|
@ -10,7 +10,7 @@ from ray.tune.logger import UnifiedLogger |
|
|
from ray.rllib.models import ModelCatalog |
|
|
from ray.rllib.models import ModelCatalog |
|
|
from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger |
|
|
from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger |
|
|
from ray.rllib.algorithms.algorithm import Algorithm |
|
|
from ray.rllib.algorithms.algorithm import Algorithm |
|
|
|
|
|
|
|
|
|
|
|
from ray.air import session |
|
|
|
|
|
|
|
|
from torch_action_mask_model import TorchActionMaskModel |
|
|
from torch_action_mask_model import TorchActionMaskModel |
|
|
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper |
|
|
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper |
|
@ -25,11 +25,13 @@ def shielding_env_creater(config): |
|
|
name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") |
|
|
name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") |
|
|
framestack = config.get("framestack", 4) |
|
|
framestack = config.get("framestack", 4) |
|
|
args = config.get("args", None) |
|
|
args = config.get("args", None) |
|
|
args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" |
|
|
|
|
|
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt" |
|
|
|
|
|
args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism" |
|
|
shielding = config.get("shielding", False) |
|
|
shielding = config.get("shielding", False) |
|
|
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) |
|
|
|
|
|
|
|
|
shield_creator = MiniGridShieldHandler(grid_file=args.grid_path, |
|
|
|
|
|
grid_to_prism_path=args.grid_to_prism_binary_path, |
|
|
|
|
|
prism_path=args.prism_path, |
|
|
|
|
|
formula=args.formula) |
|
|
|
|
|
|
|
|
env = gym.make(name) |
|
|
env = gym.make(name) |
|
|
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) |
|
|
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) |
|
@ -64,7 +66,10 @@ def ppo(args): |
|
|
.rollouts(num_rollout_workers=args.workers) |
|
|
.rollouts(num_rollout_workers=args.workers) |
|
|
.resources(num_gpus=0) |
|
|
.resources(num_gpus=0) |
|
|
.environment( env="mini-grid-shielding", |
|
|
.environment( env="mini-grid-shielding", |
|
|
env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) |
|
|
|
|
|
|
|
|
env_config={"name": args.env, |
|
|
|
|
|
"args": args, |
|
|
|
|
|
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, |
|
|
|
|
|
},) |
|
|
.framework("torch") |
|
|
.framework("torch") |
|
|
.callbacks(MyCallbacks) |
|
|
.callbacks(MyCallbacks) |
|
|
.evaluation(evaluation_config={ |
|
|
.evaluation(evaluation_config={ |
|
@ -72,7 +77,9 @@ def ppo(args): |
|
|
"evaluation_duration": 10, |
|
|
"evaluation_duration": 10, |
|
|
"evaluation_num_workers":1, |
|
|
"evaluation_num_workers":1, |
|
|
"env": "mini-grid-shielding", |
|
|
"env": "mini-grid-shielding", |
|
|
"env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}}) |
|
|
|
|
|
|
|
|
"env_config": {"name": args.env, |
|
|
|
|
|
"args": args, |
|
|
|
|
|
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}}) |
|
|
.rl_module(_enable_rl_module_api = False) |
|
|
.rl_module(_enable_rl_module_api = False) |
|
|
.debugging(logger_config={ |
|
|
.debugging(logger_config={ |
|
|
"type": UnifiedLogger, |
|
|
"type": UnifiedLogger, |
|
@ -87,7 +94,8 @@ def ppo(args): |
|
|
metric="episode_reward_mean", |
|
|
metric="episode_reward_mean", |
|
|
mode="max", |
|
|
mode="max", |
|
|
num_samples=1, |
|
|
num_samples=1, |
|
|
trial_name_creator=trial_name_creator, |
|
|
|
|
|
|
|
|
trial_name_creator=trial_name_creator, |
|
|
|
|
|
|
|
|
), |
|
|
), |
|
|
run_config=air.RunConfig( |
|
|
run_config=air.RunConfig( |
|
|
stop = {"episode_reward_mean": 94, |
|
|
stop = {"episode_reward_mean": 94, |
|
@ -144,7 +152,7 @@ def ppo(args): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
def main(): |
|
|
ray.init(num_cpus=4) |
|
|
|
|
|
|
|
|
ray.init(num_cpus=3) |
|
|
import argparse |
|
|
import argparse |
|
|
args = parse_arguments(argparse) |
|
|
args = parse_arguments(argparse) |
|
|
|
|
|
|
|
|