You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

320 lines
11 KiB

import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
from enum import Enum
from abc import ABC
from minigrid.core.actions import Actions
import os
import time
class Action():
    """One action choice together with the value and labels attached to it.

    Args:
        idx: Identifier of the action/choice, stored unchanged.
        prob: Value associated with the choice (defaults to 1), stored unchanged.
        labels: Labels attached to the choice. When omitted, each instance
            receives its own new empty list.
    """

    def __init__(self, idx, prob=1, labels=None) -> None:
        self.idx = idx
        self.prob = prob
        # A literal ``[]`` default would be created once and shared by every
        # Action built without explicit labels, so mutating one instance's
        # labels would leak into all the others.
        self.labels = [] if labels is None else labels
class ShieldHandler(ABC):
    """Base class for shield handlers.

    It declares no abstract methods, so it can be instantiated directly;
    subclasses are expected to override ``create_shield``.
    """

    def __init__(self) -> None:
        """Set up nothing; the base class keeps no state."""
        return None

    def create_shield(self, **kwargs) -> dict:
        """Placeholder for shield construction.

        Accepts arbitrary keyword arguments and ignores them. This base
        implementation does no work and returns ``None``.
        """
        return None
class MiniGridShieldHandler(ShieldHandler):
    """Create a shield for a MiniGrid environment.

    The pipeline is: dump the environment grid to a text file, run an
    external grid-to-PRISM converter on it, then build and model check the
    resulting PRISM program with stormpy and read the shield out of the
    model-checking result.

    Args:
        grid_file: Path the textual grid is written to.
        grid_to_prism_path: Command/path of the grid-to-PRISM converter.
        prism_path: Path of the PRISM file produced by the converter.
        formula: Property string parsed by stormpy for the PRISM program.
        shield_value: Value passed to the stormpy shield expression.
        prism_config: Optional config file handed to the converter via ``-c``.
        shield_comparision: ``'absolute'`` selects the absolute shield
            comparison; any other value selects the relative one.
    """

    def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9, prism_config=None, shield_comparision='relative') -> None:
        self.grid_file = grid_file
        self.grid_to_prism_path = grid_to_prism_path
        self.prism_path = prism_path
        self.formula = formula
        self.prism_config = prism_config
        self.shield_value = shield_value
        self.shield_comparision = shield_comparision

    def __export_grid_to_text(self, env):
        """Write ``env.printGrid(init=True)`` to ``self.grid_file``, overwriting it."""
        # The context manager closes the handle even if printGrid() or the
        # write raises.
        with open(self.grid_file, "w") as grid_out:
            grid_out.write(env.printGrid(init=True))

    def __create_prism(self):
        """Run the converter on the grid file and append the lava label.

        Raises:
            AssertionError: If the converter exits with a non-zero status.
        """
        # Earlier variants of this call used -v 'Agent,Blue' and a fixed
        # '-c adv_config.yaml'; the config is now optional.
        command = F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path}"
        if self.prism_config is not None:
            command += F" -c {self.prism_config}"
        result = os.system(command)

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O.
        if result != 0:
            raise AssertionError("Prism file could not be generated")

        with open(self.prism_path, "a") as prism_out:
            prism_out.write("label \"AgentIsInLava\" = AgentIsInLava;")

    def __create_shield_dict(self):
        """Model check the PRISM program and collect the shield's choices.

        Returns:
            dict: Maps the string valuation of every model state to the list
            of ``Action`` objects taken from the shield scheduler's choice
            map for that state.

        Raises:
            AssertionError: If the model-checking result carries no shield.
        """
        print(self.prism_path)
        program = stormpy.parse_prism_program(self.prism_path)

        if self.shield_comparision == 'absolute':
            shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE
        else:
            shield_comp = stormpy.logic.ShieldComparison.RELATIVE
        shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value)

        formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
        options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
        options.set_build_state_valuations(True)
        options.set_build_choice_labels(True)
        options.set_build_all_labels()
        model = stormpy.build_sparse_model_with_options(program, options)

        result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
        # Explicit raise keeps the original exception type while surviving -O.
        if not result.has_shield:
            raise AssertionError("Model checking result does not contain a shield")
        shield = result.shield

        action_dictionary = {}
        shield_scheduler = shield.construct()
        state_valuations = model.state_valuations
        choice_labeling = model.choice_labeling
        stormpy.shields.export_shield(model, shield, "myshield")

        for stateID in model.states:
            choices = shield_scheduler.get_choice(stateID).choice_map
            state_valuation = state_valuations.get_string(stateID)
            # Each entry is indexed as (value, choice id): entry[0] becomes
            # the Action's prob and entry[1] its idx.
            actions_to_be_executed = [
                Action(
                    idx=entry[1],
                    prob=entry[0],
                    labels=choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, entry[1])),
                )
                for entry in choices
            ]
            action_dictionary[state_valuation] = actions_to_be_executed

        return action_dictionary

    def create_shield(self, **kwargs):
        """Export the grid, generate the PRISM file and build the shield dictionary.

        Args:
            **kwargs: Must contain ``env``, the environment whose grid is exported.

        Returns:
            dict: The dictionary produced by the shield-dictionary step.
        """
        env = kwargs["env"]
        self.__export_grid_to_text(env)
        self.__create_prism()
        return self.__create_shield_dict()
