You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

195 lines
6.8 KiB

12 months ago
12 months ago
12 months ago
  1. import stormpy
  2. import stormpy.core
  3. import stormpy.simulator
  4. import stormpy.shields
  5. import stormpy.logic
  6. import stormpy.examples
  7. import stormpy.examples.files
  8. from helpers import extract_doors, extract_keys, extract_adversaries
  9. from abc import ABC
  10. import os
  11. import time
  12. class Action():
  13. def __init__(self, idx, prob=1, labels=[]) -> None:
  14. self.idx = idx
  15. self.prob = prob
  16. self.labels = labels
  17. class ShieldHandler(ABC):
  18. def __init__(self) -> None:
  19. pass
  20. def create_shield(self, **kwargs) -> dict:
  21. pass
  22. class MiniGridShieldHandler(ShieldHandler):
  23. def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None, shield_comparision='relative') -> None:
  24. self.grid_file = grid_file
  25. self.grid_to_prism_path = grid_to_prism_path
  26. self.prism_path = prism_path
  27. self.formula = formula
  28. self.prism_config = prism_config
  29. self.shield_value = shield_value
  30. self.shield_comparision = shield_comparision
  31. def __export_grid_to_text(self, env):
  32. f = open(self.grid_file, "w")
  33. f.write(env.printGrid(init=True))
  34. f.close()
  35. def __create_prism(self):
  36. # result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
  37. if self.prism_config is None:
  38. result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path}")
  39. else:
  40. result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
  41. # result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
  42. assert result == 0, "Prism file could not be generated"
  43. f = open(self.prism_path, "a")
  44. f.write("label \"AgentIsInLava\" = AgentIsInLava;")
  45. f.close()
  46. def __create_shield_dict(self):
  47. print(self.prism_path)
  48. program = stormpy.parse_prism_program(self.prism_path)
  49. shield_comp = stormpy.logic.ShieldComparison.RELATIVE
  50. if self.shield_comparision == 'absolute':
  51. shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE
  52. shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value)
  53. formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
  54. options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
  55. options.set_build_state_valuations(True)
  56. options.set_build_choice_labels(True)
  57. options.set_build_all_labels()
  58. model = stormpy.build_sparse_model_with_options(program, options)
  59. result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
  60. assert result.has_shield
  61. shield = result.shield
  62. stormpy.shields.export_shield(model, shield, "Grid.shield")
  63. action_dictionary = {}
  64. shield_scheduler = shield.construct()
  65. state_valuations = model.state_valuations
  66. choice_labeling = model.choice_labeling
  67. for stateID in model.states:
  68. choice = shield_scheduler.get_choice(stateID)
  69. choices = choice.choice_map
  70. state_valuation = state_valuations.get_string(stateID)
  71. actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
  72. action_dictionary[state_valuation] = actions_to_be_executed
  73. return action_dictionary
  74. def create_shield(self, **kwargs):
  75. env = kwargs["env"]
  76. self.__export_grid_to_text(env)
  77. self.__create_prism()
  78. return self.__create_shield_dict()
  79. def create_shield_query(env):
  80. coordinates = env.env.agent_pos
  81. view_direction = env.env.agent_dir
  82. keys = extract_keys(env)
  83. doors = extract_doors(env)
  84. adversaries = extract_adversaries(env)
  85. if env.carrying:
  86. agent_carrying = F"Agent_is_carrying_object\t"
  87. else:
  88. agent_carrying = "!Agent_is_carrying_object\t"
  89. key_positions = []
  90. agent_key_status = []
  91. for key in keys:
  92. key_color = key[0].color
  93. key_x = key[1]
  94. key_y = key[2]
  95. if env.carrying and env.carrying.type == "key":
  96. agent_key_text = F"Agent_has_{env.carrying.color}_key\t& "
  97. key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t"
  98. else:
  99. agent_key_text = F"!Agent_has_{key_color}_key\t& "
  100. key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t"
  101. key_positions.append(key_position)
  102. agent_key_status.append(agent_key_text)
  103. if key_positions:
  104. key_positions[-1] = key_positions[-1].strip()
  105. door_status = []
  106. for door in doors:
  107. status = ""
  108. if door.is_open:
  109. status = F"!Door{door.color}locked\t& Door{door.color}open\t&"
  110. elif door.is_locked:
  111. status = F"Door{door.color}locked\t& !Door{door.color}open\t&"
  112. else:
  113. status = F"!Door{door.color}locked\t& !Door{door.color}open\t&"
  114. door_status.append(status)
  115. adv_status = []
  116. adv_positions = []
  117. for adversary in adversaries:
  118. status = ""
  119. position = ""
  120. if adversary.carrying:
  121. carrying = F"{adversary.name}_is_carrying_object\t"
  122. else:
  123. carrying = F"!{adversary.name}_is_carrying_object\t"
  124. status = F"{carrying}& !{adversary.name}Done\t& "
  125. position = F"x{adversary.name}={adversary.cur_pos[1]}\t& y{adversary.name}={adversary.cur_pos[0]}\t& view{adversary.name}={adversary.adversary_dir}"
  126. adv_status.append(status)
  127. adv_positions.append(position)
  128. door_status_text = ""
  129. if door_status:
  130. door_status_text = F"& {''.join(door_status)}\t"
  131. adv_status_text = ""
  132. if adv_status:
  133. adv_status_text = F"& {''.join(adv_status)}"
  134. adv_positions_text = ""
  135. if adv_positions:
  136. adv_positions_text = F"\t& {''.join(adv_positions)}"
  137. key_positions_text = ""
  138. if key_positions:
  139. key_positions_text = F"\t& {''.join(key_positions)}"
  140. move_text = ""
  141. if adversaries:
  142. move_text = F"move=0\t& "
  143. agent_position = F"& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}"
  144. query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
  145. return query