Browse Source

major refactor in utils

- introduced common_parser for arguments
- the shield dict uses minigrid.core.State instead of strings
- switched shield query to minigrid get_symbolic_state
refactoring
sp 10 months ago
parent
commit
372006a1da
  1. 60
      examples/shields/rl/sb3utils.py
  2. 304
      examples/shields/rl/utils.py

60
examples/shields/rl/sb3utils.py

@ -2,63 +2,36 @@ import gymnasium as gym
import numpy as np import numpy as np
import random import random
from utils import MiniGridShieldHandler, create_shield_query
from utils import MiniGridShieldHandler, common_parser
class MiniGridSbShieldingWrapper(gym.core.Wrapper): class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def __init__(self, def __init__(self,
env, env,
shield_creator : MiniGridShieldHandler,
shield_query_creator,
shield_handler : MiniGridShieldHandler,
create_shield_at_reset = True, create_shield_at_reset = True,
mask_actions=True, mask_actions=True,
): ):
super(MiniGridSbShieldingWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
super().__init__(env)
self.observation_space = env.observation_space.spaces["image"] self.observation_space = env.observation_space.spaces["image"]
self.shield_creator = shield_creator
self.shield_handler = shield_handler
self.mask_actions = mask_actions self.mask_actions = mask_actions
self.shield_query_creator = shield_query_creator
self.create_shield_at_reset = create_shield_at_reset
def create_action_mask(self):
if not self.mask_actions:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
cur_pos_str = self.shield_query_creator(self.env)
allowed_actions = []
# Create the mask
# If shield restricts actions, mask only valid actions with 1.0
# else set all actions valid
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action.labels)
if index is None:
assert(False)
mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
else:
for index, x in enumerate(mask):
mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask
shield = self.shield_handler.create_shield(env=self.env)
self.shield = shield
def create_action_mask(self):
try:
return self.shield[self.env.get_symbolic_state()]
except:
return [1.0] * 3 + [1.0] * 4
def reset(self, *, seed=None, options=None): def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options) obs, infos = self.env.reset(seed=seed, options=options)
shield = self.shield_creator.create_shield(env=self.env)
if self.create_shield_at_reset and self.mask_actions:
shield = self.shield_handler.create_shield(env=self.env)
self.shield = shield self.shield = shield
return obs["image"], infos return obs["image"], infos
@ -68,3 +41,8 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
return obs, rew, done, truncated, info return obs, rew, done, truncated, info
def parse_sb3_arguments():
parser = common_parser()
args = parser.parse_args()
return args

304
examples/shields/rl/utils.py

