From b1b014dbd6721c8b129584fb17eb8946d1d39921 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Mon, 28 Aug 2023 16:04:15 +0200 Subject: [PATCH] some refactoring as preparation for sb3 example added sb3 example --- examples/shields/rl/11_minigridrl.py | 164 ++------------------------- examples/shields/rl/13_minigridsb.py | 130 +++++++++++++++++++++ examples/shields/rl/MaskModels.py | 18 +-- examples/shields/rl/Wrapper.py | 61 +++++----- examples/shields/rl/helpers.py | 135 +++++++++++++++++++++- 5 files changed, 310 insertions(+), 198 deletions(-) create mode 100644 examples/shields/rl/13_minigridsb.py diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index 497f48b..a7cc555 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -1,21 +1,13 @@ -from typing import Dict, Optional, Union +from typing import Dict from ray.rllib.env.base_env import BaseEnv from ray.rllib.evaluation import RolloutWorker from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.episode_v2 import EpisodeV2 from ray.rllib.policy import Policy from ray.rllib.utils.typing import PolicyID -import stormpy -import stormpy.core -import stormpy.simulator -from datetime import datetime -import stormpy.shields -import stormpy.logic -import stormpy.examples -import stormpy.examples.files -import os +from datetime import datetime import gymnasium as gym @@ -26,8 +18,6 @@ import ray from ray.tune import register_env from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.algorithms.dqn.dqn import DQNConfig -from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator -from ray import tune, air from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.tune.logger import pretty_print from ray.rllib.models import ModelCatalog @@ -37,11 +27,10 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN from ray.rllib.models.preprocessors import get_preprocessor from MaskModels import TorchActionMaskModel from Wrapper import OneHotWrapper, MiniGridEnvWrapper -from helpers import extract_keys +from helpers import extract_keys, parse_arguments, create_shield_dict import matplotlib.pyplot as plt -import argparse @@ -58,6 +47,7 @@ class MyCallbacks(DefaultCallbacks): # img = env.get_frame() # plt.imshow(img) # plt.show() + def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: episode.user_data["count"] = episode.user_data["count"] + 1 @@ -70,47 +60,7 @@ class MyCallbacks(DefaultCallbacks): # print(env.printGrid()) # print(episode.user_data["count"]) - - - -def parse_arguments(argparse): - parser = argparse.ArgumentParser() - # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") - parser.add_argument("--env", - help="gym environment to load", - default="MiniGrid-LavaCrossingS9N1-v0", - choices=[ - "MiniGrid-LavaCrossingS9N1-v0", - "MiniGrid-DoorKey-8x8-v0", - "MiniGrid-Dynamic-Obstacles-8x8-v0", - "MiniGrid-Empty-Random-6x6-v0", - "MiniGrid-Fetch-6x6-N2-v0", - "MiniGrid-FourRooms-v0", - "MiniGrid-KeyCorridorS6R3-v0", - "MiniGrid-GoToDoor-8x8-v0", - "MiniGrid-LavaGapS7-v0", - "MiniGrid-SimpleCrossingS9N3-v0", - "MiniGrid-BlockedUnlockPickup-v0", - "MiniGrid-LockedRoom-v0", - "MiniGrid-ObstructedMaze-1Dlh-v0", - "MiniGrid-DoorKey-16x16-v0", - "MiniGrid-RedBlueDoors-6x6-v0",]) - - # parser.add_argument("--seed", type=int, help="seed for environment", default=None) - parser.add_argument("--grid_to_prism_path", default="./main") - parser.add_argument("--grid_path", default="Grid.txt") - parser.add_argument("--prism_path", default="Grid.PRISM") - parser.add_argument("--no_masking", default=False) - parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"]) - parser.add_argument("--log_dir", default="../log_results/") - parser.add_argument("--iterations", type=int, default=30 ) - - args = parser.parse_args() - - return args - - - + def env_creater_custom(config): framestack = config.get("framestack", 4) @@ -130,86 +80,10 @@ def env_creater_custom(config): return env -def env_creater_cart(config): - return gym.make("CartPole-v1") - -def env_creater(config): - name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") - framestack = config.get("framestack", 4) - - env = gym.make(name) - env = minigrid.wrappers.ImgObsWrapper(env) - env = OneHotWrapper(env, - config.vector_index if hasattr(config, "vector_index") else 0, - framestack=framestack - ) - - print(F"Created Minigrid Environment is {env}") - - return env - def create_log_dir(args): return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}" -def create_shield(grid_to_prism_path, grid_file, prism_path): - os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}") - - f = open(prism_path, "a") - f.write("label \"AgentIsInLava\" = AgentIsInLava;") - f.close() - - - program = stormpy.parse_prism_program(prism_path) - formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" - # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" - shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, - stormpy.logic.ShieldComparison.ABSOLUTE, 0.9) - - - # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9) - - formulas = stormpy.parse_properties_for_prism_program(formula_str, program) - options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) - options.set_build_state_valuations(True) - options.set_build_choice_labels(True) - options.set_build_all_labels() - model = stormpy.build_sparse_model_with_options(program, options) - - result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) - - assert result.has_scheduler - assert result.has_shield - shield = result.shield - - action_dictionary = {} - shield_scheduler = shield.construct() - - for stateID in model.states: - choice = shield_scheduler.get_choice(stateID) - choices = choice.choice_map - state_valuation = model.state_valuations.get_string(stateID) - - actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices] - - action_dictionary[state_valuation] = actions_to_be_executed - - stormpy.shields.export_shield(model, shield, "Grid.shield") - return action_dictionary - -def export_grid_to_text(env, grid_file): - f = open(grid_file, "w") - # print(env) - f.write(env.printGrid(init=True)) - f.close() - -def create_environment(args): - env_id= args.env - env = gym.make(env_id) - env.reset() - return env - - def register_custom_minigrid_env(args): env_name = "mini-grid" register_env(env_name, env_creater_custom) @@ -218,25 +92,7 @@ def register_custom_minigrid_env(args): "pa_model", TorchActionMaskModel ) - -def create_shield_dict(args): - env = create_environment(args) - # print(env.printGrid(init=False)) - - grid_file = args.grid_path - grid_to_prism_path = args.grid_to_prism_path - export_grid_to_text(env, grid_file) - - prism_path = args.prism_path - shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path) - #shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} - - #print(F"Shield dictionary {shield_dict}") - # for state_id in model.states: - # choices = shield.get_choice(state_id) - # print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") - - return shield_dict + def ppo(args): @@ -311,14 +167,16 @@ def dqn(args): result = algo.train() print(pretty_print(result)) - # if i % 5 == 0: - # checkpoint_dir = algo.save() - # print(f"Checkpoint saved in directory {checkpoint_dir}") + if i % 5 == 0: + print("Saving checkpoint") + checkpoint_dir = algo.save() + print(f"Checkpoint saved in directory {checkpoint_dir}") ray.shutdown() def main(): + import argparse args = parse_arguments(argparse) if args.algorithm == "ppo": diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py new file mode 100644 index 0000000..73299e5 --- /dev/null +++ b/examples/shields/rl/13_minigridsb.py @@ -0,0 +1,130 @@ +from sb3_contrib import MaskablePPO +from sb3_contrib.common.maskable.evaluation import evaluate_policy +from sb3_contrib.common.maskable.utils import get_action_masks +from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy +from sb3_contrib.common.wrappers import ActionMasker +from stable_baselines3.common.callbacks import BaseCallback + +import gymnasium as gym +from gymnasium.spaces import Dict, Box + +import numpy as np +import time + +from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping + +class CustomCallback(BaseCallback): + def __init__(self, verbose: int = 0, env=None): + super(CustomCallback, self).__init__(verbose) + self.env = env + + + def _on_step(self) -> bool: + #print(self.env.printGrid()) + return super()._on_step() + + +class MiniGridEnvWrapper(gym.core.Wrapper): + def __init__(self, env, shield={}, keys=[], no_masking=False): + super(MiniGridEnvWrapper, self).__init__(env) + self.max_available_actions = env.action_space.n + self.observation_space = env.observation_space.spaces["image"] + + self.keys = keys + self.shield = shield + self.no_masking = no_masking + + def create_action_mask(self): + coordinates = self.env.agent_pos + view_direction = self.env.agent_dir + + key_text = "" + + # only support one key for now + if self.keys: + key_text = F"!Agent_has_{self.keys[0]}_key\t& " + + + if self.env.carrying and self.env.carrying.type == "key": + key_text = F"Agent_has_{self.env.carrying.color}_key\t& " + + #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") + cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" + + allowed_actions = [] + + + # Create the mask + # If shield restricts action mask only valid with 1.0 + # else set all actions as valid + mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) + + if cur_pos_str in self.shield and self.shield[cur_pos_str]: + allowed_actions = self.shield[cur_pos_str] + for allowed_action in allowed_actions: + index = get_action_index_mapping(allowed_action[1]) + if index is None: + assert(False) + mask[index] = 1.0 + else: + # print(F"Not in shield {cur_pos_str}") + for index, x in enumerate(mask): + mask[index] = 1.0 + + if self.no_masking: + return np.array([1.0] * self.max_available_actions, dtype=np.int8) + + return mask + + def reset(self, *, seed=None, options=None): + obs, infos = self.env.reset(seed=seed, options=options) + return obs["image"], infos + + def step(self, action): + # print(F"Performed action in step: {action}") + orig_obs, rew, done, truncated, info = self.env.step(action) + + #print(F"Original observation is {orig_obs}") + obs = orig_obs["image"] + + #print(F"Info is {info}") + return obs, rew, done, truncated, info + + + +def mask_fn(env: gym.Env): + return env.create_action_mask() + + + +def main(): + import argparse + args = parse_arguments(argparse) + shield = create_shield_dict(args) + + env = gym.make(args.env, render_mode="rgb_array") + keys = extract_keys(env) + env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking) + env = ActionMasker(env, mask_fn) + callback = CustomCallback(1, env) + model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=args.log_dir) + model.learn(args.iterations, callback=callback) + + mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10) + + vec_env = model.get_env() + obs = vec_env.reset() + terminated = truncated = False + while not terminated and not truncated: + action_masks = None + action, _states = model.predict(obs, action_masks=action_masks) + obs, reward, terminated, truncated, info = env.step(action) + # action, _states = model.predict(obs, deterministic=True) + # obs, rewards, dones, info = vec_env.step(action) + vec_env.render("human") + time.sleep(0.2) + + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/shields/rl/MaskModels.py b/examples/shields/rl/MaskModels.py index e36c6da..607b529 100644 --- a/examples/shields/rl/MaskModels.py +++ b/examples/shields/rl/MaskModels.py @@ -54,19 +54,12 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): def forward(self, input_dict, state, seq_lens): # Extract the available actions tensor from the observation. - # print(F"Input dict is {input_dict} at obs: {input_dict['obs']}") - # print(F"State is {state}") - - # print(input_dict["env"]) - - # Compute the unmasked logits. + # Compute the unmasked logits. logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]}) - # print(F"Caluclated Logits {logits} with size {logits.size()} Count: {self.count}") - + action_mask = input_dict["obs"]["action_mask"] - #print(F"Action mask is {action_mask} with dimension {action_mask.size()}") - + # If action masking is disabled, directly return unmasked logits if self.no_masking: return logits, state @@ -74,12 +67,9 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): # assert(False) # Convert action_mask into a [0.0 || -inf]-type mask. inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) - # print(F"Logits Size: {logits.size()} Inf-Mask Size: {inf_mask.size()}") - # print(F"Logits:{logits} Inf-Mask: {inf_mask}") masked_logits = logits + inf_mask - # print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}") - + # # Return masked logits. return masked_logits, state diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrapper.py index d67e364..d67abbb 100644 --- a/examples/shields/rl/Wrapper.py +++ b/examples/shields/rl/Wrapper.py @@ -8,6 +8,7 @@ from ray.rllib.utils.numpy import one_hot from helpers import get_action_index_mapping + class OneHotWrapper(gym.core.ObservationWrapper): def __init__(self, env, vector_index, framestack): super().__init__(env) @@ -30,11 +31,11 @@ class OneHotWrapper(gym.core.ObservationWrapper): "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int), } - ) - - + ) + + # print(F"Set obersvation space to {self.observation_space}") - + def observation(self, obs): # Debug output: max-x/y positions to watch exploration progress. @@ -61,23 +62,23 @@ class OneHotWrapper(gym.core.ObservationWrapper): self.init_x = self.agent_pos[0] self.init_y = self.agent_pos[1] - + self.x_positions.append(self.agent_pos[0]) self.y_positions.append(self.agent_pos[1]) - + image = obs["data"] # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten. objects = one_hot(image[:, :, 0], depth=11) colors = one_hot(image[:, :, 1], depth=6) states = one_hot(image[:, :, 2], depth=3) - + all_ = np.concatenate([objects, colors, states], -1) all_flat = np.reshape(all_, (-1,)) direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) single_frame = np.concatenate([all_flat, direction]) self.frame_buffer.append(single_frame) - + #obs["one-hot"] = np.concatenate(self.frame_buffer) tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] } return tmp#np.concatenate(self.frame_buffer) @@ -95,49 +96,49 @@ class MiniGridEnvWrapper(gym.core.Wrapper): ) self.keys = keys self.shield = shield - + def create_action_mask(self): coordinates = self.env.agent_pos view_direction = self.env.agent_dir - + key_text = "" - + # only support one key for now if self.keys: key_text = F"!Agent_has_{self.keys[0]}_key\t& " - - + + if self.env.carrying and self.env.carrying.type == "key": key_text = F"Agent_has_{self.env.carrying.color}_key\t& " #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" - + allowed_actions = [] - - + + # Create the mask # If shield restricts action mask only valid with 1.0 # else set all actions as valid mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) - + if cur_pos_str in self.shield and self.shield[cur_pos_str]: allowed_actions = self.shield[cur_pos_str] for allowed_action in allowed_actions: - index = get_action_index_mapping(allowed_action[1]) + index = get_action_index_mapping(allowed_action[1]) if index is None: - assert(False) + assert(False) mask[index] = 1.0 else: - print("Not in shield") + # print(F"Not in shield {cur_pos_str}") for index, x in enumerate(mask): mask[index] = 1.0 - - - #print(F"Action Mask for position {coordinates} and view {view_direction} is {mask} Position String: {cur_pos_str})") - + + # mask[0] = 1.0 + # print(F"Action Mask for position {coordinates} and view {view_direction} is {mask} Position String: {cur_pos_str})") + return mask - + def reset(self, *, seed=None, options=None): obs, infos = self.env.reset(seed=seed, options=options) mask = self.create_action_mask() @@ -145,19 +146,19 @@ class MiniGridEnvWrapper(gym.core.Wrapper): "data": obs["image"], "action_mask": mask }, infos - + def step(self, action): # print(F"Performed action in step: {action}") orig_obs, rew, done, truncated, info = self.env.step(action) - + mask = self.create_action_mask() #print(F"Original observation is {orig_obs}") obs = { "data": orig_obs["image"], "action_mask": mask, } - + #print(F"Info is {info}") return obs, rew, done, truncated, info - - + + diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index aa8738b..111c3a6 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -1,5 +1,18 @@ import minigrid from minigrid.core.actions import Actions +import gymnasium as gym + +import stormpy +import stormpy.core +import stormpy.simulator + +import stormpy.shields +import stormpy.logic + +import stormpy.examples +import stormpy.examples.files + +import os def extract_keys(env): @@ -36,4 +49,124 @@ def get_action_index_mapping(actions): return Actions.done - raise ValueError(F"Action string {action_str} not supported") \ No newline at end of file + raise ValueError(F"Action string {action_str} not supported") + + + +def parse_arguments(argparse): + parser = argparse.ArgumentParser() + # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") + parser.add_argument("--env", + help="gym environment to load", + default="MiniGrid-LavaCrossingS9N1-v0", + choices=[ + "MiniGrid-LavaCrossingS9N1-v0", + "MiniGrid-DoorKey-8x8-v0", + "MiniGrid-Dynamic-Obstacles-8x8-v0", + "MiniGrid-Empty-Random-6x6-v0", + "MiniGrid-Fetch-6x6-N2-v0", + "MiniGrid-FourRooms-v0", + "MiniGrid-KeyCorridorS6R3-v0", + "MiniGrid-GoToDoor-8x8-v0", + "MiniGrid-LavaGapS7-v0", + "MiniGrid-SimpleCrossingS9N3-v0", + "MiniGrid-BlockedUnlockPickup-v0", + "MiniGrid-LockedRoom-v0", + "MiniGrid-ObstructedMaze-1Dlh-v0", + "MiniGrid-DoorKey-16x16-v0", + "MiniGrid-RedBlueDoors-6x6-v0",]) + + # parser.add_argument("--seed", type=int, help="seed for environment", default=None) + parser.add_argument("--grid_to_prism_path", default="./main") + parser.add_argument("--grid_path", default="Grid.txt") + parser.add_argument("--prism_path", default="Grid.PRISM") + parser.add_argument("--no_masking", default=False) + parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"]) + parser.add_argument("--log_dir", default="../log_results/") + parser.add_argument("--iterations", type=int, default=30 ) + + args = parser.parse_args() + + return args + + + +def create_environment(args): + env_id= args.env + env = gym.make(env_id) + env.reset() + return env + + +def export_grid_to_text(env, grid_file): + f = open(grid_file, "w") + # print(env) + f.write(env.printGrid(init=True)) + f.close() + + +def create_shield(grid_to_prism_path, grid_file, prism_path): + os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}") + + f = open(prism_path, "a") + f.write("label \"AgentIsInLava\" = AgentIsInLava;") + f.close() + + + program = stormpy.parse_prism_program(prism_path) + formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" + # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" + # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, + # stormpy.logic.ShieldComparison.ABSOLUTE, 0.9) + + shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) + # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9) + + formulas = stormpy.parse_properties_for_prism_program(formula_str, program) + options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) + options.set_build_state_valuations(True) + options.set_build_choice_labels(True) + options.set_build_all_labels() + model = stormpy.build_sparse_model_with_options(program, options) + + result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) + + assert result.has_scheduler + assert result.has_shield + shield = result.shield + + action_dictionary = {} + shield_scheduler = shield.construct() + + for stateID in model.states: + choice = shield_scheduler.get_choice(stateID) + choices = choice.choice_map + state_valuation = model.state_valuations.get_string(stateID) + + actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices] + + action_dictionary[state_valuation] = actions_to_be_executed + + stormpy.shields.export_shield(model, shield, "Grid.shield") + return action_dictionary + + +def create_shield_dict(args): + env = create_environment(args) + # print(env.printGrid(init=False)) + + grid_file = args.grid_path + grid_to_prism_path = args.grid_to_prism_path + export_grid_to_text(env, grid_file) + + prism_path = args.prism_path + shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path) + #shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} + + #print(F"Shield dictionary {shield_dict}") + # for state_id in model.states: + # choices = shield.get_choice(state_id) + # print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") + + return shield_dict +