Browse Source

added expname to grid path

refactoring
Thomas Knoll 1 year ago
parent
commit
41f94bf92e
  1. 24
      examples/shields/rl/15_train_eval_tune.py

24
examples/shields/rl/15_train_eval_tune.py

@ -10,7 +10,7 @@ from ray.tune.logger import UnifiedLogger
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger
from ray.rllib.algorithms.algorithm import Algorithm
from ray.air import session
from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
@ -25,11 +25,13 @@ def shielding_env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
framestack = config.get("framestack", 4)
args = config.get("args", None)
args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
shielding = config.get("shielding", False)
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
grid_to_prism_path=args.grid_to_prism_binary_path,
prism_path=args.prism_path,
formula=args.formula)
env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding)
@ -64,7 +66,10 @@ def ppo(args):
.rollouts(num_rollout_workers=args.workers)
.resources(num_gpus=0)
.environment( env="mini-grid-shielding",
env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
env_config={"name": args.env,
"args": args,
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
},)
.framework("torch")
.callbacks(MyCallbacks)
.evaluation(evaluation_config={
@ -72,7 +77,9 @@ def ppo(args):
"evaluation_duration": 10,
"evaluation_num_workers":1,
"env": "mini-grid-shielding",
"env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
"env_config": {"name": args.env,
"args": args,
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
.rl_module(_enable_rl_module_api = False)
.debugging(logger_config={
"type": UnifiedLogger,
@ -88,6 +95,7 @@ def ppo(args):
mode="max",
num_samples=1,
trial_name_creator=trial_name_creator,
),
run_config=air.RunConfig(
stop = {"episode_reward_mean": 94,
@ -144,7 +152,7 @@ def ppo(args):
def main():
ray.init(num_cpus=4)
ray.init(num_cpus=3)
import argparse
args = parse_arguments(argparse)

Loading…
Cancel
Save