"""Shield computation and helper utilities for shielded MiniGrid experiments.

Exports a MiniGrid layout to text, turns it into a PRISM model with an external
grid-to-prism converter (or copies a given PRISM file), computes a pre-safety
shield with stormpy, and translates that shield into per-state action masks.
Also provides shield-overlay rendering, experiment naming, a shared command-line
parser and an observation-transposing gym wrapper.
"""
import stormpy
import stormpy.core
import stormpy.simulator

import stormpy.shields
import stormpy.logic

import stormpy.examples
import stormpy.examples.files

from enum import Enum
from abc import ABC

from PIL import Image, ImageDraw

import re
import sys
import tempfile
import datetime
import shutil

import numpy as np

import gymnasium as gym

from minigrid.core.actions import Actions
from minigrid.core.state import to_state, State

import os
import time

import argparse


# Integer-valued variables in a stormpy state-valuation string, e.g. "x=3" or
# "clock=-1".  Values are captured as strings.  Compiled once because it is
# applied to every state of the model.
_INT_VALUATION_RE = re.compile(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)')
# Boolean variables: an optional "!" (negation) followed by a name and
# trailing whitespace.
_BOOL_VALUATION_RE = re.compile(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+')

# Ordered (label substring, action-mask index) pairs.  For every choice label
# the FIRST matching substring wins, exactly like an if/elif chain, so the
# order matters (e.g. "move" is tested before "left").  The indices presumably
# follow minigrid's Actions numbering (left=0, right=1, forward=2, ...) --
# confirm against the Actions enum in use.
_ACTION_LABEL_TO_INDEX = (
    ("move", 2),
    ("left", 0),
    ("right", 1),
    ("pickup", 3),
    ("drop", 4),
    ("toggle", 5),
    ("done", 6),
)


def tic():
    """Start (or restart) the global stopwatch read by :func:`toc`."""
    # Homemade version of matlab tic and toc functions: https://stackoverflow.com/a/18903019
    global startTime_for_tictoc
    # perf_counter is monotonic, so the measured interval cannot go negative
    # if the wall clock is adjusted in between.
    startTime_for_tictoc = time.perf_counter()


def toc():
    """Print the seconds elapsed since the last :func:`tic` call."""
    if 'startTime_for_tictoc' in globals():
        print(f"Elapsed time is {time.perf_counter() - startTime_for_tictoc} seconds.")
    else:
        print("Toc: start time not set")


class ShieldingConfig(Enum):
    """In which phases (training and/or evaluation) the shield is applied."""

    Training = 'training'
    Evaluation = 'evaluation'
    Disabled = 'none'
    Full = 'full'

    def __str__(self) -> str:
        # Print as the raw value so argparse choices/help show e.g. "full".
        return self.value


def shield_needed(shielding):
    """Return True if *shielding* requires a shield in any phase."""
    return shielding in (ShieldingConfig.Training, ShieldingConfig.Evaluation, ShieldingConfig.Full)


def shielded_evaluation(shielding):
    """Return True if the shield is active during evaluation."""
    return shielding in (ShieldingConfig.Evaluation, ShieldingConfig.Full)


def shielded_training(shielding):
    """Return True if the shield is active during training."""
    return shielding in (ShieldingConfig.Training, ShieldingConfig.Full)


class ShieldHandler(ABC):
    """Base interface for objects that compute a shield (state -> action mask)."""

    def __init__(self) -> None:
        pass

    def create_shield(self, **kwargs) -> dict:
        """Return the shield as a dict; implemented by subclasses."""
        pass


class MiniGridShieldHandler(ShieldHandler):
    """Compute a stormpy pre-safety shield for a MiniGrid environment.

    Intermediate files (grid text, PRISM model, optionally the exported
    shield) live in a fresh ``shielding_files_<timestamp>_<random>``
    directory under the current working directory.  The directory is removed
    after the shield is computed (and again, tolerantly, on destruction)
    unless ``nocleanup`` is set.
    """

    def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula,
                 prism_config=None, shield_value=0.9, shield_comparison='absolute',
                 nocleanup=False, prism_file=None) -> None:
        """
        Args:
            grid_to_prism_binary: command used to convert the grid text into
                a PRISM model (invoked via the shell).
            grid_file: file name for the exported grid, inside the temp dir.
            prism_path: file name for the PRISM model, inside the temp dir.
            formula: PCTL property string passed to stormpy.
            prism_config: optional config passed to the converter with ``-c``.
            shield_value: threshold for the shield comparison.
            shield_comparison: ``'absolute'``; any other value selects
                relative comparison.
            nocleanup: keep the temp directory and export the shield into it.
            prism_file: optional ready-made PRISM file; when given it is
                copied instead of running the converter.
        """
        # Set before anything that can fail so __del__ can always read it.
        self.nocleanup = nocleanup
        timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
        # mkdtemp is the public, collision-safe replacement for the private
        # tempfile._get_candidate_names() + os.mkdir.  relpath keeps the
        # original relative "shielding_files_..." form (Python >= 3.12 would
        # otherwise return an absolute path), so the unquoted shell command
        # built in __create_prism keeps working unchanged.
        self.tmp_dir_name = os.path.relpath(
            tempfile.mkdtemp(prefix=f"shielding_files_{timestamp}_", dir=os.curdir))
        self.grid_file = self.tmp_dir_name + "/" + grid_file
        self.grid_to_prism_binary = grid_to_prism_binary
        self.prism_path = self.tmp_dir_name + "/" + prism_path
        self.prism_config = prism_config
        self.prism_file = prism_file

        # Cached result of create_shield(); None until computed.
        self.action_dictionary = None
        self.formula = formula
        comparison = (stormpy.logic.ShieldComparison.ABSOLUTE
                      if shield_comparison == "absolute"
                      else stormpy.logic.ShieldComparison.RELATIVE)
        self.shield_expression = stormpy.logic.ShieldExpression(
            stormpy.logic.ShieldingType.PRE_SAFETY, comparison, shield_value)

    def __del__(self):
        tmp_dir = getattr(self, "tmp_dir_name", None)
        if tmp_dir is None or getattr(self, "nocleanup", False):
            return
        # __create_shield_dict already deletes the directory when nocleanup is
        # False, so it is usually gone by now; ignore_errors avoids a spurious
        # FileNotFoundError being reported during garbage collection.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    def __export_grid_to_text(self, env):
        """Write ``env.printGrid(init=True)`` to ``self.grid_file``."""
        with open(self.grid_file, "w") as f:
            f.write(env.printGrid(init=True))

    def __create_prism(self):
        """Produce the PRISM model at ``self.prism_path``.

        Raises:
            AssertionError: if the converter exits with a non-zero status.
        """
        if self.prism_file is not None:
            print(self.prism_file)
            print(self.prism_path)
            shutil.copyfile(self.prism_file, self.prism_path)
            return
        command = f"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}"
        if self.prism_config is not None:
            command += f" -c {self.prism_config}"
        result = os.system(command)
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if result != 0:
            raise AssertionError("Prism file could not be generated")

    def __create_shield_dict(self):
        """Model-check the PRISM model and translate the shield into masks.

        Returns:
            dict mapping ``to_state(ints, booleans)`` keys to 7-element action
            masks (see get_allowed_actions_mask).

        Raises:
            AssertionError: if model checking does not yield a shield.
        """
        program = stormpy.parse_prism_program(self.prism_path)
        formulas = stormpy.parse_properties_for_prism_program(self.formula, program)

        options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
        # State valuations and choice labels are needed to map model states
        # back to grid states and model choices back to agent actions.
        options.set_build_state_valuations(True)
        options.set_build_choice_labels(True)
        options.set_build_all_labels()

        print("LOG: Starting with explicit model creation...")
        tic()
        model = stormpy.build_sparse_model_with_options(program, options)
        toc()

        print("LOG: Starting with model checking...")
        tic()
        result = stormpy.model_checking(model, formulas[0], extract_scheduler=True,
                                        shield_expression=self.shield_expression)
        toc()

        if not result.has_shield:
            raise AssertionError("Model checking did not produce a shield")
        shield = result.shield
        shield_scheduler = shield.construct()
        state_valuations = model.state_valuations
        choice_labeling = model.choice_labeling

        if self.nocleanup:
            stormpy.shields.export_shield(model, shield, self.tmp_dir_name + "/shield")

        print("LOG: Starting to translate shield...")
        tic()
        action_dictionary = {}
        for state_id in model.states:
            choices = shield_scheduler.get_choice(state_id).choice_map
            state_valuation = state_valuations.get_string(state_id)

            ints = dict(_INT_VALUATION_RE.findall(state_valuation))
            booleans = {name: negation != "!"
                        for negation, name in _BOOL_VALUATION_RE.findall(state_valuation)}
            # Only states with previousActionAgent == 7 and clock == 0 (or
            # those variables absent) are kept -- presumably the states where
            # the agent is about to choose; TODO confirm against the model.
            if int(ints.get("previousActionAgent", 7)) != 7:
                continue
            if int(ints.get("clock", 0)) != 0:
                continue

            state = to_state(ints, booleans)
            # Each choice_map entry's second element is the choice passed to
            # get_choice_index; its labels name the allowed action.
            labels = [choice_labeling.get_labels_of_choice(model.get_choice_index(state_id, entry[1]))
                      for entry in choices]
            action_dictionary[state] = get_allowed_actions_mask(labels)
        toc()

        self.action_dictionary = action_dictionary
        # Remove shielding_files_* immediately, only to remove clutter for the demo
        if not self.nocleanup:
            shutil.rmtree(self.tmp_dir_name)
        return action_dictionary

    def create_shield(self, **kwargs):
        """Return the shield, computing it on the first call.

        Keyword Args:
            env: environment providing ``printGrid(init=True)``; only needed
                on the first call.

        Returns:
            The cached or freshly computed state -> action-mask dict.
        """
        if self.action_dictionary is not None:
            return self.action_dictionary

        env = kwargs["env"]
        self.__export_grid_to_text(env)
        self.__create_prism()
        print("Computing new shield")
        return self.__create_shield_dict()


