Browse Source

use shield in evaluation when full shielding

refactoring
sp 10 months ago
parent
commit
d7e7a2411b
  1. 23
      examples/shields/rl/13_minigridsb.py

23
examples/shields/rl/13_minigridsb.py

@ -1,5 +1,5 @@
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat
@ -10,7 +10,7 @@ from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
import time
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
from stable_baselines3.common.callbacks import EvalCallback
@ -35,8 +35,11 @@ def main():
log_dir = create_log_dir(args)
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)])
if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training or args.shielding == ShieldingConfig.Evaluation:
if shield_needed(args.shielding):
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup)
env = gym.make(args.env, render_mode="rgb_array")
env = RGBImgObsWrapper(env)
env = ImgObsWrapper(env)
@ -63,11 +66,19 @@ def main():
assert(False) # TODO Do something proper
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto")
model.set_logger(new_logger)
steps = args.steps
eval_freq=max(500, int(args.steps/30))
n_eval_episodes=5
if shielded_evaluation(args.shielding):
evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir,
log_path=log_dir, eval_freq=eval_freq,
deterministic=True, render=False, n_eval_episodes=n_eval_episodes)
else:
evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir,
log_path=log_dir, eval_freq=max(500, int(args.steps/30)),
deterministic=True, render=False, n_eval_episodes=5)
steps = args.steps
log_path=log_dir, eval_freq=eval_freq,
deterministic=True, render=False, n_eval_episodes=n_eval_episodes)
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback(), evalCallback])

Loading…
Cancel
Save