From 11d5c3c811bf99a040113cb3fcc1e113d357d7d0 Mon Sep 17 00:00:00 2001 From: sp Date: Mon, 15 Jan 2024 13:33:41 +0100 Subject: [PATCH] store shield files in local tmp dir --- examples/shields/rl/utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index 9986772..27d1f13 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -13,6 +13,7 @@ from abc import ABC import re import sys +import tempfile, datetime, shutil import gymnasium as gym @@ -51,21 +52,26 @@ class ShieldHandler(ABC): pass class MiniGridShieldHandler(ShieldHandler): - def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute') -> None: - self.grid_file = grid_file + def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute', cleanup=True) -> None: + self.tmp_dir_name = f"shielding_files_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{next(tempfile._get_candidate_names())}" + os.mkdir(self.tmp_dir_name) + self.grid_file = self.tmp_dir_name + "/" + grid_file self.grid_to_prism_binary = grid_to_prism_binary - self.prism_path = prism_path + self.prism_path = self.tmp_dir_name + "/" + prism_path self.prism_config = prism_config self.formula = formula shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value) + self.cleanup = cleanup + def __del__(self): + if self.cleanup: + shutil.rmtree(self.tmp_dir_name) def __export_grid_to_text(self, env): - f = open(self.grid_file, "w") - f.write(env.printGrid(init=True)) - f.close() + with open(self.grid_file, "w") as f: + f.write(env.printGrid(init=True)) def __create_prism(self): @@ -162,6 +168,7 @@ def common_parser(): parser.add_argument("--prism_config", default=None) parser.add_argument("--shield_value", default=0.9, type=float) parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute']) + parser.add_argument("--cleanup", action=argparse.BooleanOptionalAction, default=True) return parser class MiniWrapper(gym.Wrapper):