|
@ -23,8 +23,8 @@ class Simulator: |
|
|
|
|
|
|
|
|
def available_actions(self): |
|
|
def available_actions(self): |
|
|
""" |
|
|
""" |
|
|
Returns an iterable over the available actions. The action mode may be used to select how actions are referred to. |
|
|
|
|
|
TODO: Support multiple action modes |
|
|
|
|
|
|
|
|
Returns an iterable over the available actions. |
|
|
|
|
|
The action mode may be used to select how actions are referred to. |
|
|
|
|
|
|
|
|
:return: |
|
|
:return: |
|
|
""" |
|
|
""" |
|
@ -55,6 +55,7 @@ class Simulator: |
|
|
|
|
|
|
|
|
def set_observation_mode(self, mode): |
|
|
def set_observation_mode(self, mode): |
|
|
""" |
|
|
""" |
|
|
|
|
|
Select the observation mode, that is, how the states are represented |
|
|
|
|
|
|
|
|
:param mode: STATE_LEVEL or PROGRAM_LEVEL |
|
|
:param mode: STATE_LEVEL or PROGRAM_LEVEL |
|
|
:type mode: |
|
|
:type mode: |
|
@ -64,6 +65,12 @@ class Simulator: |
|
|
self._observation_mode = mode |
|
|
self._observation_mode = mode |
|
|
|
|
|
|
|
|
def set_action_mode(self, mode): |
|
|
def set_action_mode(self, mode): |
|
|
|
|
|
""" |
|
|
|
|
|
Select the action mode, that is, how the actions are represented |
|
|
|
|
|
|
|
|
|
|
|
:param mode: SimulatorActionMode.INDEX_LEVEL or SimulatorActionMode.GLOBAL_NAMES |
|
|
|
|
|
:return: |
|
|
|
|
|
""" |
|
|
if not isinstance(mode, SimulatorActionMode): |
|
|
if not isinstance(mode, SimulatorActionMode): |
|
|
raise RuntimeError("Action mode must be a SimulatorActionMode") |
|
|
raise RuntimeError("Action mode must be a SimulatorActionMode") |
|
|
self._action_mode = mode |
|
|
self._action_mode = mode |
|
@ -96,7 +103,9 @@ class SparseSimulator(Simulator): |
|
|
if self._action_mode == SimulatorActionMode.INDEX_LEVEL: |
|
|
if self._action_mode == SimulatorActionMode.INDEX_LEVEL: |
|
|
return range(self.nr_available_actions()) |
|
|
return range(self.nr_available_actions()) |
|
|
else: |
|
|
else: |
|
|
assert self._model.has_choice_labeling(), "Global names require choice labeling" |
|
|
|
|
|
|
|
|
assert self._action_mode == SimulatorActionMode.GLOBAL_NAMES, "Unknown type of simulator action mode" |
|
|
|
|
|
if not self._model.has_choice_labeling(): |
|
|
|
|
|
raise RuntimeError("Global names action mode requires model with choice labeling") |
|
|
av_actions = [] |
|
|
av_actions = [] |
|
|
current_state = self._engine.get_current_state() |
|
|
current_state = self._engine.get_current_state() |
|
|
for action_offset in range(self.nr_available_actions()): |
|
|
for action_offset in range(self.nr_available_actions()): |
|
@ -150,7 +159,6 @@ class SparseSimulator(Simulator): |
|
|
check = self._engine.step(action) |
|
|
check = self._engine.step(action) |
|
|
assert check |
|
|
assert check |
|
|
elif self._action_mode == SimulatorActionMode.GLOBAL_NAMES: |
|
|
elif self._action_mode == SimulatorActionMode.GLOBAL_NAMES: |
|
|
current_state = self._engine.get_current_state() |
|
|
|
|
|
action_index = None |
|
|
action_index = None |
|
|
av_actions = self.available_actions() |
|
|
av_actions = self.available_actions() |
|
|
for offset, label in enumerate(av_actions): |
|
|
for offset, label in enumerate(av_actions): |
|
@ -158,14 +166,13 @@ class SparseSimulator(Simulator): |
|
|
action_index = offset |
|
|
action_index = offset |
|
|
break |
|
|
break |
|
|
if action_index is None: |
|
|
if action_index is None: |
|
|
raise ValueError("Could not find action: ") |
|
|
|
|
|
|
|
|
raise ValueError(f"Could not find action: {action}") |
|
|
check = self._engine.step(action_index) |
|
|
check = self._engine.step(action_index) |
|
|
assert check |
|
|
assert check |
|
|
else: |
|
|
else: |
|
|
raise ValueError("Unrecognized type of action %s" % action) |
|
|
|
|
|
|
|
|
raise ValueError(f"Unrecognized type of action {action}") |
|
|
return self._report_result() |
|
|
return self._report_result() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def restart(self): |
|
|
def restart(self): |
|
|
self._engine.reset_to_initial_state() |
|
|
self._engine.reset_to_initial_state() |
|
|
return self._report_result() |
|
|
return self._report_result() |
|
|