Browse Source

major refactor in utils

- introduced common_parser for arguments
- the shield dict uses minigrid.core.State instead of strings
- switched shield query to minigrid get_symbolic_state
refactoring
sp 10 months ago
parent
commit
372006a1da
  1. 62
      examples/shields/rl/sb3utils.py
  2. 304
      examples/shields/rl/utils.py

62
examples/shields/rl/sb3utils.py

@ -2,64 +2,37 @@ import gymnasium as gym
import numpy as np
import random
from utils import MiniGridShieldHandler, create_shield_query
from utils import MiniGridShieldHandler, common_parser
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def __init__(self,
env,
shield_creator : MiniGridShieldHandler,
shield_query_creator,
shield_handler : MiniGridShieldHandler,
create_shield_at_reset = True,
mask_actions=True,
):
super(MiniGridSbShieldingWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
super().__init__(env)
self.observation_space = env.observation_space.spaces["image"]
self.shield_creator = shield_creator
self.shield_handler = shield_handler
self.mask_actions = mask_actions
self.shield_query_creator = shield_query_creator
self.create_shield_at_reset = create_shield_at_reset
def create_action_mask(self):
if not self.mask_actions:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
cur_pos_str = self.shield_query_creator(self.env)
allowed_actions = []
# Create the mask
# If shield restricts actions, mask only valid actions with 1.0
# else set all actions valid
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action.labels)
if index is None:
assert(False)
mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
else:
for index, x in enumerate(mask):
mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask
shield = self.shield_handler.create_shield(env=self.env)
self.shield = shield
def create_action_mask(self):
try:
return self.shield[self.env.get_symbolic_state()]
except:
return [1.0] * 3 + [1.0] * 4
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
shield = self.shield_creator.create_shield(env=self.env)
self.shield = shield
if self.create_shield_at_reset and self.mask_actions:
shield = self.shield_handler.create_shield(env=self.env)
self.shield = shield
return obs["image"], infos
def step(self, action):
@ -68,3 +41,8 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
return obs, rew, done, truncated, info
def parse_sb3_arguments():
parser = common_parser()
args = parser.parse_args()
return args

304
examples/shields/rl/utils.py

