Browse Source

sh info callback, removed ws

refactoring
sp 11 months ago
parent
commit
07841090c6
  1. 2
      examples/shields/rl/15_train_eval_tune.py
  2. 5
      examples/shields/rl/callbacks.py

2
examples/shields/rl/15_train_eval_tune.py

@ -78,7 +78,7 @@ def ppo(args):
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
},)
.framework("torch")
.callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12]))
.callbacks(MyCallbacks, ShieldInfoCallback)
.evaluation(evaluation_config={
"evaluation_interval": 1,
"evaluation_duration": 10,

5
examples/shields/rl/callbacks.py

@ -18,10 +18,10 @@ import matplotlib.pyplot as plt
import tensorflow as tf
class ShieldInfoCallback(DefaultCallbacks):
def on_episode_start(self, log_dir, data) -> None:
def on_episode_start(self) -> None:
file_writer = tf.summary.create_file_writer(log_dir)
with file_writer.as_default():
tf.summary.text("first_text", str(data), step=0)
tf.summary.text("first_text", "testing", step=0)
def on_episode_step(self) -> None:
pass
@ -92,4 +92,3 @@ class MyCallbacks(DefaultCallbacks):
def on_evaluate_end(self, *, algorithm: Algorithm, evaluation_metrics: dict, **kwargs) -> None:
print("Evaluate End")
Loading…
Cancel
Save