Browse Source

added rudimentary key / door masking

refactoring
Thomas Knoll 2 years ago
parent
commit
97f7d23cda
  1. 8
      examples/shields/rl/11_minigridrl.py
  2. 32
      examples/shields/rl/13_minigridsb.py
  3. 21
      examples/shields/rl/Wrapper.py
  4. 21
      examples/shields/rl/helpers.py

8
examples/shields/rl/11_minigridrl.py

@ -7,8 +7,6 @@ from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID from ray.rllib.utils.typing import PolicyID
from datetime import datetime
import gymnasium as gym import gymnasium as gym
import minigrid import minigrid
@ -27,7 +25,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.preprocessors import get_preprocessor
from MaskModels import TorchActionMaskModel from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper from Wrapper import OneHotWrapper, MiniGridEnvWrapper
from helpers import extract_keys, parse_arguments, create_shield_dict from helpers import extract_keys, parse_arguments, create_shield_dict, create_log_dir
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -80,8 +78,6 @@ def env_creater_custom(config):
return env return env
def create_log_dir(args):
return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}"
def register_custom_minigrid_env(args): def register_custom_minigrid_env(args):
@ -96,7 +92,7 @@ def register_custom_minigrid_env(args):
def ppo(args): def ppo(args):
ray.init(num_cpus=3) ray.init(num_cpus=1)
register_custom_minigrid_env(args) register_custom_minigrid_env(args)

32
examples/shields/rl/13_minigridsb.py

@ -8,10 +8,12 @@ from stable_baselines3.common.callbacks import BaseCallback
import gymnasium as gym import gymnasium as gym
from gymnasium.spaces import Dict, Box from gymnasium.spaces import Dict, Box
from minigrid.core.actions import Actions
import numpy as np import numpy as np
import time import time
from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping, create_log_dir
class CustomCallback(BaseCallback): class CustomCallback(BaseCallback):
def __init__(self, verbose: int = 0, env=None): def __init__(self, verbose: int = 0, env=None):
@ -35,6 +37,10 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
self.no_masking = no_masking self.no_masking = no_masking
def create_action_mask(self): def create_action_mask(self):
if self.no_masking:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
coordinates = self.env.agent_pos coordinates = self.env.agent_pos
view_direction = self.env.agent_dir view_direction = self.env.agent_dir
@ -70,9 +76,19 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
# print(F"Not in shield {cur_pos_str}") # print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask): for index, x in enumerate(mask):
mask[index] = 1.0 mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
if front_tile is not None and front_tile.type == "key":
mask[Actions.pickup] = 1.0
if self.env.carrying:
mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
if self.no_masking:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
return mask return mask
@ -107,8 +123,14 @@ def main():
env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking) env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking)
env = ActionMasker(env, mask_fn) env = ActionMasker(env, mask_fn)
callback = CustomCallback(1, env) callback = CustomCallback(1, env)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=args.log_dir) model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))
model.learn(args.iterations, callback=callback) iterations = args.iterations
if iterations < 10_000:
iterations = 10_000
model.learn(iterations, callback=callback)
mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10) mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10)

21
examples/shields/rl/Wrapper.py

@ -1,6 +1,7 @@
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
from minigrid.core.actions import Actions
from gymnasium.spaces import Dict, Box from gymnasium.spaces import Dict, Box
from collections import deque from collections import deque
@ -116,7 +117,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
allowed_actions = [] allowed_actions = []
# Create the mask # Create the mask
# If shield restricts action mask only valid with 1.0 # If shield restricts action mask only valid with 1.0
# else set all actions as valid # else set all actions as valid
@ -125,7 +126,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
if cur_pos_str in self.shield and self.shield[cur_pos_str]: if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str] allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions: for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action[1]) index = get_action_index_mapping(allowed_action[1]) # Allowed_action is a set
if index is None: if index is None:
assert(False) assert(False)
mask[index] = 1.0 mask[index] = 1.0
@ -133,10 +134,18 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
# print(F"Not in shield {cur_pos_str}") # print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask): for index, x in enumerate(mask):
mask[index] = 1.0 mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
# mask[0] = 1.0 if front_tile is not None and front_tile.type == "key":
# print(F"Action Mask for position {coordinates} and view {view_direction} is {mask} Position String: {cur_pos_str})") mask[Actions.pickup] = 1.0
if self.env.carrying:
mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask return mask
def reset(self, *, seed=None, options=None): def reset(self, *, seed=None, options=None):

21
examples/shields/rl/helpers.py

@ -2,6 +2,8 @@ import minigrid
from minigrid.core.actions import Actions from minigrid.core.actions import Actions
import gymnasium as gym import gymnasium as gym
from datetime import datetime
import stormpy import stormpy
import stormpy.core import stormpy.core
import stormpy.simulator import stormpy.simulator
@ -28,6 +30,9 @@ def extract_keys(env):
return keys return keys
def create_log_dir(args):
return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}-env:{args.env}"
def get_action_index_mapping(actions): def get_action_index_mapping(actions):
for action_str in actions: for action_str in actions:
@ -62,19 +67,13 @@ def parse_arguments(argparse):
choices=[ choices=[
"MiniGrid-LavaCrossingS9N1-v0", "MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-DoorKey-8x8-v0", "MiniGrid-DoorKey-8x8-v0",
"MiniGrid-Dynamic-Obstacles-8x8-v0", "MiniGrid-LockedRoom-v0",
"MiniGrid-Empty-Random-6x6-v0",
"MiniGrid-Fetch-6x6-N2-v0",
"MiniGrid-FourRooms-v0", "MiniGrid-FourRooms-v0",
"MiniGrid-KeyCorridorS6R3-v0",
"MiniGrid-GoToDoor-8x8-v0",
"MiniGrid-LavaGapS7-v0", "MiniGrid-LavaGapS7-v0",
"MiniGrid-SimpleCrossingS9N3-v0", "MiniGrid-SimpleCrossingS9N3-v0",
"MiniGrid-BlockedUnlockPickup-v0",
"MiniGrid-LockedRoom-v0",
"MiniGrid-ObstructedMaze-1Dlh-v0",
"MiniGrid-DoorKey-16x16-v0", "MiniGrid-DoorKey-16x16-v0",
"MiniGrid-RedBlueDoors-6x6-v0",]) "MiniGrid-Empty-Random-6x6-v0",
])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None) # parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_to_prism_path", default="./main") parser.add_argument("--grid_to_prism_path", default="./main")
@ -114,8 +113,8 @@ def create_shield(grid_to_prism_path, grid_file, prism_path):
program = stormpy.parse_prism_program(prism_path) program = stormpy.parse_prism_program(prism_path)
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" # formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
# formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
# stormpy.logic.ShieldComparison.ABSOLUTE, 0.9) # stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)

|||||||
100:0
Loading…
Cancel
Save