
added tune example; refactored code and added evaluation logging

Thomas Knoll · 2 years ago
commit 138d917fd6
  1. examples/shields/rl/11_minigridrl.py      (48 changed lines)
  2. examples/shields/rl/13_minigridsb.py      (9 changed lines)
  3. examples/shields/rl/14_train_eval.py      (42 changed lines)
  4. examples/shields/rl/15_train_eval_tune.py (118 changed lines)
  5. examples/shields/rl/ShieldHandlers.py     (19 changed lines)
  6. examples/shields/rl/Wrappers.py           (67 changed lines)
  7. examples/shields/rl/callbacks.py          (61 changed lines)
  8. examples/shields/rl/helpers.py            (6 changed lines)

examples/shields/rl/11_minigridrl.py (48 changed lines)

@@ -1,10 +1,4 @@
-from typing import Dict
-from ray.rllib.env.base_env import BaseEnv
-from ray.rllib.evaluation import RolloutWorker
-from ray.rllib.evaluation.episode import Episode
-from ray.rllib.evaluation.episode_v2 import EpisodeV2
-from ray.rllib.policy import Policy
-from ray.rllib.utils.typing import PolicyID
 import gymnasium as gym
@@ -15,7 +9,6 @@ import minigrid
 from ray.tune import register_env
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.algorithms.dqn.dqn import DQNConfig
-from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.tune.logger import pretty_print
 from ray.rllib.models import ModelCatalog
@@ -23,42 +16,13 @@ from ray.rllib.models import ModelCatalog
 from TorchActionMaskModel import TorchActionMaskModel
 from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
 from helpers import parse_arguments, create_log_dir, ShieldingConfig
-from ShieldHandlers import MiniGridShieldHandler
+from ShieldHandlers import MiniGridShieldHandler, create_shield_query
+from callbacks import MyCallbacks
 import matplotlib.pyplot as plt
 from ray.tune.logger import TBXLogger
-class MyCallbacks(DefaultCallbacks):
-    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-        # print(F"Episode started Environment: {base_env.get_sub_environments()}")
-        env = base_env.get_sub_environments()[0]
-        episode.user_data["count"] = 0
-        # print("On episode start print")
-        # print(env.printGrid())
-        # print(worker)
-        # print(env.action_space.n)
-        # print(env.actions)
-        # print(env.mission)
-        # print(env.observation_space)
-        # img = env.get_frame()
-        # plt.imshow(img)
-        # plt.show()
-    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-        episode.user_data["count"] = episode.user_data["count"] + 1
-        env = base_env.get_sub_environments()[0]
-        # print(env.printGrid())
-    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
-        # print(F"Episode end Environment: {base_env.get_sub_environments()}")
-        env = base_env.get_sub_environments()[0]
-        # print("On episode end print")
-        # print(env.printGrid())
 def shielding_env_creater(config):
     name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
     framestack = config.get("framestack", 4)
@@ -69,7 +33,7 @@ def shielding_env_creater(config):
     shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
     env = gym.make(name)
-    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator)
+    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
     # env = minigrid.wrappers.ImgObsWrapper(env)
     # env = ImgObsWrapper(env)
     env = OneHotShieldingWrapper(env,
@@ -98,7 +62,7 @@ def ppo(args):
     config = (PPOConfig()
         .rollouts(num_rollout_workers=args.workers)
         .resources(num_gpus=0)
-        .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
+        .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
         .framework("torch")
         .callbacks(MyCallbacks)
         .rl_module(_enable_rl_module_api = False)
@@ -132,7 +96,7 @@ def dqn(args):
     config = config.rollouts(num_rollout_workers=args.workers)
     config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
     config = config.framework("torch")
-    #config = config.callbacks(MyCallbacks)
+    config = config.callbacks(MyCallbacks)
     config = config.rl_module(_enable_rl_module_api = False)
     config = config.debugging(logger_config={
         "type": TBXLogger,

examples/shields/rl/13_minigridsb.py (9 changed lines)

@@ -11,8 +11,8 @@ from minigrid.core.actions import Actions
 import numpy as np
 import time
-from helpers import parse_arguments, extract_keys, get_action_index_mapping, create_log_dir
-from ShieldHandlers import MiniGridShieldHandler
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
+from ShieldHandlers import MiniGridShieldHandler, create_shield_query
 from Wrappers import MiniGridSbShieldingWrapper
 class CustomCallback(BaseCallback):
@@ -27,6 +27,7 @@ class CustomCallback(BaseCallback):
 def mask_fn(env: gym.Env):
     return env.create_action_mask()
@@ -42,10 +43,10 @@ def main():
     shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
     env = gym.make(args.env, render_mode="rgb_array")
-    env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, no_masking=args.no_masking)
+    env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full)
     env = ActionMasker(env, mask_fn)
     callback = CustomCallback(1, env)
-    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))
+    model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args))
     iterations = args.iterations
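With sb3-contrib, MaskablePPO picks the mask up from the ActionMasker wrapper during learning, but at prediction time the mask must be passed explicitly. A minimal sketch reusing the env, mask_fn, and model defined above:

    model.learn(10_000, callback=callback)  # masks are applied automatically during rollouts
    obs, _ = env.reset()
    action, _ = model.predict(obs, action_masks=mask_fn(env))  # explicit mask at inference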

examples/shields/rl/14_train_eval.py (42 changed lines)

@@ -9,18 +9,20 @@ from ray.tune import register_env
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.algorithms.dqn.dqn import DQNConfig
 # from ray.rllib.algorithms.callbacks import DefaultCallbacks
-from ray.tune.logger import pretty_print
+from ray.tune.logger import pretty_print, TBXLogger, TBXLoggerCallback, DEFAULT_LOGGERS, UnifiedLogger, CSVLogger
 from ray.rllib.models import ModelCatalog
 from TorchActionMaskModel import TorchActionMaskModel
 from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
 from helpers import parse_arguments, create_log_dir, ShieldingConfig
-from ShieldHandlers import MiniGridShieldHandler
+from ShieldHandlers import MiniGridShieldHandler, create_shield_query
+from callbacks import MyCallbacks
 import matplotlib.pyplot as plt
+from torch.utils.tensorboard import SummaryWriter
 from ray.tune.logger import TBXLogger
@@ -39,7 +41,7 @@ def shielding_env_creater(config):
     shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
     env = gym.make(name)
-    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, mask_actions=shielding)
+    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=shielding)
     env = OneHotShieldingWrapper(env,
         config.vector_index if hasattr(config, "vector_index") else 0,
@@ -67,16 +69,18 @@ def ppo(args):
         .rollouts(num_rollout_workers=args.workers)
         .resources(num_gpus=0)
         .environment( env="mini-grid-shielding",
-            env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
+            env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
         .framework("torch")
-        .evaluation(evaluation_config={ "evaluation_interval": 1,
-            "evaluation_parallel_to_training": False,
+        .callbacks(MyCallbacks)
+        .evaluation(evaluation_config={
+            "evaluation_interval": 1,
+            "evaluation_duration": 10,
+            "evaluation_num_workers": 1,
             "env": "mini-grid-shielding",
-            "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Evaluation}})
-        #.callbacks(MyCallbacks)
+            "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
-            "type": TBXLogger,
+            "type": UnifiedLogger,
             "logdir": create_log_dir(args)
         })
         .training(_enable_learner_api=False, model={
@@ -90,16 +94,34 @@ def ppo(args):
     iterations = args.iterations
     for i in range(iterations):
         algo.train()
         if i % 5 == 0:
             algo.save()
+    writer = SummaryWriter(log_dir=F"{create_log_dir(args)}-eval")
+    # csv_logger = CSVLogger()  # unused; CSVLogger requires config and logdir arguments
     for i in range(iterations):
         eval_result = algo.evaluate()
         print(pretty_print(eval_result))
+        print(eval_result)
+        # logger.on_result(eval_result)
+        evaluation = eval_result['evaluation']
+        episode_reward_mean = evaluation['episode_reward_mean']
+        episode_len_mean = evaluation['episode_len_mean']
+        print(episode_reward_mean)
+        writer.add_scalar("evaluation/episode_reward_mean", episode_reward_mean, i)
+        writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
+    writer.close()
 def main():
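The logging loop above assumes the result structure returned by Algorithm.evaluate(); roughly, the slice it consumes looks like this (keys taken from the code above, values purely illustrative):

    eval_result = {
        "evaluation": {
            "episode_reward_mean": 0.73,   # averaged over the 10 evaluation episodes
            "episode_len_mean": 42.1,
            # ... further sampler stats, hist_data, custom_metrics ...
        },
    }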

examples/shields/rl/15_train_eval_tune.py (118 changed lines, new file)

@@ -0,0 +1,118 @@
+import gymnasium as gym
+import minigrid
+# import numpy as np
+# import ray
+from ray.tune import register_env
+from ray import tune, air
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.algorithms.dqn.dqn import DQNConfig
+# from ray.rllib.algorithms.callbacks import DefaultCallbacks
+from ray.tune.logger import pretty_print, TBXLogger, TBXLoggerCallback, DEFAULT_LOGGERS, UnifiedLogger
+from ray.rllib.models import ModelCatalog
+from TorchActionMaskModel import TorchActionMaskModel
+from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
+from ShieldHandlers import MiniGridShieldHandler, create_shield_query
+from callbacks import MyCallbacks
+import matplotlib.pyplot as plt
+from torch.utils.tensorboard import SummaryWriter
+
+def shielding_env_creater(config):
+    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
+    framestack = config.get("framestack", 4)
+    args = config.get("args", None)
+    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
+    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
+    shielding = config.get("shielding", False)
+    # if shielding:
+    #     assert(False)
+    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
+    env = gym.make(name)
+    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=shielding)
+    env = OneHotShieldingWrapper(env,
+                                 config.vector_index if hasattr(config, "vector_index") else 0,
+                                 framestack=framestack
+                                 )
+    return env
+
+def register_minigrid_shielding_env(args):
+    env_name = "mini-grid-shielding"
+    register_env(env_name, shielding_env_creater)
+    ModelCatalog.register_custom_model(
+        "shielding_model",
+        TorchActionMaskModel
+    )
+
+def ppo(args):
+    register_minigrid_shielding_env(args)
+    config = (PPOConfig()
+        .rollouts(num_rollout_workers=args.workers)
+        .resources(num_gpus=0)
+        .environment( env="mini-grid-shielding",
+            env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
+        .framework("torch")
+        .callbacks(MyCallbacks)
+        .evaluation(evaluation_config={
+            "evaluation_interval": 1,
+            "evaluation_duration": 10,
+            "evaluation_num_workers": 1,
+            "env": "mini-grid-shielding",
+            "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
+        .rl_module(_enable_rl_module_api = False)
+        .debugging(logger_config={
+            "type": UnifiedLogger,
+            "logdir": create_log_dir(args)
+        })
+        .training(_enable_learner_api=False, model={
+            "custom_model": "shielding_model"
+        }))
+
+    tuner = tune.Tuner("PPO",
+                       run_config=air.RunConfig(
+                           stop={"episode_reward_mean": 50},
+                           checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True),
+                           storage_path=F"{create_log_dir(args)}-tuner"
+                       ),
+                       param_space=config,)
+    tuner.fit()
+
+    iterations = args.iterations
+    print(config.to_dict())
+    tune.run("PPO", config=config)
+    # print(episode_reward_mean)
+    # writer.add_scalar("evaluation/episode_reward", episode_reward_mean, i)
+
+def main():
+    import argparse
+    args = parse_arguments(argparse)
+    ppo(args)
+
+if __name__ == '__main__':
+    main()
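Note that the script launches two runs: tune.Tuner(...).fit() and then tune.run("PPO", config=config) with the same configuration. A sketch of inspecting the outcome of fit(), assuming the standard episode_reward_mean metric reported by RLlib:

    results = tuner.fit()
    best = results.get_best_result(metric="episode_reward_mean", mode="max")
    print(best.checkpoint)  # available because checkpoint_at_end=True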

examples/shields/rl/ShieldHandlers.py (19 changed lines)

@@ -15,7 +15,7 @@ import os
 class ShieldHandler(ABC):
     def __init__(self) -> None:
         pass
-    def create_shield(self, **kwargs):
+    def create_shield(self, **kwargs) -> dict:
         pass
 class MiniGridShieldHandler(ShieldHandler):
@@ -32,7 +32,9 @@ class MiniGridShieldHandler(ShieldHandler):
     def __create_prism(self):
-        os.system(F"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}")
+        result = os.system(F"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}")
+        assert result == 0, "Prism file could not be generated"
         f = open(self.prism_path, "a")
         f.write("label \"AgentIsInLava\" = AgentIsInLava;")
@@ -79,3 +81,16 @@ class MiniGridShieldHandler(ShieldHandler):
         return self.__create_shield_dict()
+
+def create_shield_query(env):
+    coordinates = env.env.agent_pos
+    view_direction = env.env.agent_dir
+
+    key_text = ""
+    # only support one key for now
+    # print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir}")
+    cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
+
+    return cur_pos_str
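The returned string serves as the lookup key into the shield dictionary built by the handler. As a concrete sketch: for an agent at position (3, 2) facing direction 1, and with key_text left empty, the query would be '[!AgentDone\t& xAgent=3\t& yAgent=2\t& viewAgent=1]' (with literal tab characters):

    query = create_shield_query(env)   # expects a wrapped env exposing env.env.agent_pos / agent_dir
    allowed = shield.get(query, [])    # hypothetical lookup; empty means the state is not in the shield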

examples/shields/rl/Wrappers.py (67 changed lines)

@@ -82,7 +82,12 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
 class MiniGridShieldingWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True, mask_actions=True):
+    def __init__(self,
+                 env,
+                 shield_creator : ShieldHandler,
+                 shield_query_creator,
+                 create_shield_at_reset=True,
+                 mask_actions=True):
         super(MiniGridShieldingWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = Dict(
@@ -95,32 +100,18 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
         self.create_shield_at_reset = create_shield_at_reset
         self.shield = shield_creator.create_shield(env=self.env)
         self.mask_actions = mask_actions
+        self.shield_query_creator = shield_query_creator
     def create_action_mask(self):
         if not self.mask_actions:
             return np.array([1.0] * self.max_available_actions, dtype=np.int8)
-        coordinates = self.env.agent_pos
-        view_direction = self.env.agent_dir
-        key_text = ""
-        # only support one key for now
-        if self.keys:
-            key_text = F"!Agent_has_{self.keys[0]}_key\t& "
-        if self.env.carrying and self.env.carrying.type == "key":
-            key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
-        cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
-        allowed_actions = []
+        cur_pos_str = self.shield_query_creator(self.env)
         # Create the mask
         # If shield restricts action mask only valid with 1.0
         # else set all actions as valid
+        allowed_actions = []
         mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
         if cur_pos_str in self.shield and self.shield[cur_pos_str]:
@@ -175,38 +166,32 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
 class MiniGridSbShieldingWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield_creator : ShieldHandler, no_masking=False):
+    def __init__(self,
+                 env,
+                 shield_creator : ShieldHandler,
+                 shield_query_creator,
+                 create_shield_at_reset = True,
+                 mask_actions=True,
+                 ):
         super(MiniGridSbShieldingWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = env.observation_space.spaces["image"]
         self.shield_creator = shield_creator
-        self.no_masking = no_masking
+        self.mask_actions = mask_actions
+        self.shield_query_creator = shield_query_creator
     def create_action_mask(self):
-        if self.no_masking:
+        if not self.mask_actions:
             return np.array([1.0] * self.max_available_actions, dtype=np.int8)
-        coordinates = self.env.agent_pos
-        view_direction = self.env.agent_dir
+        cur_pos_str = self.shield_query_creator(self.env)
-        key_text = ""
-        # only support one key for now
-        if self.keys:
-            key_text = F"!Agent_has_{self.keys[0]}_key\t& "
-        if self.env.carrying and self.env.carrying.type == "key":
-            key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
-        # print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir}")
-        cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
         allowed_actions = []
         # Create the mask
-        # If shield restricts action mask only valid with 1.0
-        # else set all actions as valid
+        # If shield restricts actions, mask only valid actions with 1.0
+        # else set all actions valid
         mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
         if cur_pos_str in self.shield and self.shield[cur_pos_str]:
@@ -215,25 +200,21 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
             index = get_action_index_mapping(allowed_action[1])
             if index is None:
                 assert(False)
             mask[index] = 1.0
         else:
-            # print(F"Not in shield {cur_pos_str}")
             for index, x in enumerate(mask):
                 mask[index] = 1.0
         front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
-        # if front_tile is not None and front_tile.type == "key":
-        #     mask[Actions.pickup] = 1.0
-        # if self.env.carrying:
-        #     mask[Actions.drop] = 1.0
         if front_tile and front_tile.type == "door":
             mask[Actions.toggle] = 1.0
         return mask
     def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset(seed=seed, options=options)
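Because the query builder is now injected, environment-specific predicates can be restored without touching the wrappers. A hypothetical creator that reinstates the removed key predicate, mirroring the deleted inline code:

    def key_aware_shield_query(env):
        # Hypothetical: rebuilds the key clause the wrappers used to construct inline.
        x, y = env.agent_pos
        key_text = ""
        if env.carrying and env.carrying.type == "key":
            key_text = F"Agent_has_{env.carrying.color}_key\t& "
        return f"[{key_text}!AgentDone\t& xAgent={x}\t& yAgent={y}\t& viewAgent={env.agent_dir}]"

    # env = MiniGridShieldingWrapper(env, shield_creator=..., shield_query_creator=key_aware_shield_query)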

examples/shields/rl/callbacks.py (61 changed lines, new file)

@@ -0,0 +1,61 @@
+from typing import Dict
+
+from ray.rllib.policy import Policy
+from ray.rllib.utils.typing import PolicyID
+from ray.rllib.algorithms.algorithm import Algorithm
+from ray.rllib.env.base_env import BaseEnv
+from ray.rllib.evaluation import RolloutWorker
+from ray.rllib.evaluation.episode import Episode
+from ray.rllib.evaluation.episode_v2 import EpisodeV2
+from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
+
+class MyCallbacks(DefaultCallbacks):
+    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
+        # print(F"Episode started Environment: {base_env.get_sub_environments()}")
+        env = base_env.get_sub_environments()[0]
+        episode.user_data["count"] = 0
+        episode.user_data["ran_into_lava"] = []
+        episode.user_data["goals_reached"] = []
+        episode.hist_data["ran_into_lava"] = []
+        episode.hist_data["goals_reached"] = []
+        # print("On episode start print")
+        # print(env.printGrid())
+        # print(worker)
+        # print(env.action_space.n)
+        # print(env.actions)
+        # print(env.mission)
+        # print(env.observation_space)
+        # img = env.get_frame()
+        # plt.imshow(img)
+        # plt.show()
+
+    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
+        episode.user_data["count"] = episode.user_data["count"] + 1
+        env = base_env.get_sub_environments()[0]
+        # print(env.printGrid())
+
+    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
+        # print(F"Episode end Environment: {base_env.get_sub_environments()}")
+        env = base_env.get_sub_environments()[0]
+        agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
+        episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal")
+        episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava")
+        episode.custom_metrics["reached_goal"] = agent_tile is not None and agent_tile.type == "goal"
+        episode.custom_metrics["ran_into_lava"] = agent_tile is not None and agent_tile.type == "lava"
+        # print("On episode end print")
+        # print(env.printGrid())
+        episode.hist_data["goals_reached"] = episode.user_data["goals_reached"]
+        episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"]
+
+    def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None:
+        print("Evaluate Start")
+
+    def on_evaluate_end(self, *, algorithm: Algorithm, evaluation_metrics: dict, **kwargs) -> None:
+        print("Evaluate End")

examples/shields/rl/helpers.py (6 changed lines)

@@ -20,7 +20,7 @@ class ShieldingConfig(Enum):
     Training = 'training'
     Evaluation = 'evaluation'
     Disabled = 'none'
-    Enabled = 'full'
+    Full = 'full'
     def __str__(self) -> str:
         return self.value
@@ -39,7 +39,7 @@ def extract_keys(env):
     return keys
 def create_log_dir(args):
-    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}"
+    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}-iterations:{args.iterations}"
 def get_action_index_mapping(actions):
@@ -93,7 +93,7 @@ def parse_arguments(argparse):
     parser.add_argument("--iterations", type=int, default=30)
     parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]")  # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
     parser.add_argument("--workers", type=int, default=1)
-    parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Enabled)
+    parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
     args = parser.parse_args()
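Since argparse is invoked with type=ShieldingConfig, the flag is parsed by enum value, and __str__ returns that value for the log-dir name. A quick sketch of the round trip after the rename:

    assert ShieldingConfig("full") is ShieldingConfig.Full
    assert str(ShieldingConfig.Full) == "full"  # as it appears in create_log_dir(...)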
