You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
232 lines
9.7 KiB
232 lines
9.7 KiB
from __future__ import annotations
|
|
from typing import Callable, Optional
|
|
|
|
from re import A
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
|
|
import os, subprocess, sys, re
|
|
|
|
from minigrid_shield_handler import MiniGridShieldHandler
|
|
from utils import common_parser, create_dt_overlay_image
|
|
from utils import NUM_MOVEMENT_ACTIONS
|
|
from minigrid.core.state import to_state, State
|
|
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback, EvalCallback
|
|
from stable_baselines3.common.logger import Image
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
class ShieldInterferenceTracker(gym.core.Wrapper):
    """Count, per symbolic state, visits and shield interventions.

    After every step the wrapped environment's symbolic state is looked up in
    the shield found on the wrapper stack (``get_wrapper_attr("shield")``).
    The shield "interferes" in a state when the entries of its action mask sum
    to less than ``NUM_MOVEMENT_ACTIONS``. For a 0/1 mask, that means at least
    one movement action is blocked.

    Args:
        env: Environment to wrap. Its wrapper stack must expose a ``shield``
            mapping and a ``get_symbolic_state()`` method.
        interference_delta: Optional callable mapping a symbolic state to an
            iterable of further states. Those states are also marked in
            ``interference_tracker`` whenever the shield interferes in the
            given state.
    """

    def __init__(self, env, interference_delta: Optional[Callable] = None):
        super().__init__(env)
        # symbolic state -> number of steps in which the shield restricted actions
        self.interference_tracker: dict = dict()
        # symbolic state -> number of steps that ended in this state
        self.visited_tracker: dict = dict()
        self.interference_delta = interference_delta

    def step(self, action):
        """Step the wrapped environment and update the counters.

        Counting is best-effort. If the symbolic state cannot be computed, or
        it is not covered by the shield, the step result is still returned
        unchanged and no statistics are recorded for that step.
        """
        obs, rew, done, truncated, info = self.env.step(action)
        try:
            symbolic_state = self.env.get_wrapper_attr("get_symbolic_state")()
            mask = self.env.get_wrapper_attr("shield")[symbolic_state]
            if sum(mask) < NUM_MOVEMENT_ACTIONS:
                self.interference_tracker[symbolic_state] = self.interference_tracker.get(symbolic_state, 0) + 1
                if callable(self.interference_delta):
                    for state in self.interference_delta(symbolic_state):
                        # Related states are only marked: set to 1 when unseen,
                        # otherwise left unchanged (not incremented).
                        # NOTE(review): confirm not incrementing is intended.
                        self.interference_tracker.setdefault(state, 1)
            self.visited_tracker[symbolic_state] = self.visited_tracker.get(symbolic_state, 0) + 1
        except Exception as e:
            # Deliberately best-effort (e.g. a state missing from the shield),
            # but keep the reason visible at debug level instead of dropping it.
            logger.debug("ShieldInterferenceTracker: no statistics for this step ({!r})", e)
        return obs, rew, done, truncated, info

    def tracking_statistics(self):
        """Return the ``(visited_tracker, interference_tracker)`` dicts (live references)."""
        return self.visited_tracker, self.interference_tracker

    def clear_statistics(self) -> None:
        """Empty both counters in place."""
        self.visited_tracker.clear()
        self.interference_tracker.clear()
