Browse Source

moved text writing to standard callback

refactoring
sp 12 months ago
parent
commit
de5f35dd5f
  1. 4
      examples/shields/rl/15_train_eval_tune.py
  2. 4
      examples/shields/rl/callbacks.py

4
examples/shields/rl/15_train_eval_tune.py

@ -19,7 +19,7 @@ from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name
from shieldhandlers import MiniGridShieldHandler, create_shield_query from shieldhandlers import MiniGridShieldHandler, create_shield_query
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from callbacks import MyCallbacks, ShieldInfoCallback
from callbacks import MyCallbacks
def shielding_env_creater(config): def shielding_env_creater(config):
@ -79,7 +79,7 @@ def ppo(args):
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
},) },)
.framework("torch") .framework("torch")
.callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback]))
.callbacks(MyCallbacks)
.evaluation(evaluation_config={ .evaluation(evaluation_config={
"evaluation_interval": 1, "evaluation_interval": 1,
"evaluation_duration": 10, "evaluation_duration": 10,

4
examples/shields/rl/callbacks.py

@ -28,6 +28,10 @@ class ShieldInfoCallback(DefaultCallbacks):
class MyCallbacks(DefaultCallbacks): class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
file_writer = tf.summary.create_file_writer(log_dir)
with file_writer.as_default():
tf.summary.text("first_text", "testing", step=0)
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}") # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0] env = base_env.get_sub_environments()[0]
episode.user_data["count"] = 0 episode.user_data["count"] = 0

Loading…
Cancel
Save