
added num_gpus as arg, first try sh info callback

refactoring

sp, 12 months ago
commit 618ab6e73c
1. examples/shields/rl/15_train_eval_tune.py (24 changes)
2. examples/shields/rl/callbacks.py (9 changes)
3. examples/shields/rl/helpers.py (1 change)

examples/shields/rl/15_train_eval_tune.py (24 changes)

@@ -68,34 +68,34 @@ def trial_name_creator(trial : Trial):
 def ppo(args):
     register_minigrid_shielding_env(args)
     logdir = args.log_dir
     config = (PPOConfig()
         .rollouts(num_rollout_workers=args.workers)
-        .resources(num_gpus=0)
+        .resources(num_gpus=args.num_gpus)
         .environment( env="mini-grid-shielding",
                       env_config={"name": args.env,
-                                  "args": args,
+                                  "args": args,
                                   "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
                                   },)
         .framework("torch")
-        .callbacks(MyCallbacks)
-        .evaluation(evaluation_config={
+        .callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12])
+        .evaluation(evaluation_config={
             "evaluation_interval": 1,
             "evaluation_duration": 10,
             "evaluation_num_workers":1,
-            "env": "mini-grid-shielding",
-            "env_config": {"name": args.env,
-                           "args": args,
-                           "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
+            "env": "mini-grid-shielding",
+            "env_config": {"name": args.env,
+                           "args": args,
+                           "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
-            "type": UnifiedLogger,
+            "type": UnifiedLogger,
             "logdir": logdir
         })
         .training(_enable_learner_api=False ,model={
-            "custom_model": "shielding_model"
+            "custom_model": "shielding_model"
         }))
     tuner = tune.Tuner("PPO",
         tune_config=tune.TuneConfig(
             metric="episode_reward_mean",
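Note on the .callbacks() change above: as committed, the new line .callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12]) is missing its closing parenthesis, and on this (pre-RLModule) API stack AlgorithmConfig.callbacks() expects a single DefaultCallbacks subclass, not a class plus an instance. A minimal sketch of one way to wire both callbacks together with make_multi_callbacks (the import is visible in the callbacks.py hunk below); the factory function and its arguments here are illustrative, not the committed code:

    # Sketch only: combine both callback classes into one for .callbacks().
    from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
    from ray.rllib.algorithms.ppo import PPOConfig

    class MyCallbacks(DefaultCallbacks):  # stand-in for the class in callbacks.py
        pass

    def make_shield_info_callback(log_dir, data):
        # Hypothetical factory: .callbacks() expects a class it can instantiate
        # without arguments, so bake the settings into a throwaway subclass.
        class ConfiguredShieldInfoCallback(DefaultCallbacks):
            def on_episode_start(self, **kwargs) -> None:
                print(f"shield info {data!r} would be logged under {log_dir}")
        return ConfiguredShieldInfoCallback

    config = (PPOConfig()
              .callbacks(make_multi_callbacks([MyCallbacks,
                                               make_shield_info_callback("logs", [1, 12])])))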

examples/shields/rl/callbacks.py (9 changes)

@@ -15,7 +15,16 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
 import matplotlib.pyplot as plt
+import tensorflow as tf
+
+class ShieldInfoCallback(DefaultCallbacks):
+    def on_episode_start(self, log_dir, data) -> None:
+        file_writer = tf.summary.create_file_writer(log_dir)
+        with file_writer.as_default():
+            tf.summary.text("first_text", str(data), step=0)
+    def on_episode_step(self) -> None:
+        pass

 class MyCallbacks(DefaultCallbacks):
     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
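Note on ShieldInfoCallback above: its on_episode_start(self, log_dir, data) does not match the DefaultCallbacks hook signature (compare MyCallbacks directly below it), so RLlib would never supply log_dir or data through that hook. A minimal sketch of the same TensorBoard text logging with a conforming signature, assuming the settings are passed at construction time; the "shield_info" tag and the defaults are assumptions, not the committed code:

    # Sketch only: same idea with the real RLlib hook signature.
    import tensorflow as tf
    from ray.rllib.algorithms.callbacks import DefaultCallbacks

    class ShieldInfoCallback(DefaultCallbacks):
        def __init__(self, log_dir="logs", data=None):
            super().__init__()
            self.file_writer = tf.summary.create_file_writer(log_dir)
            self.data = data

        def on_episode_start(self, *, worker, base_env, policies, episode,
                             env_index=None, **kwargs) -> None:
            # One TensorBoard text entry per episode, keyed by the episode id.
            with self.file_writer.as_default():
                tf.summary.text("shield_info", str(self.data), step=episode.episode_id)
            self.file_writer.flush()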

examples/shields/rl/helpers.py (1 change)

@@ -132,6 +132,7 @@ def parse_arguments(argparse):
     parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
     # parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
     parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--num_gpus", type=float, default=0)
     parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
     parser.add_argument("--steps", default=20_000, type=int)
     parser.add_argument("--expname", default="exp")
