|
@@ -37,6 +37,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN
 from ray.rllib.models.preprocessors import get_preprocessor
 from MaskModels import TorchActionMaskModel
 from Wrapper import OneHotWrapper, MiniGridEnvWrapper
+from helpers import extract_keys


 import matplotlib.pyplot as plt
@@ -61,12 +62,12 @@ class MyCallbacks(DefaultCallbacks):
     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
         episode.user_data["count"] = episode.user_data["count"] + 1
         env = base_env.get_sub_environments()[0]
-        #print(env.env.env.printGrid())
+        #print(env.printGrid())


     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
         env = base_env.get_sub_environments()[0]
-        # print(env.env.env.printGrid())
+        # print(env.printGrid())
         # print(episode.user_data["count"])

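
Note (annotation, not part of the patch): the "count" entry incremented in on_episode_step is presumably initialised in on_episode_start, which sits above this hunk. A minimal sketch of such an initialiser inside MyCallbacks, using the standard DefaultCallbacks hook and the argument style visible in the signatures above:

    def on_episode_start(self, *, worker, base_env, policies, episode, env_index=None, **kwargs) -> None:
        # Reset the per-episode step counter that on_episode_step increments.
        episode.user_data["count"] = 0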
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -96,17 +97,21 @@ def parse_arguments(argparse): |
|
|
"MiniGrid-RedBlueDoors-6x6-v0",]) |
|
|
"MiniGrid-RedBlueDoors-6x6-v0",]) |
|
|
|
|
|
|
|
|
# parser.add_argument("--seed", type=int, help="seed for environment", default=None) |
|
|
# parser.add_argument("--seed", type=int, help="seed for environment", default=None) |
|
|
|
|
|
parser.add_argument("--grid_to_prism_path", default="./main") |
|
|
parser.add_argument("--grid_path", default="Grid.txt") |
|
|
parser.add_argument("--grid_path", default="Grid.txt") |
|
|
parser.add_argument("--prism_path", default="Grid.PRISM") |
|
|
parser.add_argument("--prism_path", default="Grid.PRISM") |
|
|
parser.add_argument("--no_masking", default=False) |
|
|
parser.add_argument("--no_masking", default=False) |
|
|
parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"]) |
|
|
parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"]) |
|
|
parser.add_argument("--log_dir", default="../log_results/") |
|
|
parser.add_argument("--log_dir", default="../log_results/") |
|
|
|
|
|
parser.add_argument("--iterations", type=int, default=30 ) |
|
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
return args |
|
|
return args |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def env_creater_custom(config): |
|
|
def env_creater_custom(config): |
|
|
framestack = config.get("framestack", 4) |
|
|
framestack = config.get("framestack", 4) |
|
|
shield = config.get("shield", {}) |
|
|
shield = config.get("shield", {}) |
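
Note (annotation, not part of the patch): the two new options wire the external grid-to-PRISM converter and the training length into the script. An illustrative invocation, assuming the file is run directly (the entry-point name main.py is hypothetical):

    # python3 main.py --algorithm ppo --iterations 50 \
    #     --grid_to_prism_path ./main --grid_path Grid.txt --prism_path Grid.PRISM
    # The parsed values are then available as args.grid_to_prism_path and args.iterations.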
|
@@ -114,7 +119,8 @@ def env_creater_custom(config):
     framestack = config.get("framestack", 4)

     env = gym.make(name)
-    env = MiniGridEnvWrapper(env, shield=shield)
+    keys = extract_keys(env)
+    env = MiniGridEnvWrapper(env, shield=shield, keys=keys)
     # env = minigrid.wrappers.ImgObsWrapper(env)
     # env = ImgObsWrapper(env)
     env = OneHotWrapper(env,
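
Note (annotation, not part of the patch): env_creater_custom is the factory behind the "mini-grid" id used in the PPO/DQN configs below; the registration itself lives in register_custom_minigrid_env, outside this diff. A minimal sketch of that registration, assuming it only registers the environment (the real function may also register the pa_model custom model):

    from ray.tune.registry import register_env

    def register_custom_minigrid_env(args):
        # Make "mini-grid" resolvable by .environment(env="mini-grid", ...);
        # the env_config dict (shield, name, framestack) is forwarded to env_creater_custom.
        register_env("mini-grid", env_creater_custom)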
|
@@ -142,10 +148,12 @@ def env_creater(config):

     return env


+def create_log_dir(args):
+    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}"

-def create_shield(grid_file, prism_path):
-    os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}")
+def create_shield(grid_to_prism_path, grid_file, prism_path):
+    os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}")

     f = open(prism_path, "a")
     f.write("label \"AgentIsInLava\" = AgentIsInLava;")
@@ -154,7 +162,12 @@ def create_shield(grid_file, prism_path):
     program = stormpy.parse_prism_program(prism_path)

     formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
-    #formula_str = "Pmax=? [G \"AgentIsInGoalAndNotDone\"]"
+    # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
+
+    shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
+                                                          stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)
+
+    # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9)

     formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
     options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
@@ -163,7 +176,6 @@ def create_shield(grid_file, prism_path):
     options.set_build_all_labels()
     model = stormpy.build_sparse_model_with_options(program, options)

-    shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
     result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)

     assert result.has_scheduler
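
Note (annotation, not part of the patch): with extract_scheduler=True and a shield expression, the model-checking result carries the synthesised shield. A sketch of how create_shield presumably turns it into the shield_dict consumed by MiniGridEnvWrapper, mirroring the commented-out comprehension in create_shield_dict further below; choice_map is assumed to come from the shielding-enabled stormpy/Tempest build:

    shield = result.scheduler   # scheduler extracted during model checking
    # One entry per model state: the choices the shield still allows in that state.
    shield_dict = {state.id: shield.get_choice(state).choice_map for state in model.states}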
|
@@ -189,7 +201,6 @@ def export_grid_to_text(env, grid_file):
     f = open(grid_file, "w")
     # print(env)
     f.write(env.printGrid(init=True))
-    # f.write(env.pprint_grid())
     f.close()

 def create_environment(args):
@@ -210,14 +221,14 @@ def register_custom_minigrid_env(args):


 def create_shield_dict(args):
     env = create_environment(args)
-    # print(env.pprint_grid())
     # print(env.printGrid(init=False))

     grid_file = args.grid_path
+    grid_to_prism_path = args.grid_to_prism_path
     export_grid_to_text(env, grid_file)

     prism_path = args.prism_path
-    shield_dict = create_shield(grid_file, prism_path)
+    shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path)
     #shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states}
     #print(F"Shield dictionary {shield_dict}")
@@ -238,13 +249,13 @@ def ppo(args):
     config = (PPOConfig()
         .rollouts(num_rollout_workers=1)
         .resources(num_gpus=0)
-        .environment(env="mini-grid", env_config={"shield": shield_dict})
+        .environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env})
         .framework("torch")
         .callbacks(MyCallbacks)
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
             "type": "ray.tune.logger.TBXLogger",
-            "logdir": F"{args.log_dir}{datetime.now()}-{args.algorithm}"
+            "logdir": create_log_dir(args)
         })
         .training(_enable_learner_api=False ,model={
             "custom_model": "pa_model",
@@ -279,13 +290,13 @@ def dqn(args):
     config = DQNConfig()
     config = config.resources(num_gpus=0)
     config = config.rollouts(num_rollout_workers=1)
-    config = config.environment(env="mini-grid", env_config={"shield": shield_dict })
+    config = config.environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env })
     config = config.framework("torch")
     config = config.callbacks(MyCallbacks)
     config = config.rl_module(_enable_rl_module_api = False)
     config = config.debugging(logger_config={
         "type": "ray.tune.logger.TBXLogger",
-        "logdir": F"{args.log_dir}{datetime.now()}-{args.algorithm}"
+        "logdir": create_log_dir(args)
     })
     config = config.training(hiddens=[], dueling=False, model={
         "custom_model": "pa_model",
@@ -296,13 +307,13 @@ def dqn(args):
         config.build()
     )

-    for i in range(30):
+    for i in range(args.iterations):
         result = algo.train()
         print(pretty_print(result))

-        if i % 5 == 0:
-            checkpoint_dir = algo.save()
-            print(f"Checkpoint saved in directory {checkpoint_dir}")
+        # if i % 5 == 0:
+        #     checkpoint_dir = algo.save()
+        #     print(f"Checkpoint saved in directory {checkpoint_dir}")


     ray.shutdown()
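
Note (annotation, not part of the patch): a minimal sketch of an entry point tying the pieces in this diff together; the actual main function is outside the patch, so the wiring below is an assumption based on the names and choices defined above:

    def main():
        args = parse_arguments(argparse)
        register_custom_minigrid_env(args)
        if args.algorithm == "ppo":
            ppo(args)
        else:
            dqn(args)

    if __name__ == "__main__":
        main()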
|
|
|
|
|
|
|
|