|
|
@ -15,7 +15,16 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callback |
|
|
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
|
|
|
|
|
class ShieldInfoCallback(DefaultCallbacks): |
|
|
|
def on_episode_start(self, log_dir, data) -> None: |
|
|
|
file_writer = tf.summary.create_file_writer(log_dir) |
|
|
|
with file_writer.as_default(): |
|
|
|
tf.summary.text("first_text", str(data), step=0) |
|
|
|
|
|
|
|
def on_episode_step(self) -> None: |
|
|
|
pass |
|
|
|
|
|
|
|
class MyCallbacks(DefaultCallbacks): |
|
|
|
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: |
|
|
|