Browse Source

added dqn handling skeleton

refactoring
Thomas Knoll 1 year ago
parent
commit
e42becef88
  1. 82
      examples/shields/rl/11_minigridrl.py
  2. 4
      examples/shields/rl/MaskEnvironments.py
  3. 10
      examples/shields/rl/MaskModels.py
  4. 73
      examples/shields/rl/Wrapper.py

82
examples/shields/rl/11_minigridrl.py

@ -25,6 +25,7 @@ import numpy as np
import ray
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator
from ray import tune, air
from ray.rllib.algorithms.callbacks import DefaultCallbacks
@ -37,7 +38,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor
from MaskEnvironments import ParametricActionsMiniGridEnv
from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper, ImgObsWrapper
from Wrapper import OneHotWrapper, MiniGridEnvWrapper
import matplotlib.pyplot as plt
@ -62,7 +63,7 @@ class MyCallbacks(DefaultCallbacks):
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0]
print(env.env.env.printGrid())
#print(env.env.env.printGrid())
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
@ -83,6 +84,7 @@ def parse_arguments(argparse):
parser.add_argument("--grid_path", default="Grid.txt")
parser.add_argument("--prism_path", default="Grid.PRISM")
parser.add_argument("--no_masking", default=False)
parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
args = parser.parse_args()
@ -108,13 +110,13 @@ def env_creater_custom(config):
framestack=framestack
)
obs = env.observation_space.sample()
obs2, infos = env.reset(seed=None, options={})
# obs = env.observation_space.sample()
# obs2, infos = env.reset(seed=None, options={})
print(F"Obs is {obs} before reset. After reset: {obs2}")
# print(F"Obs is {obs} before reset. After reset: {obs2}")
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
print(F"Created Custom Minigrid Environment is {env}")
# print(F"Created Custom Minigrid Environment is {env}")
return env
@ -194,12 +196,16 @@ def create_environment(args):
return env
def main():
args = parse_arguments(argparse)
def register_custom_minigrid_env():
    """Register the custom MiniGrid environment and the action-mask model.

    The environment is registered under the name "mini-grid" with
    ``env_creater_custom`` as its creator, and ``TorchActionMaskModel`` is
    registered in the model catalog under the name "pa_model".
    """
    register_env("mini-grid", env_creater_custom)
    ModelCatalog.register_custom_model("pa_model", TorchActionMaskModel)
