|  |  | @ -18,6 +18,7 @@ class Simulator: | 
			
		
	
		
			
				
					|  |  |  |         self._seed = seed | 
			
		
	
		
			
				
					|  |  |  |         self._observation_mode = SimulatorObservationMode.STATE_LEVEL | 
			
		
	
		
			
				
					|  |  |  |         self._action_mode = SimulatorActionMode.INDEX_LEVEL | 
			
		
	
		
			
				
					|  |  |  |         self._full_observe = False | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def available_actions(self): | 
			
		
	
		
			
				
					|  |  |  |         """ | 
			
		
	
	
		
			
				
					|  |  | @ -61,6 +62,15 @@ class Simulator: | 
			
		
	
		
			
				
					|  |  |  |             raise RuntimeError("Observation mode must be a SimulatorObservationMode") | 
			
		
	
		
			
				
					|  |  |  |         self._observation_mode = mode | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def set_full_observability(self, value): | 
			
		
	
		
			
				
					|  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |         Sets whether the full state space is observable. | 
			
		
	
		
			
				
					|  |  |  |         Default inherited from the model, but this method overrides the setting. | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         :param value: | 
			
		
	
		
			
				
					|  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |         self._full_observe = value | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | class SparseSimulator(Simulator): | 
			
		
	
		
			
				
					|  |  |  |     """ | 
			
		
	
	
		
			
				
					|  |  | @ -74,6 +84,7 @@ class SparseSimulator(Simulator): | 
			
		
	
		
			
				
					|  |  |  |         if seed is not None: | 
			
		
	
		
			
				
					|  |  |  |             self._engine.set_seed(seed) | 
			
		
	
		
			
				
					|  |  |  |         self._state_valuations = None | 
			
		
	
		
			
				
					|  |  |  |         self.set_full_observability(self._model.model_type != stormpy.storage.ModelType.POMDP) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def available_actions(self): | 
			
		
	
		
			
				
					|  |  |  |         return range(self.nr_available_actions()) | 
			
		
	
	
		
			
				
					|  |  | @ -81,11 +92,30 @@ class SparseSimulator(Simulator): | 
			
		
	
		
			
				
					|  |  |  |     def nr_available_actions(self): | 
			
		
	
		
			
				
					|  |  |  |         return self._model.get_nr_available_actions(self._engine.get_current_state()) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def _report_observation(self): | 
			
		
	
		
			
				
					|  |  |  |     def _report_state(self): | 
			
		
	
		
			
				
					|  |  |  |         if self._observation_mode == SimulatorObservationMode.STATE_LEVEL: | 
			
		
	
		
			
				
					|  |  |  |             return self._engine.get_current_state() | 
			
		
	
		
			
				
					|  |  |  |         elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL: | 
			
		
	
		
			
				
					|  |  |  |             return self._state_valuations.get_state(self._engine.get_current_state()) | 
			
		
	
		
			
				
					|  |  |  |         assert False, "The observation mode is unexpected" | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def _report_observation(self): | 
			
		
	
		
			
				
					|  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |         :return: | 
			
		
	
		
			
				
					|  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |         #TODO this should be ensured earlier | 
			
		
	
		
			
				
					|  |  |  |         assert self._model.model_type == stormpy.storage.ModelType.POMDP | 
			
		
	
		
			
				
					|  |  |  |         if self._observation_mode == SimulatorObservationMode.STATE_LEVEL: | 
			
		
	
		
			
				
					|  |  |  |             return self._model.get_observation(self._engine.get_current_state()) | 
			
		
	
		
			
				
					|  |  |  |         elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL: | 
			
		
	
		
			
				
					|  |  |  |             raise NotImplementedError("Program level observations are not implemented in storm") | 
			
		
	
		
			
				
					|  |  |  |         assert False, "The observation mode is unexpected" | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def _report_result(self): | 
			
		
	
		
			
				
					|  |  |  |         if self._full_observe: | 
			
		
	
		
			
				
					|  |  |  |             return self._report_state() | 
			
		
	
		
			
				
					|  |  |  |         else: | 
			
		
	
		
			
				
					|  |  |  |             return self._report_observation() | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def step(self, action=None): | 
			
		
	
		
			
				
					|  |  |  |         if action is None: | 
			
		
	
	
		
			
				
					|  |  | @ -98,12 +128,12 @@ class SparseSimulator(Simulator): | 
			
		
	
		
			
				
					|  |  |  |                 raise RuntimeError(f"Only {self.nr_available_actions()} actions available") | 
			
		
	
		
			
				
					|  |  |  |             check = self._engine.step(action) | 
			
		
	
		
			
				
					|  |  |  |             assert check | 
			
		
	
		
			
				
					|  |  |  |         return self._report_observation() | 
			
		
	
		
			
				
					|  |  |  |         return self._report_result() | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def restart(self): | 
			
		
	
		
			
				
					|  |  |  |         self._engine.reset_to_initial_state() | 
			
		
	
		
			
				
					|  |  |  |         return self._report_observation() | 
			
		
	
		
			
				
					|  |  |  |         return self._report_result() | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def is_done(self): | 
			
		
	
		
			
				
					|  |  |  |         return self._model.is_sink_state(self._engine.get_current_state()) | 
			
		
	
	
		
			
				
					|  |  | 
 |