@ -11,16 +11,37 @@ import stormpy.examples.files
from enum import Enum from enum import Enum
from abc import ABC from abc import ABC
import re
import sys
from minigrid.core.actions import Actions from minigrid.core.actions import Actions
from minigrid.core.state import to_state
import os import os
import time import time
class Action():
def __init__(self, idx, prob=1, labels=[]) -> None:
self.idx = idx
self.prob = prob
self.labels = labels
import argparse
def tic():
#Homemade version of matlab tic and toc functions: https://stackoverflow.com/a/18903019
global startTime_for_tictoc
startTime_for_tictoc = time.time()
def toc():
if 'startTime_for_tictoc' in globals():
print("Elapsed time is " + str(time.time() - startTime_for_tictoc) + " seconds.")
else:
print("Toc: start time not set")
class ShieldingConfig(Enum):
Training = 'training'
Evaluation = 'evaluation'
Disabled = 'none'
Full = 'full'
def __str__(self) -> str:
return self.value
class ShieldHandler(ABC): class ShieldHandler(ABC):
def __init__(self) -> None: def __init__(self) -> None:
@ -29,14 +50,16 @@ class ShieldHandler(ABC):
pass pass
class MiniGridShieldHandler(ShieldHandler): class MiniGridShieldHandler(ShieldHandler):
def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None, shield_comparision='relative') -> None:
def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute') -> None:
self.grid_file = grid_file self.grid_file = grid_file
self.grid_to_prism_path = grid_to_prism_path
self.grid_to_prism_binary = grid_to_prism_binary
self.prism_path = prism_path self.prism_path = prism_path
self.formula = formula
self.prism_config = prism_config self.prism_config = prism_config
self.shield_value = shield_value
self.shield_comparision = shield_comparision
self.formula = formula
shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE
self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value)
def __export_grid_to_text(self, env): def __export_grid_to_text(self, env):
f = open(self.grid_file, "w") f = open(self.grid_file, "w")
@ -45,53 +68,53 @@ class MiniGridShieldHandler(ShieldHandler):
def __create_prism(self): def __create_prism(self):
# result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
if self.prism_config is None: if self.prism_config is None:
result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path}")
result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}")
else: else:
result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
# result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
assert result == 0, "Prism file could not be generated" assert result == 0, "Prism file could not be generated"
f = open(self.prism_path, "a")
f.close()
def __create_shield_dict(self): def __create_shield_dict(self):
print(self.prism_path)
program = stormpy.parse_prism_program(self.prism_path) program = stormpy.parse_prism_program(self.prism_path)
shield_comp = stormpy.logic.ShieldComparison.RELATIVE
if self.shield_comparision == 'absolute':
shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value)
formulas = stormpy.parse_properties_for_prism_program(self.formula, program) formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
options.set_build_state_valuations(True) options.set_build_state_valuations(True)
options.set_build_choice_labels(True) options.set_build_choice_labels(True)
options.set_build_all_labels() options.set_build_all_labels()
print(f"LOG: Starting with explicit model creation...")
tic()
model = stormpy.build_sparse_model_with_options(program, options) model = stormpy.build_sparse_model_with_options(program, options)
toc()
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
print(f"LOG: Starting with model checking...")
tic()
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=self.shield_expression)
toc()
assert result.has_shield assert result.has_shield
shield = result.shield shield = result.shield
action_dictionary = {}
action_dictionary = dict()
shield_scheduler = shield.construct() shield_scheduler = shield.construct()
state_valuations = model.state_valuations state_valuations = model.state_valuations
choice_labeling = model.choice_labeling choice_labeling = model.choice_labeling
stormpy.shields.export_shield(model, shield, "myshield")
#stormpy.shields.export_shield(model, shield, "current.shield")
for stateID in model.states: for stateID in model.states:
choice = shield_scheduler.get_choice(stateID) choice = shield_scheduler.get_choice(stateID)
choices = choice.choice_map choices = choice.choice_map
state_valuation = state_valuations.get_string(stateID) state_valuation = state_valuations.get_string(stateID)
actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
action_dictionary[state_valuation] = actions_to_be_executed
ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=([a-zA-Z0-9]+)', state_valuation))
booleans = dict(re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]', state_valuation)) #TODO does not parse everything correctly?
if int(ints.get("previousActionAgent", 3)) != 3:
continue
if int(ints.get("clock", 0)) != 0:
continue
state = to_state(ints, booleans)
action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices])
return action_dictionary return action_dictionary
@ -103,216 +126,39 @@ class MiniGridShieldHandler(ShieldHandler):
return self.__create_shield_dict() return self.__create_shield_dict()
def create_shield_query(env):
coordinates = env.env.agent_pos
view_direction = env.env.agent_dir
keys = extract_keys(env)
doors = extract_doors(env)
adversaries = extract_adversaries(env)
if env.carrying:
agent_carrying = F"Agent_is_carrying_object\t"
else:
agent_carrying = "!Agent_is_carrying_object\t"
key_positions = []
agent_key_status = []
for key in keys:
key_color = key[0].color
key_x = key[1]
key_y = key[2]
if env.carrying and env.carrying.type == "key":
agent_key_text = F"Agent_has_{env.carrying.color}_key\t& "
key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t"
else:
agent_key_text = F"!Agent_has_{key_color}_key\t& "
key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t"
key_positions.append(key_position)
agent_key_status.append(agent_key_text)
if key_positions:
key_positions[-1] = key_positions[-1].strip()
door_status = []
for door in doors:
status = ""
if door.is_open:
status = F"!Door{door.color}locked\t& Door{door.color}open\t&"
elif door.is_locked:
status = F"Door{door.color}locked\t& !Door{door.color}open\t&"
else:
status = F"!Door{door.color}locked\t& !Door{door.color}open\t&"
door_status.append(status)
adv_status = []
adv_positions = []
for adversary in adversaries:
status = ""
position = ""
if adversary.carrying:
carrying = F"{adversary.name}_is_carrying_object\t"
else:
carrying = F"!{adversary.name}_is_carrying_object\t"
status = F"{carrying}& !{adversary.name}Done\t& "
position = F"x{adversary.name}={adversary.cur_pos[1]}\t& y{adversary.name}={adversary.cur_pos[0]}\t& view{adversary.name}={adversary.adversary_dir}"
adv_status.append(status)
adv_positions.append(position)
door_status_text = ""
if door_status:
door_status_text = F"& {''.join(door_status)}\t"
adv_status_text = ""
if adv_status:
adv_status_text = F"& {''.join(adv_status)}"
adv_positions_text = ""
if adv_positions:
adv_positions_text = F"\t& {''.join(adv_positions)}"
key_positions_text = ""
if key_positions:
key_positions_text = F"\t& {''.join(key_positions)}"
move_text = ""
if adversaries:
move_text = F"move=0\t& "
agent_position = F"& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}"
query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
return query
class ShieldingConfig(Enum):
Training = 'training'
Evaluation = 'evaluation'
Disabled = 'none'
Full = 'full'
def __str__(self) -> str:
return self.value
def extract_keys(env):
keys = []
for j in range(env.grid.height):
for i in range(env.grid.width):
obj = env.grid.get(i,j)
if obj and obj.type == "key":
keys.append((obj, i, j))
if env.carrying and env.carrying.type == "key":
keys.append((env.carrying, -1, -1))
# TODO Maybe need to add ordering of keys so it matches the order in the shield
return keys
def extract_doors(env):
doors = []
for j in range(env.grid.height):
for i in range(env.grid.width):
obj = env.grid.get(i,j)
if obj and obj.type == "door":
doors.append(obj)
return doors
def extract_adversaries(env):
adv = []
if not hasattr(env, "adversaries") or not env.adversaries:
return []
for color, adversary in env.adversaries.items():
adv.append(adversary)
return adv
def create_log_dir(args): def create_log_dir(args):
return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}"
return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}"
def test_name(args): def test_name(args):
return F"{args.expname}" return F"{args.expname}"
def get_action_index_mapping(actions):
for action_str in actions:
if not "Agent" in action_str:
continue
if "move" in action_str:
return Actions.forward
elif "left" in action_str:
return Actions.left
elif "right" in action_str:
return Actions.right
elif "pickup" in action_str:
return Actions.pickup
elif "done" in action_str:
return Actions.done
elif "drop" in action_str:
return Actions.drop
elif "toggle" in action_str:
return Actions.toggle
elif "unlock" in action_str:
return Actions.toggle
raise ValueError("No action mapping found")
def parse_arguments(argparse):
def get_allowed_actions_mask(actions):
action_mask = [0.0] * 3 + [1.0] * 4
actions_labels = [label for labels in actions for label in list(labels)]
for action_label in actions_labels:
if "move" in action_label:
action_mask[2] = 1.0
elif "left" in action_label:
action_mask[0] = 1.0
elif "right" in action_label:
action_mask[1] = 1.0
return action_mask
def common_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
parser.add_argument("--env", parser.add_argument("--env",
help="gym environment to load", help="gym environment to load",
default="MiniGrid-LavaSlipperyCliffS12-v2",
choices=[
"MiniGrid-Adv-8x8-v0",
"MiniGrid-AdvSimple-8x8-v0",
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-LavaCrossingS9N3-v0",
"MiniGrid-LavaSlipperyCliffS12-v0",
"MiniGrid-LavaFaultyS12-30-v0",
])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_to_prism_binary_path", default="./main")
parser.add_argument("--grid_path", default="grid")
parser.add_argument("--prism_path", default="grid")
parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
default="MiniGrid-LavaSlipperyCliff-16x12-v0")
parser.add_argument("--grid_file", default="grid.txt")
parser.add_argument("--prism_output_file", default="grid.prism")
parser.add_argument("--log_dir", default="../log_results/") parser.add_argument("--log_dir", default="../log_results/")
parser.add_argument("--evaluations", type=int, default=30 )
parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--num_gpus", type=float, default=0)
parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")
parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
parser.add_argument("--steps", default=20_000, type=int) parser.add_argument("--steps", default=20_000, type=int)
parser.add_argument("--expname", default="exp")
parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction) parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
parser.add_argument("--prism_config", default=None) parser.add_argument("--prism_config", default=None)
parser.add_argument("--shield_value", default=0.9, type=float) parser.add_argument("--shield_value", default=0.9, type=float)
parser.add_argument("--probability_displacement", default=1/4, type=float)
parser.add_argument("--probability_intended", default=3/4, type=float)
parser.add_argument("--probability_turn_displacement", default=0/4, type=float)
parser.add_argument("--probability_turn_intended", default=4/4, type=float)
parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
# parser.add_argument("--random_starts", default=1, type=int)
args = parser.parse_args()
return args
parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
return parser
Loading…
Cancel
Save