@ -11,16 +11,37 @@ import stormpy.examples.files
from enum import Enum
from abc import ABC
import re
import sys
from minigrid.core.actions import Actions
from minigrid.core.state import to_state
import os
import time
class Action():
def __init__(self, idx, prob=1, labels=[]) -> None:
self.idx = idx
self.prob = prob
self.labels = labels
import argparse
def tic():
#Homemade version of matlab tic and toc functions: https://stackoverflow.com/a/18903019
global startTime_for_tictoc
startTime_for_tictoc = time.time()
def toc():
if 'startTime_for_tictoc' in globals():
print("Elapsed time is " + str(time.time() - startTime_for_tictoc) + " seconds.")
else:
print("Toc: start time not set")
class ShieldingConfig(Enum):
Training = 'training'
Evaluation = 'evaluation'
Disabled = 'none'
Full = 'full'
def __str__(self) -> str:
return self.value
class ShieldHandler(ABC):
def __init__(self) -> None:
@ -29,14 +50,16 @@ class ShieldHandler(ABC):
pass
class MiniGridShieldHandler(ShieldHandler):
def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None, shield_comparision='relative') -> None:
def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute') -> None:
self.grid_file = grid_file
self.grid_to_prism_path = grid_to_prism_path
self.grid_to_prism_binary = grid_to_prism_binary
self.prism_path = prism_path
self.formula = formula
self.prism_config = prism_config
self.shield_value = shield_value
self.shield_comparision = shield_comparision
self.formula = formula
shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE
self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value)
def __export_grid_to_text(self, env):
f = open(self.grid_file, "w")
@ -45,53 +68,53 @@ class MiniGridShieldHandler(ShieldHandler):
def __create_prism(self):
# result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
if self.prism_config is None:
result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path}")
result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}")
else:
result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
# result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
assert result == 0, "Prism file could not be generated"
f = open(self.prism_path, "a")
f.close()
def __create_shield_dict(self):
print(self.prism_path)
program = stormpy.parse_prism_program(self.prism_path)
shield_comp = stormpy.logic.ShieldComparison.RELATIVE
if self.shield_comparision == 'absolute':
shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value)
formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
options.set_build_state_valuations(True)
options.set_build_choice_labels(True)
options.set_build_all_labels()
print(f"LOG: Starting with explicit model creation...")
tic()
model = stormpy.build_sparse_model_with_options(program, options)
toc()
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
print(f"LOG: Starting with model checking...")
tic()
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=self.shield_expression)
toc()
assert result.has_shield
shield = result.shield
action_dictionary = {}
action_dictionary = dict()
shield_scheduler = shield.construct()
state_valuations = model.state_valuations
choice_labeling = model.choice_labeling
stormpy.shields.export_shield(model, shield, "myshield")
#stormpy.shields.export_shield(model, shield, "current.shield")
for stateID in model.states:
choice = shield_scheduler.get_choice(stateID)
choices = choice.choice_map
state_valuation = state_valuations.get_string(stateID)
actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
action_dictionary[state_valuation] = actions_to_be_executed
ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=([a-zA-Z0-9]+)', state_valuation))
booleans = dict(re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]', state_valuation)) #TODO does not parse everything correctly?
if int(ints.get("previousActionAgent", 3)) != 3:
continue
if int(ints.get("clock", 0)) != 0:
continue
state = to_state(ints, booleans)
action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices])
return action_dictionary
@ -103,216 +126,39 @@ class MiniGridShieldHandler(ShieldHandler):
return self.__create_shield_dict()
def create_shield_query(env):
coordinates = env.env.agent_pos
view_direction = env.env.agent_dir
keys = extract_keys(env)
doors = extract_doors(env)
adversaries = extract_adversaries(env)
if env.carrying:
agent_carrying = F"Agent_is_carrying_object\t"
else:
agent_carrying = "!Agent_is_carrying_object\t"
key_positions = []
agent_key_status = []
for key in keys:
key_color = key[0].color
key_x = key[1]
key_y = key[2]
if env.carrying and env.carrying.type == "key":
agent_key_text = F"Agent_has_{env.carrying.color}_key\t& "
key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t"
else:
agent_key_text = F"!Agent_has_{key_color}_key\t& "
key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t"
key_positions.append(key_position)
agent_key_status.append(agent_key_text)
if key_positions:
key_positions[-1] = key_positions[-1].strip()
door_status = []
for door in doors:
status = ""
if door.is_open:
status = F"!Door{door.color}locked\t& Door{door.color}open\t&"
elif door.is_locked:
status = F"Door{door.color}locked\t& !Door{door.color}open\t&"
else:
status = F"!Door{door.color}locked\t& !Door{door.color}open\t&"
door_status.append(status)
adv_status = []
adv_positions = []
for adversary in adversaries:
status = ""
position = ""
if adversary.carrying:
carrying = F"{adversary.name}_is_carrying_object\t"
else:
carrying = F"!{adversary.name}_is_carrying_object\t"
status = F"{carrying}& !{adversary.name}Done\t& "
position = F"x{adversary.name}={adversary.cur_pos[1]}\t& y{adversary.name}={adversary.cur_pos[0]}\t& view{adversary.name}={adversary.adversary_dir}"
adv_status.append(status)
adv_positions.append(position)
door_status_text = ""
if door_status:
door_status_text = F"& {''.join(door_status)}\t"
adv_status_text = ""
if adv_status:
adv_status_text = F"& {''.join(adv_status)}"
adv_positions_text = ""
if adv_positions:
adv_positions_text = F"\t& {''.join(adv_positions)}"
key_positions_text = ""
if key_positions:
key_positions_text = F"\t& {''.join(key_positions)}"
move_text = ""
if adversaries:
move_text = F"move=0\t& "
agent_position = F"& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}"
query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
return query
class ShieldingConfig(Enum):
Training = 'training'
Evaluation = 'evaluation'
Disabled = 'none'
Full = 'full'
def __str__(self) -> str:
return self.value
def extract_keys(env):
keys = []
for j in range(env.grid.height):
for i in range(env.grid.width):
obj = env.grid.get(i,j)
if obj and obj.type == "key":
keys.append((obj, i, j))
if env.carrying and env.carrying.type == "key":
keys.append((env.carrying, -1, -1))
# TODO Maybe need to add ordering of keys so it matches the order in the shield
return keys
def extract_doors(env):
doors = []
for j in range(env.grid.height):
for i in range(env.grid.width):
obj = env.grid.get(i,j)
if obj and obj.type == "door":
doors.append(obj)
return doors
def extract_adversaries(env):
adv = []
if not hasattr(env, "adversaries") or not env.adversaries:
return []
for color, adversary in env.adversaries.items():
adv.append(adversary)
return adv
def create_log_dir(args):
return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}"
return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}"
def test_name(args):
return F"{args.expname}"
def get_action_index_mapping(actions):
for action_str in actions:
if not "Agent" in action_str:
continue
if "move" in action_str:
return Actions.forward
elif "left" in action_str:
return Actions.left
elif "right" in action_str:
return Actions.right
elif "pickup" in action_str:
return Actions.pickup
elif "done" in action_str:
return Actions.done
elif "drop" in action_str:
return Actions.drop
elif "toggle" in action_str:
return Actions.toggle
elif "unlock" in action_str:
return Actions.toggle
raise ValueError("No action mapping found")
def parse_arguments(argparse):
def get_allowed_actions_mask(actions):
action_mask = [0.0] * 3 + [1.0] * 4
actions_labels = [label for labels in actions for label in list(labels)]
for action_label in actions_labels:
if "move" in action_label:
action_mask[2] = 1.0
elif "left" in action_label:
action_mask[0] = 1.0
elif "right" in action_label:
action_mask[1] = 1.0
return action_mask
def common_parser():
parser = argparse.ArgumentParser()
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
parser.add_argument("--env",
help="gym environment to load",
default="MiniGrid-LavaSlipperyCliffS12-v2",
choices=[
"MiniGrid-Adv-8x8-v0",
"MiniGrid-AdvSimple-8x8-v0",
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-LavaCrossingS9N3-v0",
"MiniGrid-LavaSlipperyCliffS12-v0",
"MiniGrid-LavaFaultyS12-30-v0",
])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_to_prism_binary_path", default="./main")
parser.add_argument("--grid_path", default="grid")
parser.add_argument("--prism_path", default="grid")
parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
default="MiniGrid-LavaSlipperyCliff-16x12-v0")
parser.add_argument("--grid_file", default="grid.txt")
parser.add_argument("--prism_output_file", default="grid.prism")
parser.add_argument("--log_dir", default="../log_results/")
parser.add_argument("--evaluations", type=int, default=30 )
parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--num_gpus", type=float, default=0)
parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")
parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
parser.add_argument("--steps", default=20_000, type=int)
parser.add_argument("--expname", default="exp")
parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
parser.add_argument("--prism_config", default=None)
parser.add_argument("--shield_value", default=0.9, type=float)
parser.add_argument("--probability_displacement", default=1/4, type=float)
parser.add_argument("--probability_intended", default=3/4, type=float)
parser.add_argument("--probability_turn_displacement", default=0/4, type=float)
parser.add_argument("--probability_turn_intended", default=4/4, type=float)
parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
# parser.add_argument("--random_starts", default=1, type=int)
args = parser.parse_args()
return args
parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
return parser
Loading…
Cancel
Save