Browse Source

Added num_gpus as a command-line argument; first attempt at a shield info callback

refactoring
sp 11 months ago
parent
commit
618ab6e73c
  1. 4
      examples/shields/rl/15_train_eval_tune.py
  2. 9
      examples/shields/rl/callbacks.py
  3. 1
      examples/shields/rl/helpers.py

4
examples/shields/rl/15_train_eval_tune.py

@ -71,14 +71,14 @@ def ppo(args):
config = (PPOConfig() config = (PPOConfig()
.rollouts(num_rollout_workers=args.workers) .rollouts(num_rollout_workers=args.workers)
.resources(num_gpus=0)
.resources(num_gpus=args.num_gpus)
.environment( env="mini-grid-shielding", .environment( env="mini-grid-shielding",
env_config={"name": args.env, env_config={"name": args.env,
"args": args, "args": args,
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
},) },)
.framework("torch") .framework("torch")
.callbacks(MyCallbacks)
.callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12])
.evaluation(evaluation_config={ .evaluation(evaluation_config={
"evaluation_interval": 1, "evaluation_interval": 1,
"evaluation_duration": 10, "evaluation_duration": 10,

9
examples/shields/rl/callbacks.py

@ -15,7 +15,16 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callback
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import tensorflow as tf
class ShieldInfoCallback(DefaultCallbacks):
    """Callback that writes a text summary to a TensorFlow event file.

    The log directory and payload can be supplied at construction time or
    passed directly to ``on_episode_start``. Explicit arguments to the hook
    take precedence over the values stored on the instance.
    """

    def __init__(self, log_dir=None, data=None):
        """Store the default log directory and payload.

        Args:
            log_dir: Directory handed to ``tf.summary.create_file_writer``.
            data: Object whose ``str()`` form is written as the summary text.
        """
        super().__init__()
        self.log_dir = log_dir
        self.data = data

    def on_episode_start(self, log_dir=None, data=None, **kwargs) -> None:
        """Write ``str(data)`` under the tag ``"first_text"`` at step 0.

        Args:
            log_dir: Overrides the log directory stored on the instance.
            data: Overrides the payload stored on the instance.
            **kwargs: Keyword arguments supplied by the framework when it
                invokes the hook (worker, base_env, policies, episode, ...);
                accepted so the call does not fail, and otherwise unused.
        """
        if log_dir is None:
            log_dir = self.log_dir
        if data is None:
            data = self.data
        if log_dir is None:
            # No destination configured (e.g. the class was instantiated
            # without arguments), so there is nowhere to write.
            return

        file_writer = tf.summary.create_file_writer(log_dir)
        try:
            with file_writer.as_default():
                tf.summary.text("first_text", str(data), step=0)
        finally:
            # A fresh writer is created on every call; close it so the
            # underlying event file handle is not left open.
            file_writer.close()

    def on_episode_step(self, **kwargs) -> None:
        """Do nothing; accepts the framework's keyword arguments."""
        pass
class MyCallbacks(DefaultCallbacks): class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:

1
examples/shields/rl/helpers.py

@ -132,6 +132,7 @@ def parse_arguments(argparse):
parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]") # parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
parser.add_argument("--workers", type=int, default=1) parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--num_gpus", type=float, default=0)
parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
parser.add_argument("--steps", default=20_000, type=int) parser.add_argument("--steps", default=20_000, type=int)
parser.add_argument("--expname", default="exp") parser.add_argument("--expname", default="exp")

Loading…
Cancel
Save