The source code and dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

224 lines
8.1 KiB

2 months ago
  1. from enum import Enum
  2. import stormpy.core
  3. class SimulatorObservationMode(Enum):
  4. STATE_LEVEL = 0,
  5. PROGRAM_LEVEL = 1
  6. class SimulatorActionMode(Enum):
  7. INDEX_LEVEL = 0,
  8. GLOBAL_NAMES = 1
  9. class Simulator:
  10. """
  11. Base class for simulators.
  12. """
  13. def __init__(self, seed=None):
  14. self._seed = seed
  15. self._observation_mode = SimulatorObservationMode.STATE_LEVEL
  16. self._action_mode = SimulatorActionMode.INDEX_LEVEL
  17. self._full_observe = False
  18. def available_actions(self):
  19. """
  20. Returns an iterable over the available actions.
  21. The action mode may be used to select how actions are referred to.
  22. :return:
  23. """
  24. raise NotImplementedError("Abstract Superclass")
  25. def step(self, action=None):
  26. """
  27. Do a step taking the passed action.
  28. :param action: The index of the action, for deterministic actions, action may be None.
  29. :return: The observation (on state or program level).
  30. """
  31. raise NotImplementedError("Abstract superclass")
  32. def restart(self):
  33. """
  34. Reset the simulator to the initial state
  35. """
  36. raise NotImplementedError("Abstract superclass")
  37. def is_done(self):
  38. """
  39. Is the simulator in a sink state?
  40. :return: Yes, if the simulator is in a sink state.
  41. """
  42. return False
  43. def set_observation_mode(self, mode):
  44. """
  45. Select the observation mode, that is, how the states are represented
  46. :param mode: STATE_LEVEL or PROGRAM_LEVEL
  47. :type mode:
  48. """
  49. if not isinstance(mode, SimulatorObservationMode):
  50. raise RuntimeError("Observation mode must be a SimulatorObservationMode")
  51. self._observation_mode = mode
  52. def set_action_mode(self, mode):
  53. """
  54. Select the action mode, that is, how the actions are represented
  55. :param mode: SimulatorActionMode.INDEX_LEVEL or SimulatorActionMode.GLOBAL_NAMES
  56. :return:
  57. """
  58. if not isinstance(mode, SimulatorActionMode):
  59. raise RuntimeError("Action mode must be a SimulatorActionMode")
  60. self._action_mode = mode
  61. def set_full_observability(self, value):
  62. """
  63. Sets whether the full state space is observable.
  64. Default inherited from the model, but this method overrides the setting.
  65. :param value:
  66. """
  67. self._full_observe = value
  68. class SparseSimulator(Simulator):
  69. """
  70. Simulator on top of sparse models.
  71. """
  72. def __init__(self, model, seed=None):
  73. super().__init__(seed)
  74. self._model = model
  75. if self._model.is_exact:
  76. self._engine = stormpy.core._DiscreteTimeSparseModelSimulatorExact(model)
  77. else:
  78. self._engine = stormpy.core._DiscreteTimeSparseModelSimulatorDouble(model)
  79. if seed is not None:
  80. self._engine.set_seed(seed)
  81. self._state_valuations = None
  82. self.set_full_observability(self._model.model_type != stormpy.storage.ModelType.POMDP)
  83. def set_seed(self, value):
  84. self._engine.set_seed(value)
  85. def available_actions(self):
  86. if self._action_mode == SimulatorActionMode.INDEX_LEVEL:
  87. return range(self.nr_available_actions())
  88. else:
  89. assert self._action_mode == SimulatorActionMode.GLOBAL_NAMES, "Unknown type of simulator action mode"
  90. if not self._model.has_choice_labeling():
  91. raise RuntimeError("Global names action mode requires model with choice labeling")
  92. av_actions = []
  93. current_state = self._engine.get_current_state()
  94. for action_offset in range(self.nr_available_actions()):
  95. choice_label = self._model.choice_labeling.get_labels_of_choice(self._model.get_choice_index(current_state, action_offset))
  96. if len(choice_label) == 0:
  97. av_actions.append(f"_act_{action_offset}")
  98. elif len(choice_label) == 1:
  99. av_actions.append(list(choice_label)[0])
  100. else:
  101. assert False, "Unknown type of choice label, support not implemented"
  102. return av_actions
  103. def nr_available_actions(self):
  104. if not self._model.is_nondeterministic_model:
  105. return 1
  106. return self._model.get_nr_available_actions(self._engine.get_current_state())
  107. def _report_state(self):
  108. if self._observation_mode == SimulatorObservationMode.STATE_LEVEL:
  109. return self._engine.get_current_state()
  110. elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL:
  111. return self._state_valuations.get_json(self._engine.get_current_state())
  112. assert False, "The observation mode is unexpected"
  113. def _report_observation(self):
  114. """
  115. :return:
  116. """
  117. #TODO this should be ensured earlier
  118. assert self._model.model_type == stormpy.storage.ModelType.POMDP
  119. if self._observation_mode == SimulatorObservationMode.STATE_LEVEL:
  120. return self._model.get_observation(self._engine.get_current_state())
  121. elif self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL:
  122. raise NotImplementedError("Program level observations are not implemented in storm")
  123. assert False, "The observation mode is unexpected"
  124. def _report_result(self):
  125. if self._full_observe:
  126. return self._report_state(), self._report_rewards()
  127. else:
  128. return self._report_observation(), self._report_rewards()
  129. def _report_rewards(self):
  130. return self._engine.get_last_reward()
  131. def random_step(self):
  132. check = self._engine.random_step()
  133. assert check
  134. return self._report_result()
  135. def step(self, action=None):
  136. if action is None:
  137. if self._model.is_nondeterministic_model and self.nr_available_actions() > 1:
  138. raise RuntimeError("Must specify an action in nondeterministic models.")
  139. check = self._engine.step(0)
  140. assert check
  141. elif type(action) == int and self._action_mode == SimulatorActionMode.INDEX_LEVEL:
  142. if action >= self.nr_available_actions():
  143. raise RuntimeError(f"Only {self.nr_available_actions()} actions available")
  144. check = self._engine.step(action)
  145. assert check
  146. elif self._action_mode == SimulatorActionMode.GLOBAL_NAMES:
  147. action_index = None
  148. av_actions = self.available_actions()
  149. for offset, label in enumerate(av_actions):
  150. if action == label:
  151. action_index = offset
  152. break
  153. if action_index is None:
  154. raise ValueError(f"Could not find action: {action}")
  155. check = self._engine.step(action_index)
  156. assert check
  157. else:
  158. raise ValueError(f"Unrecognized type of action {action}")
  159. return self._report_result()
  160. def restart(self):
  161. self._engine.reset_to_initial_state()
  162. return self._report_result()
  163. def is_done(self):
  164. return self._model.is_sink_state(self._engine.get_current_state())
  165. def set_observation_mode(self, mode):
  166. super().set_observation_mode(mode)
  167. if self._observation_mode == SimulatorObservationMode.PROGRAM_LEVEL:
  168. if not self._model.has_state_valuations():
  169. raise RuntimeError("Program level observations require model with state valuations")
  170. self._state_valuations = self._model.state_valuations
  171. def get_current_state(self):
  172. return self._engine.get_current_state()
  173. def create_simulator(model, seed = None):
  174. """
  175. Factory method for creating a simulator.
  176. :param model: Some form of model
  177. :param seed: A seed for reproducibility. If None (default), the seed is internally generated.
  178. :return: A simulator that can simulate on top of this model
  179. """
  180. if isinstance(model, stormpy.storage._ModelBase):
  181. if model.is_sparse_model:
  182. return SparseSimulator(model, seed)
  183. else:
  184. raise NotImplementedError("Currently, we only support simulators for sparse models.")