diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index a7cc555..035b195 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -7,8 +7,6 @@ from ray.rllib.policy import Policy from ray.rllib.utils.typing import PolicyID -from datetime import datetime - import gymnasium as gym import minigrid @@ -27,7 +25,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN from ray.rllib.models.preprocessors import get_preprocessor from MaskModels import TorchActionMaskModel from Wrapper import OneHotWrapper, MiniGridEnvWrapper -from helpers import extract_keys, parse_arguments, create_shield_dict +from helpers import extract_keys, parse_arguments, create_shield_dict, create_log_dir import matplotlib.pyplot as plt @@ -80,8 +78,6 @@ def env_creater_custom(config): return env -def create_log_dir(args): - return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}" def register_custom_minigrid_env(args): @@ -96,7 +92,7 @@ def register_custom_minigrid_env(args): def ppo(args): - ray.init(num_cpus=3) + ray.init(num_cpus=1) register_custom_minigrid_env(args) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 73299e5..5873ed6 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -8,10 +8,12 @@ from stable_baselines3.common.callbacks import BaseCallback import gymnasium as gym from gymnasium.spaces import Dict, Box +from minigrid.core.actions import Actions + import numpy as np import time -from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping +from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping, create_log_dir class CustomCallback(BaseCallback): def __init__(self, verbose: int = 0, env=None): @@ -35,6 +37,10 @@ class MiniGridEnvWrapper(gym.core.Wrapper): self.no_masking = no_masking def create_action_mask(self): + if self.no_masking: + return np.array([1.0] * self.max_available_actions, dtype=np.int8) + + coordinates = self.env.agent_pos view_direction = self.env.agent_dir @@ -70,9 +76,19 @@ class MiniGridEnvWrapper(gym.core.Wrapper): # print(F"Not in shield {cur_pos_str}") for index, x in enumerate(mask): mask[index] = 1.0 + + front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) + + if front_tile is not None and front_tile.type == "key": + mask[Actions.pickup] = 1.0 + + if self.env.carrying: + mask[Actions.drop] = 1.0 + + if front_tile and front_tile.type == "door": + mask[Actions.toggle] = 1.0 + - if self.no_masking: - return np.array([1.0] * self.max_available_actions, dtype=np.int8) return mask @@ -107,8 +123,14 @@ def main(): env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking) env = ActionMasker(env, mask_fn) callback = CustomCallback(1, env) - model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=args.log_dir) - model.learn(args.iterations, callback=callback) + model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args)) + + iterations = args.iterations + + if iterations < 10_000: + iterations = 10_000 + + model.learn(iterations, callback=callback) mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10) diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrapper.py index d67abbb..41aaf2a 100644 --- a/examples/shields/rl/Wrapper.py +++ b/examples/shields/rl/Wrapper.py @@ -1,6 +1,7 @@ import gymnasium as gym import numpy as np +from minigrid.core.actions import Actions from gymnasium.spaces import Dict, Box from collections import deque @@ -116,7 +117,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper): allowed_actions = [] - + # Create the mask # If shield restricts action mask only valid with 1.0 # else set all actions as valid @@ -125,7 +126,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper): if cur_pos_str in self.shield and self.shield[cur_pos_str]: allowed_actions = self.shield[cur_pos_str] for allowed_action in allowed_actions: - index = get_action_index_mapping(allowed_action[1]) + index = get_action_index_mapping(allowed_action[1]) # Allowed_action is a set if index is None: assert(False) mask[index] = 1.0 @@ -133,10 +134,18 @@ class MiniGridEnvWrapper(gym.core.Wrapper): # print(F"Not in shield {cur_pos_str}") for index, x in enumerate(mask): mask[index] = 1.0 - - # mask[0] = 1.0 - # print(F"Action Mask for position {coordinates} and view {view_direction} is {mask} Position String: {cur_pos_str})") - + + front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) + + if front_tile is not None and front_tile.type == "key": + mask[Actions.pickup] = 1.0 + + if self.env.carrying: + mask[Actions.drop] = 1.0 + + if front_tile and front_tile.type == "door": + mask[Actions.toggle] = 1.0 + return mask def reset(self, *, seed=None, options=None): diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index 111c3a6..714ac00 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -2,6 +2,8 @@ import minigrid from minigrid.core.actions import Actions import gymnasium as gym +from datetime import datetime + import stormpy import stormpy.core import stormpy.simulator @@ -28,6 +30,9 @@ def extract_keys(env): return keys +def create_log_dir(args): + return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}-env:{args.env}" + def get_action_index_mapping(actions): for action_str in actions: @@ -62,19 +67,13 @@ def parse_arguments(argparse): choices=[ "MiniGrid-LavaCrossingS9N1-v0", "MiniGrid-DoorKey-8x8-v0", - "MiniGrid-Dynamic-Obstacles-8x8-v0", - "MiniGrid-Empty-Random-6x6-v0", - "MiniGrid-Fetch-6x6-N2-v0", + "MiniGrid-LockedRoom-v0", "MiniGrid-FourRooms-v0", - "MiniGrid-KeyCorridorS6R3-v0", - "MiniGrid-GoToDoor-8x8-v0", "MiniGrid-LavaGapS7-v0", "MiniGrid-SimpleCrossingS9N3-v0", - "MiniGrid-BlockedUnlockPickup-v0", - "MiniGrid-LockedRoom-v0", - "MiniGrid-ObstructedMaze-1Dlh-v0", "MiniGrid-DoorKey-16x16-v0", - "MiniGrid-RedBlueDoors-6x6-v0",]) + "MiniGrid-Empty-Random-6x6-v0", + ]) # parser.add_argument("--seed", type=int, help="seed for environment", default=None) parser.add_argument("--grid_to_prism_path", default="./main") @@ -114,8 +113,8 @@ def create_shield(grid_to_prism_path, grid_file, prism_path): program = stormpy.parse_prism_program(prism_path) - formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" - # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" + # formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" + formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, # stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)