def rectangle_for_overlay(x, y, dir, tile_size, width=2, offset=0, thickness=0):
    """Return the two corner points of a bar along one edge of tile (x, y).

    ``dir`` 0/1/2/3 selects the right/bottom/left/top edge (in image
    coordinates); any other value returns None.
    """
    if dir == 0:
        return (((x+1)*tile_size-width-thickness, y*tile_size+offset),
                ((x+1)*tile_size, (y+1)*tile_size-offset))
    if dir == 1:
        return ((x*tile_size+offset, (y+1)*tile_size-width-thickness),
                ((x+1)*tile_size-offset, (y+1)*tile_size))
    if dir == 2:
        return ((x*tile_size, y*tile_size+offset),
                (x*tile_size+width+thickness, (y+1)*tile_size-offset))
    if dir == 3:
        return ((x*tile_size+offset, y*tile_size),
                ((x+1)*tile_size-offset, y*tile_size+width+thickness))


def triangle_for_overlay(x, y, dir, tile_size):
    """Return the vertices of a triangle from one edge of tile (x, y) to its centre.

    ``dir`` 0/1/2/3 selects the right/bottom/left/top edge (in image
    coordinates); any other value returns None.
    """
    offset = tile_size/2
    if dir == 0:
        return [((x+1)*tile_size, y*tile_size), ((x+1)*tile_size, (y+1)*tile_size),
                ((x+1)*tile_size-offset, y*tile_size+tile_size/2)]
    if dir == 1:
        return [(x*tile_size, (y+1)*tile_size), ((x+1)*tile_size, (y+1)*tile_size),
                (x*tile_size+tile_size/2, (y+1)*tile_size-offset)]
    if dir == 2:
        return [(x*tile_size, y*tile_size), (x*tile_size, (y+1)*tile_size),
                (x*tile_size+offset, y*tile_size+tile_size/2)]
    if dir == 3:
        return [(x*tile_size, y*tile_size), ((x+1)*tile_size, y*tile_size),
                (x*tile_size+tile_size/2, y*tile_size+offset)]


