|
@ -13,6 +13,7 @@ from abc import ABC |
|
|
|
|
|
|
|
|
import re |
|
|
import re |
|
|
import sys |
|
|
import sys |
|
|
|
|
|
import tempfile, datetime, shutil |
|
|
|
|
|
|
|
|
import gymnasium as gym |
|
|
import gymnasium as gym |
|
|
|
|
|
|
|
@ -51,21 +52,26 @@ class ShieldHandler(ABC): |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
class MiniGridShieldHandler(ShieldHandler): |
|
|
class MiniGridShieldHandler(ShieldHandler): |
|
|
def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute') -> None: |
|
|
|
|
|
self.grid_file = grid_file |
|
|
|
|
|
|
|
|
def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute', cleanup=True) -> None: |
|
|
|
|
|
self.tmp_dir_name = f"shielding_files_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{next(tempfile._get_candidate_names())}" |
|
|
|
|
|
os.mkdir(self.tmp_dir_name) |
|
|
|
|
|
self.grid_file = self.tmp_dir_name + "/" + grid_file |
|
|
self.grid_to_prism_binary = grid_to_prism_binary |
|
|
self.grid_to_prism_binary = grid_to_prism_binary |
|
|
self.prism_path = prism_path |
|
|
|
|
|
|
|
|
self.prism_path = self.tmp_dir_name + "/" + prism_path |
|
|
self.prism_config = prism_config |
|
|
self.prism_config = prism_config |
|
|
|
|
|
|
|
|
self.formula = formula |
|
|
self.formula = formula |
|
|
shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE |
|
|
shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE |
|
|
self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value) |
|
|
self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value) |
|
|
|
|
|
|
|
|
|
|
|
self.cleanup = cleanup |
|
|
|
|
|
def __del__(self): |
|
|
|
|
|
if self.cleanup: |
|
|
|
|
|
shutil.rmtree(self.tmp_dir_name) |
|
|
|
|
|
|
|
|
def __export_grid_to_text(self, env): |
|
|
def __export_grid_to_text(self, env): |
|
|
f = open(self.grid_file, "w") |
|
|
|
|
|
|
|
|
with open(self.grid_file, "w") as f: |
|
|
f.write(env.printGrid(init=True)) |
|
|
f.write(env.printGrid(init=True)) |
|
|
f.close() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __create_prism(self): |
|
|
def __create_prism(self): |
|
@ -162,6 +168,7 @@ def common_parser(): |
|
|
parser.add_argument("--prism_config", default=None) |
|
|
parser.add_argument("--prism_config", default=None) |
|
|
parser.add_argument("--shield_value", default=0.9, type=float) |
|
|
parser.add_argument("--shield_value", default=0.9, type=float) |
|
|
parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute']) |
|
|
parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute']) |
|
|
|
|
|
parser.add_argument("--cleanup", action=argparse.BooleanOptionalAction, default=True) |
|
|
return parser |
|
|
return parser |
|
|
|
|
|
|
|
|
class MiniWrapper(gym.Wrapper): |
|
|
class MiniWrapper(gym.Wrapper): |
|
|