From 89675dfe80126341831ddaddb06772d5fe5cba99 Mon Sep 17 00:00:00 2001 From: sp Date: Fri, 28 Jun 2024 14:56:47 +0200 Subject: [PATCH] added argument for predefined prism file Needed when skipping M2P --- examples/shields/rl/13_minigridsb.py | 4 +--- examples/shields/rl/utils.py | 7 ++++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 616e48c..1cd0a7e 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -24,11 +24,9 @@ def mask_fn(env: gym.Env): def nomask_fn(env: gym.Env): return [1.0] * 7 - def main(): args = parse_sb3_arguments() - formula = args.formula shield_value = args.shield_value shield_comparison = args.shield_comparison @@ -38,7 +36,7 @@ def main(): if shield_needed(args.shielding): - shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup) + shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup, prism_file=None) env = gym.make(args.env, render_mode="rgb_array") diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index d74856b..29d9d98 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -61,13 +61,14 @@ class ShieldHandler(ABC): pass class MiniGridShieldHandler(ShieldHandler): - def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute', nocleanup=False) -> None: + def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute', nocleanup=False, prism_file=None) -> None: self.tmp_dir_name = f"shielding_files_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{next(tempfile._get_candidate_names())}" os.mkdir(self.tmp_dir_name) self.grid_file = self.tmp_dir_name + "/" + grid_file self.grid_to_prism_binary = grid_to_prism_binary self.prism_path = self.tmp_dir_name + "/" + prism_path self.prism_config = prism_config + self.prism_file = prism_file self.formula = formula shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE @@ -84,6 +85,9 @@ class MiniGridShieldHandler(ShieldHandler): def __create_prism(self): + if self.prism_file is not None: + shutil.copyfile(self.prism_file, self.prism_path) + return if self.prism_config is None: result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}") else: @@ -188,6 +192,7 @@ def common_parser(): default="MiniGrid-LavaSlipperyCliff-16x13-v0") parser.add_argument("--grid_file", default="grid.txt") + parser.add_argument("--prism_file", default=None) parser.add_argument("--prism_output_file", default="grid.prism") parser.add_argument("--log_dir", default="../log_results/") parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")