Browse Source

changed shield creation to create shield on reset

refactoring
Thomas Knoll 1 year ago
parent
commit
1c2dbf706e
  1. 36
      examples/shields/rl/11_minigridrl.py
  2. 16
      examples/shields/rl/13_minigridsb.py
  3. 3
      examples/shields/rl/MaskModels.py
  4. 13
      examples/shields/rl/Wrapper.py
  5. 8
      examples/shields/rl/helpers.py

36
examples/shields/rl/11_minigridrl.py

@ -5,7 +5,7 @@ from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
from ray.rllib.algorithms.algorithm import Algorithm
import gymnasium as gym
@ -29,9 +29,6 @@ from helpers import extract_keys, parse_arguments, create_shield_dict, create_lo
import matplotlib.pyplot as plt
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# print(F"Episode started Environment: {base_env.get_sub_environments()}")
@ -50,7 +47,7 @@ class MyCallbacks(DefaultCallbacks):
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0]
#print(env.printGrid())
# print(env.printGrid())
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# print(F"Episode end Environment: {base_env.get_sub_environments()}")
@ -65,10 +62,9 @@ def env_creater_custom(config):
shield = config.get("shield", {})
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
args = config.get("args", None)
env = gym.make(name)
keys = extract_keys(env)
env = MiniGridEnvWrapper(env, shield=shield, keys=keys)
env = MiniGridEnvWrapper(env, args=args)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env)
env = OneHotWrapper(env,
@ -76,6 +72,7 @@ def env_creater_custom(config):
framestack=framestack
)
return env
@ -96,12 +93,11 @@ def ppo(args):
register_custom_minigrid_env(args)
shield_dict = create_shield_dict(args)
config = (PPOConfig()
.rollouts(num_rollout_workers=1)
.resources(num_gpus=0)
.environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env})
.environment(env="mini-grid", env_config={"name": args.env, "args": args})
.framework("torch")
.callbacks(MyCallbacks)
.rl_module(_enable_rl_module_api = False)
@ -111,7 +107,7 @@ def ppo(args):
})
.training(_enable_learner_api=False ,model={
"custom_model": "pa_model",
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
"custom_model_config" : {"no_masking": args.no_masking}
}))
algo =(
@ -119,11 +115,7 @@ def ppo(args):
config.build()
)
# while not terminated and not truncated:
# action = algo.compute_single_action(obs)
# obs, reward, terminated, truncated = env.step(action)
for i in range(30):
for i in range(args.iterations):
result = algo.train()
print(pretty_print(result))
@ -131,18 +123,24 @@ def ppo(args):
checkpoint_dir = algo.save()
print(f"Checkpoint saved in directory {checkpoint_dir}")
# terminated = truncated = False
# while not terminated and not truncated:
# action = algo.compute_single_action(obs)
# obs, reward, terminated, truncated = env.step(action)
ray.shutdown()
def dqn(args):
register_custom_minigrid_env(args)
shield_dict = create_shield_dict(args)
config = DQNConfig()
config = config.resources(num_gpus=0)
config = config.rollouts(num_rollout_workers=1)
config = config.environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env })
config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
config = config.framework("torch")
config = config.callbacks(MyCallbacks)
config = config.rl_module(_enable_rl_module_api = False)
@ -152,7 +150,7 @@ def dqn(args):
})
config = config.training(hiddens=[], dueling=False, model={
"custom_model": "pa_model",
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
"custom_model_config" : {"no_masking": args.no_masking}
})
algo = (

16
examples/shields/rl/13_minigridsb.py

@ -27,13 +27,12 @@ class CustomCallback(BaseCallback):
class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env, shield={}, keys=[], no_masking=False):
def __init__(self, env, args=None, no_masking=False):
super(MiniGridEnvWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = env.observation_space.spaces["image"]
self.keys = keys
self.shield = shield
self.args = args
self.no_masking = no_masking
def create_action_mask(self):
@ -94,6 +93,12 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
keys = extract_keys(self.env)
shield = create_shield_dict(self.env, self.args)
self.keys = keys
self.shield = shield
return obs["image"], infos
def step(self, action):
@ -116,11 +121,10 @@ def mask_fn(env: gym.Env):
def main():
import argparse
args = parse_arguments(argparse)
shield = create_shield_dict(args)
env = gym.make(args.env, render_mode="rgb_array")
keys = extract_keys(env)
env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking)
env = MiniGridEnvWrapper(env,args=args, no_masking=args.no_masking)
env = ActionMasker(env, mask_fn)
callback = CustomCallback(1, env)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))

3
examples/shields/rl/MaskModels.py

@ -34,9 +34,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
)
nn.Module.__init__(self)
assert("shield" in custom_config)
self.shield = custom_config["shield"]
self.count = 0
self.internal_model = TorchFC(

13
examples/shields/rl/Wrapper.py

@ -7,7 +7,7 @@ from gymnasium.spaces import Dict, Box
from collections import deque
from ray.rllib.utils.numpy import one_hot
from helpers import get_action_index_mapping
from helpers import get_action_index_mapping, create_shield_dict, extract_keys
class OneHotWrapper(gym.core.ObservationWrapper):
@ -86,7 +86,7 @@ class OneHotWrapper(gym.core.ObservationWrapper):
class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env, shield={}, keys=[]):
def __init__(self, env, args=None):
super(MiniGridEnvWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = Dict(
@ -95,8 +95,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
"action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
}
)
self.keys = keys
self.shield = shield
self.args = args
def create_action_mask(self):
coordinates = self.env.agent_pos
@ -140,8 +139,8 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
if front_tile is not None and front_tile.type == "key":
mask[Actions.pickup] = 1.0
if self.env.carrying:
mask[Actions.drop] = 1.0
# if self.env.carrying:
# mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
@ -150,6 +149,8 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
self.shield = create_shield_dict(self.env, self.args)
self.keys = extract_keys(self.env)
mask = self.create_action_mask()
return {
"data": obs["image"],

8
examples/shields/rl/helpers.py

@ -20,7 +20,7 @@ import os
def extract_keys(env):
env.reset()
keys = []
print(env.grid)
#print(env.grid)
for j in range(env.grid.height):
for i in range(env.grid.width):
obj = env.grid.get(i,j)
@ -113,8 +113,8 @@ def create_shield(grid_to_prism_path, grid_file, prism_path):
program = stormpy.parse_prism_program(prism_path)
# formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
# formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
# stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)
@ -150,7 +150,7 @@ def create_shield(grid_to_prism_path, grid_file, prism_path):
return action_dictionary
def create_shield_dict(args):
def create_shield_dict(env, args):
env = create_environment(args)
# print(env.printGrid(init=False))
Loading…
Cancel
Save