"""Shared helpers for shielded RL experiments.

Contains action/state definitions, overlay-image rendering for grid
environments, action-mask construction from action labels, the common
argument parser, and translation of model state valuations into symbolic
state objects.
"""

import argparse
import datetime
import os
import re
import shutil
import sys
import tempfile
import time
from dataclasses import dataclass, fields
from enum import Enum
from typing import Optional

import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
from PIL import Image, ImageDraw
import numpy as np
from colour import Color
import gymnasium as gym
import minigrid
from minigrid.core.state import to_state, State
from minigrid.core.world_object import Lava, Goal
from loguru import logger


class StateClassification(Enum):
    Safe = "s"
    Critical = "c"
    Dangerous = "d"
    Violated = "u"


class Actions(Enum):
    Left = 0
    Right = 1
    East = 2
    South = 3
    West = 4
    North = 5
    Forward = 6
    Deadlock = 99


# https://highway-env.farama.org/actions/#highway_env.envs.common.action.DiscreteMetaAction.ACTIONS_ALL
class HighwayActions(Enum):
    Switch_left = 0
    Noop = 1
    Switch_right = 2
    Accelerate = 3
    Break = 4


class TaxinetActions(Enum):
    Steer_left = 0
    Noop = 1
    Steer_right = 2


class TrafficActions(Enum):
    EW_1EW_2 = 0
    EW_1NS_2 = 1
    NS_1EW_2 = 2
    NS_1NS_2 = 3


class _CsvRecord:
    """Field-name and CSV helpers shared by the state dataclasses below.

    The helpers are derived from ``dataclasses.fields`` so they follow the
    declaration order of the concrete dataclass.
    """

    @classmethod
    def feature_space(cls) -> list[str]:
        """Return the field names in declaration order."""
        return [f.name for f in fields(cls)]

    @classmethod
    def header(cls) -> str:
        """Return the field names joined by commas (a CSV header row)."""
        return ",".join(cls.feature_space())

    def csv(self) -> str:
        """Return this instance's field values joined by commas."""
        return ",".join(str(getattr(self, f.name)) for f in fields(self))


@dataclass(frozen=True, eq=True)
class HighwayState(_CsvRecord):
    lane_ego: int
    distance_one: int
    distance_two: int = 0
    velocity_ego: int = 3
    lane_one: int = 1
    lane_two: int = 2
    velocity_one: int = 1
    velocity_two: int = 1
    previous_action: int = 0


@dataclass(frozen=True, eq=True)
class TaxinetState(_CsvRecord):
    cte: int
    he: int


@dataclass(frozen=True, eq=True)
class TrafficState(_CsvRecord):
    lE_1: int
    lW_1: int
    lN_1: int
    lS_1: int
    lE_2: int
    lW_2: int
    lN_2: int
    lS_2: int


NUM_MOVEMENT_ACTIONS = 3

WHITE = (255, 255, 255, 255)
RED = (255, 0, 0, 255)
GREEN = (0, 255, 0, 255)
BLACK = (0, 0, 0, 250)
DARKGREEN = (0, 200, 0, 250)
VIOLET = (70, 0, 200, 250)
GREY = (100, 100, 100, 250)
YELLOW = (230, 230, 0, 250)
ORANGE = (230, 150, 0, 250)


def tic():
    """Store the current time in a module global for a later ``toc()``."""
    # Homemade version of matlab tic and toc functions:
    # https://stackoverflow.com/a/18903019
    global startTime_for_tictoc
    startTime_for_tictoc = time.time()


def toc():
    """Print the seconds elapsed since the last ``tic()``, if there was one."""
    if 'startTime_for_tictoc' in globals():
        print("Elapsed time is " + str(time.time() - startTime_for_tictoc) + " seconds.")
    else:
        print("Toc: start time not set")


class ShieldingConfig(Enum):
    Training = 'training'
    Evaluation = 'evaluation'
    Disabled = 'none'
    Full = 'full'

    def __str__(self) -> str:
        return self.value


