diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index 73e58a5..3bcec40 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -39,28 +39,21 @@ def extract_keys(env): return keys def create_log_dir(args): - return F"{args.log_dir}{args.algorithm}-shielding:{args.shielding}-evaluations:{args.evaluations}-steps:{args.steps}" + return F"{args.log_dir}{args.algorithm}-shielding:{args.shielding}-evaluations:{args.evaluations}-steps:{args.steps}-env:{args.env}" def get_action_index_mapping(actions): for action_str in actions: - if "left" in action_str: + if "move" in action_str: + return Actions.forward + elif "left" in action_str: return Actions.left elif "right" in action_str: return Actions.right - elif "east" in action_str: - return Actions.forward - elif "south" in action_str: - return Actions.forward - elif "west" in action_str: - return Actions.forward - elif "north" in action_str: - return Actions.forward elif "pickup" in action_str: return Actions.pickup elif "done" in action_str: - return Actions.done - + return Actions.done raise ValueError(F"Action string {action_str} not supported") @@ -75,6 +68,10 @@ def parse_arguments(argparse): choices=[ "MiniGrid-LavaCrossingS9N1-v0", "MiniGrid-LavaCrossingS9N3-v0", + "MiniGrid-LavaSlipperyS12-v0", + "MiniGrid-LavaSlipperyS12-v1", + "MiniGrid-LavaSlipperyS12-v2", + "MiniGrid-LavaSlipperyS12-v3", # "MiniGrid-DoorKey-8x8-v0", # "MiniGrid-LockedRoom-v0", # "MiniGrid-FourRooms-v0", diff --git a/examples/shields/rl/shieldhandlers.py b/examples/shields/rl/shieldhandlers.py index 5ca4b54..4feefb8 100644 --- a/examples/shields/rl/shieldhandlers.py +++ b/examples/shields/rl/shieldhandlers.py @@ -12,6 +12,12 @@ from abc import ABC import os +class Action(): + def __init__(self, idx, prob=1, labels=[]) -> None: + self.idx = idx + self.prob = prob + self.labels = labels + class ShieldHandler(ABC): def __init__(self) -> None: pass @@ -41,6 +47,7 @@ class MiniGridShieldHandler(ShieldHandler): f.close() def __create_shield_dict(self): + print(self.prism_path) program = stormpy.parse_prism_program(self.prism_path) shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) @@ -56,6 +63,7 @@ class MiniGridShieldHandler(ShieldHandler): assert result.has_scheduler assert result.has_shield shield = result.shield + stormpy.shields.export_shield(model, shield, "Grid.shield") action_dictionary = {} shield_scheduler = shield.construct() @@ -65,12 +73,11 @@ class MiniGridShieldHandler(ShieldHandler): choices = choice.choice_map state_valuation = model.state_valuations.get_string(stateID) - actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices] + #actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices] + actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices] action_dictionary[state_valuation] = actions_to_be_executed - # stormpy.shields.export_shield(model, shield, "Grid.shield") - return action_dictionary diff --git a/examples/shields/rl/wrappers.py b/examples/shields/rl/wrappers.py index 7bdc5e8..3f0c2f8 100644 --- a/examples/shields/rl/wrappers.py +++ b/examples/shields/rl/wrappers.py @@ -1,5 +1,6 @@ import gymnasium as gym import numpy as np +import random from minigrid.core.actions import Actions @@ -15,9 +16,9 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): def __init__(self, env, vector_index, framestack): super().__init__(env) self.framestack = framestack - # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types. + # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. # +4: Direction. - self.single_frame_dim = 49 * (11 + 6 + 3) + 4 + self.single_frame_dim = 49 * (16 + 6 + 3) + 4 self.init_x = None self.init_y = None self.x_positions = [] @@ -66,8 +67,8 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): image = obs["data"] - # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten. - objects = one_hot(image[:, :, 0], depth=11) + # One-hot the last dim into 12, 6, 3 one-hot vectors, then flatten. + objects = one_hot(image[:, :, 0], depth=16) colors = one_hot(image[:, :, 1], depth=6) states = one_hot(image[:, :, 2], depth=3) @@ -115,12 +116,15 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) if cur_pos_str in self.shield and self.shield[cur_pos_str]: - allowed_actions = self.shield[cur_pos_str] - for allowed_action in allowed_actions: - index = get_action_index_mapping(allowed_action[1]) # Allowed_action is a set - if index is None: - assert(False) - mask[index] = 1.0 + allowed_actions = self.shield[cur_pos_str] + for allowed_action in allowed_actions: + index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set + if index is None: + assert(False) + + allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] + mask[index] = allowed + else: for index, x in enumerate(mask): mask[index] = 1.0 @@ -195,13 +199,13 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) if cur_pos_str in self.shield and self.shield[cur_pos_str]: - allowed_actions = self.shield[cur_pos_str] - for allowed_action in allowed_actions: - index = get_action_index_mapping(allowed_action[1]) - if index is None: + allowed_actions = self.shield[cur_pos_str] + for allowed_action in allowed_actions: + index = get_action_index_mapping(allowed_action.labels) + if index is None: assert(False) - - mask[index] = 1.0 + + mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] else: for index, x in enumerate(mask): mask[index] = 1.0