|
@ -25,6 +25,7 @@ import numpy as np |
|
|
import ray |
|
|
import ray |
|
|
from ray.tune import register_env |
|
|
from ray.tune import register_env |
|
|
from ray.rllib.algorithms.ppo import PPOConfig |
|
|
from ray.rllib.algorithms.ppo import PPOConfig |
|
|
|
|
|
from ray.rllib.algorithms.dqn.dqn import DQNConfig |
|
|
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator |
|
|
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator |
|
|
from ray import tune, air |
|
|
from ray import tune, air |
|
|
from ray.rllib.algorithms.callbacks import DefaultCallbacks |
|
|
from ray.rllib.algorithms.callbacks import DefaultCallbacks |
|
@ -37,7 +38,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN |
|
|
from ray.rllib.models.preprocessors import get_preprocessor |
|
|
from ray.rllib.models.preprocessors import get_preprocessor |
|
|
from MaskEnvironments import ParametricActionsMiniGridEnv |
|
|
from MaskEnvironments import ParametricActionsMiniGridEnv |
|
|
from MaskModels import TorchActionMaskModel |
|
|
from MaskModels import TorchActionMaskModel |
|
|
from Wrapper import OneHotWrapper, MiniGridEnvWrapper, ImgObsWrapper |
|
|
|
|
|
|
|
|
from Wrapper import OneHotWrapper, MiniGridEnvWrapper |
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
@ -62,7 +63,7 @@ class MyCallbacks(DefaultCallbacks): |
|
|
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: |
|
|
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: |
|
|
episode.user_data["count"] = episode.user_data["count"] + 1 |
|
|
episode.user_data["count"] = episode.user_data["count"] + 1 |
|
|
env = base_env.get_sub_environments()[0] |
|
|
env = base_env.get_sub_environments()[0] |
|
|
print(env.env.env.printGrid()) |
|
|
|
|
|
|
|
|
#print(env.env.env.printGrid()) |
|
|
|
|
|
|
|
|
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: |
|
|
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: |
|
|
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}") |
|
|
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}") |
|
@ -83,6 +84,7 @@ def parse_arguments(argparse): |
|
|
parser.add_argument("--grid_path", default="Grid.txt") |
|
|
parser.add_argument("--grid_path", default="Grid.txt") |
|
|
parser.add_argument("--prism_path", default="Grid.PRISM") |
|
|
parser.add_argument("--prism_path", default="Grid.PRISM") |
|
|
parser.add_argument("--no_masking", default=False) |
|
|
parser.add_argument("--no_masking", default=False) |
|
|
|
|
|
parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"]) |
|
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
@ -108,13 +110,13 @@ def env_creater_custom(config): |
|
|
framestack=framestack |
|
|
framestack=framestack |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
obs = env.observation_space.sample() |
|
|
|
|
|
obs2, infos = env.reset(seed=None, options={}) |
|
|
|
|
|
|
|
|
# obs = env.observation_space.sample() |
|
|
|
|
|
# obs2, infos = env.reset(seed=None, options={}) |
|
|
|
|
|
|
|
|
print(F"Obs is {obs} before reset. After reset: {obs2}") |
|
|
|
|
|
|
|
|
# print(F"Obs is {obs} before reset. After reset: {obs2}") |
|
|
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env) |
|
|
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env) |
|
|
|
|
|
|
|
|
print(F"Created Custom Minigrid Environment is {env}") |
|
|
|
|
|
|
|
|
# print(F"Created Custom Minigrid Environment is {env}") |
|
|
|
|
|
|
|
|
return env |
|
|
return env |
|
|
|
|
|
|
|
@ -194,12 +196,16 @@ def create_environment(args): |
|
|
return env |
|
|
return env |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
|
|
args = parse_arguments(argparse) |
|
|
|
|
|
|
|
|
def register_custom_minigrid_env(): |
|
|
|
|
|
env_name = "mini-grid" |
|
|
|
|
|
register_env(env_name, env_creater_custom) |
|
|
|
|
|
ModelCatalog.register_custom_model( |
|
|
|
|
|
"pa_model", |
|
|
|
|
|
TorchActionMaskModel |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def create_shield_dict(args): |
|
|
env = create_environment(args) |
|
|
env = create_environment(args) |
|
|
ray.init(num_cpus=3) |
|
|
|
|
|
|
|
|
|
|
|
# print(env.pprint_grid()) |
|
|
# print(env.pprint_grid()) |
|
|
# print(env.printGrid(init=False)) |
|
|
# print(env.printGrid(init=False)) |
|
|
|
|
|
|
|
@ -214,20 +220,22 @@ def main(): |
|
|
# for state_id in model.states: |
|
|
# for state_id in model.states: |
|
|
# choices = shield.get_choice(state_id) |
|
|
# choices = shield.get_choice(state_id) |
|
|
# print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") |
|
|
# print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") |
|
|
|
|
|
|
|
|
env_name = "mini-grid" |
|
|
|
|
|
register_env(env_name, env_creater_custom) |
|
|
|
|
|
ModelCatalog.register_custom_model( |
|
|
|
|
|
"pa_model", |
|
|
|
|
|
TorchActionMaskModel |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return shield_dict |
|
|
|
|
|
|
|
|
|
|
|
def ppo(args): |
|
|
|
|
|
|
|
|
|
|
|
ray.init(num_cpus=3) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
register_custom_minigrid_env() |
|
|
|
|
|
shield_dict = create_shield_dict(args) |
|
|
|
|
|
|
|
|
config = (PPOConfig() |
|
|
config = (PPOConfig() |
|
|
.rollouts(num_rollout_workers=1) |
|
|
.rollouts(num_rollout_workers=1) |
|
|
.resources(num_gpus=0) |
|
|
.resources(num_gpus=0) |
|
|
.environment(env="mini-grid", env_config={"shield": shield_dict }) |
|
|
.environment(env="mini-grid", env_config={"shield": shield_dict }) |
|
|
.framework("torch") |
|
|
.framework("torch") |
|
|
.experimental(_disable_preprocessor_api=False) |
|
|
|
|
|
.callbacks(MyCallbacks) |
|
|
.callbacks(MyCallbacks) |
|
|
.rl_module(_enable_rl_module_api = False) |
|
|
.rl_module(_enable_rl_module_api = False) |
|
|
.training(_enable_learner_api=False ,model={ |
|
|
.training(_enable_learner_api=False ,model={ |
|
@ -255,10 +263,48 @@ def main(): |
|
|
if i % 5 == 0: |
|
|
if i % 5 == 0: |
|
|
checkpoint_dir = algo.save() |
|
|
checkpoint_dir = algo.save() |
|
|
print(f"Checkpoint saved in directory {checkpoint_dir}") |
|
|
print(f"Checkpoint saved in directory {checkpoint_dir}") |
|
|
|
|
|
|
|
|
|
|
|
ray.shutdown() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dqn(args): |
|
|
|
|
|
config = DQNConfig() |
|
|
|
|
|
register_custom_minigrid_env() |
|
|
|
|
|
shield_dict = create_shield_dict(args) |
|
|
|
|
|
replay_config = config.replay_buffer_config.update( |
|
|
|
|
|
{ |
|
|
|
|
|
"capacity": 60000, |
|
|
|
|
|
"prioritized_replay_alpha": 0.5, |
|
|
|
|
|
"prioritized_replay_beta": 0.5, |
|
|
|
|
|
"prioritized_replay_eps": 3e-6, |
|
|
|
|
|
} |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
config = config.training(replay_buffer_config=replay_config, model={ |
|
|
|
|
|
"custom_model": "pa_model", |
|
|
|
|
|
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking} |
|
|
|
|
|
}) |
|
|
|
|
|
config = config.resources(num_gpus=0) |
|
|
|
|
|
config = config.rollouts(num_rollout_workers=1) |
|
|
|
|
|
config = config.framework("torch") |
|
|
|
|
|
config = config.callbacks(MyCallbacks) |
|
|
|
|
|
config = config.rl_module(_enable_rl_module_api = False) |
|
|
|
|
|
|
|
|
|
|
|
config = config.environment(env="mini-grid", env_config={"shield": shield_dict }) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
|
|
args = parse_arguments(argparse) |
|
|
|
|
|
|
|
|
|
|
|
if args.algorithm == "ppo": |
|
|
|
|
|
ppo(args) |
|
|
|
|
|
elif args.algorithm == "dqn": |
|
|
|
|
|
dqn(args) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ray.shutdown() |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
if __name__ == '__main__': |
|
|
main() |
|
|
main() |