def create_shield_query(env):
    """Render the current environment state as a bracketed query string.

    The result has the form ``[term\\t& term\\t& ...]`` and covers, in this
    order: whether the agent carries something, per-key ownership flags,
    ``!AgentDone``, adversary status, ``move=0`` (only with adversaries),
    door status, agent position and view, adversary positions, and key
    positions. Presumably this string is matched against state descriptions
    of the shield; confirm against the caller.

    Args:
        env: Environment exposing ``env.env.agent_pos``, ``env.env.agent_dir``,
            ``carrying`` and whatever the ``extract_*`` helpers read.

    Returns:
        str: The assembled query.
    """
    agent_xy = env.env.agent_pos
    agent_view = env.env.agent_dir

    key_entries = extract_keys(env)
    door_objects = extract_doors(env)
    adversary_objects = extract_adversaries(env)

    carrying_flag = "" if env.carrying else "!"
    agent_carrying = f"{carrying_flag}Agent_is_carrying_object\t"

    # --- keys -----------------------------------------------------------
    key_positions = []
    agent_key_status = []
    for key_obj, key_x, key_y in key_entries:
        key_color = key_obj.color
        if env.carrying and env.carrying.type == "key":
            # While a key is carried, every entry reports the carried key's color.
            ownership = f"Agent_has_{env.carrying.color}_key\t& "
        else:
            ownership = f"!Agent_has_{key_color}_key\t& "
        key_positions.append(f"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t")
        agent_key_status.append(ownership)
    if key_positions:
        # Only the last entry loses its surrounding whitespace.
        key_positions[-1] = key_positions[-1].strip()

    # --- doors ----------------------------------------------------------
    door_status = []
    for door in door_objects:
        if door.is_open:
            locked_flag, open_flag = "!", ""
        elif door.is_locked:
            locked_flag, open_flag = "", "!"
        else:
            locked_flag, open_flag = "!", "!"
        color = door.color
        door_status.append(f"{locked_flag}Door{color}locked\t& {open_flag}Door{color}open\t&")

    # --- adversaries ----------------------------------------------------
    adv_status = []
    adv_positions = []
    for adversary in adversary_objects:
        name = adversary.name
        adv_carry_flag = "" if adversary.carrying else "!"
        adv_status.append(f"{adv_carry_flag}{name}_is_carrying_object\t& !{name}Done\t& ")
        # NOTE: x is taken from cur_pos[1] and y from cur_pos[0], unlike the
        # agent below, which uses index 0 for x.
        adv_positions.append(
            f"x{name}={adversary.cur_pos[1]}\t& y{name}={adversary.cur_pos[0]}\t& view{name}={adversary.adversary_dir}"
        )

    # --- optional sections: empty when there is nothing to report ---------
    joined_doors = "".join(door_status)
    joined_adv_status = "".join(adv_status)
    joined_adv_positions = "".join(adv_positions)
    joined_key_positions = "".join(key_positions)

    door_status_text = f"& {joined_doors}\t" if door_status else ""
    adv_status_text = f"& {joined_adv_status}" if adv_status else ""
    adv_positions_text = f"\t& {joined_adv_positions}" if adv_positions else ""
    key_positions_text = f"\t& {joined_key_positions}" if key_positions else ""
    move_text = "move=0\t& " if adversary_objects else ""

    agent_position = f"& xAgent={agent_xy[0]}\t& yAgent={agent_xy[1]}\t& viewAgent={agent_view}"

    pieces = [
        "[",
        agent_carrying,
        "& ",
        "".join(agent_key_status),
        "!AgentDone\t",
        adv_status_text,
        move_text,
        door_status_text,
        agent_position,
        adv_positions_text,
        key_positions_text,
        "]",
    ]
    return "".join(pieces)
