You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

162 lines
5.4 KiB

  1. from enum import Enum
  2. import stormpy.core
  3. class SimulatorObservationMode(Enum):
  4. STATE_LEVEL = 0,
  5. PROGRAM_LEVEL = 1
  6. class SimulatorActionMode(Enum):
  7. INDEX_LEVEL = 0
  8. class Simulator:
  9. """
  10. Base class for simulators.
  11. """
  12. def __init__(self, seed=None):
  13. self._seed = seed
  14. self._observation_mode = SimulatorObservationMode.STATE_LEVEL
  15. self._action_mode = SimulatorActionMode.INDEX_LEVEL
  16. self._full_observe = False
  17. def available_actions(self):
  18. """
  19. Returns an iterable over the available actions. The action mode may be used to select how actions are referred to.
  20. TODO: Support multiple action modes
  21. :return:
  22. """
  23. raise NotImplementedError("Abstract Superclass")
  24. def step(self, action=None):
  25. """
  26. Do a step taking the passed action.
  27. :param action: The index of the action, for deterministic actions, action may be None.
  28. :return: The observation (on state or program level).
  29. """
  30. raise NotImplementedError("Abstract superclass")
  31. def restart(self):
  32. """
  33. Reset the simulator to the initial state
  34. """
  35. raise NotImplementedError("Abstract superclass")
  36. def is_done(self):
  37. """
  38. Is the simulator in a sink state?
  39. :return: Yes, if the simulator is in a sink state.
  40. """
  41. return False
  42. def set_observation_mode(self, mode):
  43. """
  44. :param mode: STATE_LEVEL or PROGRAM_LEVEL
  45. :type mode:
  46. """
  47. if not isinstance(mode, SimulatorObservationMode):
  48. raise RuntimeError("Observation mode must be a SimulatorObservationMode")
  49. self._observation_mode = mode
  50. def set_full_observability(self, value):
  51. """
  52. Sets whether the full state space is observable.
  53. Default inherited from the model, but this method overrides the setting.
  54. :param value:
  55. """
  56. self._full_observe = value
  57. class SparseSimulator(Simulator):
  58. """
  59. Simulator on top of sparse models.
  60. """
  61. def __init__(self, model, seed=None):
  62. super().__init__(seed)
  63. self._model = model
  64. self._engine = stormpy.core._DiscreteTimeSparseModelSimulatorDouble(model)
  65. if seed is not None:
  66. self._engine.set_seed(seed)
  67. self._state_valuations = None
  68. self.set_full_observability(self._model.model_type != stormpy.storage.ModelType.POMDP)
  69. def available_actions(self):
  70. return range(self.nr_available_actions())
  71. def nr_available_actions(self):
  72. return self._model.get_nr_available_actions(self._engine.get_current_state())
  73. def _report_state(self):
  74. if self._observation_mode == SimulatorObservationMode.STATE_LEVEL:
  75. return self._engine.get_current_state()
  76. elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL:
  77. return self._state_valuations.get_state(self._engine.get_current_state())
  78. assert False, "The observation mode is unexpected"
  79. def _report_observation(self):
  80. """
  81. :return:
  82. """
  83. #TODO this should be ensured earlier
  84. assert self._model.model_type == stormpy.storage.ModelType.POMDP
  85. if self._observation_mode == SimulatorObservationMode.STATE_LEVEL:
  86. return self._model.get_observation(self._engine.get_current_state())
  87. elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL:
  88. raise NotImplementedError("Program level observations are not implemented in storm")
  89. assert False, "The observation mode is unexpected"
  90. def _report_result(self):
  91. if self._full_observe:
  92. return self._report_state()
  93. else:
  94. return self._report_observation()
  95. def step(self, action=None):
  96. if action is None:
  97. if self._model.is_nondeterministic_model and self.nr_available_actions() > 1:
  98. raise RuntimeError("Must specify an action in nondeterministic models.")
  99. check = self._engine.step(0)
  100. assert check
  101. else:
  102. if action >= self.nr_available_actions():
  103. raise RuntimeError(f"Only {self.nr_available_actions()} actions available")
  104. check = self._engine.step(action)
  105. assert check
  106. return self._report_result()
  107. def restart(self):
  108. self._engine.reset_to_initial_state()
  109. return self._report_result()
  110. def is_done(self):
  111. return self._model.is_sink_state(self._engine.get_current_state())
  112. def set_observation_mode(self, mode):
  113. super().set_observation_mode(mode)
  114. if self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL:
  115. if not self._model.has_state_valuations():
  116. raise RuntimeError("Program level observations require model with state valuations")
  117. self._state_valuations = self._model.state_valuations
  118. def create_simulator(model, seed = None):
  119. """
  120. Factory method for creating a simulator.
  121. :param model: Some form of model
  122. :param seed: A seed for reproducibility. If None (default), the seed is internally generated.
  123. :return: A simulator that can simulate on top of this model
  124. """
  125. if isinstance(model, stormpy.storage._ModelBase):
  126. if model.is_sparse_model:
  127. return SparseSimulator(model, seed)
  128. else:
  129. raise NotImplementedError("Currently, we only support simulators for sparse models.")