Browse Source

added config / some adversary fixes

refactoring
Thomas Knoll 1 year ago
parent
commit
afc9f5bc4d
  1. 12
      adv_config.yaml
  2. 25
      examples/shields/rl/callbacks.py
  3. 12
      examples/shields/rl/helpers.py
  4. 11
      examples/shields/rl/shieldhandlers.py
  5. 6
      examples/shields/rl/wrappers.py

12
adv_config.yaml

@ -0,0 +1,12 @@
---
labels:
  - label: "AgentIsInGoal"
    text: "AgentIsInGoal"
  - label: "AgentRanIntoAdversary"
    text: "AgentRanIntoAdversary"
formulas:
  - formula: "AgentRanIntoAdversary"
    content: "(xAgent=xBlue) & (yAgent=yBlue)"
...

25
examples/shields/rl/callbacks.py

@ -1,8 +1,9 @@
from typing import Dict
from typing import Dict, Optional
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
from ray.rllib.utils.typing import EnvType, PolicyID
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.env.base_env import BaseEnv
@ -12,6 +13,10 @@ from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
import matplotlib.pyplot as plt
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
# print(F"Episode started Environment: {base_env.get_sub_environments()}")
@ -19,8 +24,11 @@ class MyCallbacks(DefaultCallbacks):
episode.user_data["count"] = 0
episode.user_data["ran_into_lava"] = []
episode.user_data["goals_reached"] = []
episode.user_data["ran_into_adversary"] = []
episode.hist_data["ran_into_lava"] = []
episode.hist_data["goals_reached"] = []
episode.hist_data["ran_into_adversary"] = []
# print("On episode start print")
# print(env.printGrid())
# print(worker)
@ -28,7 +36,6 @@ class MyCallbacks(DefaultCallbacks):
# print(env.actions)
# print(env.mission)
# print(env.observation_space)
# img = env.get_frame()
# plt.imshow(img)
# plt.show()
@ -42,16 +49,26 @@ class MyCallbacks(DefaultCallbacks):
# print(F"Episode end Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
ran_into_adversary = False
if hasattr(env, "adversaries"):
adversaries = env.adversaries.values()
for adversary in adversaries:
if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
ran_into_adversary = True
break
episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal")
episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava")
episode.user_data["ran_into_adversary"].append(ran_into_adversary)
episode.custom_metrics["reached_goal"] = agent_tile is not None and agent_tile.type == "goal"
episode.custom_metrics["ran_into_lava"] = agent_tile is not None and agent_tile.type == "lava"
episode.custom_metrics["ran_into_adversary"] = ran_into_adversary
#print("On episode end print")
# print(env.printGrid())
episode.hist_data["goals_reached"] = episode.user_data["goals_reached"]
episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"]
episode.hist_data["ran_into_adversary"] = episode.user_data["ran_into_adversary"]
def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None:
print("Evaluate Start")

12
examples/shields/rl/helpers.py

@ -71,6 +71,9 @@ def test_name(args):
def get_action_index_mapping(actions):
for action_str in actions:
if not "Agent" in action_str:
continue
if "move" in action_str:
return Actions.forward
elif "left" in action_str:
@ -88,8 +91,7 @@ def get_action_index_mapping(actions):
elif "unlock" in action_str:
return Actions.toggle
return Actions.done
raise ValueError("No action mapping found")
def parse_arguments(argparse):
@ -100,6 +102,8 @@ def parse_arguments(argparse):
default="MiniGrid-LavaCrossingS9N1-v0",
choices=[
"MiniGrid-Adv-8x8-v0",
"MiniGrid-AdvSimple-8x8-v0",
"MiniGrid-SingleDoor-7x6-v0",
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-LavaCrossingS9N3-v0",
"MiniGrid-LavaSlipperyS12-v0",
@ -110,7 +114,6 @@ def parse_arguments(argparse):
# "MiniGrid-DoubleDoor-16x16-v0",
# "MiniGrid-DoubleDoor-12x12-v0",
# "MiniGrid-DoubleDoor-10x8-v0",
# "MiniGrid-SingleDoor-7x6-v0",
# "MiniGrid-LockedRoom-v0",
# "MiniGrid-FourRooms-v0",
# "MiniGrid-LavaGapS7-v0",
@ -126,7 +129,8 @@ def parse_arguments(argparse):
parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
parser.add_argument("--log_dir", default="../log_results/")
parser.add_argument("--evaluations", type=int, default=10 )
parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
parser.add_argument("--formula", default="Pmax=? [G !\"AgentRanIntoAdversary\"]")
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
parser.add_argument("--steps", default=20_000, type=int)

11
examples/shields/rl/shieldhandlers.py

@ -40,13 +40,12 @@ class MiniGridShieldHandler(ShieldHandler):
def __create_prism(self):
result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path}")
result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
assert result == 0, "Prism file could not be generated"
f = open(self.prism_path, "a")
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
f.write("label \"AgentIsInGoal\" = AgentIsInGoal;")
f.close()
def __create_shield_dict(self):
@ -63,7 +62,6 @@ class MiniGridShieldHandler(ShieldHandler):
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
assert result.has_scheduler
assert result.has_shield
shield = result.shield
stormpy.shields.export_shield(model, shield, "Grid.shield")
@ -172,8 +170,13 @@ def create_shield_query(env):
if key_positions:
key_positions_text = F"\t& {''.join(key_positions)}"
move_text = ""
if adversaries:
move_text = F"move=0\t& "
agent_position = F"xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}"
query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
return query

6
examples/shields/rl/wrappers.py

@ -110,7 +110,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
cur_pos_str = self.shield_query_creator(self.env)
# print(F"Pos string {cur_pos_str}")
# print(F"Shield {list(self.shield.keys())[0]}")
# print(cur_pos_str in self.shield)
# print(F"Is pos str in shield: {cur_pos_str in self.shield}")
# Create the mask
# If shield restricts action mask only valid with 1.0
# else set all actions as valid
@ -127,6 +127,8 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
assert(False)
allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
if allowed_action.prob == 0 and allowed:
assert False
mask[index] = allowed
else:
@ -141,7 +143,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
# print(F"Mask is {mask} State: {cur_pos_str}")
return mask
def reset(self, *, seed=None, options=None):

Loading…
Cancel
Save