diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py
index 9917e44..315f7b1 100644
--- a/examples/shields/rl/15_train_eval_tune.py
+++ b/examples/shields/rl/15_train_eval_tune.py
@@ -68,34 +68,34 @@ def trial_name_creator(trial : Trial):
 def ppo(args):
     register_minigrid_shielding_env(args)
     logdir = args.log_dir
-
+
     config = (PPOConfig()
         .rollouts(num_rollout_workers=args.workers)
-        .resources(num_gpus=0)
+        .resources(num_gpus=args.num_gpus)
         .environment( env="mini-grid-shielding",
             env_config={"name": args.env,
-                        "args": args,
+                        "args": args,
                         "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
                         },)
         .framework("torch")
-        .callbacks(MyCallbacks)
-        .evaluation(evaluation_config={
+        .callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12]))
+        .evaluation(evaluation_config={
             "evaluation_interval": 1,
             "evaluation_duration": 10,
             "evaluation_num_workers":1,
-            "env": "mini-grid-shielding",
-            "env_config": {"name": args.env,
-                           "args": args,
-                           "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
+            "env": "mini-grid-shielding",
+            "env_config": {"name": args.env,
+                           "args": args,
+                           "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
-            "type": UnifiedLogger,
+            "type": UnifiedLogger,
             "logdir": logdir
         })
         .training(_enable_learner_api=False ,model={
-            "custom_model": "shielding_model"
+            "custom_model": "shielding_model"
         }))
-
+
     tuner = tune.Tuner("PPO",
         tune_config=tune.TuneConfig(
             metric="episode_reward_mean",
diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py
index 240fec7..1be3ecc 100644
--- a/examples/shields/rl/callbacks.py
+++ b/examples/shields/rl/callbacks.py
@@ -15,7 +15,16 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callback
 
 import matplotlib.pyplot as plt
+import tensorflow as tf
 
 
+class ShieldInfoCallback(DefaultCallbacks):
+    def on_episode_start(self, log_dir, data) -> None:
+        file_writer = tf.summary.create_file_writer(log_dir)
+        with file_writer.as_default():
+            tf.summary.text("first_text", str(data), step=0)
+
+    def on_episode_step(self) -> None:
+        pass
 
 class MyCallbacks(DefaultCallbacks):
     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py
index af59256..f06e463 100644
--- a/examples/shields/rl/helpers.py
+++ b/examples/shields/rl/helpers.py
@@ -132,6 +132,7 @@ def parse_arguments(argparse):
     parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]")  # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
     # parser.add_argument("--formula", default="<> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
     parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--num_gpus", type=float, default=0)
     parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
     parser.add_argument("--steps", default=20_000, type=int)
     parser.add_argument("--expname", default="exp")
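
Note on the new ShieldInfoCallback: RLlib's .callbacks() config setting expects a callbacks class (or a zero-argument factory), not a pre-built instance, and the episode hooks receive keyword-only arguments as in MyCallbacks above. Below is a minimal sketch of how the callback could be made to fit that interface; the constructor, the use of make_multi_callbacks, and the functools.partial factory are assumptions for illustration and are not part of this patch (the "first_text" tag and the [1, 12] data are kept from the patch):

# sketch for callbacks.py (assumes Ray exposes make_multi_callbacks)
import tensorflow as tf
from ray.rllib.algorithms.callbacks import DefaultCallbacks


class ShieldInfoCallback(DefaultCallbacks):
    def __init__(self, log_dir, data) -> None:
        super().__init__()
        # RLlib instantiates callbacks itself, so the log dir and the shield
        # data are bound at construction time via a factory (see below).
        self.log_dir = log_dir
        self.data = data
        self.file_writer = tf.summary.create_file_writer(self.log_dir)

    def on_episode_start(self, *, worker, base_env, policies, episode, env_index, **kwargs) -> None:
        # Write the shield data as a TensorBoard text summary.
        with self.file_writer.as_default():
            tf.summary.text("first_text", str(self.data), step=0)

    # on_episode_step can be inherited from DefaultCallbacks as a no-op.


# sketch for ppo() in 15_train_eval_tune.py -- combine both callbacks:
#     .callbacks(make_multi_callbacks(
#         [MyCallbacks, functools.partial(ShieldInfoCallback, logdir, [1, 12])]))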
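
The new --num_gpus flag is declared as a float, which matches Ray's support for fractional GPU allocation in .resources(). An illustrative invocation (assuming a CUDA-visible device is available) would be:

python 15_train_eval_tune.py --num_gpus 0.5 --workers 2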