Browse Source

added probability arguments

refactoring
Thomas Knoll 11 months ago
parent
commit
45d110b199
  1. 16
      examples/shields/rl/11_minigridrl.py
  2. 9
      examples/shields/rl/15_train_eval_tune.py
  3. 4
      examples/shields/rl/helpers.py
  4. 11
      examples/shields/rl/shieldhandlers.py
  5. 39
      slippery_prob_08.yaml

16
examples/shields/rl/11_minigridrl.py

@ -23,9 +23,19 @@ def shielding_env_creater(config):
args.grid_path = F"{args.grid_path}_{config.worker_index}_{args.prism_config}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}_{args.prism_config}.prism"
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula, args.shield_value, args.prism_config)
env = gym.make(name, randomize_start=True)
prob_forward = args.prob_forward
prob_direct = args.prob_direct
prob_next = args.prob_next
shield_creator = MiniGridShieldHandler(args.grid_path,
args.grid_to_prism_binary_path,
args.prism_path,
args.formula,
args.shield_value,
args.prism_config,
shield_comparision=args.shield_comparision)
env = gym.make(name, randomize_start=True,probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator,
shield_query_creator=create_shield_query,
mask_actions=args.shielding != ShieldingConfig.Disabled,

9
examples/shields/rl/15_train_eval_tune.py

@ -33,9 +33,14 @@ def shielding_env_creater(config):
prism_path=args.prism_path,
formula=args.formula,
shield_value=args.shield_value,
prism_config=args.prism_config)
prism_config=args.prism_config,
shield_comparision=args.shield_comparision)
env = gym.make(name, randomize_start=True)
prob_forward = args.prob_forward
prob_direct = args.prob_direct
prob_next = args.prob_next
env = gym.make(name, randomize_start=True,probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled)
env = OneHotShieldingWrapper(env,

4
examples/shields/rl/helpers.py

@ -138,6 +138,10 @@ def parse_arguments(argparse):
parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
parser.add_argument("--prism_config", default=None)
parser.add_argument("--shield_value", default=0.9, type=float)
parser.add_argument("--prob_direct", default=1/4, type=float)
parser.add_argument("--prob_forward", default=3/4, type=float)
parser.add_argument("--prob_next", default=1/8, type=float)
parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
# parser.add_argument("--random_starts", default=1, type=int)
args = parser.parse_args()

11
examples/shields/rl/shieldhandlers.py

@ -27,13 +27,14 @@ class ShieldHandler(ABC):
pass
class MiniGridShieldHandler(ShieldHandler):
def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None) -> None:
def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None, shield_comparision='relative') -> None:
self.grid_file = grid_file
self.grid_to_prism_path = grid_to_prism_path
self.prism_path = prism_path
self.formula = formula
self.prism_config = prism_config
self.shield_value = shield_value
self.shield_comparision = shield_comparision
def __export_grid_to_text(self, env):
f = open(self.grid_file, "w")
@ -58,7 +59,13 @@ class MiniGridShieldHandler(ShieldHandler):
def __create_shield_dict(self):
print(self.prism_path)
program = stormpy.parse_prism_program(self.prism_path)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, self.shield_value)
shield_comp = stormpy.logic.ShieldComparison.RELATIVE
if self.shield_comparision == 'absolute':
shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value)
formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])

39
slippery_prob_08.yaml

@ -0,0 +1,39 @@
---
labels:
- label: "AgentIsInGoal"
text: "AgentIsInGoal"
constants:
- constant: "prop_slippery_turn"
type: "double"
value: "9/9"
overwrite: True
- constant: "prop_next_neighbour_turn"
type: "double"
value: "0/9"
overwrite: True
- constant: "prop_slippery_move_forward"
type: "double"
value: "4/5"
overwrite: True
- constant: "prop_direct_neighbour"
type: "double"
value: "1/5"
overwrite: True
- constant: "prop_next_neighbour"
type: "double"
value: "1/10"
overwrite: True
- constant: "total_prop"
type: "double"
value: "4"
overwrite: True
...
# const double prop_zero = 0/9;
# const double prop_next_neighbour = 1/9;
# const double prop_slippery_move_forward = 7/9;
# const double prop_slippery_turn = 6/9;
# const double prop_next_neighbour_turn = 1/9;
# const double prop_direct_neighbour = 2/9;
# const double total_prop = 9;
Loading…
Cancel
Save