def shield_needed(shielding):
    """Return True for every configuration except ``Disabled``."""
    return shielding in [ShieldingConfig.Training, ShieldingConfig.Evaluation, ShieldingConfig.Full]


def shielded_evaluation(shielding):
    """Return True when the configuration is ``Evaluation`` or ``Full``."""
    return shielding in [ShieldingConfig.Evaluation, ShieldingConfig.Full]


def shielded_training(shielding):
    """Return True when the configuration is ``Training`` or ``Full``."""
    return shielding in [ShieldingConfig.Training, ShieldingConfig.Full]


def rectangle_for_overlay(x, y, dir, tile_size, width=2, offset=0, thickness=0):
    """Return the two corner points of a strip along one edge of tile (x, y).

    ``dir`` selects the edge: 0 is the edge at the largest x, 1 the edge at
    the largest y, 2 the edge at the smallest x and 3 the edge at the
    smallest y. The strip is ``width + thickness`` pixels deep and shortened
    by ``offset`` pixels at both ends. Any other ``dir`` yields ``None``.
    """
    if dir == 0:
        return (((x+1)*tile_size-width-thickness, y*tile_size+offset),
                ((x+1)*tile_size, (y+1)*tile_size-offset))
    if dir == 1:
        return ((x*tile_size+offset, (y+1)*tile_size-width-thickness),
                ((x+1)*tile_size-offset, (y+1)*tile_size))
    if dir == 2:
        return ((x*tile_size, y*tile_size+offset),
                (x*tile_size+width+thickness, (y+1)*tile_size-offset))
    if dir == 3:
        return ((x*tile_size+offset, y*tile_size),
                ((x+1)*tile_size-offset, y*tile_size+width+thickness))
    return None


def triangle_for_overlay(x, y, dir, tile_size):
    """Return the three vertices of one quarter-triangle of tile (x, y).

    Each triangle spans one full edge of the tile (same edge numbering as
    ``rectangle_for_overlay``) and has its apex at the tile centre, so the
    four directions together cover the whole tile. Any ``dir`` outside
    0..3 yields ``None``.
    """
    offset = tile_size/2
    if dir == 0:
        return [((x+1)*tile_size, y*tile_size),
                ((x+1)*tile_size, (y+1)*tile_size),
                ((x+1)*tile_size-offset, y*tile_size+tile_size/2)]
    if dir == 1:
        return [(x*tile_size, (y+1)*tile_size),
                ((x+1)*tile_size, (y+1)*tile_size),
                (x*tile_size+tile_size/2, (y+1)*tile_size-offset)]
    if dir == 2:
        return [(x*tile_size, y*tile_size),
                (x*tile_size, (y+1)*tile_size),
                (x*tile_size+offset, y*tile_size+tile_size/2)]
    if dir == 3:
        return [(x*tile_size, y*tile_size),
                ((x+1)*tile_size, y*tile_size),
                (x*tile_size+tile_size/2, y*tile_size+offset)]
    return None


def _render_with_overlay(env):
    """Render ``env`` and return (image, transparent overlay, draw, tile size)."""
    img = Image.fromarray(env.render()).convert("RGBA")
    overlay = Image.new("RGBA", img.size, (255, 255, 255, 0))
    width = env.width if env.width else env.grid.size
    tile_size = img.size[0] // width
    return img, overlay, ImageDraw.Draw(overlay), tile_size


def _fill_tile(draw, x, y, tile_size, fill):
    """Paint all four quarter-triangles of tile (x, y) with ``fill``."""
    for direction in range(4):
        draw.polygon(triangle_for_overlay(x, y, direction, tile_size), fill=fill)


def _fill_static_tile(env, draw, x, y, tile_size) -> bool:
    """Paint lava tiles red and goal tiles green; return True if painted."""
    cell = env.grid.get(x, y)
    if isinstance(cell, Lava):
        _fill_tile(draw, x, y, tile_size, RED)
        return True
    if isinstance(cell, Goal):
        _fill_tile(draw, x, y, tile_size, GREEN)
        return True
    return False


def create_shield_overlay_image(env, shield, dangerous_states=None):
    """Render ``env`` after a reset and colour its interior tiles by shield.

    Lava tiles are red and goal tiles green. For the remaining tiles the
    shield is looked up with ``State(x, y, '')``: tiles contained in
    ``dangerous_states`` are orange, tiles whose first four mask entries sum
    to less than 4 are yellow. Tiles missing from the shield are left as
    rendered.

    Args:
        env: grid environment providing ``reset``, ``render``, ``width``,
            ``height`` and ``grid``.
        shield: mapping from ``State`` to an action mask.
        dangerous_states: optional container of ``State`` objects; defaults
            to empty.

    Returns:
        The composited RGBA ``PIL.Image``.
    """
    if dangerous_states is None:
        dangerous_states = []
    env.reset()
    img, overlay, draw, ts = _render_with_overlay(env)
    for x in range(1, env.width - 1):
        for y in range(1, env.height - 1):
            if _fill_static_tile(env, draw, x, y, ts):
                continue
            state = State(x, y, '')
            try:
                mask = shield[state][0:4]
            except KeyError:
                # The shield has no entry for this tile; leave it untouched.
                continue
            if state in dangerous_states:
                _fill_tile(draw, x, y, ts, ORANGE)
            elif sum(mask) < 4:
                _fill_tile(draw, x, y, ts, YELLOW)
    return Image.alpha_composite(img, overlay)


def create_critical_dangerous_overlay_image(env, shield, ignore_view: bool = False):
    """Render ``env`` and highlight the shield's dangerous/critical positions.

    Positions are matched against the keys of
    ``shield.dangerous_set_positions`` (orange) and
    ``shield.critical_set_positions`` (yellow). With ``ignore_view`` the keys
    are looked up as ``"x,y"`` and the whole tile is painted; otherwise they
    are looked up as ``"x,y,dir"`` and only the matching quarter-triangle is
    painted. Lava tiles are red, goal tiles green. The environment is not
    reset.

    Returns:
        The composited RGBA ``PIL.Image``.
    """
    # Sets give constant-time membership tests inside the nested loops.
    dangerous = set(shield.dangerous_set_positions.keys())
    critical = set(shield.critical_set_positions.keys())
    img, overlay, draw, ts = _render_with_overlay(env)
    for x in range(1, env.width - 1):
        for y in range(1, env.height - 1):
            if _fill_static_tile(env, draw, x, y, ts):
                continue
            if ignore_view:
                position = f"{x},{y}"
                if position in dangerous:
                    _fill_tile(draw, x, y, ts, ORANGE)
                elif position in critical:
                    _fill_tile(draw, x, y, ts, YELLOW)
                continue
            for direction in range(4):
                position = f"{x},{y},{direction}"
                if position in dangerous:
                    draw.polygon(triangle_for_overlay(x, y, direction, ts), fill=ORANGE)
                elif position in critical:
                    draw.polygon(triangle_for_overlay(x, y, direction, ts), fill=YELLOW)
    return Image.alpha_composite(img, overlay)


# Colour per 3-entry action mask used by create_dt_overlay_image. Two allowed
# actions are yellow, exactly one is orange, none (deadlock) is red; a mask
# with all three allowed is absent on purpose and keeps the white base colour.
_DT_MASK_COLOURS = {
    (1.0, 1.0, 0.0): YELLOW,
    (1.0, 0.0, 0.0): ORANGE,
    (1.0, 0.0, 1.0): YELLOW,
    (0.0, 1.0, 0.0): ORANGE,
    (0.0, 1.0, 1.0): YELLOW,
    (0.0, 0.0, 1.0): ORANGE,
    (0.0, 0.0, 0.0): RED,
}


def create_dt_overlay_image(env, decision_tree, translator):
    """Render ``env`` after a reset and colour tiles by decision-tree output.

    For every interior tile and direction the tree is evaluated on
    ``translator.features_given_position(x, y, dir)``; the resulting labels
    are turned into a mask whose first three entries pick the colour from
    ``_DT_MASK_COLOURS``. Lava tiles are red, goal tiles green. A
    ``KeyError`` during evaluation leaves the triangle white.

    Returns:
        The composited RGBA ``PIL.Image``.
    """
    env.reset()
    img, overlay, draw, ts = _render_with_overlay(env)
    for x in range(1, env.width - 1):
        for y in range(1, env.height - 1):
            if _fill_static_tile(env, draw, x, y, ts):
                continue
            for direction in range(4):
                triangle = triangle_for_overlay(x, y, direction, ts)
                draw.polygon(triangle, fill=WHITE)
                try:
                    labels = decision_tree.evaluate(
                        translator.features_given_position(x, y, direction))
                except KeyError:
                    continue
                # NOTE(review): get_allowed_actions_mask orders its entries
                # east/south/west/north, whereas the colour table was
                # originally commented in terms of left/right/forward --
                # confirm which mask helper is intended here.
                mask = tuple(get_allowed_actions_mask(labels)[0:3])
                fill = _DT_MASK_COLOURS.get(mask)
                if fill is not None:
                    draw.polygon(triangle, fill=fill)
    return Image.alpha_composite(img, overlay)


def create_heatmap(env, heatmap_data: dict, high: Optional[int] = None, low: Optional[int] = 0):
    """Render ``env`` after a reset and colour tiles by ``heatmap_data``.

    Values are mapped onto a light-blue to dark-blue gradient with
    ``high - low`` steps (at least one). Lava tiles are red, goal tiles
    green; triangles without a value stay white.

    Args:
        env: grid environment providing ``reset``, ``render``, ``width``,
            ``height`` and ``grid``.
        heatmap_data: mapping from ``State(x, y, dir)`` to a numeric value.
        high: upper end of the colour scale; when falsy the maximum of
            ``heatmap_data`` is used.
        low: lower end of the colour scale; ``None`` is treated as 0.

    Returns:
        The composited RGBA ``PIL.Image``.
    """
    env.reset()
    img, overlay, draw, ts = _render_with_overlay(env)
    if low is None:
        low = 0
    if not high:
        high = np.max(list(heatmap_data.values()))
    # At least one step, so a flat data set still yields a usable gradient.
    colors = list(Color("lightblue").range_to(Color("darkblue"), max(high - low, 1)))
    last_index = len(colors) - 1
    for x in range(1, env.width - 1):
        for y in range(1, env.height - 1):
            if _fill_static_tile(env, draw, x, y, ts):
                continue
            for direction in range(4):
                triangle = triangle_for_overlay(x, y, direction, ts)
                draw.polygon(triangle, fill=WHITE)
                try:
                    value = heatmap_data[State(x, y, direction)]
                except KeyError:
                    logger.trace(f"No Value for {x}, {y}, {direction}.")
                    continue
                # Clamp: a value equal to `low` must not wrap around to the
                # darkest colour via index -1, and a value above `high` must
                # not run past the end of the gradient.
                index = min(max(value - low - 1, 0), last_index)
                color = tuple(int(255 * channel) for channel in colors[index].rgb)
                draw.polygon(triangle, fill=color)
    return Image.alpha_composite(img, overlay)


def expname(args, env):
    """Build a timestamped experiment name from ``env`` and shield settings."""
    return f"{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{env}_{args.shielding}_{args.shield_comparison}_{args.shield_value}_{args.expname_suffix}"


def create_log_dir(args, env):
    """Create (if needed) and return the log directory for this experiment."""
    log_dir = f"{args.log_dir}/{expname(args, env)}"
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


def get_allowed_actions_mask_with_view(actions):
    """Build a 7-entry action mask from case-insensitive action labels.

    Index layout: 0 left, 1 right, 2 move, 3 pickup, 4 drop, 5 toggle,
    6 done. Only the first matching keyword of a label sets an entry;
    independently of that, a label containing a compass direction (north,
    south, east, west) also enables index 2.
    """
    action_mask = [0.0] * 7
    for action_label in actions:
        action_label = action_label.lower()
        if "move" in action_label:
            action_mask[2] = 1.0
        elif "left" in action_label:
            action_mask[0] = 1.0
        elif "right" in action_label:
            action_mask[1] = 1.0
        elif "pickup" in action_label:
            action_mask[3] = 1.0
        elif "drop" in action_label:
            action_mask[4] = 1.0
        elif "toggle" in action_label:
            action_mask[5] = 1.0
        elif "done" in action_label:
            action_mask[6] = 1.0
        if "north" in action_label or "south" in action_label or "east" in action_label or "west" in action_label:
            action_mask[2] = 1.0
    return action_mask


def get_allowed_actions_mask(actions):
    """Build a 4-entry mask from case-insensitive compass-direction labels.

    Index layout: 0 east, 1 south, 2 west, 3 north. Labels are checked in
    the order north, south, east, west and only the first match counts.
    """
    action_mask = [0.0] * 4
    for action_label in actions:
        action_label = action_label.lower()
        if "north" in action_label:
            action_mask[3] = 1.0
        elif "south" in action_label:
            action_mask[1] = 1.0
        elif "east" in action_label:
            action_mask[0] = 1.0
        elif "west" in action_label:
            action_mask[2] = 1.0
    return action_mask


def get_allowed_actions_mask_for_labels(action_labels, ignore_view: bool = False):
    """Flatten an iterable of label collections and build the matching mask.

    Uses the 4-entry compass mask when ``ignore_view`` is set and the
    7-entry view-based mask otherwise.
    """
    actions = [label for labels in action_labels for label in list(labels)]
    if ignore_view:
        return get_allowed_actions_mask(actions)
    return get_allowed_actions_mask_with_view(actions)


def dir_path(string):
    """Return ``string`` if it is an existing directory.

    Raises:
        NotADirectoryError: if ``string`` is not an existing directory.
    """
    if os.path.isdir(string):
        return string
    raise NotADirectoryError(string)


def common_parser():
    """Return the argument parser shared by the experiment scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--env",
                        help="gym environment to load",
                        choices=gym.envs.registry.keys(),
                        default="MiniGrid-LavaSlipperyCliff-16x13-v0")
    parser.add_argument("--grid_file", default="grid.txt")
    parser.add_argument("--prism_file", default=None)
    parser.add_argument("--prism_output_file", default="grid.prism")
    parser.add_argument("--log_dir", default="../log_results/")
    parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")
    parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
    parser.add_argument("--steps", default=20_000, type=int)
    parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
    parser.add_argument("--prism_config", default=None)
    parser.add_argument("--shield_value", default=0.9, type=float)
    parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
    parser.add_argument("--nocleanup", action=argparse.BooleanOptionalAction)
    parser.add_argument("--expname_suffix", default="")
    parser.add_argument("--action_comparison", choices=["absolute", "relative"])
    parser.add_argument("--weighted", action="store_true")
    parser.add_argument("--min_frequency", type=int, default=0)
    return parser


class MiniWrapper(gym.Wrapper):
    """Wrapper that swaps the first two axes of every observation."""

    def __init__(self, env):
        super().__init__(env)
        self.env = env

    def reset(self, *, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        return obs.transpose(1, 0, 2), info

    def observations(self, obs):
        return obs

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return obs.transpose(1, 0, 2), reward, terminated, truncated, info


# Patterns for the textual state valuations: `name=value` assignments and
# (optionally `!`-negated) names that are followed by whitespace.
_INT_ASSIGNMENT_RE = re.compile(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)')
_BOOLEAN_RE = re.compile(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+')


def _parse_valuation(state: str):
    """Split a valuation string into (assignments, booleans).

    Assignment values are returned as the raw matched strings; booleans map
    each name to ``False`` when it was prefixed with ``!`` and ``True``
    otherwise.
    """
    ints = dict(_INT_ASSIGNMENT_RE.findall(state))
    booleans = {name: negation != "!" for negation, name in _BOOLEAN_RE.findall(state)}
    return ints, booleans


def _valuation_as_dict(state_id, state_valuations) -> dict:
    """Return the valuation of ``state_id`` with assignment values as ints."""
    ints, booleans = _parse_valuation(state_valuations.get_string(state_id))
    return {**{name: int(value) for name, value in ints.items()}, **booleans}


def id_to_symbolic(state_id: int, state_valuations) -> Optional[State]:
    """Translate the valuation of ``state_id`` via ``to_symbolic``."""
    return to_symbolic(state_valuations.get_string(state_id))


def to_symbolic(state: str) -> Optional[State]:
    """Parse a valuation string into a ``State``.

    Returns ``None`` when ``previousActionAgent`` is present and differs
    from 7, or when ``clock`` is present and differs from 0.
    """
    ints, booleans = _parse_valuation(state)
    if int(ints.get("previousActionAgent", 7)) != 7:
        return None
    if int(ints.get("clock", 0)) != 0:
        return None
    return to_state(ints, booleans)


def dataclass_from_dict(klass, dikt):
    """Instantiate ``klass`` from ``dikt``, recursing into annotated fields.

    Fields missing from ``dikt`` fall back to the class-level default. If a
    missing field has no default, the ``KeyError`` is logged as a warning
    and ``None`` is returned. When ``klass`` cannot be treated this way
    (e.g. it has no ``__annotations__``), lists/tuples are converted
    element-wise using ``klass.__args__[0]`` and anything else is returned
    unchanged.
    """
    try:
        fieldtypes = klass.__annotations__
        new_dikt = dict()
        for k in list(fieldtypes.keys()):
            try:
                new_dikt[k] = dikt[k]
            except KeyError:
                logger.trace(f"Could not fetch value for field {k} from input, defaulting to {klass.__dict__[k]=}")
                new_dikt[k] = klass.__dict__[k]
        return klass(**{f: dataclass_from_dict(fieldtypes[f], new_dikt[f]) for f in new_dikt})
    except KeyError as e:
        logger.warning(str(e))
        return None
    except Exception:
        # Deliberately broad: plain types such as `int` end up here and the
        # value is passed through as-is.
        if isinstance(dikt, (tuple, list)):
            return [dataclass_from_dict(klass.__args__[0], f) for f in dikt]
        return dikt


def id_to_symbolic_highway_state(state_id, state_valuations):
    """Build a ``HighwayState`` from the valuation of ``state_id``."""
    return dataclass_from_dict(HighwayState, _valuation_as_dict(state_id, state_valuations))


def id_to_symbolic_taxinet_state(state_id, state_valuations):
    """Build a ``TaxinetState`` from the valuation of ``state_id``."""
    return dataclass_from_dict(TaxinetState, _valuation_as_dict(state_id, state_valuations))


def id_to_symbolic_traffic_state(state_id, state_valuations):
    """Build a ``TrafficState`` from the valuation of ``state_id``."""
    return dataclass_from_dict(TrafficState, _valuation_as_dict(state_id, state_valuations))


def symbolic_minigrid_to_model_state(model, col, row, view=None):
    """Return the first model state located at (``col``, ``row``).

    When ``view`` is given the state's ``viewAgent`` must match as well.
    States that ``id_to_symbolic`` maps to ``None`` are skipped. Returns
    ``None`` when no state matches.
    """
    state_valuations = model.state_valuations
    for model_state in model.states:
        ms = id_to_symbolic(model_state, state_valuations)
        if ms is None:
            continue
        if ms.colAgent != col or ms.rowAgent != row:
            continue
        if view is None or ms.viewAgent == view:
            return model_state
    return None