from __future__ import annotations from typing import Callable, Optional from re import A import gymnasium as gym import numpy as np import os, subprocess, sys, re from minigrid_shield_handler import MiniGridShieldHandler from utils import common_parser, create_dt_overlay_image from utils import NUM_MOVEMENT_ACTIONS from minigrid.core.state import to_state, State from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback, EvalCallback from stable_baselines3.common.logger import Image from dataclasses import dataclass from loguru import logger class ShieldInterferenceTracker(gym.core.Wrapper): def __init__(self, env, interference_delta: Optional[Callable] = None): super().__init__(env) self.interference_tracker = dict() self.visited_tracker = dict() self.interference_delta = interference_delta def step(self, action): obs, rew, done, truncated, info = self.env.step(action) try: symbolic_state = self.env.get_symbolic_state() mask = self.env.get_wrapper_attr("shield")[symbolic_state] if sum(mask) < NUM_MOVEMENT_ACTIONS: self.interference_tracker[symbolic_state] = self.interference_tracker.get(symbolic_state, 0) + 1 if self.interference_delta and callable(self.interference_delta): for state in self.interference_delta(symbolic_state): self.interference_tracker[state] = self.interference_tracker.get(state, 1) self.visited_tracker[symbolic_state] = self.visited_tracker.get(symbolic_state, 0) + 1 except Exception as e: pass #logger.info(self.env.get_wrapper_attr("shield")) #logger.error(f"Exception: {e}") #assert False return obs, rew, done, truncated, info def tracking_statistics(self): return self.visited_tracker, self.interference_tracker def clear_statistics(self) -> None: self.visited_tracker.clear() self.interference_tracker.clear() class MiniGridSbShieldingWrapper(gym.core.Wrapper): def __init__(self, env, shield_handler : MiniGridShieldHandler, create_shield_at_reset = False, ): super().__init__(env) self.shield_handler = shield_handler self.create_shield_at_reset = create_shield_at_reset shield = self.shield_handler.create_shield(env=self.env) self.shield = shield def create_action_mask(self): try: return self.shield[self.env.get_symbolic_state()] except: return [0.0] * 4 def reset(self, *, seed=None, options=None): obs, infos = self.env.reset(seed=seed, options=options) if self.create_shield_at_reset: shield = self.shield_handler.create_shield(env=self.env) self.shield = shield return obs, infos def step(self, action): obs, rew, done, truncated, info = self.env.step(action) info["state_not_modelled"] = not self.shield.__contains__(self.env.get_symbolic_state()) return obs, rew, done, truncated, info def parse_sb3_arguments(): parser = common_parser() args = parser.parse_args() return args class DecisionTreeAnalysis(BaseCallback): def __init__(self, env, envName, shield_handler, isodate, weighted, action_comparison, complete_shield_num_nodes, saveimg : bool = False, clear: bool = False, verbose=0): logger.info(f"Init DecisionTreeAnalysis") super().__init__(verbose) self.env = env self.envName = envName self.shield_handler = shield_handler self.isodate = isodate self.weighted = weighted self.action_comparison = action_comparison self.WORKSPACE = "./" self.clear_statistics = clear self.saveimg = saveimg self.complete_shield_num_nodes = complete_shield_num_nodes def _on_step(self) -> bool: visited, interference = self.env.tracking_statistics() self.shield = self.env.shield #for min_frequency in [0, 10, 50, 100]: for min_frequency in [0]: for complement in [False]:#[False, True]: csv_file_name = self.__write_csv(visited, "visited", complement, min_frequency) self.__run_analysis(csv_file_name, ignore_mask=[1.0,1.0,1.0]) #csv_file_name = self.__write_csv(interference, "interference", complement, min_frequency) #self.__run_analysis(csv_file_name, ignore_mask=[1.0,1.0,1.0]) if self.clear_statistics: self.env.get_wrapper_attr("clear_statistics")() return True def __run_analysis(self, csv_file_name, ignore_mask) -> None: logger.info(f"{self.num_timesteps} | Running analysis") self.__create_dt(csv_file_name) self.__run_dot(csv_file_name) self.__parse_modified_dot_and_create_dt(csv_file_name) if self.saveimg: img = self.__create_image() if img: experiment_name = csv_file_name.split("/")[-1] img.save(f"./results/{experiment_name.split('.')[0]}.png") else: self.decision_tree.evaluate_restrictions(ignore_mask=ignore_mask) self.__compare_to_shield() def __write_csv(self, statistics: dict, tracking: str, complement: bool, min_frequency: int): csv_file_name = f"./results/{self.isodate}_{self.envName}{'_' if not complement else '_complemented_'}{tracking}_shield_{self.action_comparison}_{['noweight','weighted_min'+str(min_frequency)][self.weighted]}_steps_{self.num_timesteps:08}.csv" write_csv(csv_file_name, self.shield, statistics, self.shield_handler.num_features, self.action_comparison, complement, min_frequency, self.weighted) return csv_file_name def __create_dt(self, csvFile): logger.info(f"{self.num_timesteps} | Running DTcontrol") _, num_nodes = run_dtcontrol(csvFile, self.WORKSPACE, self.shield_handler.feature_space) self.logger.record("decision_tree/num_nodes", num_nodes) self.logger.record("decision_tree/complete_shield", self.complete_shield_num_nodes) #graphdf = createDTs(os.path.join(self.WORKSPACE, csvFile), self.WORKSPACE) #prefix = csvFile.replace('.csv', '') #plot_dts(graphdf, prefix) def __parse_modified_dot_and_create_dt(self, csvFile): logger.info(f"{self.num_timesteps} | Creating decision trees in memory") PATH = dt_dir_for_experiment(csvFile) self.decision_tree = create_decision_tree_from_csv(f"{PATH}/default_modified.dot", {feature: i for i, feature in enumerate(self.shield_handler.feature_space)}) def __create_image(self): logger.info(f"{self.num_timesteps} | Rendering decision trees onto image") try: #self.decision_tree.leaves_to_action_dict img = create_dt_overlay_image(self.env, self.decision_tree) #img.show() return img except UnboundLocalError as e: logger.warning(f"{e}. The tree is likely a single node") return None def __run_dot(self, csvFile): logger.info(f"{self.num_timesteps} | Plotting decision trees using dot") PATH = dt_dir_for_experiment(csvFile) run_dot(os.path.join(PATH, "default_modified.dot")) def __compare_to_shield(self): logger.info(f"{self.num_timesteps} | Comparing decision trees with complete shield") state_valuations = self.shield_handler.model.state_valuations minigrid_states = list() for state_id in self.shield_handler.model.states: state_valuation = state_valuations.get_string(state_id) ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', state_valuation)) booleans = re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', state_valuation) booleans = {b[1]: False if b[0] == "!" else True for b in booleans} if int(ints.get("previousActionAgent", 7)) != 7: continue if int(ints.get("clock", 0)) != 0: continue minigrid_states.append(to_state(ints, booleans)) comparison = self.decision_tree.compare_to_shield(self.env, minigrid_states, self.shield_handler.create_shield(env=self.env)) self.logger.record("decision_tree/comparison", comparison) class InfoCallback(BaseCallback): """ Custom callback for plotting additional values in tensorboard. """ def __init__(self, verbose=0): super().__init__(verbose) self.sum_goal = 0 self.sum_lava = 0 self.sum_collisions = 0 self.sum_opened_door = 0 self.sum_picked_up = 0 self.no_shield_action = 0 def _on_step(self) -> bool: infos = self.locals["infos"][0] if infos["reached_goal"]: self.sum_goal += 1 if infos["ran_into_lava"]: self.sum_lava += 1 self.logger.record("info/sum_reached_goal", self.sum_goal) self.logger.record("info/sum_ran_into_lava", self.sum_lava) if "collision" in infos: if infos["collision"]: self.sum_collisions += 1 self.logger.record("info/sum_collision", self.sum_collisions) if "opened_door" in infos: if infos["opened_door"]: self.sum_opened_door += 1 self.logger.record("info/sum_opened_door", self.sum_opened_door) if "picked_up" in infos: if infos["picked_up"]: self.sum_picked_up += 1 self.logger.record("info/sum_picked_up", self.sum_picked_up) if "no_shield_action" in infos: if infos["no_shield_action"]: self.no_shield_action += 1 self.logger.record("info/no_shield_action", self.no_shield_action) return True