|
@ -18,6 +18,7 @@ class Simulator: |
|
|
self._seed = seed |
|
|
self._seed = seed |
|
|
self._observation_mode = SimulatorObservationMode.STATE_LEVEL |
|
|
self._observation_mode = SimulatorObservationMode.STATE_LEVEL |
|
|
self._action_mode = SimulatorActionMode.INDEX_LEVEL |
|
|
self._action_mode = SimulatorActionMode.INDEX_LEVEL |
|
|
|
|
|
self._full_observe = False |
|
|
|
|
|
|
|
|
def available_actions(self): |
|
|
def available_actions(self): |
|
|
""" |
|
|
""" |
|
@ -61,6 +62,15 @@ class Simulator: |
|
|
raise RuntimeError("Observation mode must be a SimulatorObservationMode") |
|
|
raise RuntimeError("Observation mode must be a SimulatorObservationMode") |
|
|
self._observation_mode = mode |
|
|
self._observation_mode = mode |
|
|
|
|
|
|
|
|
|
|
|
def set_full_observability(self, value): |
|
|
|
|
|
""" |
|
|
|
|
|
Sets whether the full state space is observable. |
|
|
|
|
|
Default inherited from the model, but this method overrides the setting. |
|
|
|
|
|
|
|
|
|
|
|
:param value: |
|
|
|
|
|
""" |
|
|
|
|
|
self._full_observe = value |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SparseSimulator(Simulator): |
|
|
class SparseSimulator(Simulator): |
|
|
""" |
|
|
""" |
|
@ -74,6 +84,7 @@ class SparseSimulator(Simulator): |
|
|
if seed is not None: |
|
|
if seed is not None: |
|
|
self._engine.set_seed(seed) |
|
|
self._engine.set_seed(seed) |
|
|
self._state_valuations = None |
|
|
self._state_valuations = None |
|
|
|
|
|
self.set_full_observability(self._model.model_type != stormpy.storage.ModelType.POMDP) |
|
|
|
|
|
|
|
|
def available_actions(self): |
|
|
def available_actions(self): |
|
|
return range(self.nr_available_actions()) |
|
|
return range(self.nr_available_actions()) |
|
@ -81,11 +92,30 @@ class SparseSimulator(Simulator): |
|
|
def nr_available_actions(self): |
|
|
def nr_available_actions(self): |
|
|
return self._model.get_nr_available_actions(self._engine.get_current_state()) |
|
|
return self._model.get_nr_available_actions(self._engine.get_current_state()) |
|
|
|
|
|
|
|
|
def _report_observation(self): |
|
|
|
|
|
|
|
|
def _report_state(self): |
|
|
if self._observation_mode == SimulatorObservationMode.STATE_LEVEL: |
|
|
if self._observation_mode == SimulatorObservationMode.STATE_LEVEL: |
|
|
return self._engine.get_current_state() |
|
|
return self._engine.get_current_state() |
|
|
elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL: |
|
|
elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL: |
|
|
return self._state_valuations.get_state(self._engine.get_current_state()) |
|
|
return self._state_valuations.get_state(self._engine.get_current_state()) |
|
|
|
|
|
assert False, "The observation mode is unexpected" |
|
|
|
|
|
|
|
|
|
|
|
def _report_observation(self): |
|
|
|
|
|
""" |
|
|
|
|
|
:return: |
|
|
|
|
|
""" |
|
|
|
|
|
#TODO this should be ensured earlier |
|
|
|
|
|
assert self._model.model_type == stormpy.storage.ModelType.POMDP |
|
|
|
|
|
if self._observation_mode == SimulatorObservationMode.STATE_LEVEL: |
|
|
|
|
|
return self._model.get_observation(self._engine.get_current_state()) |
|
|
|
|
|
elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL: |
|
|
|
|
|
raise NotImplementedError("Program level observations are not implemented in storm") |
|
|
|
|
|
assert False, "The observation mode is unexpected" |
|
|
|
|
|
|
|
|
|
|
|
def _report_result(self): |
|
|
|
|
|
if self._full_observe: |
|
|
|
|
|
return self._report_state() |
|
|
|
|
|
else: |
|
|
|
|
|
return self._report_observation() |
|
|
|
|
|
|
|
|
def step(self, action=None): |
|
|
def step(self, action=None): |
|
|
if action is None: |
|
|
if action is None: |
|
@ -98,12 +128,12 @@ class SparseSimulator(Simulator): |
|
|
raise RuntimeError(f"Only {self.nr_available_actions()} actions available") |
|
|
raise RuntimeError(f"Only {self.nr_available_actions()} actions available") |
|
|
check = self._engine.step(action) |
|
|
check = self._engine.step(action) |
|
|
assert check |
|
|
assert check |
|
|
return self._report_observation() |
|
|
|
|
|
|
|
|
return self._report_result() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def restart(self): |
|
|
def restart(self): |
|
|
self._engine.reset_to_initial_state() |
|
|
self._engine.reset_to_initial_state() |
|
|
return self._report_observation() |
|
|
|
|
|
|
|
|
return self._report_result() |
|
|
|
|
|
|
|
|
def is_done(self): |
|
|
def is_done(self): |
|
|
return self._model.is_sink_state(self._engine.get_current_state()) |
|
|
return self._model.is_sink_state(self._engine.get_current_state()) |
|
|