def create_shield_overlay_image(env, shield):
    """Render *env* and show it with a red triangle wherever "move" is blocked.

    For every tile and direction present in *shield*, a triangle is drawn when
    mask index 2 (set from "move" labels by get_allowed_actions_mask) is <= 0.
    States missing from the shield are skipped.
    """
    env.reset()
    img = Image.fromarray(env.render()).convert("RGBA")
    ts = env.tile_size
    overlay = Image.new("RGBA", img.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(overlay)
    red = (255, 0, 0, 200)
    for x in range(env.width):
        for y in range(env.height):
            for direction in range(4):
                try:
                    action_mask = shield[State(x, y, direction, "")]
                except KeyError:
                    continue  # this state is not covered by the shield
                if action_mask[2] <= 0.0:
                    draw.polygon(triangle_for_overlay(x, y, direction, ts), fill=red)
    img = Image.alpha_composite(img, overlay)
    img.show()


def expname(args):
    """Build a timestamped experiment name from the parsed CLI arguments."""
    return f"{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{args.env}_{args.shielding}_{args.shield_comparison}_{args.shield_value}_{args.expname_suffix}"


def create_log_dir(args):
    """Create (if needed) and return ``<args.log_dir>/<expname(args)>``."""
    log_dir = f"{args.log_dir}/{expname(args)}"
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


def get_allowed_actions_mask(actions):
    """Convert per-choice label collections into a 7-element action mask.

    Args:
        actions: iterable of label collections (one per allowed choice).

    Returns:
        list of 7 floats; an entry is 1.0 if some label maps to that action
        (see _ACTION_LABEL_TO_INDEX), else 0.0.
    """
    action_mask = [0.0] * 7
    for labels in actions:
        for label in labels:
            for keyword, index in _ACTION_LABEL_TO_INDEX:
                if keyword in label:
                    action_mask[index] = 1.0
                    break  # first match wins, as in an if/elif chain
    return action_mask


def common_parser():
    """Return the argument parser shared by the shielding experiment scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--env",
                        help="gym environment to load",
                        choices=gym.envs.registry.keys(),
                        default="MiniGrid-LavaSlipperyCliff-16x13-v0")
    parser.add_argument("--grid_file", default="grid.txt")
    parser.add_argument("--prism_file", default=None)
    parser.add_argument("--prism_output_file", default="grid.prism")
    parser.add_argument("--log_dir", default="../log_results/")
    parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")
    parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig),
                        default=ShieldingConfig.Full)
    parser.add_argument("--steps", default=20_000, type=int)
    parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
    parser.add_argument("--prism_config", default=None)
    parser.add_argument("--shield_value", default=0.9, type=float)
    parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
    parser.add_argument("--nocleanup", action=argparse.BooleanOptionalAction)
    parser.add_argument("--expname_suffix", default="")
    return parser


class MiniWrapper(gym.Wrapper):
    """Gym wrapper that swaps the first two axes of every observation.

    Assumes observations are 3-D arrays supporting ``transpose(1, 0, 2)``
    (presumably image-like) -- confirm against the wrapped environment.
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped env and return the transposed observation and info."""
        obs, info = self.env.reset(seed=seed, options=options)
        return obs.transpose(1, 0, 2), info

    def observations(self, obs):
        """Return *obs* unchanged."""
        return obs

    def step(self, action):
        """Step the wrapped env, transposing the returned observation."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        return obs.transpose(1, 0, 2), reward, terminated, truncated, info