|
|
|
|
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    """Hold a shield for a MiniGrid environment and flag unmodelled states.

    The shield is built with ``shield_handler.create_shield(env=...)`` on
    construction. When ``create_shield_at_reset`` is set, it is rebuilt after
    every ``reset``. It is exposed as ``self.shield`` and indexed by the
    environment's ``get_symbolic_state()``.

    Args:
        env: Environment to wrap. Its wrapper stack must provide
            ``get_symbolic_state()``.
        shield_handler: Object whose ``create_shield(env=...)`` returns the shield.
        create_shield_at_reset: Rebuild the shield after every ``reset``.
    """

    def __init__(self,
                 env,
                 shield_handler: MiniGridShieldHandler,
                 create_shield_at_reset=False,
                 ):
        super().__init__(env)
        self.shield_handler = shield_handler
        self.create_shield_at_reset = create_shield_at_reset
        self.shield = self.shield_handler.create_shield(env=self.env)

    def create_action_mask(self):
        """Return the shield's action mask for the current symbolic state.

        Falls back to ``[0.0] * 4`` if the mask cannot be obtained, e.g. when
        the current state is not a key of the shield.
        NOTE(review): confirm that an all-zero fallback and its hard-coded
        length of 4 are intended.
        """
        try:
            return self.shield[self.env.get_wrapper_attr("get_symbolic_state")()]
        except Exception:
            # Catch ordinary errors only, so KeyboardInterrupt/SystemExit still propagate.
            return [0.0] * 4

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped env and, if configured, rebuild the shield."""
        obs, infos = self.env.reset(seed=seed, options=options)
        if self.create_shield_at_reset:
            self.shield = self.shield_handler.create_shield(env=self.env)
        return obs, infos

    def step(self, action):
        """Step the wrapped env.

        Sets ``info["state_not_modelled"]`` to True when the resulting symbolic
        state is not a key of the shield.
        """
        obs, rew, done, truncated, info = self.env.step(action)
        symbolic_state = self.env.get_wrapper_attr("get_symbolic_state")()
        info["state_not_modelled"] = symbolic_state not in self.shield
        return obs, rew, done, truncated, info
|
|
|
|
def parse_sb3_arguments():
    """Parse command-line arguments with the project's shared parser.

    Returns:
        The namespace produced by calling ``parse_args()`` on the parser
        returned by ``common_parser()``.
    """
    return common_parser().parse_args()
|
|
|
|
|
|
class DecisionTreeAnalysis(BaseCallback):
    """Callback that turns the tracked shield statistics into decision trees.

    Each call of ``_on_step``:

    1. reads ``(visited, interference)`` from the environment's
       ``tracking_statistics()`` and its current ``shield``;
    2. writes the visit statistics to a CSV file under ``./results`` via
       ``write_csv``;
    3. runs dtcontrol on that CSV (``run_dtcontrol``) and logs the node count,
       renders ``default_modified.dot`` with ``run_dot``, and loads that tree
       into ``self.decision_tree``;
    4. if ``saveimg`` is set, saves an overlay image of the tree to
       ``./results/<csv base name>.png``; otherwise evaluates the tree's
       restrictions and records a comparison against the complete shield;
    5. clears the environment's statistics if ``clear`` was set.

    NOTE(review): ``write_csv``, ``run_dtcontrol``, ``dt_dir_for_experiment``,
    ``create_decision_tree_from_csv`` and ``run_dot`` are not imported in this
    file as shown. Confirm where they are defined.
    NOTE(review): the whole pipeline runs on every ``_on_step`` call.
    Presumably the callback is only triggered periodically (e.g. wrapped in
    SB3's ``EveryNTimesteps``); confirm at the call site.

    Args:
        env: Environment whose wrapper stack provides ``tracking_statistics``,
            ``shield`` and ``clear_statistics``.
        envName: Environment name used in output file names.
        shield_handler: Provides ``num_features``, ``feature_space``, ``model``
            and ``create_shield``.
        isodate: Prefix for output file names (presumably a timestamp).
        weighted: Selects the ``weighted_min<k>`` (truthy) or ``noweight``
            (falsy) file-name tag. Also forwarded to ``write_csv``.
        action_comparison: Forwarded to ``write_csv`` and used in file names.
        complete_shield_num_nodes: Logged as ``decision_tree/complete_shield``.
        saveimg: Save an overlay image instead of comparing with the shield.
        clear: Clear the tracked statistics after each analysis.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    # Analysis configurations that are run. More values can be listed to sweep
    # configurations, e.g. (0, 10, 50, 100) and (False, True).
    _MIN_FREQUENCIES = (0,)
    _COMPLEMENTS = (False,)

    # Patterns for the state-valuation strings of ``shield_handler.model``:
    # ``name=value`` pairs, and boolean flags (optionally negated with ``!``)
    # followed by whitespace. Compiled once instead of per state.
    _INT_VALUATION_RE = re.compile(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)')
    _BOOL_VALUATION_RE = re.compile(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+')

    def __init__(self, env, envName, shield_handler, isodate, weighted, action_comparison, complete_shield_num_nodes, saveimg: bool = False, clear: bool = False, verbose=0):
        logger.info("Init DecisionTreeAnalysis")
        super().__init__(verbose)
        self.env = env
        self.envName = envName
        self.shield_handler = shield_handler
        self.isodate = isodate
        self.weighted = weighted
        self.action_comparison = action_comparison
        self.WORKSPACE = "./"
        # Boolean flag (unlike the tracker's clear_statistics() method).
        self.clear_statistics = clear
        self.saveimg = saveimg
        self.complete_shield_num_nodes = complete_shield_num_nodes

    def _on_step(self) -> bool:
        """Run the decision-tree analysis on the current statistics; always returns True."""
        visited, interference = self.env.get_wrapper_attr("tracking_statistics")()
        self.shield = self.env.get_wrapper_attr("shield")
        for min_frequency in self._MIN_FREQUENCIES:
            for complement in self._COMPLEMENTS:
                csv_file_name = self.__write_csv(visited, "visited", complement, min_frequency)
                self.__run_analysis(csv_file_name, ignore_mask=[1.0, 1.0, 1.0])
                # The same analysis could be run on ``interference`` with
                # tracking="interference"; it is currently disabled.
        if self.clear_statistics:
            self.env.get_wrapper_attr("clear_statistics")()
        return True

    def __run_analysis(self, csv_file_name, ignore_mask) -> None:
        """Build the tree for ``csv_file_name``, then save an image or compare to the shield."""
        logger.info(f"{self.num_timesteps} | Running analysis")
        self.__create_dt(csv_file_name)
        self.__run_dot(csv_file_name)
        self.__parse_modified_dot_and_create_dt(csv_file_name)
        if self.saveimg:
            img = self.__create_image()
            if img is not None:
                # splitext only strips the final extension, so dots elsewhere in
                # the name (e.g. fractional seconds in isodate) are kept.
                experiment_name = os.path.splitext(os.path.basename(csv_file_name))[0]
                img.save(f"./results/{experiment_name}.png")
        else:
            self.decision_tree.evaluate_restrictions(ignore_mask=ignore_mask)
            self.__compare_to_shield()

    def __write_csv(self, statistics: dict, tracking: str, complement: bool, min_frequency: int):
        """Write ``statistics`` plus the shield to a CSV under ./results and return its path."""
        separator = '_complemented_' if complement else '_'
        weighting = f"weighted_min{min_frequency}" if self.weighted else "noweight"
        csv_file_name = (
            f"./results/{self.isodate}_{self.envName}{separator}{tracking}"
            f"_shield_{self.action_comparison}_{weighting}_steps_{self.num_timesteps:08}.csv"
        )
        write_csv(csv_file_name, self.shield, statistics, self.shield_handler.num_features, self.action_comparison, complement, min_frequency, self.weighted)
        return csv_file_name

    def __create_dt(self, csvFile):
        """Run dtcontrol on ``csvFile`` and log the tree's node count."""
        logger.info(f"{self.num_timesteps} | Running DTcontrol")
        _, num_nodes = run_dtcontrol(csvFile, self.WORKSPACE, self.shield_handler.feature_space)
        self.logger.record("decision_tree/num_nodes", num_nodes)
        self.logger.record("decision_tree/complete_shield", self.complete_shield_num_nodes)

    def __parse_modified_dot_and_create_dt(self, csvFile):
        """Load ``default_modified.dot`` for this experiment into ``self.decision_tree``."""
        logger.info(f"{self.num_timesteps} | Creating decision trees in memory")
        path = dt_dir_for_experiment(csvFile)
        # Map each feature name to its column index in the feature space.
        feature_indices = {feature: i for i, feature in enumerate(self.shield_handler.feature_space)}
        # NOTE(review): despite its name, create_decision_tree_from_csv receives the .dot file here.
        self.decision_tree = create_decision_tree_from_csv(f"{path}/default_modified.dot", feature_indices)

    def __create_image(self):
        """Return an overlay image of the decision tree, or None if it cannot be rendered."""
        logger.info(f"{self.num_timesteps} | Rendering decision trees onto image")
        try:
            return create_dt_overlay_image(self.env, self.decision_tree)
        except UnboundLocalError as e:
            # Per the message below, this is expected when the tree is a single node.
            logger.warning(f"{e}. The tree is likely a single node")
            return None

    def __run_dot(self, csvFile):
        """Render this experiment's ``default_modified.dot`` with dot."""
        logger.info(f"{self.num_timesteps} | Plotting decision trees using dot")
        path = dt_dir_for_experiment(csvFile)
        run_dot(os.path.join(path, "default_modified.dot"))

    def __compare_to_shield(self):
        """Compare the decision tree with a freshly created shield and log the result.

        Only model states whose ``previousActionAgent`` is 7 (or absent) and
        whose ``clock`` is 0 (or absent) are converted to MiniGrid states and
        compared.
        """
        logger.info(f"{self.num_timesteps} | Comparing decision trees with complete shield")
        model = self.shield_handler.model
        state_valuations = model.state_valuations
        minigrid_states = []
        for state_id in model.states:
            state_valuation = state_valuations.get_string(state_id)
            ints = dict(self._INT_VALUATION_RE.findall(state_valuation))
            if int(ints.get("previousActionAgent", 7)) != 7:
                continue
            if int(ints.get("clock", 0)) != 0:
                continue
            # A leading "!" marks the flag as False.
            booleans = {name: negation != "!" for negation, name in self._BOOL_VALUATION_RE.findall(state_valuation)}
            minigrid_states.append(to_state(ints, booleans))
        comparison = self.decision_tree.compare_to_shield(self.env, minigrid_states, self.shield_handler.create_shield(env=self.env))
        self.logger.record("decision_tree/comparison", comparison)
|
|
|
|
|
|
class InfoCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.

    On every step it reads the ``info`` dict of the first environment
    (``self.locals["infos"][0]``), increments running counters for the
    events flagged there, and records the cumulative counts with
    ``self.logger.record``.
    NOTE(review): with several vectorized envs, only env 0 is counted. Confirm
    this is intended.
    """

    # (info key, counter attribute, logger tag) for events that are only
    # recorded once the environment reports the key.
    _OPTIONAL_EVENTS = (
        ("collision", "sum_collisions", "info/sum_collision"),
        ("opened_door", "sum_opened_door", "info/sum_opened_door"),
        ("picked_up", "sum_picked_up", "info/sum_picked_up"),
        ("no_shield_action", "no_shield_action", "info/no_shield_action"),
    )

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.sum_goal = 0
        self.sum_lava = 0
        self.sum_collisions = 0
        self.sum_opened_door = 0
        self.sum_picked_up = 0
        self.no_shield_action = 0

    def _on_step(self) -> bool:
        """Update and record the event counters; always returns True."""
        infos = self.locals["infos"][0]
        # Goal/lava totals are always recorded. A missing key counts as
        # "did not happen" instead of raising KeyError.
        if infos.get("reached_goal", False):
            self.sum_goal += 1
        if infos.get("ran_into_lava", False):
            self.sum_lava += 1
        self.logger.record("info/sum_reached_goal", self.sum_goal)
        self.logger.record("info/sum_ran_into_lava", self.sum_lava)
        for key, attr, tag in self._OPTIONAL_EVENTS:
            if key not in infos:
                continue
            if infos[key]:
                setattr(self, attr, getattr(self, attr) + 1)
            self.logger.record(tag, getattr(self, attr))
        return True
|