From 372006a1daeefef54bcfdd4afc3479d2d5ca8469 Mon Sep 17 00:00:00 2001 From: sp Date: Sun, 14 Jan 2024 16:28:55 +0100 Subject: [PATCH] major refactor in utils - introduced common_parser for arguments - the shield dict uses minigrid.core.State instead of strings - switched shield query to minigrid get_symbolic_state --- examples/shields/rl/sb3utils.py | 72 +++---- examples/shields/rl/utils.py | 322 +++++++++----------------------- 2 files changed, 109 insertions(+), 285 deletions(-) diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py index d81cffe..b49c38b 100644 --- a/examples/shields/rl/sb3utils.py +++ b/examples/shields/rl/sb3utils.py @@ -2,69 +2,47 @@ import gymnasium as gym import numpy as np import random -from utils import MiniGridShieldHandler, create_shield_query +from utils import MiniGridShieldHandler, common_parser class MiniGridSbShieldingWrapper(gym.core.Wrapper): - def __init__(self, - env, - shield_creator : MiniGridShieldHandler, - shield_query_creator, + def __init__(self, + env, + shield_handler : MiniGridShieldHandler, create_shield_at_reset = True, mask_actions=True, ): - super(MiniGridSbShieldingWrapper, self).__init__(env) - self.max_available_actions = env.action_space.n + super().__init__(env) self.observation_space = env.observation_space.spaces["image"] - - self.shield_creator = shield_creator - self.mask_actions = mask_actions - self.shield_query_creator = shield_query_creator - - def create_action_mask(self): - if not self.mask_actions: - return np.array([1.0] * self.max_available_actions, dtype=np.int8) - - cur_pos_str = self.shield_query_creator(self.env) - - allowed_actions = [] - # Create the mask - # If shield restricts actions, mask only valid actions with 1.0 - # else set all actions valid - mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) + self.shield_handler = shield_handler + self.mask_actions = mask_actions + self.create_shield_at_reset = create_shield_at_reset - if cur_pos_str in self.shield and self.shield[cur_pos_str]: - allowed_actions = self.shield[cur_pos_str] - for allowed_action in allowed_actions: - index = get_action_index_mapping(allowed_action.labels) - if index is None: - assert(False) - - mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] - else: - for index, x in enumerate(mask): - mask[index] = 1.0 - - front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) + shield = self.shield_handler.create_shield(env=self.env) + self.shield = shield - - if front_tile and front_tile.type == "door": - mask[Actions.toggle] = 1.0 - - return mask - + def create_action_mask(self): + try: + return self.shield[self.env.get_symbolic_state()] + except: + return [1.0] * 3 + [1.0] * 4 def reset(self, *, seed=None, options=None): obs, infos = self.env.reset(seed=seed, options=options) - - shield = self.shield_creator.create_shield(env=self.env) - - self.shield = shield + + if self.create_shield_at_reset and self.mask_actions: + shield = self.shield_handler.create_shield(env=self.env) + self.shield = shield return obs["image"], infos def step(self, action): orig_obs, rew, done, truncated, info = self.env.step(action) obs = orig_obs["image"] - + return obs, rew, done, truncated, info +def parse_sb3_arguments(): + parser = common_parser() + args = parser.parse_args() + + return args diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index 283eee2..1c9c8a6 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -11,16 +11,37 @@ import stormpy.examples.files from enum import Enum from abc import ABC +import re +import sys + from minigrid.core.actions import Actions +from minigrid.core.state import to_state import os import time -class Action(): - def __init__(self, idx, prob=1, labels=[]) -> None: - self.idx = idx - self.prob = prob - self.labels = labels + +import argparse + +def tic(): + #Homemade version of matlab tic and toc functions: https://stackoverflow.com/a/18903019 + global startTime_for_tictoc + startTime_for_tictoc = time.time() + +def toc(): + if 'startTime_for_tictoc' in globals(): + print("Elapsed time is " + str(time.time() - startTime_for_tictoc) + " seconds.") + else: + print("Toc: start time not set") + +class ShieldingConfig(Enum): + Training = 'training' + Evaluation = 'evaluation' + Disabled = 'none' + Full = 'full' + + def __str__(self) -> str: + return self.value class ShieldHandler(ABC): def __init__(self) -> None: @@ -29,290 +50,115 @@ class ShieldHandler(ABC): pass class MiniGridShieldHandler(ShieldHandler): - def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None, shield_comparision='relative') -> None: + def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute') -> None: self.grid_file = grid_file - self.grid_to_prism_path = grid_to_prism_path + self.grid_to_prism_binary = grid_to_prism_binary self.prism_path = prism_path - self.formula = formula self.prism_config = prism_config - self.shield_value = shield_value - self.shield_comparision = shield_comparision - + + self.formula = formula + shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE + self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value) + + def __export_grid_to_text(self, env): f = open(self.grid_file, "w") f.write(env.printGrid(init=True)) f.close() - + def __create_prism(self): - # result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml") if self.prism_config is None: - result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path}") + result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}") else: - result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}") - # result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml") + result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}") assert result == 0, "Prism file could not be generated" - f = open(self.prism_path, "a") - f.close() - def __create_shield_dict(self): - print(self.prism_path) program = stormpy.parse_prism_program(self.prism_path) - shield_comp = stormpy.logic.ShieldComparison.RELATIVE - - if self.shield_comparision == 'absolute': - shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE - - shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value) - - formulas = stormpy.parse_properties_for_prism_program(self.formula, program) options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) options.set_build_state_valuations(True) options.set_build_choice_labels(True) options.set_build_all_labels() + print(f"LOG: Starting with explicit model creation...") + tic() model = stormpy.build_sparse_model_with_options(program, options) - - result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) - + toc() + + print(f"LOG: Starting with model checking...") + tic() + result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=self.shield_expression) + toc() + assert result.has_shield shield = result.shield - action_dictionary = {} + action_dictionary = dict() shield_scheduler = shield.construct() state_valuations = model.state_valuations choice_labeling = model.choice_labeling - stormpy.shields.export_shield(model, shield, "myshield") - + + #stormpy.shields.export_shield(model, shield, "current.shield") + for stateID in model.states: choice = shield_scheduler.get_choice(stateID) choices = choice.choice_map state_valuation = state_valuations.get_string(stateID) - actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices] - action_dictionary[state_valuation] = actions_to_be_executed - + ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=([a-zA-Z0-9]+)', state_valuation)) + booleans = dict(re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]', state_valuation)) #TODO does not parse everything correctly? + + if int(ints.get("previousActionAgent", 3)) != 3: + continue + if int(ints.get("clock", 0)) != 0: + continue + state = to_state(ints, booleans) + action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices]) + return action_dictionary - - + + def create_shield(self, **kwargs): env = kwargs["env"] self.__export_grid_to_text(env) self.__create_prism() - - return self.__create_shield_dict() - -def create_shield_query(env): - coordinates = env.env.agent_pos - view_direction = env.env.agent_dir - - keys = extract_keys(env) - doors = extract_doors(env) - adversaries = extract_adversaries(env) - - - if env.carrying: - agent_carrying = F"Agent_is_carrying_object\t" - else: - agent_carrying = "!Agent_is_carrying_object\t" - - key_positions = [] - agent_key_status = [] - - for key in keys: - key_color = key[0].color - key_x = key[1] - key_y = key[2] - if env.carrying and env.carrying.type == "key": - agent_key_text = F"Agent_has_{env.carrying.color}_key\t& " - key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t" - else: - agent_key_text = F"!Agent_has_{key_color}_key\t& " - key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t" - - key_positions.append(key_position) - agent_key_status.append(agent_key_text) - - if key_positions: - key_positions[-1] = key_positions[-1].strip() - - door_status = [] - for door in doors: - status = "" - if door.is_open: - status = F"!Door{door.color}locked\t& Door{door.color}open\t&" - elif door.is_locked: - status = F"Door{door.color}locked\t& !Door{door.color}open\t&" - else: - status = F"!Door{door.color}locked\t& !Door{door.color}open\t&" - - door_status.append(status) - - adv_status = [] - adv_positions = [] - - for adversary in adversaries: - status = "" - position = "" - - if adversary.carrying: - carrying = F"{adversary.name}_is_carrying_object\t" - else: - carrying = F"!{adversary.name}_is_carrying_object\t" - - status = F"{carrying}& !{adversary.name}Done\t& " - position = F"x{adversary.name}={adversary.cur_pos[1]}\t& y{adversary.name}={adversary.cur_pos[0]}\t& view{adversary.name}={adversary.adversary_dir}" - adv_status.append(status) - adv_positions.append(position) - - door_status_text = "" - - if door_status: - door_status_text = F"& {''.join(door_status)}\t" - - adv_status_text = "" - - if adv_status: - adv_status_text = F"& {''.join(adv_status)}" - - adv_positions_text = "" - - if adv_positions: - adv_positions_text = F"\t& {''.join(adv_positions)}" - - key_positions_text = "" - - if key_positions: - key_positions_text = F"\t& {''.join(key_positions)}" - - move_text = "" - - if adversaries: - move_text = F"move=0\t& " - - agent_position = F"& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}" - query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]" - - return query - - -class ShieldingConfig(Enum): - Training = 'training' - Evaluation = 'evaluation' - Disabled = 'none' - Full = 'full' - - def __str__(self) -> str: - return self.value + return self.__create_shield_dict() -def extract_keys(env): - keys = [] - for j in range(env.grid.height): - for i in range(env.grid.width): - obj = env.grid.get(i,j) - - if obj and obj.type == "key": - keys.append((obj, i, j)) - - if env.carrying and env.carrying.type == "key": - keys.append((env.carrying, -1, -1)) - # TODO Maybe need to add ordering of keys so it matches the order in the shield - return keys - -def extract_doors(env): - doors = [] - for j in range(env.grid.height): - for i in range(env.grid.width): - obj = env.grid.get(i,j) - - if obj and obj.type == "door": - doors.append(obj) - - return doors - -def extract_adversaries(env): - adv = [] - - if not hasattr(env, "adversaries") or not env.adversaries: - return [] - - for color, adversary in env.adversaries.items(): - adv.append(adversary) - - - return adv def create_log_dir(args): - return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}" + return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}" def test_name(args): return F"{args.expname}" -def get_action_index_mapping(actions): - for action_str in actions: - if not "Agent" in action_str: - continue - - if "move" in action_str: - return Actions.forward - elif "left" in action_str: - return Actions.left - elif "right" in action_str: - return Actions.right - elif "pickup" in action_str: - return Actions.pickup - elif "done" in action_str: - return Actions.done - elif "drop" in action_str: - return Actions.drop - elif "toggle" in action_str: - return Actions.toggle - elif "unlock" in action_str: - return Actions.toggle - - raise ValueError("No action mapping found") - - -def parse_arguments(argparse): +def get_allowed_actions_mask(actions): + action_mask = [0.0] * 3 + [1.0] * 4 + actions_labels = [label for labels in actions for label in list(labels)] + for action_label in actions_labels: + if "move" in action_label: + action_mask[2] = 1.0 + elif "left" in action_label: + action_mask[0] = 1.0 + elif "right" in action_label: + action_mask[1] = 1.0 + return action_mask + +def common_parser(): parser = argparse.ArgumentParser() - # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") parser.add_argument("--env", help="gym environment to load", - default="MiniGrid-LavaSlipperyCliffS12-v2", - choices=[ - "MiniGrid-Adv-8x8-v0", - "MiniGrid-AdvSimple-8x8-v0", - "MiniGrid-LavaCrossingS9N1-v0", - "MiniGrid-LavaCrossingS9N3-v0", - "MiniGrid-LavaSlipperyCliffS12-v0", - "MiniGrid-LavaFaultyS12-30-v0", - ]) - - # parser.add_argument("--seed", type=int, help="seed for environment", default=None) - parser.add_argument("--grid_to_prism_binary_path", default="./main") - parser.add_argument("--grid_path", default="grid") - parser.add_argument("--prism_path", default="grid") - parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"]) + default="MiniGrid-LavaSlipperyCliff-16x12-v0") + + parser.add_argument("--grid_file", default="grid.txt") + parser.add_argument("--prism_output_file", default="grid.prism") parser.add_argument("--log_dir", default="../log_results/") - parser.add_argument("--evaluations", type=int, default=30 ) - parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" - # parser.add_argument("--formula", default="<> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]") - parser.add_argument("--workers", type=int, default=1) - parser.add_argument("--num_gpus", type=float, default=0) + parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]") parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) parser.add_argument("--steps", default=20_000, type=int) - parser.add_argument("--expname", default="exp") parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction) parser.add_argument("--prism_config", default=None) parser.add_argument("--shield_value", default=0.9, type=float) - parser.add_argument("--probability_displacement", default=1/4, type=float) - parser.add_argument("--probability_intended", default=3/4, type=float) - parser.add_argument("--probability_turn_displacement", default=0/4, type=float) - parser.add_argument("--probability_turn_intended", default=4/4, type=float) - parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute']) - # parser.add_argument("--random_starts", default=1, type=int) - args = parser.parse_args() - - return args + parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute']) + return parser