Browse Source

arguments and log dir

refactoring
Thomas Knoll 1 year ago
parent
commit
7f20c3f909
  1. 60
      examples/shields/rl/11_minigridrl.py
  2. 91
      examples/shields/rl/MaskEnvironments.py
  3. 5
      examples/shields/rl/Wrapper.py

60
examples/shields/rl/11_minigridrl.py

@ -8,7 +8,7 @@ from ray.rllib.utils.typing import PolicyID
import stormpy
import stormpy.core
import stormpy.simulator
from datetime import datetime
import stormpy.shields
import stormpy.logic
@ -35,7 +35,6 @@ from ray.rllib.models import ModelCatalog
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor
from MaskEnvironments import ParametricActionsMiniGridEnv
from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper
@ -76,14 +75,32 @@ class MyCallbacks(DefaultCallbacks):
def parse_arguments(argparse):
parser = argparse.ArgumentParser()
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0")
parser.add_argument("--seed", type=int, help="seed for environment", default=1)
parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32)
parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees")
parser.add_argument("--env",
help="gym environment to load",
default="MiniGrid-LavaCrossingS9N1-v0",
choices=[
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-DoorKey-8x8-v0",
"MiniGrid-Dynamic-Obstacles-8x8-v0",
"MiniGrid-Empty-Random-6x6-v0",
"MiniGrid-Fetch-6x6-N2-v0",
"MiniGrid-FourRooms-v0",
"MiniGrid-KeyCorridorS6R3-v0",
"MiniGrid-GoToDoor-8x8-v0",
"MiniGrid-LavaGapS7-v0",
"MiniGrid-SimpleCrossingS9N3-v0",
"MiniGrid-BlockedUnlockPickup-v0",
"MiniGrid-LockedRoom-v0",
"MiniGrid-ObstructedMaze-1Dlh-v0",
"MiniGrid-DoorKey-16x16-v0",
"MiniGrid-RedBlueDoors-6x6-v0",])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_path", default="Grid.txt")
parser.add_argument("--prism_path", default="Grid.PRISM")
parser.add_argument("--no_masking", default=False)
parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
parser.add_argument("--log_dir", default="../log_results/")
args = parser.parse_args()
@ -91,12 +108,8 @@ def parse_arguments(argparse):
def env_creater_custom(config):
# name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
# # name = config.get("name", "MiniGrid-Empty-8x8-v0")
framestack = config.get("framestack", 4)
shield = config.get("shield", {})
# env = gym.make(name)
# env = ParametricActionsMiniGridEnv(config)
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
@ -109,14 +122,6 @@ def env_creater_custom(config):
framestack=framestack
)
# obs = env.observation_space.sample()
# obs2, infos = env.reset(seed=None, options={})
# print(F"Obs is {obs} before reset. After reset: {obs2}")
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
# print(F"Created Custom Minigrid Environment is {env}")
return env
def env_creater_cart(config):
@ -124,11 +129,9 @@ def env_creater_cart(config):
def env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
# name = config.get("name", "MiniGrid-Empty-8x8-v0")
framestack = config.get("framestack", 4)
env = gym.make(name)
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
env = minigrid.wrappers.ImgObsWrapper(env)
env = OneHotWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
@ -151,6 +154,7 @@ def create_shield(grid_file, prism_path):
program = stormpy.parse_prism_program(prism_path)
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
#formula_str = "Pmax=? [G \"AgentIsInGoalAndNotDone\"]"
formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
@ -234,18 +238,18 @@ def ppo(args):
config = (PPOConfig()
.rollouts(num_rollout_workers=1)
.resources(num_gpus=0)
.environment(env="mini-grid", env_config={"shield": shield_dict })
.environment(env="mini-grid", env_config={"shield": shield_dict})
.framework("torch")
.callbacks(MyCallbacks)
.rl_module(_enable_rl_module_api = False)
.debugging(logger_config={
"type": "ray.tune.logger.TBXLogger",
"logdir": F"{args.log_dir}{datetime.now()}-{args.algorithm}"
})
.training(_enable_learner_api=False ,model={
"custom_model": "pa_model",
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
# "fcnet_hiddens": [256,256],
# "fcnet_activation": "relu",
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
}))
algo =(
@ -279,6 +283,10 @@ def dqn(args):
config = config.framework("torch")
config = config.callbacks(MyCallbacks)
config = config.rl_module(_enable_rl_module_api = False)
config = config.debugging(logger_config={
"type": "ray.tune.logger.TBXLogger",
"logdir": F"{args.log_dir}{datetime.now()}-{args.algorithm}"
})
config = config.training(hiddens=[], dueling=False, model={
"custom_model": "pa_model",
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}

91
examples/shields/rl/MaskEnvironments.py

@ -1,91 +0,0 @@
import random
import minigrid
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict, Discrete
from Wrapper import OneHotWrapper
class ParametricActionsMiniGridEnv(gym.Env):
"""Parametric action version of MiniGrid.
"""
def __init__(self, config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
self.left_action_embed = np.random.randn(2)
self.right_action_embed = np.random.randn(2)
framestack = config.get("framestack", 4)
# env = gym.make(name)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = OneHotWrapper(env,
# config.vector_index if hasattr(config, "vector_index") else 0,
# framestack=framestack
# )
self.wrapped = gym.make(name)
# self.observation_space = Dict(
# {
# "action_mask": None,
# "avail_actions": None,
# "cart": self.wrapped.observation_space,
# }
# )
print(F"Wrapped environment is {self.wrapped}")
self.step_count = 0
self.action_space = self.wrapped.action_space
self.observation_space = self.wrapped.observation_space
def update_avail_actions(self):
self.action_assignments = np.array(
[[0.0, 0.0]] * self.action_space.n, dtype=np.float32
)
self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.int8)
self.left_idx, self.right_idx = random.sample(range(self.action_space.n), 2)
self.action_assignments[self.left_idx] = self.left_action_embed
self.action_assignments[self.right_idx] = self.right_action_embed
self.action_mask[self.left_idx] = 1
self.action_mask[self.right_idx] = 1
def reset(self, *, seed=None, options=None):
self.update_avail_actions()
obs, infos = self.wrapped.reset()
return obs, infos
return {
"action_mask": self.action_mask,
"avail_action": self.action_assignments,
"cart": obs,
}, infos
def step(self, action):
if action == self.left_idx:
actual_action = 0
elif action == self.right_idx:
actual_action = 1
else:
actual_action = 0
# raise ValueError(
# "Chosen action was not one of the non-zero action embeddings",
# action,
# self.action_assignments,
# self.action_mask,
# self.left_idx,
# self.right_idx,
# )
orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action)
self.update_avail_actions()
self.action_mask = self.action_mask.astype(np.int8)
print(F"Info is {info}")
info["Hello" : "Ich kenn mich nix aus"]
return orig_obs, rew, done, truncated, info
obs = {
"action_mask": self.action_mask,
"action_mask": self.action_assignments,
"cart": orig_obs,
}
return obs, rew, done, truncated, info

5
examples/shields/rl/Wrapper.py

@ -82,7 +82,7 @@ class OneHotWrapper(gym.core.ObservationWrapper):
class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env, shield):
def __init__(self, env, shield={}):
super(MiniGridEnvWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = Dict(
@ -91,7 +91,6 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
"action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
}
)
self.shield = shield
@ -124,7 +123,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset()
obs, infos = self.env.reset(seed=seed, options=options)
mask = self.create_action_mask()
return {
"data": obs["image"],

Loading…
Cancel
Save