| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -23,8 +23,8 @@ class Simulator: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def available_actions(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Returns an iterable over the available actions. The action mode may be used to select how actions are referred to. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        TODO: Support multiple action modes | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Returns an iterable over the available actions. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        The action mode may be used to select how actions are referred to. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :return: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -55,6 +55,7 @@ class Simulator: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def set_observation_mode(self, mode): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Select the observation mode, that is, how the states are represented | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :param mode: STATE_LEVEL or PROGRAM_LEVEL | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :type mode: | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -64,6 +65,12 @@ class Simulator: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._observation_mode = mode | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def set_action_mode(self, mode): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Select the action mode, that is, how the actions are represented | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :param mode: SimulatorActionMode.INDEX_LEVEL or SimulatorActionMode.GLOBAL_NAMES | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :return: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if not isinstance(mode, SimulatorActionMode): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise RuntimeError("Action mode must be a SimulatorActionMode") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._action_mode = mode | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -96,7 +103,9 @@ class SparseSimulator(Simulator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if self._action_mode == SimulatorActionMode.INDEX_LEVEL: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return range(self.nr_available_actions()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            assert self._model.has_choice_labeling(), "Global names require choice labeling" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            assert self._action_mode == SimulatorActionMode.GLOBAL_NAMES, "Unknown type of simulator action mode" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if not self._model.has_choice_labeling(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise RuntimeError("Global names action mode requires model with choice labeling") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            av_actions = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            current_state = self._engine.get_current_state() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for action_offset in range(self.nr_available_actions()): | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -150,7 +159,6 @@ class SparseSimulator(Simulator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            check = self._engine.step(action) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            assert check | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        elif self._action_mode == SimulatorActionMode.GLOBAL_NAMES: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            current_state = self._engine.get_current_state() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            action_index = None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            av_actions = self.available_actions() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for offset, label in enumerate(av_actions): | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -158,14 +166,13 @@ class SparseSimulator(Simulator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    action_index = offset | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    break | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if action_index is None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise ValueError("Could not find action: ") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise ValueError(f"Could not find action: {action}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            check = self._engine.step(action_index) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            assert check | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise ValueError("Unrecognized type of action %s" % action) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise ValueError(f"Unrecognized type of action {action}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return self._report_result() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def restart(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._engine.reset_to_initial_state() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return self._report_result() | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |