You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

525 lines
20 KiB

from typing import Optional
import re
import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
from dataclasses import dataclass
from enum import Enum
from PIL import Image, ImageDraw
import re
import sys
import tempfile, datetime, shutil
import numpy as np
from colour import Color
import gymnasium as gym
import minigrid
from minigrid.core.state import to_state, State
from minigrid.core.world_object import Lava, Goal
import os
import time
import argparse
from loguru import logger
class StateClassification(Enum):
    """Classification labels for a state, each stored as a one-letter code."""

    Safe = "s"
    Critical = "c"
    Dangerous = "d"
    Violated = "u"
class Actions(Enum):
    """Integer codes for the agent's actions.

    Turning, the four compass moves and ``Forward`` are numbered 0-6;
    ``Deadlock`` is mapped to 99.
    """

    # Turning
    Left = 0
    Right = 1

    # Compass moves
    East = 2
    South = 3
    West = 4
    North = 5

    Forward = 6

    Deadlock = 99
"""
https://highway-env.farama.org/actions/#highway_env.envs.common.action.DiscreteMetaAction.ACTIONS_ALL
"""
class HighwayActions(Enum):
    """Integer codes for the five highway actions (values 0-4)."""

    Switch_left = 0
    Noop = 1
    Switch_right = 2
    Accelerate = 3
    Break = 4
class TaxinetActions(Enum):
    """Integer codes for the three steering actions (values 0-2)."""

    Steer_left = 0
    Noop = 1
    Steer_right = 2
class TrafficActions(Enum):
    """Integer codes for the four traffic actions (values 0-3).

    Each member name pairs an ``EW``/``NS`` choice for index 1 with an
    ``EW``/``NS`` choice for index 2.
    """

    EW_1EW_2 = 0
    EW_1NS_2 = 1
    NS_1EW_2 = 2
    NS_1NS_2 = 3
@dataclass(frozen=True, eq=True)
class HighwayState:
    """Immutable, hashable record of the integer features of a highway state.

    ``lane_ego`` and ``distance_one`` are required; every other field has a
    default value.
    """

    lane_ego: int
    distance_one: int
    distance_two: int = 0
    velocity_ego: int = 3
    lane_one: int = 1
    lane_two: int = 2
    velocity_one: int = 1
    velocity_two: int = 1
    previous_action: int = 0

    @staticmethod
    def feature_space() -> list:
        """Return the field names as a list, in declaration order."""
        # Read ``__annotations__`` as an attribute instead of through the raw
        # class ``__dict__``: the dict entry is not guaranteed to exist on
        # interpreters that evaluate annotations lazily.
        return list(HighwayState.__annotations__.keys())

    @staticmethod
    def header() -> str:
        """Return the field names joined by commas (CSV header line)."""
        return ",".join(map(str, HighwayState.__annotations__.keys()))

    def csv(self) -> str:
        """Return this instance's field values joined by commas."""
        return ",".join(map(str, self.__dict__.values()))
@dataclass(frozen=True, eq=True)
class TaxinetState:
    """Immutable, hashable record with the two integer features ``cte`` and ``he``."""

    cte: int
    he: int

    @staticmethod
    def feature_space() -> list:
        """Return the field names as a list, in declaration order."""
        # Read ``__annotations__`` as an attribute instead of through the raw
        # class ``__dict__``: the dict entry is not guaranteed to exist on
        # interpreters that evaluate annotations lazily.
        return list(TaxinetState.__annotations__.keys())

    @staticmethod
    def header() -> str:
        """Return the field names joined by commas (CSV header line)."""
        return ",".join(map(str, TaxinetState.__annotations__.keys()))

    def csv(self) -> str:
        """Return this instance's field values joined by commas."""
        return ",".join(map(str, self.__dict__.values()))
@dataclass(frozen=True, eq=True)
class TrafficState:
    """Immutable, hashable record of eight integer features.

    The fields are ``lE``/``lW``/``lN``/``lS`` for index 1 followed by the
    same four for index 2; all are required.
    """

    lE_1: int
    lW_1: int
    lN_1: int
    lS_1: int
    lE_2: int
    lW_2: int
    lN_2: int
    lS_2: int

    @staticmethod
    def feature_space() -> list:
        """Return the field names as a list, in declaration order."""
        # Read ``__annotations__`` as an attribute instead of through the raw
        # class ``__dict__``: the dict entry is not guaranteed to exist on
        # interpreters that evaluate annotations lazily.
        return list(TrafficState.__annotations__.keys())

    def header(self) -> str:
        """Return this instance's field names joined by commas.

        Unlike ``feature_space`` this is an instance method: the names are
        read from the instance's own ``__dict__``.
        """
        return ",".join(map(str, self.__dict__.keys()))

    def csv(self) -> str:
        """Return this instance's field values joined by commas."""
        return ",".join(map(str, self.__dict__.values()))
# Number of movement actions. Not referenced in the visible part of this file
# (presumably left / right / forward — TODO confirm against callers).
NUM_MOVEMENT_ACTIONS = 3
# 4-tuples passed as ``fill`` to ImageDraw on RGBA overlays, so presumably
# (R, G, B, A) — TODO confirm. The first three are fully opaque.
WHITE = (255,255,255,255)
RED = (255,0,0,255)
GREEN = (0,255,0,255)
# The remaining colours use 250 instead of 255 as their fourth component.
BLACK = (0,0,0,250)
DARKGREEN = (0,200,0,250)
VIOLET = (70,0,200,250)
GREY = (100,100,100,250)
YELLOW = (230,230,0,250)
ORANGE = (230,150,0,250)
def tic():
    """Store the current ``time.time()`` in the module global ``startTime_for_tictoc``.

    ``toc`` reads that global to report the elapsed time.
    """
    # Homemade version of matlab tic and toc functions: https://stackoverflow.com/a/18903019
    globals()["startTime_for_tictoc"] = time.time()
def toc():
    """Print the seconds elapsed since the start mark stored in ``startTime_for_tictoc``.

    If that module global has not been set, a notice is printed instead.
    """
    if 'startTime_for_tictoc' not in globals():
        print("Toc: start time not set")
        return
    elapsed = time.time() - startTime_for_tictoc
    print(f"Elapsed time is {elapsed} seconds.")
class ShieldingConfig(Enum):
    """Shielding modes, each backed by the string used to select it.

    ``str(member)`` yields the backing string rather than the default
    ``ClassName.Member`` form.
    """

    Training = "training"
    Evaluation = "evaluation"
    Disabled = "none"
    Full = "full"

    def __str__(self) -> str:
        """Return the member's string value."""
        return self.value
def shield_needed(shielding):
    """Return True if ``shielding`` is ``Training``, ``Evaluation`` or ``Full``."""
    modes_with_shield = (
        ShieldingConfig.Training,
        ShieldingConfig.Evaluation,
        ShieldingConfig.Full,
    )
    return shielding in modes_with_shield
def shielded_evaluation(shielding):
    """Return True if ``shielding`` is ``Evaluation`` or ``Full``."""
    evaluation_modes = (ShieldingConfig.Evaluation, ShieldingConfig.Full)
    return shielding in evaluation_modes
def shielded_training(shielding):
    """Return True if ``shielding`` is ``Training`` or ``Full``."""
    training_modes = (ShieldingConfig.Training, ShieldingConfig.Full)
    return shielding in training_modes
def rectangle_for_overlay(x, y, dir, tile_size, width=2, offset=0, thickness=0):
    """Return the corner pair of a strip along one edge of tile ``(x, y)``.

    The strip is ``width + thickness`` deep and shortened by ``offset`` at
    both ends. ``dir`` picks the edge: 0 = larger-x edge, 1 = larger-y edge,
    2 = smaller-x edge, 3 = smaller-y edge.

    Returns:
        ``((x0, y0), (x1, y1))`` in pixel coordinates, or ``None`` when
        ``dir`` is not one of 0-3.
    """
    x_lo = x * tile_size
    x_hi = (x + 1) * tile_size
    y_lo = y * tile_size
    y_hi = (y + 1) * tile_size

    if dir == 0:
        return ((x_hi - width - thickness, y_lo + offset), (x_hi, y_hi - offset))
    if dir == 1:
        return ((x_lo + offset, y_hi - width - thickness), (x_hi - offset, y_hi))
    if dir == 2:
        return ((x_lo, y_lo + offset), (x_lo + width + thickness, y_hi - offset))
    if dir == 3:
        return ((x_lo + offset, y_lo), (x_hi - offset, y_lo + width + thickness))
    return None
def triangle_for_overlay(x,y, dir, tile_size):
    """Return the three vertices of a triangle inside tile ``(x, y)``.

    The triangle's base is one full edge of the tile and its apex lies half a
    tile inwards from the middle of that edge. ``dir`` picks the base edge:
    0 = larger-x edge, 1 = larger-y edge, 2 = smaller-x edge,
    3 = smaller-y edge.

    Returns:
        A list of three ``(px, py)`` points, or ``None`` when ``dir`` is not
        one of 0-3.
    """
    half = tile_size/2
    x_lo = x*tile_size
    x_hi = (x+1)*tile_size
    y_lo = y*tile_size
    y_hi = (y+1)*tile_size

    if dir == 0:
        return [(x_hi, y_lo), (x_hi, y_hi), (x_hi - half, y_lo + half)]
    if dir == 1:
        return [(x_lo, y_hi), (x_hi, y_hi), (x_lo + half, y_hi - half)]
    if dir == 2:
        return [(x_lo, y_lo), (x_lo, y_hi), (x_lo + half, y_lo + half)]
    if dir == 3:
        return [(x_lo, y_lo), (x_hi, y_lo), (x_lo + half, y_lo + half)]
    return None
def create_shield_overlay_image(env, shield, dangerous_states=None):
    """Render ``env`` and colour its inner cells according to ``shield``.

    The environment is reset and rendered, then every cell strictly inside
    the outer border is examined:

    * ``Lava`` cells are painted RED and ``Goal`` cells GREEN;
    * otherwise the cell's ``State(x, y, '')`` is looked up in ``shield``;
      cells without an entry are left unpainted;
    * a state contained in ``dangerous_states`` is painted ORANGE;
    * else, if the first four entries of the shield value sum to less than 4,
      the cell is painted YELLOW.

    Args:
        env: environment exposing ``reset``, ``render``, ``width``,
            ``height`` and ``grid.get(x, y)``.
        shield: mapping from ``State`` to a sliceable sequence of numbers.
        dangerous_states: optional collection of states to highlight;
            ``None`` (the default) means no state is highlighted.

    Returns:
        The rendered frame composited with the overlay, as an RGBA image.
    """
    # A fresh empty list per call instead of one mutable default shared
    # between calls.
    if dangerous_states is None:
        dangerous_states = []

    env.reset()
    img = Image.fromarray(env.render()).convert("RGBA")
    overlay = Image.new("RGBA", img.size, (255, 255, 255, 0))
    width = env.width if env.width else env.grid.size
    ts = img.size[0] // width  # edge length of one tile in pixels
    draw = ImageDraw.Draw(overlay)

    def paint_cell(x, y, colour):
        # A cell is covered by its four direction triangles.
        for direction in range(0, 4):
            draw.polygon(triangle_for_overlay(x, y, direction, ts), fill=colour)

    for x in range(1, env.width - 1):
        for y in range(1, env.height - 1):
            cell = env.grid.get(x, y)
            if isinstance(cell, Lava):
                paint_cell(x, y, RED)
                continue
            if isinstance(cell, Goal):
                paint_cell(x, y, GREEN)
                continue
            # Only the lookup is guarded, so a KeyError raised while drawing
            # is no longer swallowed.
            try:
                state = State(x, y, '')
                mask = shield[state][0:4]
            except KeyError:
                continue
            if state in dangerous_states:
                paint_cell(x, y, ORANGE)
            elif sum(mask) < 4:
                paint_cell(x, y, YELLOW)

    img = Image.alpha_composite(img, overlay)
    return img
def create_critical_dangerous_overlay_image(env, shield, ignore_view: bool = False):
    """Render ``env`` and mark the shield's dangerous and critical positions.

    Unlike the other overlay helpers in this module, the environment is not
    reset here; the current frame is rendered as is.

    Every cell strictly inside the outer border is examined. ``Lava`` cells
    are painted RED and ``Goal`` cells GREEN. For any other cell:

    * with ``ignore_view`` the key ``"x,y"`` is checked and the whole cell is
      painted ORANGE (dangerous) or YELLOW (critical);
    * otherwise each of the four directions is checked separately under the
      key ``"x,y,dir"`` and only that direction's triangle is painted.

    Dangerous takes precedence over critical.

    Args:
        env: environment exposing ``render``, ``width``, ``height`` and
            ``grid.get(x, y)``.
        shield: object whose ``dangerous_set_positions`` and
            ``critical_set_positions`` provide ``keys()`` of position strings.
        ignore_view: if True, positions are keyed without a direction.

    Returns:
        The rendered frame composited with the overlay, as an RGBA image.
    """
    # Sets give O(1) membership tests; the lookups below run for every cell
    # and direction.
    dangerous = set(shield.dangerous_set_positions.keys())
    critical = set(shield.critical_set_positions.keys())

    img = Image.fromarray(env.render()).convert("RGBA")
    overlay = Image.new("RGBA", img.size, (255, 255, 255, 0))
    width = env.width if env.width else env.grid.size
    ts = img.size[0] // width  # edge length of one tile in pixels
    draw = ImageDraw.Draw(overlay)

    def classify(key):
        # Colour for a position key, or None when it is neither dangerous
        # nor critical.
        if key in dangerous:
            return ORANGE
        if key in critical:
            return YELLOW
        return None

    for x in range(1, env.width - 1):
        for y in range(1, env.height - 1):
            cell = env.grid.get(x, y)
            if isinstance(cell, Lava):
                fills = [RED] * 4
            elif isinstance(cell, Goal):
                fills = [GREEN] * 4
            elif ignore_view:
                fills = [classify(f"{x},{y}")] * 4
            else:
                fills = [classify(f"{x},{y},{direction}") for direction in range(0, 4)]

            for direction, fill in enumerate(fills):
                if fill is not None:
                    draw.polygon(triangle_for_overlay(x, y, direction, ts), fill=fill)

    img = Image.alpha_composite(img, overlay)
    return img
def create_dt_overlay_image(env, decision_tree, translator):
    """Render ``env`` and colour each cell direction by the decision tree's mask.

    The environment is reset and rendered. ``Lava`` cells are painted RED and
    ``Goal`` cells GREEN. For every other inner cell and each direction 0-3,
    the triangle is first painted WHITE; then
    ``decision_tree.evaluate(translator.features_given_position(x, y, dir))``
    is converted with ``get_allowed_actions_mask`` and its first three slots
    select the colour: two slots set -> YELLOW, one slot set -> ORANGE,
    none set -> RED, all three set -> left WHITE. A ``KeyError`` raised
    during this step is ignored for that direction.

    Returns:
        The rendered frame composited with the overlay, as an RGBA image.
    """
    # NOTE(review): the per-entry labels below are carried over from the
    # original code; get_allowed_actions_mask fills its slots from
    # east/south/west/north labels, so confirm the labels match.
    colour_by_mask = {
        (1.0, 1.0, 0.0): YELLOW,  # turning left or right
        (1.0, 0.0, 0.0): ORANGE,  # turning only left
        (1.0, 0.0, 1.0): YELLOW,  # turning left or forward
        (0.0, 1.0, 0.0): ORANGE,  # turning only right
        (0.0, 1.0, 1.0): YELLOW,  # turning right or forward
        (0.0, 0.0, 1.0): ORANGE,  # only forward
        (0.0, 0.0, 0.0): RED,     # deadlock
    }

    env.reset()
    frame = Image.fromarray(env.render()).convert("RGBA")
    grid_width = env.width if env.width else env.grid.size
    tile = frame.size[0] // grid_width
    layer = Image.new("RGBA", frame.size, (255, 255, 255, 0))
    pen = ImageDraw.Draw(layer)

    for col in range(1, env.width - 1):
        for row in range(1, env.height - 1):
            cell = env.grid.get(col, row)

            solid = None
            if isinstance(cell, Lava):
                solid = RED
            elif isinstance(cell, Goal):
                solid = GREEN
            if solid is not None:
                for heading in range(0, 4):
                    pen.polygon(triangle_for_overlay(col, row, heading, tile), fill=solid)
                continue

            for heading in range(0, 4):
                wedge = triangle_for_overlay(col, row, heading, tile)
                pen.polygon(wedge, fill=WHITE)
                try:
                    verdict = decision_tree.evaluate(translator.features_given_position(col, row, heading))
                    allowed = get_allowed_actions_mask(verdict)[0:3]
                    colour = colour_by_mask.get(tuple(allowed))
                    if colour is not None:
                        pen.polygon(triangle_for_overlay(col, row, heading, tile), fill=colour)
                except KeyError:
                    pass

    return Image.alpha_composite(frame, layer)
