From 45d110b199e02e414882687c79f644ea2b22170d Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Tue, 19 Dec 2023 09:10:21 +0100 Subject: [PATCH] added probability arguments --- examples/shields/rl/11_minigridrl.py | 16 ++++++++-- examples/shields/rl/15_train_eval_tune.py | 9 ++++-- examples/shields/rl/helpers.py | 4 +++ examples/shields/rl/shieldhandlers.py | 11 +++++-- slippery_prob_08.yaml | 39 +++++++++++++++++++++++ 5 files changed, 72 insertions(+), 7 deletions(-) create mode 100644 slippery_prob_08.yaml diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index 150d7ea..e66cfc1 100755 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -23,9 +23,19 @@ def shielding_env_creater(config): args.grid_path = F"{args.grid_path}_{config.worker_index}_{args.prism_config}.txt" args.prism_path = F"{args.prism_path}_{config.worker_index}_{args.prism_config}.prism" - - shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula, args.shield_value, args.prism_config) - env = gym.make(name, randomize_start=True) + prob_forward = args.prob_forward + prob_direct = args.prob_direct + prob_next = args.prob_next + + shield_creator = MiniGridShieldHandler(args.grid_path, + args.grid_to_prism_binary_path, + args.prism_path, + args.formula, + args.shield_value, + args.prism_config, + shield_comparision=args.shield_comparision) + + env = gym.make(name, randomize_start=True,probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next) env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding != ShieldingConfig.Disabled, diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index 0158e08..9917e44 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -33,9 +33,14 @@ def shielding_env_creater(config): prism_path=args.prism_path, formula=args.formula, shield_value=args.shield_value, - prism_config=args.prism_config) + prism_config=args.prism_config, + shield_comparision=args.shield_comparision) - env = gym.make(name, randomize_start=True) + prob_forward = args.prob_forward + prob_direct = args.prob_direct + prob_next = args.prob_next + + env = gym.make(name, randomize_start=True,probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next) env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled) env = OneHotShieldingWrapper(env, diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index 3e8112f..3d943eb 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -138,6 +138,10 @@ def parse_arguments(argparse): parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction) parser.add_argument("--prism_config", default=None) parser.add_argument("--shield_value", default=0.9, type=float) + parser.add_argument("--prob_direct", default=1/4, type=float) + parser.add_argument("--prob_forward", default=3/4, type=float) + parser.add_argument("--prob_next", default=1/8, type=float) + parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute']) # parser.add_argument("--random_starts", default=1, type=int) args = parser.parse_args() diff --git a/examples/shields/rl/shieldhandlers.py b/examples/shields/rl/shieldhandlers.py index 68140a8..95cab72 100644 --- a/examples/shields/rl/shieldhandlers.py +++ b/examples/shields/rl/shieldhandlers.py @@ -27,13 +27,14 @@ class ShieldHandler(ABC): pass class MiniGridShieldHandler(ShieldHandler): - def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None) -> None: + def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None, shield_comparision='relative') -> None: self.grid_file = grid_file self.grid_to_prism_path = grid_to_prism_path self.prism_path = prism_path self.formula = formula self.prism_config = prism_config self.shield_value = shield_value + self.shield_comparision = shield_comparision def __export_grid_to_text(self, env): f = open(self.grid_file, "w") @@ -58,7 +59,13 @@ class MiniGridShieldHandler(ShieldHandler): def __create_shield_dict(self): print(self.prism_path) program = stormpy.parse_prism_program(self.prism_path) - shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, self.shield_value) + + shield_comp = stormpy.logic.ShieldComparison.RELATIVE + + if self.shield_comparision == 'absolute': + shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE + + shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value) formulas = stormpy.parse_properties_for_prism_program(self.formula, program) options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) diff --git a/slippery_prob_08.yaml b/slippery_prob_08.yaml new file mode 100644 index 0000000..90c3449 --- /dev/null +++ b/slippery_prob_08.yaml @@ -0,0 +1,39 @@ +--- +labels: + - label: "AgentIsInGoal" + text: "AgentIsInGoal" + +constants: + - constant: "prop_slippery_turn" + type: "double" + value: "9/9" + overwrite: True + - constant: "prop_next_neighbour_turn" + type: "double" + value: "0/9" + overwrite: True + - constant: "prop_slippery_move_forward" + type: "double" + value: "4/5" + overwrite: True + - constant: "prop_direct_neighbour" + type: "double" + value: "1/5" + overwrite: True + - constant: "prop_next_neighbour" + type: "double" + value: "1/10" + overwrite: True + - constant: "total_prop" + type: "double" + value: "4" + overwrite: True +... + +# const double prop_zero = 0/9; +# const double prop_next_neighbour = 1/9; +# const double prop_slippery_move_forward = 7/9; +# const double prop_slippery_turn = 6/9; +# const double prop_next_neighbour_turn = 1/9; +# const double prop_direct_neighbour = 2/9; +# const double total_prop = 9; \ No newline at end of file