From 1cbaac75cb08355c4011001cc1c0e3fccaa6f2dc Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Tue, 2 Jan 2024 22:45:04 +0100 Subject: [PATCH] cleanups --- examples/shields/rl/callbacks.py | 23 +-- examples/shields/rl/helpers.py | 149 ------------------ examples/shields/rl/rllibutils.py | 5 +- examples/shields/rl/sb3utils.py | 4 +- examples/shields/rl/utils.py | 8 +- examples/shields/rl/wrappers.py | 252 ------------------------------ 6 files changed, 13 insertions(+), 428 deletions(-) delete mode 100644 examples/shields/rl/helpers.py delete mode 100644 examples/shields/rl/wrappers.py diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py index 435d33e..4cdbac5 100644 --- a/examples/shields/rl/callbacks.py +++ b/examples/shields/rl/callbacks.py @@ -15,11 +15,9 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callback import matplotlib.pyplot as plt -import tensorflow as tf -class MyCallbacks(DefaultCallbacks): +class CustomCallback(DefaultCallbacks): def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: - # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") env = base_env.get_sub_environments()[0] episode.user_data["count"] = 0 episode.user_data["ran_into_lava"] = [] @@ -29,28 +27,11 @@ class MyCallbacks(DefaultCallbacks): episode.hist_data["goals_reached"] = [] episode.hist_data["ran_into_adversary"] = [] - # print("On episode start print") - # print(env.printGrid()) - # print(worker) - # print(env.action_space.n) - # print(env.actions) - # print(env.mission) - # print(env.observation_space) - # plt.imshow(img) - # plt.show() - def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: episode.user_data["count"] = episode.user_data["count"] + 1 env = base_env.get_sub_environments()[0] - # print(env.printGrid()) - - if hasattr(env, "adversaries"): - for adversary in env.adversaries.values(): - if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]: - print(F"Adversary ran into agent. Adversary {adversary.cur_pos}, Agent {env.agent_pos}") - # assert False - + def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py deleted file mode 100644 index f06e463..0000000 --- a/examples/shields/rl/helpers.py +++ /dev/null @@ -1,149 +0,0 @@ -import minigrid -from minigrid.core.actions import Actions - -from datetime import datetime -from enum import Enum - -import os - -import stormpy -import stormpy.core -import stormpy.simulator - -import stormpy.shields -import stormpy.logic - -import stormpy.examples -import stormpy.examples.files - -class ShieldingConfig(Enum): - Training = 'training' - Evaluation = 'evaluation' - Disabled = 'none' - Full = 'full' - - def __str__(self) -> str: - return self.value - - -def extract_keys(env): - keys = [] - for j in range(env.grid.height): - for i in range(env.grid.width): - obj = env.grid.get(i,j) - - if obj and obj.type == "key": - keys.append((obj, i, j)) - - if env.carrying and env.carrying.type == "key": - keys.append((env.carrying, -1, -1)) - # TODO Maybe need to add ordering of keys so it matches the order in the shield - return keys - -def extract_doors(env): - doors = [] - for j in range(env.grid.height): - for i in range(env.grid.width): - obj = env.grid.get(i,j) - - if obj and obj.type == "door": - doors.append(obj) - - return doors - -def extract_adversaries(env): - adv = [] - - if not hasattr(env, "adversaries"): - return [] - - for color, adversary in env.adversaries.items(): - adv.append(adversary) - - - return adv - -def create_log_dir(args): - return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}" - -def test_name(args): - return F"{args.expname}" - -def get_action_index_mapping(actions): - for action_str in actions: - if not "Agent" in action_str: - continue - - if "move" in action_str: - return Actions.forward - elif "left" in action_str: - return Actions.left - elif "right" in action_str: - return Actions.right - elif "pickup" in action_str: - return Actions.pickup - elif "done" in action_str: - return Actions.done - elif "drop" in action_str: - return Actions.drop - elif "toggle" in action_str: - return Actions.toggle - elif "unlock" in action_str: - return Actions.toggle - - raise ValueError("No action mapping found") - - -def parse_arguments(argparse): - parser = argparse.ArgumentParser() - # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") - parser.add_argument("--env", - help="gym environment to load", - default="MiniGrid-LavaSlipperyS12-v2", - choices=[ - "MiniGrid-Adv-8x8-v0", - "MiniGrid-AdvSimple-8x8-v0", - "MiniGrid-SingleDoor-7x6-v0", - "MiniGrid-LavaCrossingS9N1-v0", - "MiniGrid-LavaCrossingS9N3-v0", - "MiniGrid-LavaSlipperyS12-v0", - "MiniGrid-LavaSlipperyS12-v1", - "MiniGrid-LavaSlipperyS12-v2", - "MiniGrid-LavaSlipperyS12-v3", - "MiniGrid-DoorKey-8x8-v0", - # "MiniGrid-DoubleDoor-16x16-v0", - # "MiniGrid-DoubleDoor-12x12-v0", - # "MiniGrid-DoubleDoor-10x8-v0", - # "MiniGrid-LockedRoom-v0", - # "MiniGrid-FourRooms-v0", - # "MiniGrid-LavaGapS7-v0", - # "MiniGrid-SimpleCrossingS9N3-v0", - # "MiniGrid-DoorKey-16x16-v0", - # "MiniGrid-Empty-Random-6x6-v0", - ]) - - # parser.add_argument("--seed", type=int, help="seed for environment", default=None) - parser.add_argument("--grid_to_prism_binary_path", default="./main") - parser.add_argument("--grid_path", default="grid") - parser.add_argument("--prism_path", default="grid") - parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"]) - parser.add_argument("--log_dir", default="../log_results/") - parser.add_argument("--evaluations", type=int, default=30 ) - parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" - # parser.add_argument("--formula", default="<> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]") - parser.add_argument("--workers", type=int, default=1) - parser.add_argument("--num_gpus", type=float, default=0) - parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) - parser.add_argument("--steps", default=20_000, type=int) - parser.add_argument("--expname", default="exp") - parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction) - parser.add_argument("--prism_config", default=None) - parser.add_argument("--shield_value", default=0.9, type=float) - parser.add_argument("--prob_direct", default=1/4, type=float) - parser.add_argument("--prob_forward", default=3/4, type=float) - parser.add_argument("--prob_next", default=1/8, type=float) - parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute']) - # parser.add_argument("--random_starts", default=1, type=int) - args = parser.parse_args() - - return args diff --git a/examples/shields/rl/rllibutils.py b/examples/shields/rl/rllibutils.py index 03b8253..c5c431d 100644 --- a/examples/shields/rl/rllibutils.py +++ b/examples/shields/rl/rllibutils.py @@ -9,8 +9,7 @@ from gymnasium.spaces import Dict, Box from collections import deque from ray.rllib.utils.numpy import one_hot -from helpers import get_action_index_mapping -from shieldhandlers import ShieldHandler +from utils import get_action_index_mapping, MiniGridShieldHandler, create_shield_query, ShieldingConfig class OneHotShieldingWrapper(gym.core.ObservationWrapper): @@ -85,7 +84,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): class MiniGridShieldingWrapper(gym.core.Wrapper): def __init__(self, env, - shield_creator : ShieldHandler, + shield_creator : MiniGridShieldHandler, shield_query_creator, create_shield_at_reset=True, mask_actions=True): diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py index f0798d2..d81cffe 100644 --- a/examples/shields/rl/sb3utils.py +++ b/examples/shields/rl/sb3utils.py @@ -2,10 +2,12 @@ import gymnasium as gym import numpy as np import random +from utils import MiniGridShieldHandler, create_shield_query + class MiniGridSbShieldingWrapper(gym.core.Wrapper): def __init__(self, env, - shield_creator : ShieldHandler, + shield_creator : MiniGridShieldHandler, shield_query_creator, create_shield_at_reset = True, mask_actions=True, diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index cab65be..c0eeb70 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -8,10 +8,12 @@ import stormpy.logic import stormpy.examples import stormpy.examples.files - -from helpers import extract_doors, extract_keys, extract_adversaries +from enum import Enum from abc import ABC + +from minigrid.core.actions import Actions + import os import time class Action(): @@ -66,6 +68,7 @@ class MiniGridShieldHandler(ShieldHandler): shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value) + formulas = stormpy.parse_properties_for_prism_program(self.formula, program) options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) @@ -82,6 +85,7 @@ class MiniGridShieldHandler(ShieldHandler): shield_scheduler = shield.construct() state_valuations = model.state_valuations choice_labeling = model.choice_labeling + stormpy.shields.export_shield(model, shield, "myshield") for stateID in model.states: choice = shield_scheduler.get_choice(stateID) diff --git a/examples/shields/rl/wrappers.py b/examples/shields/rl/wrappers.py deleted file mode 100644 index bbd9bac..0000000 --- a/examples/shields/rl/wrappers.py +++ /dev/null @@ -1,252 +0,0 @@ -import gymnasium as gym -import numpy as np -import random - -from minigrid.core.actions import Actions -from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX - -from gymnasium.spaces import Dict, Box -from collections import deque -from ray.rllib.utils.numpy import one_hot - -from helpers import get_action_index_mapping -from shieldhandlers import ShieldHandler - - -class OneHotShieldingWrapper(gym.core.ObservationWrapper): - def __init__(self, env, vector_index, framestack): - super().__init__(env) - self.framestack = framestack - # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. - # +4: Direction. - self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4 - self.init_x = None - self.init_y = None - self.x_positions = [] - self.y_positions = [] - self.x_y_delta_buffer = deque(maxlen=100) - self.vector_index = vector_index - self.frame_buffer = deque(maxlen=self.framestack) - for _ in range(self.framestack): - self.frame_buffer.append(np.zeros((self.single_frame_dim,))) - - self.observation_space = Dict( - { - "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), - "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int), - } - ) - - def observation(self, obs): - # Debug output: max-x/y positions to watch exploration progress. - # print(F"Initial observation in Wrapper {obs}") - if self.step_count == 0: - for _ in range(self.framestack): - self.frame_buffer.append(np.zeros((self.single_frame_dim,))) - if self.vector_index == 0: - if self.x_positions: - max_diff = max( - np.sqrt( - (np.array(self.x_positions) - self.init_x) ** 2 - + (np.array(self.y_positions) - self.init_y) ** 2 - ) - ) - self.x_y_delta_buffer.append(max_diff) - print( - "100-average dist travelled={}".format( - np.mean(self.x_y_delta_buffer) - ) - ) - self.x_positions = [] - self.y_positions = [] - self.init_x = self.agent_pos[0] - self.init_y = self.agent_pos[1] - - - self.x_positions.append(self.agent_pos[0]) - self.y_positions.append(self.agent_pos[1]) - - image = obs["data"] - # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten. - objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX)) - colors = one_hot(image[:, :, 1], depth=len(COLORS)) - states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX)) - - all_ = np.concatenate([objects, colors, states], -1) - all_flat = np.reshape(all_, (-1,)) - direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) - single_frame = np.concatenate([all_flat, direction]) - self.frame_buffer.append(single_frame) - - tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] } - return tmp - - -class MiniGridShieldingWrapper(gym.core.Wrapper): - def __init__(self, - env, - shield_creator : ShieldHandler, - shield_query_creator, - create_shield_at_reset=True, - mask_actions=True): - super(MiniGridShieldingWrapper, self).__init__(env) - self.max_available_actions = env.action_space.n - self.observation_space = Dict( - { - "data": env.observation_space.spaces["image"], - "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), - } - ) - self.shield_creator = shield_creator - self.create_shield_at_reset = create_shield_at_reset - self.shield = shield_creator.create_shield(env=self.env) - self.mask_actions = mask_actions - self.shield_query_creator = shield_query_creator - print(F"Shielding is {self.mask_actions}") - - def create_action_mask(self): - # print(F"{self.mask_actions} No shielding") - if not self.mask_actions: - # print("No shielding") - ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) - # print(ret) - return ret - - cur_pos_str = self.shield_query_creator(self.env) - # print(F"Pos string {cur_pos_str}") - # print(F"Shield {list(self.shield.keys())[0]}") - # print(F"Is pos str in shield: {cur_pos_str in self.shield}, Position Str {cur_pos_str}") - # Create the mask - # If shield restricts action mask only valid with 1.0 - # else set all actions as valid - allowed_actions = [] - mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) - - if cur_pos_str in self.shield and self.shield[cur_pos_str]: - allowed_actions = self.shield[cur_pos_str] - zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8) - has_allowed_actions = False - - for allowed_action in allowed_actions: - index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set - if index is None: - print(F"No mapping for action {list(allowed_action.labels)}") - print(F"Shield at pos {cur_pos_str}, shield {self.shield[cur_pos_str]}") - assert(False) - - allowed = 1.0 # random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] - if allowed_action.prob == 0 and allowed: - assert False - if allowed: - has_allowed_actions = True - mask[index] = allowed - - # if not has_allowed_actions: - # print(F"No action allowed for pos string {cur_pos_str}") - # assert(False) - - else: - for index, x in enumerate(mask): - mask[index] = 1.0 - - front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) - - if front_tile is not None and front_tile.type == "key": - mask[Actions.pickup] = 1.0 - - - if front_tile and front_tile.type == "door": - mask[Actions.toggle] = 1.0 - # print(F"Mask is {mask} State: {cur_pos_str}") - return mask - - def reset(self, *, seed=None, options=None): - obs, infos = self.env.reset(seed=seed, options=options) - - if self.create_shield_at_reset and self.mask_actions: - self.shield = self.shield_creator.create_shield(env=self.env) - - mask = self.create_action_mask() - return { - "data": obs["image"], - "action_mask": mask - }, infos - - def step(self, action): - orig_obs, rew, done, truncated, info = self.env.step(action) - - mask = self.create_action_mask() - obs = { - "data": orig_obs["image"], - "action_mask": mask, - } - - return obs, rew, done, truncated, info - - - -class MiniGridSbShieldingWrapper(gym.core.Wrapper): - def __init__(self, - env, - shield_creator : ShieldHandler, - shield_query_creator, - create_shield_at_reset = True, - mask_actions=True, - ): - super(MiniGridSbShieldingWrapper, self).__init__(env) - self.max_available_actions = env.action_space.n - self.observation_space = env.observation_space.spaces["image"] - - self.shield_creator = shield_creator - self.mask_actions = mask_actions - self.shield_query_creator = shield_query_creator - print(F"Shielding is {self.mask_actions}") - - def create_action_mask(self): - if not self.mask_actions: - return np.array([1.0] * self.max_available_actions, dtype=np.int8) - - cur_pos_str = self.shield_query_creator(self.env) - - allowed_actions = [] - - # Create the mask - # If shield restricts actions, mask only valid actions with 1.0 - # else set all actions valid - mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) - - if cur_pos_str in self.shield and self.shield[cur_pos_str]: - allowed_actions = self.shield[cur_pos_str] - for allowed_action in allowed_actions: - index = get_action_index_mapping(allowed_action.labels) - if index is None: - assert(False) - - mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] - else: - for index, x in enumerate(mask): - mask[index] = 1.0 - - front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) - - - if front_tile and front_tile.type == "door": - mask[Actions.toggle] = 1.0 - - return mask - - - def reset(self, *, seed=None, options=None): - obs, infos = self.env.reset(seed=seed, options=options) - - shield = self.shield_creator.create_shield(env=self.env) - - self.shield = shield - return obs["image"], infos - - def step(self, action): - orig_obs, rew, done, truncated, info = self.env.step(action) - obs = orig_obs["image"] - - return obs, rew, done, truncated, info -