You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

232 lines
9.7 KiB

from __future__ import annotations
from typing import Callable, Optional
from re import A
import gymnasium as gym
import numpy as np
import os, subprocess, sys, re
from minigrid_shield_handler import MiniGridShieldHandler
from utils import common_parser, create_dt_overlay_image
from utils import NUM_MOVEMENT_ACTIONS
from minigrid.core.state import to_state, State
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback, EvalCallback
from stable_baselines3.common.logger import Image
from dataclasses import dataclass
from loguru import logger
class ShieldInterferenceTracker(gym.core.Wrapper):
    """Count, per symbolic state, how often it is visited and how often the
    shield restricts the available movement actions there.

    Args:
        env: Environment to wrap. It must provide ``get_symbolic_state()`` and a
            ``shield`` (mapping symbolic state -> action mask) reachable through
            ``get_wrapper_attr``.
        interference_delta: Optional callable mapping a symbolic state to an
            iterable of further states that are also marked as interfered with
            whenever the shield restricts the given state.
    """

    def __init__(self, env, interference_delta: Optional[Callable] = None):
        super().__init__(env)
        self.interference_tracker = dict()
        self.visited_tracker = dict()
        self.interference_delta = interference_delta

    def step(self, action):
        """Step the wrapped env and update the tracking counters.

        Tracking is best-effort: any failure while looking up the state or the
        shield is logged and ignored, and the step result is returned unchanged.
        """
        obs, rew, done, truncated, info = self.env.step(action)
        try:
            self._record_step(self.env.get_symbolic_state())
        except Exception as e:
            # Previously swallowed silently. States missing from the shield must
            # not break training. TRACE is below loguru's default DEBUG sink, so
            # this does not flood the output on every step but can be enabled.
            logger.trace("Shield interference tracking skipped: {!r}", e)
        return obs, rew, done, truncated, info

    def _record_step(self, symbolic_state) -> None:
        """Update the visit/interference counters for ``symbolic_state``.

        Propagates any exception raised by the shield lookup or by
        ``interference_delta``; ``step`` treats those as best-effort failures.
        """
        mask = self.env.get_wrapper_attr("shield")[symbolic_state]
        # Fewer allowed actions than movement actions means the shield interfered
        # (assumes mask entries are 0/1 flags per action -- TODO confirm).
        if sum(mask) < NUM_MOVEMENT_ACTIONS:
            self.interference_tracker[symbolic_state] = self.interference_tracker.get(symbolic_state, 0) + 1
            if callable(self.interference_delta):
                for state in self.interference_delta(symbolic_state):
                    # Related states are marked with count 1 if unseen; existing
                    # counts are deliberately left unchanged (not incremented).
                    self.interference_tracker.setdefault(state, 1)
        self.visited_tracker[symbolic_state] = self.visited_tracker.get(symbolic_state, 0) + 1

    def tracking_statistics(self):
        """Return the live ``(visited, interference)`` counter dicts (not copies)."""
        return self.visited_tracker, self.interference_tracker

    def clear_statistics(self) -> None:
        """Reset both counters in place."""
        self.visited_tracker.clear()
        self.interference_tracker.clear()
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    """Hold a shield for the wrapped env and expose it as an action mask.

    Args:
        env: Environment to wrap; must provide ``get_symbolic_state()``.
        shield_handler: Handler whose ``create_shield(env=...)`` builds the
            shield, a mapping from symbolic state to action mask.
        create_shield_at_reset: If True, rebuild the shield on every ``reset``.
    """

    def __init__(self,
                 env,
                 shield_handler : MiniGridShieldHandler,
                 create_shield_at_reset = False,
                 ):
        super().__init__(env)
        self.shield_handler = shield_handler
        self.create_shield_at_reset = create_shield_at_reset
        self._build_shield()

    def _build_shield(self) -> None:
        """(Re)compute ``self.shield`` from the wrapped env's current state."""
        self.shield = self.shield_handler.create_shield(env=self.env)

    def create_action_mask(self):
        """Return the shield's action mask for the current symbolic state.

        Falls back to an all-zero mask of length 4 when the lookup fails
        (e.g. the state is not contained in the shield).
        """
        try:
            return self.shield[self.env.get_symbolic_state()]
        except Exception:
            # Was a bare ``except:``, which also trapped KeyboardInterrupt and
            # SystemExit. NOTE(review): the fallback length 4 is hard-coded and
            # independent of NUM_MOVEMENT_ACTIONS -- confirm it matches the
            # action space.
            return [0.0] * 4

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped env, rebuilding the shield if configured to."""
        obs, infos = self.env.reset(seed=seed, options=options)
        if self.create_shield_at_reset:
            self._build_shield()
        return obs, infos

    def step(self, action):
        """Step the wrapped env and flag states that the shield does not cover.

        Sets ``info["state_not_modelled"]`` to True when the resulting symbolic
        state is not a key of the shield.
        """
        obs, rew, done, truncated, info = self.env.step(action)
        info["state_not_modelled"] = self.env.get_symbolic_state() not in self.shield
        return obs, rew, done, truncated, info
def parse_sb3_arguments():
    """Parse the command line using the project's ``common_parser``.

    Returns:
        The namespace produced by ``common_parser().parse_args()``.
    """
    return common_parser().parse_args()
class DecisionTreeAnalysis(BaseCallback):
    """Stable-Baselines3 callback that turns tracked state statistics into a
    decision tree and logs metrics about it.

    Each call of ``_on_step`` performs these steps:

    * Write the visit statistics and the current shield to a CSV file under
      ``./results/``.
    * Run dtcontrol and dot on that file, then load the resulting tree.
    * Either save an overlay image (``saveimg=True``), or evaluate the tree and
      compare it with the complete shield.

    Because the whole pipeline runs per call, this callback is presumably
    attached as a periodic / after-evaluation callback rather than on every env
    step -- TODO confirm with the training script.

    Args:
        env: Wrapped environment providing ``tracking_statistics``, ``shield``
            and ``clear_statistics`` (looked up via ``get_wrapper_attr``).
        envName: Environment name, embedded in output file names.
        shield_handler: Provides ``num_features``, ``feature_space``, ``model``
            and ``create_shield``.
        isodate: Timestamp prefix for output file names.
        weighted: Forwarded to ``write_csv``. When truthy the file name carries
            ``weighted_min<N>``, otherwise ``noweight``.
        action_comparison: Forwarded to ``write_csv`` and embedded in the file name.
        complete_shield_num_nodes: Node count of the tree for the complete
            shield. It is logged next to the new tree's node count.
        saveimg: Render and save an overlay image instead of evaluating the tree.
        clear: Clear the env's tracking statistics after each analysis.
        verbose: SB3 verbosity level.
    """

    # Compiled once; applied to every model state in __compare_to_shield.
    # "name=value" pairs -> integer-valued variables (values kept as strings).
    _INT_VALUATION = re.compile(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)')
    # Optional "!" + name followed by whitespace -> boolean variables.
    # (Was ``[\s\t]+``; ``\t`` is already part of ``\s``.)
    _BOOL_VALUATION = re.compile(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)\s+')

    def __init__(self, env, envName, shield_handler, isodate, weighted, action_comparison, complete_shield_num_nodes, saveimg : bool = False, clear: bool = False, verbose=0):
        logger.info("Init DecisionTreeAnalysis")
        super().__init__(verbose)
        self.env = env
        self.envName = envName
        self.shield_handler = shield_handler
        self.isodate = isodate
        self.weighted = weighted
        self.action_comparison = action_comparison
        self.WORKSPACE = "./"
        self.clear_statistics = clear
        self.saveimg = saveimg
        self.complete_shield_num_nodes = complete_shield_num_nodes

    def _on_step(self) -> bool:
        """Run the decision-tree analysis on the current statistics; never stops training."""
        # get_wrapper_attr instead of plain attribute access: forwarding attributes
        # through gymnasium wrappers is deprecated, and the rest of this file
        # already uses get_wrapper_attr for "shield" and "clear_statistics".
        visited, interference = self.env.get_wrapper_attr("tracking_statistics")()
        self.shield = self.env.get_wrapper_attr("shield")
        # Parameter sweep kept for experimentation. Previously tried:
        # min_frequency in [0, 10, 50, 100] and complement in [False, True].
        for min_frequency in [0]:
            for complement in [False]:
                csv_file_name = self.__write_csv(visited, "visited", complement, min_frequency)
                self.__run_analysis(csv_file_name, ignore_mask=[1.0, 1.0, 1.0])
                # The same analysis can be run on `interference` by writing it
                # with tracking="interference" and calling __run_analysis again.
        if self.clear_statistics:
            self.env.get_wrapper_attr("clear_statistics")()
        return True

    def __run_analysis(self, csv_file_name, ignore_mask) -> None:
        """Build a decision tree from ``csv_file_name``, then render or evaluate it."""
        logger.info(f"{self.num_timesteps} | Running analysis")
        self.__create_dt(csv_file_name)
        self.__run_dot(csv_file_name)
        self.__parse_modified_dot_and_create_dt(csv_file_name)
        if self.saveimg:
            img = self.__create_image()
            if img is not None:
                # Strip only the trailing ".csv". Splitting on the first "." would
                # truncate names whose prefix (e.g. isodate) contains a dot, and
                # images from different steps would overwrite each other.
                experiment_name = os.path.splitext(os.path.basename(csv_file_name))[0]
                img.save(f"./results/{experiment_name}.png")
        else:
            self.decision_tree.evaluate_restrictions(ignore_mask=ignore_mask)
            self.__compare_to_shield()

    def __write_csv(self, statistics: dict, tracking: str, complement: bool, min_frequency: int):
        """Write ``statistics`` plus the current shield to a CSV file; return its path."""
        separator = '_complemented_' if complement else '_'
        weighting = f"weighted_min{min_frequency}" if self.weighted else "noweight"
        csv_file_name = (f"./results/{self.isodate}_{self.envName}{separator}{tracking}"
                         f"_shield_{self.action_comparison}_{weighting}_steps_{self.num_timesteps:08}.csv")
        write_csv(csv_file_name, self.shield, statistics, self.shield_handler.num_features, self.action_comparison, complement, min_frequency, self.weighted)
        return csv_file_name

    def __create_dt(self, csvFile):
        """Run dtcontrol on ``csvFile`` and log the node counts of the new and complete trees."""
        logger.info(f"{self.num_timesteps} | Running DTcontrol")
        _, num_nodes = run_dtcontrol(csvFile, self.WORKSPACE, self.shield_handler.feature_space)
        self.logger.record("decision_tree/num_nodes", num_nodes)
        self.logger.record("decision_tree/complete_shield", self.complete_shield_num_nodes)

    def __parse_modified_dot_and_create_dt(self, csvFile):
        """Load the tree from ``default_modified.dot`` into ``self.decision_tree``."""
        logger.info(f"{self.num_timesteps} | Creating decision trees in memory")
        PATH = dt_dir_for_experiment(csvFile)
        feature_index = {feature: i for i, feature in enumerate(self.shield_handler.feature_space)}
        self.decision_tree = create_decision_tree_from_csv(f"{PATH}/default_modified.dot", feature_index)

    def __create_image(self):
        """Return an overlay image of the decision tree, or None if rendering fails."""
        logger.info(f"{self.num_timesteps} | Rendering decision trees onto image")
        try:
            return create_dt_overlay_image(self.env, self.decision_tree)
        except UnboundLocalError as e:
            # Per the original diagnosis, the renderer hits this for degenerate
            # (single-node) trees -- NOTE(review): confirm in create_dt_overlay_image.
            logger.warning(f"{e}. The tree is likely a single node")
            return None

    def __run_dot(self, csvFile):
        """Render ``default_modified.dot`` for this experiment with dot."""
        logger.info(f"{self.num_timesteps} | Plotting decision trees using dot")
        PATH = dt_dir_for_experiment(csvFile)
        run_dot(os.path.join(PATH, "default_modified.dot"))

    def __compare_to_shield(self):
        """Compare the decision tree with a freshly built shield and log the result.

        Only model states with ``previousActionAgent == 7`` and ``clock == 0``
        are compared. Missing variables default to 7 and 0, so they pass the filter.
        """
        logger.info(f"{self.num_timesteps} | Comparing decision trees with complete shield")
        state_valuations = self.shield_handler.model.state_valuations
        minigrid_states = []
        for state_id in self.shield_handler.model.states:
            state_valuation = state_valuations.get_string(state_id)
            ints = dict(self._INT_VALUATION.findall(state_valuation))
            if int(ints.get("previousActionAgent", 7)) != 7:
                continue
            if int(ints.get("clock", 0)) != 0:
                continue
            # A leading "!" marks a false boolean; later duplicates win.
            booleans = {name: negation != "!" for negation, name in self._BOOL_VALUATION.findall(state_valuation)}
            minigrid_states.append(to_state(ints, booleans))
        comparison = self.decision_tree.compare_to_shield(self.env, minigrid_states, self.shield_handler.create_shield(env=self.env))
        self.logger.record("decision_tree/comparison", comparison)
class InfoCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.

    Keeps running totals of event flags found in the per-step ``info`` dicts,
    summed over every environment in ``self.locals["infos"]``. The totals are
    recorded under ``info/`` in the SB3 logger.
    """

    # (info key, counter attribute, logger key, always_record).
    # Counters with always_record=True are logged on every step, even when no env
    # reports the key (treated as False). The others are logged only on steps
    # where at least one env reports the key.
    _TRACKED_EVENTS = (
        ("reached_goal", "sum_goal", "info/sum_reached_goal", True),
        ("ran_into_lava", "sum_lava", "info/sum_ran_into_lava", True),
        ("collision", "sum_collisions", "info/sum_collision", False),
        ("opened_door", "sum_opened_door", "info/sum_opened_door", False),
        ("picked_up", "sum_picked_up", "info/sum_picked_up", False),
        ("no_shield_action", "no_shield_action", "info/no_shield_action", False),
    )

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.sum_goal = 0
        self.sum_lava = 0
        self.sum_collisions = 0
        self.sum_opened_door = 0
        self.sum_picked_up = 0
        self.no_shield_action = 0

    def _on_step(self) -> bool:
        """Update and record the event counters; never stops training."""
        # Previously only infos[0] was inspected, so with a vectorised env the
        # "sum_*" totals silently ignored every env but the first. Missing
        # "reached_goal"/"ran_into_lava" keys used to raise KeyError.
        all_infos = self.locals.get("infos") or []
        for info_key, attr, log_key, always_record in self._TRACKED_EVENTS:
            reported = False
            for info in all_infos:
                if info_key in info:
                    reported = True
                    if info[info_key]:
                        setattr(self, attr, getattr(self, attr) + 1)
            if always_record or reported:
                self.logger.record(log_key, getattr(self, attr))
        return True