
added dqn handling skeleton

refactoring
Thomas Knoll, 1 year ago
parent commit e42becef88
1. examples/shields/rl/11_minigridrl.py (82 lines changed)
2. examples/shields/rl/MaskEnvironments.py (4 lines changed)
3. examples/shields/rl/MaskModels.py (10 lines changed)
4. examples/shields/rl/Wrapper.py (73 lines changed)

examples/shields/rl/11_minigridrl.py (82 lines changed)

@@ -25,6 +25,7 @@ import numpy as np
 import ray
 from ray.tune import register_env
 from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.algorithms.dqn.dqn import DQNConfig
 from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator
 from ray import tune, air
 from ray.rllib.algorithms.callbacks import DefaultCallbacks
@@ -37,7 +38,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN
 from ray.rllib.models.preprocessors import get_preprocessor
 from MaskEnvironments import ParametricActionsMiniGridEnv
 from MaskModels import TorchActionMaskModel
-from Wrapper import OneHotWrapper, MiniGridEnvWrapper, ImgObsWrapper
+from Wrapper import OneHotWrapper, MiniGridEnvWrapper
 import matplotlib.pyplot as plt
@@ -62,7 +63,7 @@ class MyCallbacks(DefaultCallbacks):
     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
         episode.user_data["count"] = episode.user_data["count"] + 1
         env = base_env.get_sub_environments()[0]
-        print(env.env.env.printGrid())
+        #print(env.env.env.printGrid())

     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
@@ -83,6 +84,7 @@ def parse_arguments(argparse):
     parser.add_argument("--grid_path", default="Grid.txt")
     parser.add_argument("--prism_path", default="Grid.PRISM")
     parser.add_argument("--no_masking", default=False)
+    parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])

     args = parser.parse_args()
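
A caveat in the surrounding parser setup: `--no_masking` is declared with `default=False` but no type, so any value supplied on the command line arrives as a non-empty string and is truthy, including `--no_masking False`. A boolean flag avoids this; a minimal sketch (keeping the flag name, but the action-based form is our suggestion, not what the commit does):

    import argparse

    parser = argparse.ArgumentParser()
    # store_true lets the flag itself carry the boolean:
    # absent -> False, "--no_masking" present -> True.
    parser.add_argument("--no_masking", action="store_true")

    args = parser.parse_args(["--no_masking"])
    print(args.no_masking)  # True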
@@ -108,13 +110,13 @@ def env_creater_custom(config):
         framestack=framestack
     )
-    obs = env.observation_space.sample()
-    obs2, infos = env.reset(seed=None, options={})
-    print(F"Obs is {obs} before reset. After reset: {obs2}")
+    # obs = env.observation_space.sample()
+    # obs2, infos = env.reset(seed=None, options={})
+    # print(F"Obs is {obs} before reset. After reset: {obs2}")
     # env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
-    print(F"Created Custom Minigrid Environment is {env}")
+    # print(F"Created Custom Minigrid Environment is {env}")

     return env
@@ -194,12 +196,16 @@ def create_environment(args):
     return env

-def main():
-    args = parse_arguments(argparse)
+def register_custom_minigrid_env():
+    env_name = "mini-grid"
+    register_env(env_name, env_creater_custom)
+    ModelCatalog.register_custom_model(
+        "pa_model",
+        TorchActionMaskModel
+    )
+
+def create_shield_dict(args):
     env = create_environment(args)
-    ray.init(num_cpus=3)

     # print(env.pprint_grid())
     # print(env.printGrid(init=False))
@@ -215,19 +221,21 @@
     # choices = shield.get_choice(state_id)
     # print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")

-    env_name = "mini-grid"
-    register_env(env_name, env_creater_custom)
-    ModelCatalog.register_custom_model(
-        "pa_model",
-        TorchActionMaskModel
-    )
+    return shield_dict
+
+def ppo(args):
+    ray.init(num_cpus=3)
+    register_custom_minigrid_env()
+    shield_dict = create_shield_dict(args)
     config = (PPOConfig()
         .rollouts(num_rollout_workers=1)
         .resources(num_gpus=0)
         .environment(env="mini-grid", env_config={"shield": shield_dict })
         .framework("torch")
+        .experimental(_disable_preprocessor_api=False)
         .callbacks(MyCallbacks)
         .rl_module(_enable_rl_module_api = False)
         .training(_enable_learner_api=False ,model={
@@ -256,9 +264,47 @@
     checkpoint_dir = algo.save()
     print(f"Checkpoint saved in directory {checkpoint_dir}")
-    ray.shutdown()

+def dqn(args):
+    config = DQNConfig()
+    register_custom_minigrid_env()
+    shield_dict = create_shield_dict(args)
+    replay_config = config.replay_buffer_config.update(
+        {
+            "capacity": 60000,
+            "prioritized_replay_alpha": 0.5,
+            "prioritized_replay_beta": 0.5,
+            "prioritized_replay_eps": 3e-6,
+        }
+    )
+    config = config.training(replay_buffer_config=replay_config, model={
+        "custom_model": "pa_model",
+        "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
+    })
+    config = config.resources(num_gpus=0)
+    config = config.rollouts(num_rollout_workers=1)
+    config = config.framework("torch")
+    config = config.callbacks(MyCallbacks)
+    config = config.rl_module(_enable_rl_module_api = False)
+    config = config.environment(env="mini-grid", env_config={"shield": shield_dict })
+
+def main():
+    args = parse_arguments(argparse)
+    if args.algorithm == "ppo":
+        ppo(args)
+    elif args.algorithm == "dqn":
+        dqn(args)
+    ray.shutdown()

 if __name__ == '__main__':
     main()
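
Two remarks on the new dqn() skeleton. First, Python's dict.update() returns None, so replay_config above is always None; the settings only survive because update() mutates config.replay_buffer_config in place. Second, the function configures the algorithm but never builds or trains it. A minimal sketch of how the missing tail might look, reusing the file's own helpers; config.build() and the explicit replay dict follow the standard RLlib API, while the iteration count and the ray.init placement are illustrative assumptions:

    def dqn(args):
        ray.init(num_cpus=3)                      # mirrors the ppo() path
        register_custom_minigrid_env()
        shield_dict = create_shield_dict(args)

        config = DQNConfig()
        # Build a real dict instead of keeping dict.update()'s None return value.
        replay_config = dict(config.replay_buffer_config,
                             capacity=60000,
                             prioritized_replay_alpha=0.5,
                             prioritized_replay_beta=0.5,
                             prioritized_replay_eps=3e-6)
        config = config.training(replay_buffer_config=replay_config, model={
            "custom_model": "pa_model",
            "custom_model_config": {"shield": shield_dict, "no_masking": args.no_masking},
        })
        config = config.framework("torch").callbacks(MyCallbacks)
        config = config.environment(env="mini-grid", env_config={"shield": shield_dict})

        algo = config.build()                     # missing in the committed skeleton
        for _ in range(10):                       # illustrative iteration count
            result = algo.train()
            print(result["episode_reward_mean"])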

examples/shields/rl/MaskEnvironments.py (4 lines changed)

@@ -56,7 +56,7 @@ class ParametricActionsMiniGridEnv(gym.Env):
             return obs, infos
         return {
             "action_mask": self.action_mask,
-            "avail_actions": self.action_assignments,
+            "avail_action": self.action_assignments,
             "cart": obs,
         }, infos

@@ -83,7 +83,7 @@ class ParametricActionsMiniGridEnv(gym.Env):
             return orig_obs, rew, done, truncated, info
         obs = {
             "action_mask": self.action_mask,
-            "avail_actions": self.action_assignments,
+            "action_mask": self.action_assignments,
             "cart": orig_obs,
         }
         return obs, rew, done, truncated, info
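
Note that the new step() observation above spells the key "action_mask" twice; in a Python dict literal the later entry silently wins, so self.action_mask is dropped from the observation. Given that reset() renames the same field to "avail_action", a distinct key was presumably intended, along these lines (our reading, not what the commit does):

    obs = {
        "action_mask": self.action_mask,           # shield-derived mask
        "avail_action": self.action_assignments,   # same key reset() now uses
        "cart": orig_obs,
    }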

examples/shields/rl/MaskModels.py (10 lines changed)

@@ -25,10 +25,10 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
     ):
         orig_space = getattr(obs_space, "original_space", obs_space)
         custom_config = model_config['custom_model_config']
-        print(F"Original Space is: {orig_space}")
+        # print(F"Original Space is: {orig_space}")
         #print(model_config)
-        print(F"Observation space in model: {obs_space}")
-        print(F"Provided action space in model {action_space}")
+        #print(F"Observation space in model: {obs_space}")
+        #print(F"Provided action space in model {action_space}")

         TorchModelV2.__init__(
             self, obs_space, action_space, num_outputs, model_config, name, **kwargs

@@ -65,7 +65,7 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
         # print(F"Caluclated Logits {logits} with size {logits.size()} Count: {self.count}")
-        action_mask = input_dict["obs"]["avail_actions"]
+        action_mask = input_dict["obs"]["action_mask"]
         #print(F"Action mask is {action_mask} with dimension {action_mask.size()}")

         # If action masking is disabled, directly return unmasked logits

@@ -77,7 +77,7 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
         inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
         masked_logits = logits + inf_mask
-        print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}")
+        # print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}")

         # # Return masked logits.
         return masked_logits, state
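
The unchanged lines above use the standard RLlib action-masking trick: log(1) = 0 leaves allowed logits untouched, while log(0) = -inf is clamped to FLOAT_MIN, so forbidden actions end up with essentially zero probability after the softmax. A standalone illustration with made-up logits:

    import torch
    from ray.rllib.utils.torch_utils import FLOAT_MIN

    logits = torch.tensor([1.0, 2.0, 3.0])
    action_mask = torch.tensor([1.0, 0.0, 1.0])  # the shield forbids action 1

    # log(1) = 0, log(0) = -inf clamped to FLOAT_MIN (a very large negative number)
    inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
    masked_logits = logits + inf_mask

    print(torch.softmax(masked_logits, dim=-1))  # action 1 gets ~0 probability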

examples/shields/rl/Wrapper.py (73 lines changed)

@@ -26,12 +26,12 @@ class OneHotWrapper(gym.core.ObservationWrapper):
         self.observation_space = Dict(
             {
                 "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
-                "avail_actions": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
+                "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
             }
         )
-        print(F"Set obersvation space to {self.observation_space}")
+        # print(F"Set obersvation space to {self.observation_space}")

     def observation(self, obs):

@@ -77,7 +77,7 @@ class OneHotWrapper(gym.core.ObservationWrapper):
         self.frame_buffer.append(single_frame)
         #obs["one-hot"] = np.concatenate(self.frame_buffer)
-        tmp = {"data": np.concatenate(self.frame_buffer), "avail_actions": obs["avail_actions"] }
+        tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] }
         return tmp#np.concatenate(self.frame_buffer)

@@ -88,7 +88,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
         self.observation_space = Dict(
             {
                 "data": env.observation_space.spaces["image"],
-                "avail_actions" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
+                "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
             }
         )

@@ -98,7 +98,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
     def create_action_mask(self):
         coordinates = self.env.agent_pos
         view_direction = self.env.agent_dir
-        print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
+        #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
         cur_pos_str = f"[!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"

         allowed_actions = []
@@ -109,73 +109,40 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
         # else set everything to one
         mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)

-        # if cur_pos_str in self.shield:
-        #     allowed_actions = self.shield[cur_pos_str]
-        #     for allowed_action in allowed_actions:
-        #         index = allowed_action[0]
-        #         mask[index] = 1.0
-        # else:
-        #     for index in len(mask):
-        #         mask[index] = 1.0
-        print(F"Allowed actions for position {coordinates} and view {view_direction} are {allowed_actions}")
+        if cur_pos_str in self.shield:
+            allowed_actions = self.shield[cur_pos_str]
+            for allowed_action in allowed_actions:
+                index = allowed_action[0]
+                mask[index] = 1.0
+        else:
+            for index, x in enumerate(mask):
+                mask[index] = 1.0
+
+        mask[0] = 1.0
+        #print(F"Action Mask for position {coordinates} and view {view_direction} is {mask}")
         return mask

     def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset()
+        mask = self.create_action_mask()
         return {
             "data": obs["image"],
-            "avail_actions": np.array([0.0] * self.max_available_actions, dtype=np.int8)
+            "action_mask": mask
         }, infos

     def step(self, action):
-        print(F"Performed action in step: {action}")
+        # print(F"Performed action in step: {action}")
         orig_obs, rew, done, truncated, info = self.env.step(action)
-        actions = self.create_action_mask()
+        mask = self.create_action_mask()
         #print(F"Original observation is {orig_obs}")
         obs = {
             "data": orig_obs["image"],
-            "avail_actions": actions,
+            "action_mask": mask,
         }
         #print(F"Info is {info}")
         return obs, rew, done, truncated, info
-
-class ImgObsWrapper(gym.core.ObservationWrapper):
-    """
-    Use the image as the only observation output, no language/mission.
-
-    Example:
-        >>> import gymnasium as gym
-        >>> from minigrid.wrappers import ImgObsWrapper
-        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
-        >>> obs, _ = env.reset()
-        >>> obs.keys()
-        dict_keys(['image', 'direction', 'mission'])
-        >>> env = ImgObsWrapper(env)
-        >>> obs, _ = env.reset()
-        >>> obs.shape
-        (7, 7, 3)
-    """
-
-    def __init__(self, env):
-        """A wrapper that makes image the only observation.
-
-        Args:
-            env: The environment to apply the wrapper
-        """
-        super().__init__(env)
-        self.observation_space = env.observation_space.spaces["image"]
-        print(F"Set obersvation space to {self.observation_space}")
-
-    def observation(self, obs):
-        #print(F"obs in img obs wrapper {obs}")
-        tmp = {"data": obs["image"], "Test": obs["Test"]}
-        return tmp
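
The now-active create_action_mask() treats the shield as a plain dictionary from a PRISM-style state label (agent position and view direction) to a list of allowed choices, where each choice's first element is an action index; the trailing mask[0] = 1.0 keeps at least one action available so the model never sees an all-zero mask. A standalone sketch of that lookup, with the label format copied from the wrapper and a hypothetical two-entry shield dict (the entries are made up):

    import numpy as np

    max_available_actions = 7   # size of MiniGrid's action space
    shield = {
        # state label -> allowed choices; the first tuple element is the action index
        "[!AgentDone\t& xAgent=1\t& yAgent=1\t& viewAgent=0]": [(0,), (2,)],
        "[!AgentDone\t& xAgent=2\t& yAgent=1\t& viewAgent=0]": [(2,)],
    }

    def create_action_mask(cur_pos_str):
        mask = np.zeros(max_available_actions, dtype=np.int8)
        for allowed_action in shield.get(cur_pos_str, []):
            mask[allowed_action[0]] = 1
        if cur_pos_str not in shield:
            mask[:] = 1             # unknown state: allow everything
        mask[0] = 1                 # never hand the model an all-zero mask
        return mask

    print(create_action_mask("[!AgentDone\t& xAgent=1\t& yAgent=1\t& viewAgent=0]"))
    # -> [1 0 1 0 0 0 0]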