
added logdir handling
changes to action index handling
refactoring
Thomas Knoll · 1 year ago · commit fab1e8f23f

3 changed files:
1. examples/shields/rl/11_minigridrl.py (47 changes)
2. examples/shields/rl/Wrapper.py (26 changes)
3. examples/shields/rl/helpers.py (39 changes)

examples/shields/rl/11_minigridrl.py (47 changes)

@@ -37,6 +37,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN
 from ray.rllib.models.preprocessors import get_preprocessor
 from MaskModels import TorchActionMaskModel
 from Wrapper import OneHotWrapper, MiniGridEnvWrapper
+from helpers import extract_keys

 import matplotlib.pyplot as plt
@@ -61,12 +62,12 @@ class MyCallbacks(DefaultCallbacks):
     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
         episode.user_data["count"] = episode.user_data["count"] + 1
         env = base_env.get_sub_environments()[0]
-        #print(env.env.env.printGrid())
+        #print(env.printGrid())

     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
         env = base_env.get_sub_environments()[0]
-        # print(env.env.env.printGrid())
+        # print(env.printGrid())
         # print(episode.user_data["count"])
@@ -96,17 +97,21 @@ def parse_arguments(argparse):
                                               "MiniGrid-RedBlueDoors-6x6-v0",])
     # parser.add_argument("--seed", type=int, help="seed for environment", default=None)
+    parser.add_argument("--grid_to_prism_path", default="./main")
     parser.add_argument("--grid_path", default="Grid.txt")
     parser.add_argument("--prism_path", default="Grid.PRISM")
     parser.add_argument("--no_masking", default=False)
     parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
     parser.add_argument("--log_dir", default="../log_results/")
+    parser.add_argument("--iterations", type=int, default=30 )

     args = parser.parse_args()
     return args

 def env_creater_custom(config):
     framestack = config.get("framestack", 4)
     shield = config.get("shield", {})
@@ -114,7 +119,8 @@ def env_creater_custom(config):
     framestack = config.get("framestack", 4)

     env = gym.make(name)
-    env = MiniGridEnvWrapper(env, shield=shield)
+    keys = extract_keys(env)
+    env = MiniGridEnvWrapper(env, shield=shield, keys=keys)
     # env = minigrid.wrappers.ImgObsWrapper(env)
     # env = ImgObsWrapper(env)
     env = OneHotWrapper(env,
@@ -142,10 +148,12 @@ def env_creater(config):
     return env

+def create_log_dir(args):
+    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}"

-def create_shield(grid_file, prism_path):
-    os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}")
+def create_shield(grid_to_prism_path, grid_file, prism_path):
+    os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}")

     f = open(prism_path, "a")
     f.write("label \"AgentIsInLava\" = AgentIsInLava;")
@@ -154,7 +162,12 @@ def create_shield(grid_file, prism_path):
     program = stormpy.parse_prism_program(prism_path)

     formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
+    #formula_str = "Pmax=? [G \"AgentIsInGoalAndNotDone\"]"
+    # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
+    shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
+                                                          stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)
+    # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9)

     formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
     options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
@@ -163,7 +176,6 @@ def create_shield(grid_file, prism_path):
     options.set_build_all_labels()
     model = stormpy.build_sparse_model_with_options(program, options)

-    shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
     result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
     assert result.has_scheduler
@@ -189,7 +201,6 @@ def export_grid_to_text(env, grid_file):
     f = open(grid_file, "w")
     # print(env)
     f.write(env.printGrid(init=True))
-    # f.write(env.pprint_grid())
     f.close()

 def create_environment(args):
@@ -210,14 +221,14 @@ def register_custom_minigrid_env(args):

 def create_shield_dict(args):
     env = create_environment(args)
-    # print(env.pprint_grid())
     # print(env.printGrid(init=False))

     grid_file = args.grid_path
+    grid_to_prism_path = args.grid_to_prism_path
     export_grid_to_text(env, grid_file)

     prism_path = args.prism_path
-    shield_dict = create_shield(grid_file, prism_path)
+    shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path)

     #shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states}
     #print(F"Shield dictionary {shield_dict}")
@@ -238,13 +249,13 @@ def ppo(args):
     config = (PPOConfig()
         .rollouts(num_rollout_workers=1)
         .resources(num_gpus=0)
-        .environment(env="mini-grid", env_config={"shield": shield_dict})
+        .environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env})
         .framework("torch")
         .callbacks(MyCallbacks)
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
             "type": "ray.tune.logger.TBXLogger",
-            "logdir": F"{args.log_dir}{datetime.now()}-{args.algorithm}"
+            "logdir": create_log_dir(args)
         })
         .training(_enable_learner_api=False ,model={
             "custom_model": "pa_model",
@@ -279,13 +290,13 @@ def dqn(args):
     config = DQNConfig()
     config = config.resources(num_gpus=0)
     config = config.rollouts(num_rollout_workers=1)
-    config = config.environment(env="mini-grid", env_config={"shield": shield_dict })
+    config = config.environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env })
     config = config.framework("torch")
     config = config.callbacks(MyCallbacks)
     config = config.rl_module(_enable_rl_module_api = False)
     config = config.debugging(logger_config={
         "type": "ray.tune.logger.TBXLogger",
-        "logdir": F"{args.log_dir}{datetime.now()}-{args.algorithm}"
+        "logdir": create_log_dir(args)
     })
     config = config.training(hiddens=[], dueling=False, model={
         "custom_model": "pa_model",
@@ -296,13 +307,13 @@ def dqn(args):
         config.build()
     )

-    for i in range(30):
+    for i in range(args.iterations):
         result = algo.train()
         print(pretty_print(result))

-        if i % 5 == 0:
-            checkpoint_dir = algo.save()
-            print(f"Checkpoint saved in directory {checkpoint_dir}")
+        # if i % 5 == 0:
+        #     checkpoint_dir = algo.save()
+        #     print(f"Checkpoint saved in directory {checkpoint_dir}")

     ray.shutdown()

examples/shields/rl/Wrapper.py (26 changes)

@@ -6,6 +6,8 @@ from gymnasium.spaces import Dict, Box
 from collections import deque
 from ray.rllib.utils.numpy import one_hot
+from helpers import get_action_index_mapping

 class OneHotWrapper(gym.core.ObservationWrapper):
     def __init__(self, env, vector_index, framestack):
         super().__init__(env)
@@ -82,7 +84,7 @@ class OneHotWrapper(gym.core.ObservationWrapper):

 class MiniGridEnvWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield={}):
+    def __init__(self, env, shield={}, keys=[]):
         super(MiniGridEnvWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = Dict(
@@ -91,29 +93,43 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
                 "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
             }
         )
+        self.keys = keys
         self.shield = shield

     def create_action_mask(self):
         coordinates = self.env.agent_pos
         view_direction = self.env.agent_dir
+
+        key_text = ""
+
+        # only support one key for now
+        if self.keys:
+            key_text = F"!Agent_has_{self.keys[0]}_key\t& "
+
+        if self.env.carrying and self.env.carrying.type == "key":
+            key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
+
         #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
-        cur_pos_str = f"[!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
+        cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"

         allowed_actions = []

         # Create the mask
         # If shield restricts action mask only valid with 1.0
-        # else set everything to one
+        # else set all actions as valid
         mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)

         if cur_pos_str in self.shield and self.shield[cur_pos_str]:
             allowed_actions = self.shield[cur_pos_str]
             for allowed_action in allowed_actions:
-                index = allowed_action[0]
+                index = get_action_index_mapping(allowed_action[1])
+                if index is None:
+                    assert(False)
+
                 mask[index] = 1.0
         else:
+            print("Not in shield")
             for index, x in enumerate(mask):
                 mask[index] = 1.0
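To make the new lookup logic concrete, here is a hedged, toy-sized sketch of how create_action_mask now builds its shield key and action mask; the shield dictionary and its action tuples below are hand-written for illustration and not output of the actual stormpy pipeline, and the helper is a simplified stand-in for helpers.get_action_index_mapping:

```python
# Toy illustration only: a made-up shield dict that mimics the
# (index, action_strings) entries the wrapper expects; a real shield
# comes from create_shield in 11_minigridrl.py.
import numpy as np
from minigrid.core.actions import Actions

def get_action_index_mapping(actions):
    # simplified stand-in for helpers.get_action_index_mapping
    for s in actions:
        if "left" in s:
            return Actions.left
        if any(d in s for d in ("east", "south", "west", "north")):
            return Actions.forward
    return None

shield = {
    "[!Agent_has_yellow_key\t& !AgentDone\t& xAgent=1\t& yAgent=1\t& viewAgent=0]":
        [(0, ["turn_left"]), (1, ["move_east"])],
}

key_text = "!Agent_has_yellow_key\t& "   # agent is not carrying the key yet
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent=1\t& yAgent=1\t& viewAgent=0]"

mask = np.zeros(7, dtype=np.int8)        # MiniGrid exposes 7 discrete actions
for allowed_action in shield.get(cur_pos_str, []):
    mask[get_action_index_mapping(allowed_action[1])] = 1.0
print(mask)  # [1 0 1 0 0 0 0]: only turn_left and forward remain allowed
```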

examples/shields/rl/helpers.py (39 changes, new file)

@@ -0,0 +1,39 @@
+import minigrid
+from minigrid.core.actions import Actions
+
+def extract_keys(env):
+    env.reset()
+    keys = []
+    print(env.grid)
+    for j in range(env.grid.height):
+        for i in range(env.grid.width):
+            obj = env.grid.get(i,j)
+
+            if obj and obj.type == "key":
+                keys.append(obj.color)
+
+    return keys
+
+def get_action_index_mapping(actions):
+    for action_str in actions:
+        if "left" in action_str:
+            return Actions.left
+        elif "right" in action_str:
+            return Actions.right
+        elif "east" in action_str:
+            return Actions.forward
+        elif "south" in action_str:
+            return Actions.forward
+        elif "west" in action_str:
+            return Actions.forward
+        elif "north" in action_str:
+            return Actions.forward
+        elif "pickup" in action_str:
+            return Actions.pickup
+        elif "done" in action_str:
+            return Actions.done
+
+    raise ValueError(F"Action string {action_str} not supported")
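A brief, hedged usage sketch of the two new helpers; it assumes a standard MiniGrid environment id, that gymnasium and minigrid are installed, and that the script is run from examples/shields/rl so helpers.py is importable. The printed key color depends on the chosen environment:

```python
# Usage sketch only, not from the commit. extract_keys scans the reset grid
# for key objects; get_action_index_mapping translates the shield's action
# strings into MiniGrid actions.
import gymnasium as gym
import minigrid  # noqa: F401  (registers the MiniGrid-* environment ids)

from helpers import extract_keys, get_action_index_mapping

env = gym.make("MiniGrid-DoorKey-8x8-v0")
print(extract_keys(env.unwrapped))               # e.g. ['yellow']
print(get_action_index_mapping(["move_north"]))  # Actions.forward
print(get_action_index_mapping(["pickup_key"]))  # Actions.pickup
```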