Browse Source

simulator: action mode global names for MDPs

refactoring
Sebastian Junges 5 years ago
parent
commit
5d3319dfea
  1. 33
      examples/simulator/02-simulator.py
  2. 41
      lib/stormpy/simulator.py

33
examples/simulator/02-simulator.py

@ -9,7 +9,7 @@ import random
"""
Simulator for nondeterministic models
"""
def example_simulator_01():
def example_simulator_02():
path = stormpy.examples.files.prism_mdp_maze
prism_program = stormpy.parse_prism_program(path)
@ -36,7 +36,36 @@ def example_simulator_01():
for path in paths:
print(" ".join(path))
options = stormpy.BuilderOptions()
options.set_build_state_valuations()
options.set_build_choice_labels(True)
model = stormpy.build_sparse_model_with_options(prism_program, options)
print(model)
simulator = stormpy.simulator.create_simulator(model, seed=42)
simulator.set_observation_mode(stormpy.simulator.SimulatorObservationMode.PROGRAM_LEVEL)
simulator.set_action_mode(stormpy.simulator.SimulatorActionMode.GLOBAL_NAMES)
# 5 paths of at most 20 steps.
paths = []
for m in range(5):
path = []
state = simulator.restart()
path = [f"{state}"]
for n in range(20):
actions = simulator.available_actions()
select_action = random.randint(0,len(actions)-1)
#print(f"Randomly select action nr: {select_action} from actions {actions}")
path.append(f"--act={actions[select_action]}-->")
state = simulator.step(actions[select_action])
#print(state)
path.append(f"{state}")
if simulator.is_done():
#print("Trapped!")
break
paths.append(path)
for path in paths:
print(" ".join(path))
if __name__ == '__main__':
example_simulator_01()
example_simulator_02()

41
lib/stormpy/simulator.py

@ -8,7 +8,8 @@ class SimulatorObservationMode(Enum):
PROGRAM_LEVEL = 1
class SimulatorActionMode(Enum):
INDEX_LEVEL = 0
INDEX_LEVEL = 0,
GLOBAL_NAMES = 1
class Simulator:
"""
@ -62,6 +63,11 @@ class Simulator:
raise RuntimeError("Observation mode must be a SimulatorObservationMode")
self._observation_mode = mode
def set_action_mode(self, mode):
if not isinstance(mode, SimulatorActionMode):
raise RuntimeError("Action mode must be a SimulatorActionMode")
self._action_mode = mode
def set_full_observability(self, value):
"""
Sets whether the full state space is observable.
@ -87,7 +93,22 @@ class SparseSimulator(Simulator):
self.set_full_observability(self._model.model_type != stormpy.storage.ModelType.POMDP)
def available_actions(self):
return range(self.nr_available_actions())
if self._action_mode == SimulatorActionMode.INDEX_LEVEL:
return range(self.nr_available_actions())
else:
assert self._model.has_choice_labeling(), "Global names require choice labeling"
av_actions = []
current_state = self._engine.get_current_state()
for action_offset in range(self.nr_available_actions()):
choice_label = self._model.choice_labeling.get_labels_of_choice(self._model.get_choice_index(current_state, action_offset))
if len(choice_label) == 0:
av_actions.append(f"_act_{action_offset}")
elif len(choice_label) == 1:
av_actions.append(list(choice_label)[0])
else:
assert False, "Unknown type of choice label, support not implemented"
return av_actions
def nr_available_actions(self):
return self._model.get_nr_available_actions(self._engine.get_current_state())
@ -123,11 +144,25 @@ class SparseSimulator(Simulator):
raise RuntimeError("Must specify an action in nondeterministic models.")
check = self._engine.step(0)
assert check
else:
elif type(action) == int and self._action_mode == SimulatorActionMode.INDEX_LEVEL:
if action >= self.nr_available_actions():
raise RuntimeError(f"Only {self.nr_available_actions()} actions available")
check = self._engine.step(action)
assert check
elif self._action_mode == SimulatorActionMode.GLOBAL_NAMES:
current_state = self._engine.get_current_state()
action_index = None
av_actions = self.available_actions()
for offset, label in enumerate(av_actions):
if action == label:
action_index = offset
break
if action_index is None:
raise ValueError("Could not find action: ")
check = self._engine.step(action_index)
assert check
else:
raise ValueError("Unrecognized type of action %s" % action)
return self._report_result()

Loading…
Cancel
Save