def create_shield_dict(args):
env = create_environment(args)
ray.init(num_cpus=3)
# print(env.pprint_grid())
# print(env.printGrid(init=False))
@ -215,19 +221,21 @@ def main():
# choices = shield.get_choice(state_id)
# print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")
env_name = "mini-grid"
register_env(env_name, env_creater_custom)
ModelCatalog.register_custom_model(
"pa_model",
TorchActionMaskModel
)
return shield_dict
def ppo(args):
ray.init(num_cpus=3)
register_custom_minigrid_env()
shield_dict = create_shield_dict(args)
config = (PPOConfig()
.rollouts(num_rollout_workers=1)
.resources(num_gpus=0)
.environment(env="mini-grid", env_config={"shield": shield_dict })
.framework("torch")
.experimental(_disable_preprocessor_api=False)
.callbacks(MyCallbacks)
.rl_module(_enable_rl_module_api = False)
.training(_enable_learner_api=False ,model={
@ -256,9 +264,47 @@ def main():
checkpoint_dir = algo.save()
print(f"Checkpoint saved in directory {checkpoint_dir}")
ray.shutdown()
def dqn(args):
    """Assemble the RLlib DQN configuration for the shielded MiniGrid setup.

    This is a skeleton: it registers the custom environment and model,
    creates the shield dictionary and builds the ``DQNConfig``, but it does
    not build or train an algorithm.

    Args:
        args: Parsed command-line arguments. ``args`` is forwarded to
            ``create_shield_dict`` and ``args.no_masking`` is passed to the
            custom model configuration.

    Returns:
        The fully assembled DQN configuration object.
    """
    config = DQNConfig()
    register_custom_minigrid_env()
    shield_dict = create_shield_dict(args)

    # dict.update() mutates in place and returns None, so assigning its
    # result would pass ``replay_buffer_config=None`` to ``training()``.
    # Build the merged dictionary explicitly instead.
    replay_config = dict(config.replay_buffer_config)
    replay_config.update(
        {
            "capacity": 60000,
            "prioritized_replay_alpha": 0.5,
            "prioritized_replay_beta": 0.5,
            "prioritized_replay_eps": 3e-6,
        }
    )

    config = config.training(
        replay_buffer_config=replay_config,
        model={
            "custom_model": "pa_model",
            "custom_model_config": {
                "shield": shield_dict,
                "no_masking": args.no_masking,
            },
        },
    )
    config = config.resources(num_gpus=0)
    config = config.rollouts(num_rollout_workers=1)
    config = config.framework("torch")
    config = config.callbacks(MyCallbacks)
    config = config.rl_module(_enable_rl_module_api=False)
    config = config.environment(
        env="mini-grid", env_config={"shield": shield_dict}
    )
    return config
def main():
    """Parse the command line, run the chosen algorithm, then stop Ray.

    ``args.algorithm`` selects between the ``ppo`` and ``dqn`` entry points;
    any other value runs nothing. ``ray.shutdown()`` is called afterwards in
    every case.
    """
    args = parse_arguments(argparse)

    # Dispatch table keyed by the value of the --algorithm option.
    runners = {"ppo": ppo, "dqn": dqn}
    selected = runners.get(args.algorithm)
    if selected is not None:
        selected(args)

    ray.shutdown()
if __name__ == '__main__':
main()

4
examples/shields/rl/MaskEnvironments.py

@ -56,7 +56,7 @@ class ParametricActionsMiniGridEnv(gym.Env):
return obs, infos
return {
"action_mask": self.action_mask,
"avail_actions": self.action_assignments,
"avail_action": self.action_assignments,
"cart": obs,
}, infos
@ -83,7 +83,7 @@ class ParametricActionsMiniGridEnv(gym.Env):
return orig_obs, rew, done, truncated, info
obs = {
"action_mask": self.action_mask,
"avail_actions": self.action_assignments,
"action_mask": self.action_assignments,
"cart": orig_obs,
}
return obs, rew, done, truncated, info

10
examples/shields/rl/MaskModels.py

@ -25,10 +25,10 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
):
orig_space = getattr(obs_space, "original_space", obs_space)
custom_config = model_config['custom_model_config']
print(F"Original Space is: {orig_space}")
# print(F"Original Space is: {orig_space}")
#print(model_config)
print(F"Observation space in model: {obs_space}")
print(F"Provided action space in model {action_space}")
#print(F"Observation space in model: {obs_space}")
#print(F"Provided action space in model {action_space}")
TorchModelV2.__init__(
self, obs_space, action_space, num_outputs, model_config, name, **kwargs
@ -65,7 +65,7 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
# print(F"Caluclated Logits {logits} with size {logits.size()} Count: {self.count}")
action_mask = input_dict["obs"]["avail_actions"]
action_mask = input_dict["obs"]["action_mask"]
#print(F"Action mask is {action_mask} with dimension {action_mask.size()}")
# If action masking is disabled, directly return unmasked logits
@ -77,7 +77,7 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
masked_logits = logits + inf_mask
print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}")
# print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}")
# # Return masked logits.
return masked_logits, state

73
examples/shields/rl/Wrapper.py

@ -26,12 +26,12 @@ class OneHotWrapper(gym.core.ObservationWrapper):
self.observation_space = Dict(
{
"data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
"avail_actions": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
"action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
}
)
print(F"Set obersvation space to {self.observation_space}")
# print(F"Set obersvation space to {self.observation_space}")
def observation(self, obs):
@ -77,7 +77,7 @@ class OneHotWrapper(gym.core.ObservationWrapper):
self.frame_buffer.append(single_frame)
#obs["one-hot"] = np.concatenate(self.frame_buffer)
tmp = {"data": np.concatenate(self.frame_buffer), "avail_actions": obs["avail_actions"] }
tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] }
return tmp#np.concatenate(self.frame_buffer)
@ -88,7 +88,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
self.observation_space = Dict(
{
"data": env.observation_space.spaces["image"],
"avail_actions" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
"action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
}
)
@ -98,7 +98,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
def create_action_mask(self):
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
cur_pos_str = f"[!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
allowed_actions = []
@ -109,73 +109,40 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
# else set everything to one
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
# if cur_pos_str in self.shield:
# allowed_actions = self.shield[cur_pos_str]
# for allowed_action in allowed_actions:
# index = allowed_action[0]
# mask[index] = 1.0
# else:
# for index in len(mask):
# mask[index] = 1.0
if cur_pos_str in self.shield:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = allowed_action[0]
mask[index] = 1.0
else:
for index, x in enumerate(mask):
mask[index] = 1.0
print(F"Allowed actions for position {coordinates} and view {view_direction} are {allowed_actions}")
mask[0] = 1.0
#print(F"Action Mask for position {coordinates} and view {view_direction} is {mask}")
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset()
mask = self.create_action_mask()
return {
"data": obs["image"],
"avail_actions": np.array([0.0] * self.max_available_actions, dtype=np.int8)
"action_mask": mask
}, infos
def step(self, action):
print(F"Performed action in step: {action}")
# print(F"Performed action in step: {action}")
orig_obs, rew, done, truncated, info = self.env.step(action)
actions = self.create_action_mask()
mask = self.create_action_mask()
#print(F"Original observation is {orig_obs}")
obs = {
"data": orig_obs["image"],
"avail_actions": actions,
"action_mask": mask,
}
#print(F"Info is {info}")
return obs, rew, done, truncated, info
class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs.keys()
        dict_keys(['image', 'direction', 'mission'])
        >>> env = ImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (7, 7, 3)
    """

    def __init__(self, env):
        """A wrapper that makes image the only observation.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        # The declared space is the bare image space of the wrapped env.
        self.observation_space = env.observation_space.spaces["image"]
        print(F"Set obersvation space to {self.observation_space}")

    def observation(self, obs):
        """Return only the image entry of the wrapped observation.

        The previous implementation returned a dict with "data" and a
        leftover "Test" key, which did not belong to the declared
        ``observation_space`` and failed on observations without "Test".

        Args:
            obs: Observation dict of the wrapped environment; must contain
                the key "image".

        Returns:
            The value stored under ``obs["image"]``.
        """
        return obs["image"]
Loading…
Cancel
Save