def create_heatmap(env, heatmap_data: dict, high: Optional[int] = None, low: Optional[int] = 0):
    """Render ``env`` and shade each cell direction by its value in ``heatmap_data``.

    The environment is reset and rendered. ``Lava`` cells are painted RED and
    ``Goal`` cells GREEN. For every other inner cell and each direction 0-3
    the triangle is painted WHITE and then, if ``State(x, y, dir)`` is a key
    of ``heatmap_data``, repainted with a colour from a light-blue to
    dark-blue scale of ``high - low`` steps.

    Args:
        env: environment exposing ``reset``, ``render``, ``width``,
            ``height`` and ``grid.get(x, y)``.
        heatmap_data: mapping from ``State`` to an integer value.
        high: upper end of the scale; when falsy, the maximum of the values
            in ``heatmap_data`` is used.
        low: lower end of the scale; ``None`` is treated as 0.

    Returns:
        The rendered frame composited with the overlay, as an RGBA image.
    """
    env.reset()
    img = Image.fromarray(env.render()).convert("RGBA")
    width = env.width if env.width else env.grid.size
    ts = img.size[0] // width  # edge length of one tile in pixels
    overlay = Image.new("RGBA", img.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(overlay)

    # The annotation allows None for ``low``; without this the subtraction
    # below would raise a TypeError.
    if low is None:
        low = 0
    if not high:
        high = np.max(list(heatmap_data.values()))
    colors = list(Color("lightblue").range_to(Color("darkblue"), high - low))
    last_index = len(colors) - 1

    for x in range(1, env.width - 1):
        for y in range(1, env.height - 1):
            cell = env.grid.get(x, y)
            if isinstance(cell, Lava):
                for direction in range(0, 4):
                    draw.polygon(triangle_for_overlay(x, y, direction, ts), fill=RED)
                continue
            if isinstance(cell, Goal):
                for direction in range(0, 4):
                    draw.polygon(triangle_for_overlay(x, y, direction, ts), fill=GREEN)
                continue

            for direction in range(0, 4):
                draw.polygon(triangle_for_overlay(x, y, direction, ts), fill=WHITE)
                try:
                    value = heatmap_data[State(x, y, direction)]
                except KeyError:
                    logger.trace(f"No Value for {x}, {y}, {direction}.")
                    continue
                # Clamp to the scale: a value equal to ``low`` would give
                # index -1 and wrap to the darkest colour, and a value above
                # ``high`` would raise IndexError.
                index = min(max(value - low - 1, 0), last_index)
                color = tuple(int(255 * v) for v in colors[index].rgb)
                draw.polygon(triangle_for_overlay(x, y, direction, ts), fill=color)

    img = Image.alpha_composite(img, overlay)
    return img
def expname(args, env):
    """Build an experiment name: timestamp, ``env`` and four settings from ``args``.

    The parts are joined by underscores in this order: current local time as
    ``YYYYmmddTHHMMSS``, ``env``, ``args.shielding``,
    ``args.shield_comparison``, ``args.shield_value``,
    ``args.expname_suffix``.
    """
    timestamp = f"{datetime.datetime.now():%Y%m%dT%H%M%S}"
    settings = f"{args.shielding}_{args.shield_comparison}_{args.shield_value}_{args.expname_suffix}"
    return f"{timestamp}_{env}_{settings}"
def create_log_dir(args, env):
    """Create ``<args.log_dir>/<expname(args, env)>`` and return that path.

    Existing directories are accepted; missing parents are created.
    """
    parent = args.log_dir
    target = f"{parent}/{expname(args, env)}"
    os.makedirs(target, exist_ok=True)
    return target
def get_allowed_actions_mask_with_view(actions):
    """Translate action labels into a 7-slot mask of 0.0 / 1.0 entries.

    NOTE: this module defines ``get_allowed_actions_mask_with_view`` a second
    time further down; that later definition replaces this one at import
    time. This copy is kept behaviourally in line with it.

    For each label the first matching keyword sets one slot:
    ``move`` -> 2, ``left`` -> 0, ``right`` -> 1, ``pickup`` -> 3,
    ``drop`` -> 4, ``toggle`` -> 5, ``done`` -> 6. Independently, a label
    containing ``north``, ``south``, ``east`` or ``west`` sets slot 2.

    Args:
        actions: iterable of action label strings.

    Returns:
        A list of seven floats.
    """
    action_mask = [0.0] * 7
    for action_label in actions:
        # Match case-insensitively, like the other mask builders in this
        # module; previously "Left" or "MOVE" were silently ignored here.
        action_label = action_label.lower()
        if "move" in action_label:
            action_mask[2] = 1.0
        elif "left" in action_label:
            action_mask[0] = 1.0
        elif "right" in action_label:
            action_mask[1] = 1.0
        elif "pickup" in action_label:
            action_mask[3] = 1.0
        elif "drop" in action_label:
            action_mask[4] = 1.0
        elif "toggle" in action_label:
            action_mask[5] = 1.0
        elif "done" in action_label:
            action_mask[6] = 1.0
        if "north" in action_label or "south" in action_label or "east" in action_label or "west" in action_label:
            action_mask[2] = 1.0
    return action_mask
def get_allowed_actions_mask_with_view(actions):
    """Translate action labels into a 7-slot mask of 0.0 / 1.0 entries.

    Labels are lower-cased before matching. For each label the first keyword
    found, in the order listed, sets one slot: ``move`` -> 2, ``left`` -> 0,
    ``right`` -> 1, ``pickup`` -> 3, ``drop`` -> 4, ``toggle`` -> 5,
    ``done`` -> 6. Independently, a label containing ``north``, ``south``,
    ``east`` or ``west`` sets slot 2.

    Args:
        actions: iterable of action label strings.

    Returns:
        A list of seven floats.
    """
    # Order matters: only the first keyword that occurs in a label is applied.
    keyword_slots = (
        ("move", 2),
        ("left", 0),
        ("right", 1),
        ("pickup", 3),
        ("drop", 4),
        ("toggle", 5),
        ("done", 6),
    )
    compass_words = ("north", "south", "east", "west")

    mask = [0.0] * 7
    for raw_label in actions:
        label = raw_label.lower()
        for keyword, slot in keyword_slots:
            if keyword in label:
                mask[slot] = 1.0
                break
        if any(word in label for word in compass_words):
            mask[2] = 1.0
    return mask
def get_allowed_actions_mask(actions):
    """Translate compass action labels into a 4-slot mask of 0.0 / 1.0 entries.

    Labels are lower-cased before matching. For each label the first keyword
    found, in the order listed, sets one slot: ``north`` -> 3,
    ``south`` -> 1, ``east`` -> 0, ``west`` -> 2. Labels without any of
    these keywords are ignored.

    Args:
        actions: iterable of action label strings.

    Returns:
        A list of four floats.
    """
    # Order matters: only the first keyword that occurs in a label is applied.
    keyword_slots = (("north", 3), ("south", 1), ("east", 0), ("west", 2))

    mask = [0.0] * 4
    for raw_label in actions:
        label = raw_label.lower()
        for keyword, slot in keyword_slots:
            if keyword in label:
                mask[slot] = 1.0
                break
    return mask
def get_allowed_actions_mask_for_labels(action_labels, ignore_view: bool = False):
    """Flatten groups of action labels and convert them into an action mask.

    Args:
        action_labels: iterable of label collections; all labels are
            concatenated in order.
        ignore_view: if True use ``get_allowed_actions_mask`` (4 slots),
            otherwise ``get_allowed_actions_mask_with_view`` (7 slots).

    Returns:
        The mask produced by the selected helper.
    """
    flat_labels = []
    for group in action_labels:
        flat_labels.extend(group)

    if not ignore_view:
        return get_allowed_actions_mask_with_view(flat_labels)
    return get_allowed_actions_mask(flat_labels)
def dir_path(string):
    """Return ``string`` unchanged if it names an existing directory.

    Raises:
        NotADirectoryError: if ``os.path.isdir(string)`` is false; the
            offending value is the exception's argument.
    """
    if not os.path.isdir(string):
        raise NotADirectoryError(string)
    return string
def common_parser():
    """Build the ``argparse.ArgumentParser`` with the shared command-line options.

    Returns:
        The configured parser; parsing is left to the caller.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Environment and model files
    add(
        "--env",
        help="gym environment to load",
        choices=gym.envs.registry.keys(),
        default="MiniGrid-LavaSlipperyCliff-16x13-v0",
    )
    add("--grid_file", default="grid.txt")
    add("--prism_file", default=None)
    add("--prism_output_file", default="grid.prism")
    add("--log_dir", default="../log_results/")
    add("--formula", default="Pmax=? [G !AgentIsOnLava]")

    # Shielding and run length
    add(
        "--shielding",
        type=ShieldingConfig,
        choices=list(ShieldingConfig),
        default=ShieldingConfig.Full,
    )
    add("--steps", default=20_000, type=int)
    add("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
    add("--prism_config", default=None)
    add("--shield_value", default=0.9, type=float)
    add("--shield_comparison", default='absolute', choices=['relative', 'absolute'])

    # Housekeeping and remaining options
    add("--nocleanup", action=argparse.BooleanOptionalAction)
    add("--expname_suffix", default="")
    add("--action_comparison", choices=["absolute", "relative"])
    add("--weighted", action="store_true")
    add("--min_frequency", type=int, default=0)

    return parser
class MiniWrapper(gym.Wrapper):
    """Gym wrapper that returns every observation with its first two axes swapped.

    ``reset`` and ``step`` forward to the wrapped environment and apply
    ``transpose(1, 0, 2)`` to the observation; all other returned values are
    passed through untouched.
    """

    def __init__(self, env):
        """Wrap ``env`` and keep a direct reference to it."""
        super().__init__(env)
        self.env = env

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment; return ``(transposed_obs, info)``."""
        observation, info = self.env.reset(seed=seed, options=options)
        swapped = observation.transpose(1, 0, 2)
        return swapped, info

    def observations(self, obs):
        """Return ``obs`` unchanged."""
        return obs

    def step(self, action):
        """Step the wrapped environment; return its 5-tuple with the observation transposed."""
        observation, reward, terminated, truncated, info = self.env.step(action)
        swapped = observation.transpose(1, 0, 2)
        return swapped, reward, terminated, truncated, info
def id_to_symbolic(state_id: int, state_valuations) -> Optional[State]:
    """Convert the valuation string of ``state_id`` into a symbolic ``State``.

    Args:
        state_id: identifier passed to ``state_valuations.get_string``.
        state_valuations: object providing ``get_string(state_id)``.

    Returns:
        The result of ``to_symbolic`` for that string, which is ``None`` for
        states that ``to_symbolic`` filters out.
    """
    valuation = state_valuations.get_string(state_id)
    return to_symbolic(valuation)
def to_symbolic(state: str) -> Optional[State]:
    """Parse a state valuation string into a symbolic ``State``.

    ``name=value`` pairs are collected as raw strings; bare identifiers that
    are followed by whitespace are collected as booleans, False when prefixed
    with ``!`` and True otherwise. Both dicts are handed to ``to_state``.

    Args:
        state: the valuation string to parse.

    Returns:
        The ``State`` built by ``to_state``, or ``None`` when the string has
        a ``previousActionAgent`` other than 7 or a ``clock`` other than 0
        (a missing entry counts as 7 and 0 respectively).
    """
    ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', state))

    # Filter first, so the boolean scan is skipped for discarded states.
    if int(ints.get("previousActionAgent", 7)) != 7:
        return None
    if int(ints.get("clock", 0)) != 0:
        return None

    booleans = {
        name: negation != "!"
        for negation, name in re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', state)
    }
    return to_state(ints, booleans)
def dataclass_from_dict(klass, dikt):
    """Build an instance of ``klass`` from ``dikt``, recursing into annotated fields.

    For every name in ``klass.__annotations__`` the value is taken from
    ``dikt``; when the key is missing, the class attribute of the same name
    (``klass.__dict__[k]``) is used instead. Each value is then passed through
    ``dataclass_from_dict`` again with the field's annotation as ``klass``.

    Returns:
        * ``klass(**values)`` on the normal path;
        * ``None`` (implicitly) when a ``KeyError`` reaches the outer handler,
          after logging it as a warning — this happens when a field is
          missing from ``dikt`` and ``klass.__dict__`` has no entry for it;
        * for any other exception (for instance when ``klass`` exposes no
          ``__annotations__``): if ``dikt`` is a tuple or list, a list with
          each element converted using ``klass.__args__[0]``; otherwise
          ``dikt`` itself, unchanged.
    """
    try:
        fieldtypes = klass.__annotations__ # this can be also simplified I believe
        new_dikt = dict()
        for k in list(fieldtypes.keys()):
            try:
                new_dikt[k] = dikt[k]
            except KeyError as e:
                # The f-string itself evaluates klass.__dict__[k]; if there is
                # no class-level value this raises KeyError from inside the
                # handler and control moves to the outer ``except KeyError``.
                logger.trace(f"Could not fetch value for field {k} from input, defaulting to {klass.__dict__[k]=}")
                new_dikt[k] = klass.__dict__[k]
        return klass(**{f: dataclass_from_dict(fieldtypes[f], new_dikt[f]) for f in new_dikt})
    except KeyError as e:
        # No return here: the function falls through and yields None.
        logger.warning(str(e))
    except Exception as e:
        # Reached when ``klass`` cannot be processed as above; the input is
        # converted element-wise (sequences) or handed back as is.
        if isinstance(dikt, (tuple, list)):
            return [dataclass_from_dict(klass.__args__[0], f) for f in dikt]
        return dikt
def id_to_symbolic_highway_state(state_id, state_valuations):
    """Parse the valuation string of ``state_id`` into a ``HighwayState``.

    ``name=value`` pairs are converted to ``int``; bare identifiers followed
    by whitespace become booleans (False when prefixed with ``!``) and
    override an integer entry of the same name. The merged dict is passed to
    ``dataclass_from_dict``.
    """
    text = state_valuations.get_string(state_id)

    raw_pairs = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', text))
    features = {name: int(raw) for name, raw in raw_pairs.items()}

    for negation, name in re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', text):
        features[name] = negation != "!"

    return dataclass_from_dict(HighwayState, features)
def id_to_symbolic_taxinet_state(state_id, state_valuations):
    """Parse the valuation string of ``state_id`` into a ``TaxinetState``.

    ``name=value`` pairs are converted to ``int``; bare identifiers followed
    by whitespace become booleans (False when prefixed with ``!``) and
    override an integer entry of the same name. The merged dict is passed to
    ``dataclass_from_dict``.
    """
    text = state_valuations.get_string(state_id)

    raw_pairs = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', text))
    features = {name: int(raw) for name, raw in raw_pairs.items()}

    for negation, name in re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', text):
        features[name] = negation != "!"

    return dataclass_from_dict(TaxinetState, features)
def id_to_symbolic_traffic_state(state_id, state_valuations):
    """Parse the valuation string of ``state_id`` into a ``TrafficState``.

    ``name=value`` pairs are converted to ``int``; bare identifiers followed
    by whitespace become booleans (False when prefixed with ``!``) and
    override an integer entry of the same name. The merged dict is passed to
    ``dataclass_from_dict``.
    """
    text = state_valuations.get_string(state_id)

    raw_pairs = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', text))
    features = {name: int(raw) for name, raw in raw_pairs.items()}

    for negation, name in re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', text):
        features[name] = negation != "!"

    return dataclass_from_dict(TrafficState, features)
def symbolic_minigrid_to_model_state(model, col, row, view=None):
    """Find the first model state whose agent is at ``(col, row)``.

    Args:
        model: object exposing ``states`` (iterable) and ``state_valuations``.
        col: value to match against the symbolic state's ``colAgent``.
        row: value to match against the symbolic state's ``rowAgent``.
        view: if not ``None``, the symbolic state's ``viewAgent`` must also
            equal this value.

    Returns:
        The first matching element of ``model.states``, or ``None`` if no
        state matches.
    """
    for model_state in model.states:
        ms = id_to_symbolic(model_state, model.state_valuations)
        # id_to_symbolic yields None for states that to_symbolic filters
        # out; dereferencing those would raise AttributeError.
        if ms is None:
            continue
        if ms.colAgent != col or ms.rowAgent != row:
            continue
        if view is None or ms.viewAgent == view:
            return model_state
    return None