|
|
@ -8,7 +8,8 @@ class SimulatorObservationMode(Enum): |
|
|
|
PROGRAM_LEVEL = 1 |
|
|
|
|
|
|
|
class SimulatorActionMode(Enum): |
|
|
|
INDEX_LEVEL = 0 |
|
|
|
INDEX_LEVEL = 0, |
|
|
|
GLOBAL_NAMES = 1 |
|
|
|
|
|
|
|
class Simulator: |
|
|
|
""" |
|
|
@ -62,6 +63,11 @@ class Simulator: |
|
|
|
raise RuntimeError("Observation mode must be a SimulatorObservationMode") |
|
|
|
self._observation_mode = mode |
|
|
|
|
|
|
|
def set_action_mode(self, mode): |
|
|
|
if not isinstance(mode, SimulatorActionMode): |
|
|
|
raise RuntimeError("Action mode must be a SimulatorActionMode") |
|
|
|
self._action_mode = mode |
|
|
|
|
|
|
|
def set_full_observability(self, value): |
|
|
|
""" |
|
|
|
Sets whether the full state space is observable. |
|
|
@ -87,7 +93,22 @@ class SparseSimulator(Simulator): |
|
|
|
self.set_full_observability(self._model.model_type != stormpy.storage.ModelType.POMDP) |
|
|
|
|
|
|
|
def available_actions(self): |
|
|
|
if self._action_mode == SimulatorActionMode.INDEX_LEVEL: |
|
|
|
return range(self.nr_available_actions()) |
|
|
|
else: |
|
|
|
assert self._model.has_choice_labeling(), "Global names require choice labeling" |
|
|
|
av_actions = [] |
|
|
|
current_state = self._engine.get_current_state() |
|
|
|
for action_offset in range(self.nr_available_actions()): |
|
|
|
choice_label = self._model.choice_labeling.get_labels_of_choice(self._model.get_choice_index(current_state, action_offset)) |
|
|
|
if len(choice_label) == 0: |
|
|
|
av_actions.append(f"_act_{action_offset}") |
|
|
|
elif len(choice_label) == 1: |
|
|
|
av_actions.append(list(choice_label)[0]) |
|
|
|
else: |
|
|
|
assert False, "Unknown type of choice label, support not implemented" |
|
|
|
|
|
|
|
return av_actions |
|
|
|
|
|
|
|
def nr_available_actions(self): |
|
|
|
return self._model.get_nr_available_actions(self._engine.get_current_state()) |
|
|
@ -123,11 +144,25 @@ class SparseSimulator(Simulator): |
|
|
|
raise RuntimeError("Must specify an action in nondeterministic models.") |
|
|
|
check = self._engine.step(0) |
|
|
|
assert check |
|
|
|
else: |
|
|
|
elif type(action) == int and self._action_mode == SimulatorActionMode.INDEX_LEVEL: |
|
|
|
if action >= self.nr_available_actions(): |
|
|
|
raise RuntimeError(f"Only {self.nr_available_actions()} actions available") |
|
|
|
check = self._engine.step(action) |
|
|
|
assert check |
|
|
|
elif self._action_mode == SimulatorActionMode.GLOBAL_NAMES: |
|
|
|
current_state = self._engine.get_current_state() |
|
|
|
action_index = None |
|
|
|
av_actions = self.available_actions() |
|
|
|
for offset, label in enumerate(av_actions): |
|
|
|
if action == label: |
|
|
|
action_index = offset |
|
|
|
break |
|
|
|
if action_index is None: |
|
|
|
raise ValueError("Could not find action: ") |
|
|
|
check = self._engine.step(action_index) |
|
|
|
assert check |
|
|
|
else: |
|
|
|
raise ValueError("Unrecognized type of action %s" % action) |
|
|
|
return self._report_result() |
|
|
|
|
|
|
|
|
|
|
|