Browse Source

added rudimentary key / door masking

refactoring
Thomas Knoll 1 year ago
parent
commit
97f7d23cda
  1. 8
      examples/shields/rl/11_minigridrl.py
  2. 32
      examples/shields/rl/13_minigridsb.py
  3. 21
      examples/shields/rl/Wrapper.py
  4. 21
      examples/shields/rl/helpers.py

8
examples/shields/rl/11_minigridrl.py

@ -7,8 +7,6 @@ from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
from datetime import datetime
import gymnasium as gym
import minigrid
@ -27,7 +25,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor
from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper
from helpers import extract_keys, parse_arguments, create_shield_dict
from helpers import extract_keys, parse_arguments, create_shield_dict, create_log_dir
import matplotlib.pyplot as plt
@ -80,8 +78,6 @@ def env_creater_custom(config):
return env
def create_log_dir(args):
return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}"
def register_custom_minigrid_env(args):
@ -96,7 +92,7 @@ def register_custom_minigrid_env(args):
def ppo(args):
ray.init(num_cpus=3)
ray.init(num_cpus=1)
register_custom_minigrid_env(args)

32
examples/shields/rl/13_minigridsb.py

@ -8,10 +8,12 @@ from stable_baselines3.common.callbacks import BaseCallback
import gymnasium as gym
from gymnasium.spaces import Dict, Box
from minigrid.core.actions import Actions
import numpy as np
import time
from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping
from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping, create_log_dir
class CustomCallback(BaseCallback):
def __init__(self, verbose: int = 0, env=None):
@ -35,6 +37,10 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
self.no_masking = no_masking
def create_action_mask(self):
if self.no_masking:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
@ -70,9 +76,19 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
# print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask):
mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
if front_tile is not None and front_tile.type == "key":
mask[Actions.pickup] = 1.0
if self.env.carrying:
mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
if self.no_masking:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
return mask
@ -107,8 +123,14 @@ def main():
env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking)
env = ActionMasker(env, mask_fn)
callback = CustomCallback(1, env)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=args.log_dir)
model.learn(args.iterations, callback=callback)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))
iterations = args.iterations
if iterations < 10_000:
iterations = 10_000
model.learn(iterations, callback=callback)
mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10)

21
examples/shields/rl/Wrapper.py

@ -1,6 +1,7 @@
import gymnasium as gym
import numpy as np
from minigrid.core.actions import Actions
from gymnasium.spaces import Dict, Box
from collections import deque
@ -116,7 +117,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
allowed_actions = []
# Create the mask
# If the shield restricts the actions at this position, mark only the allowed actions with 1.0;
# otherwise mark all actions as valid
@ -125,7 +126,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action[1])
index = get_action_index_mapping(allowed_action[1]) # Allowed_action is a set
if index is None:
assert(False)
mask[index] = 1.0
@ -133,10 +134,18 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
# print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask):
mask[index] = 1.0
# mask[0] = 1.0
# print(F"Action Mask for position {coordinates} and view {view_direction} is {mask} Position String: {cur_pos_str})")
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
if front_tile is not None and front_tile.type == "key":
mask[Actions.pickup] = 1.0
if self.env.carrying:
mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask
def reset(self, *, seed=None, options=None):

21
examples/shields/rl/helpers.py

@ -2,6 +2,8 @@ import minigrid
from minigrid.core.actions import Actions
import gymnasium as gym
from datetime import datetime
import stormpy
import stormpy.core
import stormpy.simulator
@ -28,6 +30,9 @@ def extract_keys(env):
return keys
def create_log_dir(args):
return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}-env:{args.env}"
def get_action_index_mapping(actions):
for action_str in actions:
@ -62,19 +67,13 @@ def parse_arguments(argparse):
choices=[
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-DoorKey-8x8-v0",
"MiniGrid-Dynamic-Obstacles-8x8-v0",
"MiniGrid-Empty-Random-6x6-v0",
"MiniGrid-Fetch-6x6-N2-v0",
"MiniGrid-LockedRoom-v0",
"MiniGrid-FourRooms-v0",
"MiniGrid-KeyCorridorS6R3-v0",
"MiniGrid-GoToDoor-8x8-v0",
"MiniGrid-LavaGapS7-v0",
"MiniGrid-SimpleCrossingS9N3-v0",
"MiniGrid-BlockedUnlockPickup-v0",
"MiniGrid-LockedRoom-v0",
"MiniGrid-ObstructedMaze-1Dlh-v0",
"MiniGrid-DoorKey-16x16-v0",
"MiniGrid-RedBlueDoors-6x6-v0",])
"MiniGrid-Empty-Random-6x6-v0",
])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_to_prism_path", default="./main")
@ -114,8 +113,8 @@ def create_shield(grid_to_prism_path, grid_file, prism_path):
program = stormpy.parse_prism_program(prism_path)
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
# formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
# stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)

Loading…
Cancel
Save