class ShieldingConfig(Enum):
    """Shielding modes selectable on the command line.

    ``str()`` of a member yields its raw string value, so members print as
    ``training``, ``evaluation``, ``none`` and ``full``.
    """

    Training = "training"
    Evaluation = "evaluation"
    Disabled = "none"
    Full = "full"

    def __str__(self) -> str:
        """Return the member's underlying string value."""
        text = self.value
        return text
def extract_keys(env):
    """Collect every key in the environment as ``(object, x, y)`` tuples.

    The grid is scanned row by row (``y`` outer, ``x`` inner). If the agent is
    carrying a key, that key is appended last with coordinates ``(-1, -1)``.

    Args:
        env: Environment exposing ``grid`` (with ``width``, ``height`` and
            ``get(x, y)``) and ``carrying``.

    Returns:
        list: ``(key_object, x, y)`` tuples in scan order.
    """
    grid = env.grid
    found = []
    for row in range(grid.height):
        for col in range(grid.width):
            cell = grid.get(col, row)
            if not cell or cell.type != "key":
                continue
            found.append((cell, col, row))

    held = env.carrying
    if held and held.type == "key":
        found.append((held, -1, -1))
    # TODO Maybe need to add ordering of keys so it matches the order in the shield
    return found
def extract_doors(env):
    """Return all door objects found on the grid, scanned row by row.

    Args:
        env: Environment exposing ``grid`` with ``width``, ``height`` and
            ``get(x, y)``.

    Returns:
        list: The door objects in scan order (``y`` outer, ``x`` inner).
    """
    grid = env.grid
    return [
        cell
        for row in range(grid.height)
        for col in range(grid.width)
        if (cell := grid.get(col, row)) and cell.type == "door"
    ]
def extract_adversaries(env):
    """Return the environment's adversaries as a list.

    Args:
        env: Environment that may expose an ``adversaries`` mapping.

    Returns:
        list: The mapping's values in iteration order, or an empty list when
        ``env`` has no ``adversaries`` attribute.
    """
    if not hasattr(env, "adversaries"):
        return []
    # The mapping keys are not needed; only the adversary objects are kept.
    return list(env.adversaries.values())
def create_log_dir(args):
    """Compose the log directory name from the run's settings.

    The result is ``args.log_dir`` immediately followed by the dash-separated
    fields ``sh:``, ``value:``, ``comp:``, ``env:`` and ``conf:``.

    Args:
        args: Object exposing ``log_dir``, ``shielding``, ``shield_value``,
            ``shield_comparision``, ``env`` and ``prism_config``.

    Returns:
        str: The composed directory name.
    """
    prefix = f"{args.log_dir}"
    fields = (
        f"sh:{args.shielding}",
        f"value:{args.shield_value}",
        f"comp:{args.shield_comparision}",
        f"env:{args.env}",
        f"conf:{args.prism_config}",
    )
    return prefix + "-".join(fields)
def test_name(args):
    """Return the experiment name, i.e. ``args.expname`` formatted as a string.

    Args:
        args: Object exposing ``expname``.

    Returns:
        str: The formatted experiment name.
    """
    return F"{args.expname}"


