You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

149 lines
5.6 KiB

11 months ago
12 months ago
12 months ago
12 months ago
12 months ago
  1. import minigrid
  2. from minigrid.core.actions import Actions
  3. from datetime import datetime
  4. from enum import Enum
  5. import os
  6. import stormpy
  7. import stormpy.core
  8. import stormpy.simulator
  9. import stormpy.shields
  10. import stormpy.logic
  11. import stormpy.examples
  12. import stormpy.examples.files
  13. class ShieldingConfig(Enum):
  14. Training = 'training'
  15. Evaluation = 'evaluation'
  16. Disabled = 'none'
  17. Full = 'full'
  18. def __str__(self) -> str:
  19. return self.value
  20. def extract_keys(env):
  21. keys = []
  22. for j in range(env.grid.height):
  23. for i in range(env.grid.width):
  24. obj = env.grid.get(i,j)
  25. if obj and obj.type == "key":
  26. keys.append((obj, i, j))
  27. if env.carrying and env.carrying.type == "key":
  28. keys.append((env.carrying, -1, -1))
  29. # TODO Maybe need to add ordering of keys so it matches the order in the shield
  30. return keys
  31. def extract_doors(env):
  32. doors = []
  33. for j in range(env.grid.height):
  34. for i in range(env.grid.width):
  35. obj = env.grid.get(i,j)
  36. if obj and obj.type == "door":
  37. doors.append(obj)
  38. return doors
  39. def extract_adversaries(env):
  40. adv = []
  41. if not hasattr(env, "adversaries"):
  42. return []
  43. for color, adversary in env.adversaries.items():
  44. adv.append(adversary)
  45. return adv
  46. def create_log_dir(args):
  47. return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}"
  48. def test_name(args):
  49. return F"{args.expname}"
  50. def get_action_index_mapping(actions):
  51. for action_str in actions:
  52. if not "Agent" in action_str:
  53. continue
  54. if "move" in action_str:
  55. return Actions.forward
  56. elif "left" in action_str:
  57. return Actions.left
  58. elif "right" in action_str:
  59. return Actions.right
  60. elif "pickup" in action_str:
  61. return Actions.pickup
  62. elif "done" in action_str:
  63. return Actions.done
  64. elif "drop" in action_str:
  65. return Actions.drop
  66. elif "toggle" in action_str:
  67. return Actions.toggle
  68. elif "unlock" in action_str:
  69. return Actions.toggle
  70. raise ValueError("No action mapping found")
  71. def parse_arguments(argparse):
  72. parser = argparse.ArgumentParser()
  73. # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
  74. parser.add_argument("--env",
  75. help="gym environment to load",
  76. default="MiniGrid-LavaSlipperyS12-v2",
  77. choices=[
  78. "MiniGrid-Adv-8x8-v0",
  79. "MiniGrid-AdvSimple-8x8-v0",
  80. "MiniGrid-SingleDoor-7x6-v0",
  81. "MiniGrid-LavaCrossingS9N1-v0",
  82. "MiniGrid-LavaCrossingS9N3-v0",
  83. "MiniGrid-LavaSlipperyS12-v0",
  84. "MiniGrid-LavaSlipperyS12-v1",
  85. "MiniGrid-LavaSlipperyS12-v2",
  86. "MiniGrid-LavaSlipperyS12-v3",
  87. "MiniGrid-DoorKey-8x8-v0",
  88. # "MiniGrid-DoubleDoor-16x16-v0",
  89. # "MiniGrid-DoubleDoor-12x12-v0",
  90. # "MiniGrid-DoubleDoor-10x8-v0",
  91. # "MiniGrid-LockedRoom-v0",
  92. # "MiniGrid-FourRooms-v0",
  93. # "MiniGrid-LavaGapS7-v0",
  94. # "MiniGrid-SimpleCrossingS9N3-v0",
  95. # "MiniGrid-DoorKey-16x16-v0",
  96. # "MiniGrid-Empty-Random-6x6-v0",
  97. ])
  98. # parser.add_argument("--seed", type=int, help="seed for environment", default=None)
  99. parser.add_argument("--grid_to_prism_binary_path", default="./main")
  100. parser.add_argument("--grid_path", default="grid")
  101. parser.add_argument("--prism_path", default="grid")
  102. parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
  103. parser.add_argument("--log_dir", default="../log_results/")
  104. parser.add_argument("--evaluations", type=int, default=30 )
  105. parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
  106. # parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
  107. parser.add_argument("--workers", type=int, default=1)
  108. parser.add_argument("--num_gpus", type=float, default=0)
  109. parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
  110. parser.add_argument("--steps", default=20_000, type=int)
  111. parser.add_argument("--expname", default="exp")
  112. parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
  113. parser.add_argument("--prism_config", default=None)
  114. parser.add_argument("--shield_value", default=0.9, type=float)
  115. parser.add_argument("--prob_direct", default=1/4, type=float)
  116. parser.add_argument("--prob_forward", default=3/4, type=float)
  117. parser.add_argument("--prob_next", default=1/8, type=float)
  118. parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
  119. # parser.add_argument("--random_starts", default=1, type=int)
  120. args = parser.parse_args()
  121. return args