Browse Source

added argument for predefined prism file

Needed when skipping M2P
refactoring
sp 5 months ago
parent
commit
89675dfe80
  1. 4
      examples/shields/rl/13_minigridsb.py
  2. 7
      examples/shields/rl/utils.py

4
examples/shields/rl/13_minigridsb.py

@ -24,11 +24,9 @@ def mask_fn(env: gym.Env):
def nomask_fn(env: gym.Env):
return [1.0] * 7
def main():
args = parse_sb3_arguments()
formula = args.formula
shield_value = args.shield_value
shield_comparison = args.shield_comparison
@ -38,7 +36,7 @@ def main():
if shield_needed(args.shielding):
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup)
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup, prism_file=None)
env = gym.make(args.env, render_mode="rgb_array")

7
examples/shields/rl/utils.py

@ -61,13 +61,14 @@ class ShieldHandler(ABC):
pass
class MiniGridShieldHandler(ShieldHandler):
def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute', nocleanup=False) -> None:
def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute', nocleanup=False, prism_file=None) -> None:
self.tmp_dir_name = f"shielding_files_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{next(tempfile._get_candidate_names())}"
os.mkdir(self.tmp_dir_name)
self.grid_file = self.tmp_dir_name + "/" + grid_file
self.grid_to_prism_binary = grid_to_prism_binary
self.prism_path = self.tmp_dir_name + "/" + prism_path
self.prism_config = prism_config
self.prism_file = prism_file
self.formula = formula
shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE
@ -84,6 +85,9 @@ class MiniGridShieldHandler(ShieldHandler):
def __create_prism(self):
if self.prism_file is not None:
shutil.copyfile(self.prism_file, self.prism_path)
return
if self.prism_config is None:
result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}")
else:
@ -188,6 +192,7 @@ def common_parser():
default="MiniGrid-LavaSlipperyCliff-16x13-v0")
parser.add_argument("--grid_file", default="grid.txt")
parser.add_argument("--prism_file", default=None)
parser.add_argument("--prism_output_file", default="grid.prism")
parser.add_argument("--log_dir", default="../log_results/")
parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")

Loading…
Cancel
Save