# Despite the ``test_`` prefix this is a helper, not a test case. Opting out
# of pytest collection keeps it from being picked up (and failing on the
# unknown ``args`` fixture) when it is imported into a test module.
test_name.__test__ = False
def get_action_index_mapping(actions):
    """Map the first recognizable agent action label to a MiniGrid action.

    Labels not containing ``"Agent"`` are skipped. For the remaining labels
    the keywords below are tried in order, and the first label with a hit
    decides the result.

    Args:
        actions: Iterable of action label strings.

    Returns:
        The matching ``Actions`` member.

    Raises:
        ValueError: If no label yields a mapping.
    """
    # Order matters: earlier keywords win when a label contains several.
    # Both "toggle" and "unlock" map to Actions.toggle.
    keyword_to_member = (
        ("move", "forward"),
        ("left", "left"),
        ("right", "right"),
        ("pickup", "pickup"),
        ("done", "done"),
        ("drop", "drop"),
        ("toggle", "toggle"),
        ("unlock", "toggle"),
    )
    for label in actions:
        if "Agent" not in label:
            continue
        for keyword, member in keyword_to_member:
            if keyword in label:
                return getattr(Actions, member)
    raise ValueError("No action mapping found")
def parse_arguments(argparse):
    """Define the experiment's command-line options and parse them.

    Args:
        argparse: The ``argparse`` module (or an object providing
            ``ArgumentParser`` and ``BooleanOptionalAction``).

    Returns:
        The namespace returned by ``parser.parse_args()``, which reads the
        process's command-line arguments.
    """
    supported_envs = [
        "MiniGrid-Adv-8x8-v0",
        "MiniGrid-AdvSimple-8x8-v0",
        "MiniGrid-LavaCrossingS9N1-v0",
        "MiniGrid-LavaCrossingS9N3-v0",
        "MiniGrid-LavaSlipperyS12-v0",
        "MiniGrid-LavaSlipperyS12-v1",
        "MiniGrid-LavaSlipperyS12-v2",
        "MiniGrid-LavaSlipperyS12-v3",
    ]

    parser = argparse.ArgumentParser()

    # Environment selection (an earlier default was "MiniGrid-Empty-8x8-v0").
    parser.add_argument(
        "--env",
        help="gym environment to load",
        default="MiniGrid-LavaSlipperyS12-v2",
        choices=supported_envs,
    )
    # A "--seed" option (int, default None) is currently disabled.

    # Paths for the grid-to-PRISM tool chain.
    parser.add_argument("--grid_to_prism_binary_path", default="./main")
    parser.add_argument("--grid_path", default="grid")
    parser.add_argument("--prism_path", default="grid")

    # Training / evaluation settings.
    parser.add_argument("--algorithm", default="PPO", type=str.upper, choices=["PPO", "DQN"])
    parser.add_argument("--log_dir", default="../log_results/")
    parser.add_argument("--evaluations", type=int, default=30)

    # Alternative formulas used previously:
    #   Pmax=? [G ! "AgentIsInGoalAndNotDone"]
    #   <<Agent>> Pmax=? [G <= 4 !"AgentRanIntoAdversary"]
    parser.add_argument("--formula", default='Pmax=? [G !"AgentIsInLavaAndNotDone"]')

    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--num_gpus", type=float, default=0)
    parser.add_argument(
        "--shielding",
        type=ShieldingConfig,
        choices=list(ShieldingConfig),
        default=ShieldingConfig.Full,
    )
    parser.add_argument("--steps", default=20_000, type=int)
    parser.add_argument("--expname", default="exp")
    parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)

    # Shield construction parameters.
    parser.add_argument("--prism_config", default=None)
    parser.add_argument("--shield_value", default=0.9, type=float)
    parser.add_argument("--probability_displacement", default=1 / 4, type=float)
    parser.add_argument("--probability_intended", default=3 / 4, type=float)
    parser.add_argument("--shield_comparision", default="relative", choices=["relative", "absolute"])
    # A "--random_starts" option (int, default 1) is currently disabled.

    return parser.parse_args()