You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
4.3 KiB
128 lines
4.3 KiB
import gymnasium as gym
|
|
import numpy as np
|
|
import random
|
|
from moviepy.editor import ImageSequenceClip
|
|
|
|
from utils import MiniGridShieldHandler, common_parser
|
|
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
|
|
from stable_baselines3.common.logger import Image
|
|
|
|
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    """Gymnasium wrapper that attaches a shield to a MiniGrid environment.

    The shield is produced by ``shield_handler.create_shield(env=...)`` and is
    indexed with ``env.get_symbolic_state()`` in ``create_action_mask``.
    ``reset`` and ``step`` forward to the wrapped environment unchanged, except
    that ``reset`` optionally rebuilds the shield.
    """

    def __init__(self,
                 env,
                 shield_handler : MiniGridShieldHandler,
                 create_shield_at_reset = False,
                 ):
        """
        :param env: the environment to wrap; it must provide
            ``get_symbolic_state()`` for ``create_action_mask``.
        :param shield_handler: object whose ``create_shield(env=...)`` builds
            the shield.
        :param create_shield_at_reset: if True, the shield is rebuilt after
            every ``reset``.
        """
        super().__init__(env)
        self.shield_handler = shield_handler
        self.create_shield_at_reset = create_shield_at_reset

        # The shield is always built once here, even when
        # create_shield_at_reset is True (it is then rebuilt on each reset).
        self.shield = self.shield_handler.create_shield(env=self.env)

    def create_action_mask(self):
        """Return the shield entry for the environment's current symbolic state.

        If the lookup fails (for example because the state is missing from the
        shield), a fixed fallback mask is returned instead: ``0.0`` for the
        first three entries and ``1.0`` for the remaining four.
        """
        try:
            return self.shield[self.env.get_symbolic_state()]
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # are no longer swallowed.
            # NOTE(review): the 7 entries presumably map to MiniGrid's 7
            # discrete actions — confirm against the action space.
            return [0.0] * 3 + [1.0] * 4

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped env and, if configured, rebuild the shield.

        :return: ``(obs, infos)`` exactly as returned by the wrapped env.
        """
        obs, infos = self.env.reset(seed=seed, options=options)

        # The shield is rebuilt after the env reset so it is created from the
        # freshly reset environment.
        if self.create_shield_at_reset:
            self.shield = self.shield_handler.create_shield(env=self.env)

        return obs, infos

    def step(self, action):
        """Forward ``action`` to the wrapped env and return its 5-tuple unchanged."""
        return self.env.step(action)
|
|
|
|
def parse_sb3_arguments():
    """Parse the command line using the shared project parser.

    No SB3-specific options are added; this simply builds the parser from
    ``common_parser()`` and parses ``sys.argv``.

    :return: the result of ``parse_args()`` (presumably an
        ``argparse.Namespace`` — confirm against ``common_parser``).
    """
    return common_parser().parse_args()
|
|
|
|
class ImageRecorderCallback(BaseCallback):
    """Log a training-env snapshot and record a GIF of evaluation rollouts.

    At training start, one frame rendered from ``self.training_env`` is logged
    under ``trajectory/image`` (excluded from the stdout/log/json/csv outputs).
    At training end, ``evaluation_method`` is run on ``eval_env``. Each frame
    rendered during that evaluation is written to ``<log_dir>/<n_calls>.gif``
    at 3 fps.
    """

    def __init__(self, eval_env, render_freq, n_eval_episodes, evaluation_method, log_dir, deterministic=True, verbose=0):
        """
        :param eval_env: environment used for the recorded evaluation; its
            ``render()`` result is used as a GIF frame.
        :param render_freq: intended number of callback calls between
            recordings. NOTE: currently unused because periodic recording in
            ``_on_step`` is disabled.
        :param n_eval_episodes: number of episodes passed to
            ``evaluation_method``.
        :param evaluation_method: callable invoked as
            ``evaluation_method(model, env, callback=..., n_eval_episodes=...,
            deterministic=...)``; presumably SB3's ``evaluate_policy`` — confirm.
        :param log_dir: directory in which the GIF is written.
        :param deterministic: forwarded to ``evaluation_method``.
        :param verbose: verbosity level forwarded to ``BaseCallback``.
        """
        super().__init__(verbose)

        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic
        self._evaluation_method = evaluation_method
        self._log_dir = log_dir

    def _on_training_start(self):
        """Log a single frame of the training environment."""
        image = self.training_env.render(mode="rgb_array")
        # VecEnv.render may return None (e.g. when the env's render_mode is not
        # "rgb_array"); an Image wrapping None would break the TensorBoard
        # writer when the logger dumps, so skip logging in that case.
        if image is None:
            return
        self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))

    def _on_step(self) -> bool:
        """Continue training unconditionally.

        NOTE: periodic recording (calling ``self.record_video()`` whenever
        ``self.n_calls % self._render_freq == 0``) is intentionally disabled;
        a GIF is recorded only once, in ``_on_training_end``.
        """
        return True

    def _on_training_end(self) -> None:
        """Record one evaluation GIF after training finishes."""
        self.record_video()

    def record_video(self) -> bool:
        """Run the evaluation, capturing every rendered frame, and save it as a GIF.

        :return: True if a GIF was written, False if no frames were captured
            (``ImageSequenceClip`` cannot be built from an empty sequence).
        """
        screens = []

        def grab_screens(_locals, _globals) -> None:
            """
            Renders the environment in its current state, recording the screen in the captured `screens` list

            :param _locals: A dictionary containing all local variables of the callback's scope
            :param _globals: A dictionary containing all global variables of the callback's scope
            """
            screen = self._eval_env.render()
            # render() may return None when the env has no rgb_array render
            # mode; such frames cannot be encoded, so drop them.
            if screen is not None:
                screens.append(screen)

        self._evaluation_method(
            self.model,
            self._eval_env,
            callback=grab_screens,
            n_eval_episodes=self._n_eval_episodes,
            deterministic=self._deterministic,
        )

        if not screens:
            return False

        clip = ImageSequenceClip(screens, fps=3)
        clip.write_gif(f"{self._log_dir}/{self.n_calls}.gif", fps=3)
        return True
|
|
|
|
|
|
class InfoCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.

    Keeps running totals of events reported in the per-environment ``info``
    dicts and logs them under the ``info/`` namespace on every step.
    ``reached_goal`` and ``ran_into_lava`` must be present in every info dict
    (a missing key raises ``KeyError``). ``collision``, ``opened_door`` and
    ``picked_up`` are optional. Their totals are logged only on steps where
    at least one info dict contains the key.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Running totals over all training steps and all sub-environments.
        self.sum_goal = 0
        self.sum_lava = 0
        self.sum_collisions = 0
        self.sum_opened_door = 0
        self.sum_picked_up = 0

    def _on_step(self) -> bool:
        """Add this step's events to the running totals and log them.

        Every entry of ``self.locals["infos"]`` (presumably one info dict per
        sub-environment of the VecEnv) is counted. This way, events from all
        parallel environments contribute, not only the first one.

        :return: always True, so this callback never stops training.
        """
        saw_collision = False
        saw_opened_door = False
        saw_picked_up = False

        for info in self.locals["infos"]:
            if info["reached_goal"]:
                self.sum_goal += 1
            if info["ran_into_lava"]:
                self.sum_lava += 1

            if "collision" in info:
                saw_collision = True
                if info["collision"]:
                    self.sum_collisions += 1
            if "opened_door" in info:
                saw_opened_door = True
                if info["opened_door"]:
                    self.sum_opened_door += 1
            if "picked_up" in info:
                saw_picked_up = True
                if info["picked_up"]:
                    self.sum_picked_up += 1

        self.logger.record("info/sum_reached_goal", self.sum_goal)
        self.logger.record("info/sum_ran_into_lava", self.sum_lava)
        # Optional metrics are logged only when some env actually reports them.
        if saw_collision:
            self.logger.record("info/sum_collision", self.sum_collisions)
        if saw_opened_door:
            self.logger.record("info/sum_opened_door", self.sum_opened_door)
        if saw_picked_up:
            self.logger.record("info/sum_picked_up", self.sum_picked_up)

        return True
|