Browse Source

cleanups

refactoring
Thomas Knoll 1 year ago
parent
commit
1cbaac75cb
  1. 23
      examples/shields/rl/callbacks.py
  2. 149
      examples/shields/rl/helpers.py
  3. 5
      examples/shields/rl/rllibutils.py
  4. 4
      examples/shields/rl/sb3utils.py
  5. 8
      examples/shields/rl/utils.py
  6. 252
      examples/shields/rl/wrappers.py

23
examples/shields/rl/callbacks.py

@ -15,11 +15,9 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callback
import matplotlib.pyplot as plt
import tensorflow as tf
class MyCallbacks(DefaultCallbacks):
class CustomCallback(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
episode.user_data["count"] = 0
episode.user_data["ran_into_lava"] = []
@ -29,28 +27,11 @@ class MyCallbacks(DefaultCallbacks):
episode.hist_data["goals_reached"] = []
episode.hist_data["ran_into_adversary"] = []
# print("On episode start print")
# print(env.printGrid())
# print(worker)
# print(env.action_space.n)
# print(env.actions)
# print(env.mission)
# print(env.observation_space)
# plt.imshow(img)
# plt.show()
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0]
# print(env.printGrid())
if hasattr(env, "adversaries"):
for adversary in env.adversaries.values():
if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
print(F"Adversary ran into agent. Adversary {adversary.cur_pos}, Agent {env.agent_pos}")
# assert False
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:

149
examples/shields/rl/helpers.py

@ -1,149 +0,0 @@
import minigrid
from minigrid.core.actions import Actions
from datetime import datetime
from enum import Enum
import os
import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
class ShieldingConfig(Enum):
Training = 'training'
Evaluation = 'evaluation'
Disabled = 'none'
Full = 'full'
def __str__(self) -> str:
return self.value
def extract_keys(env):
keys = []
for j in range(env.grid.height):
for i in range(env.grid.width):
obj = env.grid.get(i,j)
if obj and obj.type == "key":
keys.append((obj, i, j))
if env.carrying and env.carrying.type == "key":
keys.append((env.carrying, -1, -1))
# TODO Maybe need to add ordering of keys so it matches the order in the shield
return keys
def extract_doors(env):
doors = []
for j in range(env.grid.height):
for i in range(env.grid.width):
obj = env.grid.get(i,j)
if obj and obj.type == "door":
doors.append(obj)
return doors
def extract_adversaries(env):
adv = []
if not hasattr(env, "adversaries"):
return []
for color, adversary in env.adversaries.items():
adv.append(adversary)
return adv
def create_log_dir(args):
return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}"
def test_name(args):
return F"{args.expname}"
def get_action_index_mapping(actions):
for action_str in actions:
if not "Agent" in action_str:
continue
if "move" in action_str:
return Actions.forward
elif "left" in action_str:
return Actions.left
elif "right" in action_str:
return Actions.right
elif "pickup" in action_str:
return Actions.pickup
elif "done" in action_str:
return Actions.done
elif "drop" in action_str:
return Actions.drop
elif "toggle" in action_str:
return Actions.toggle
elif "unlock" in action_str:
return Actions.toggle
raise ValueError("No action mapping found")
def parse_arguments(argparse):
parser = argparse.ArgumentParser()
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
parser.add_argument("--env",
help="gym environment to load",
default="MiniGrid-LavaSlipperyS12-v2",
choices=[
"MiniGrid-Adv-8x8-v0",
"MiniGrid-AdvSimple-8x8-v0",
"MiniGrid-SingleDoor-7x6-v0",
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-LavaCrossingS9N3-v0",
"MiniGrid-LavaSlipperyS12-v0",
"MiniGrid-LavaSlipperyS12-v1",
"MiniGrid-LavaSlipperyS12-v2",
"MiniGrid-LavaSlipperyS12-v3",
"MiniGrid-DoorKey-8x8-v0",
# "MiniGrid-DoubleDoor-16x16-v0",
# "MiniGrid-DoubleDoor-12x12-v0",
# "MiniGrid-DoubleDoor-10x8-v0",
# "MiniGrid-LockedRoom-v0",
# "MiniGrid-FourRooms-v0",
# "MiniGrid-LavaGapS7-v0",
# "MiniGrid-SimpleCrossingS9N3-v0",
# "MiniGrid-DoorKey-16x16-v0",
# "MiniGrid-Empty-Random-6x6-v0",
])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_to_prism_binary_path", default="./main")
parser.add_argument("--grid_path", default="grid")
parser.add_argument("--prism_path", default="grid")
parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
parser.add_argument("--log_dir", default="../log_results/")
parser.add_argument("--evaluations", type=int, default=30 )
parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--num_gpus", type=float, default=0)
parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
parser.add_argument("--steps", default=20_000, type=int)
parser.add_argument("--expname", default="exp")
parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
parser.add_argument("--prism_config", default=None)
parser.add_argument("--shield_value", default=0.9, type=float)
parser.add_argument("--prob_direct", default=1/4, type=float)
parser.add_argument("--prob_forward", default=3/4, type=float)
parser.add_argument("--prob_next", default=1/8, type=float)
parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
# parser.add_argument("--random_starts", default=1, type=int)
args = parser.parse_args()
return args

5
examples/shields/rl/rllibutils.py

@ -9,8 +9,7 @@ from gymnasium.spaces import Dict, Box
from collections import deque
from ray.rllib.utils.numpy import one_hot
from helpers import get_action_index_mapping
from shieldhandlers import ShieldHandler
from utils import get_action_index_mapping, MiniGridShieldHandler, create_shield_query, ShieldingConfig
class OneHotShieldingWrapper(gym.core.ObservationWrapper):
@ -85,7 +84,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
class MiniGridShieldingWrapper(gym.core.Wrapper):
def __init__(self,
env,
shield_creator : ShieldHandler,
shield_creator : MiniGridShieldHandler,
shield_query_creator,
create_shield_at_reset=True,
mask_actions=True):

4
examples/shields/rl/sb3utils.py

@ -2,10 +2,12 @@ import gymnasium as gym
import numpy as np
import random
from utils import MiniGridShieldHandler, create_shield_query
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def __init__(self,
env,
shield_creator : ShieldHandler,
shield_creator : MiniGridShieldHandler,
shield_query_creator,
create_shield_at_reset = True,
mask_actions=True,

8
examples/shields/rl/utils.py

@ -8,10 +8,12 @@ import stormpy.logic
import stormpy.examples
import stormpy.examples.files
from helpers import extract_doors, extract_keys, extract_adversaries
from enum import Enum
from abc import ABC
from minigrid.core.actions import Actions
import os
import time
class Action():
@ -66,6 +68,7 @@ class MiniGridShieldHandler(ShieldHandler):
shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value)
formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
@ -82,6 +85,7 @@ class MiniGridShieldHandler(ShieldHandler):
shield_scheduler = shield.construct()
state_valuations = model.state_valuations
choice_labeling = model.choice_labeling
stormpy.shields.export_shield(model, shield, "myshield")
for stateID in model.states:
choice = shield_scheduler.get_choice(stateID)

252
examples/shields/rl/wrappers.py

@ -1,252 +0,0 @@
import gymnasium as gym
import numpy as np
import random
from minigrid.core.actions import Actions
from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX
from gymnasium.spaces import Dict, Box
from collections import deque
from ray.rllib.utils.numpy import one_hot
from helpers import get_action_index_mapping
from shieldhandlers import ShieldHandler
class OneHotShieldingWrapper(gym.core.ObservationWrapper):
def __init__(self, env, vector_index, framestack):
super().__init__(env)
self.framestack = framestack
# 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.
# +4: Direction.
self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4
self.init_x = None
self.init_y = None
self.x_positions = []
self.y_positions = []
self.x_y_delta_buffer = deque(maxlen=100)
self.vector_index = vector_index
self.frame_buffer = deque(maxlen=self.framestack)
for _ in range(self.framestack):
self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
self.observation_space = Dict(
{
"data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
"action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
}
)
def observation(self, obs):
# Debug output: max-x/y positions to watch exploration progress.
# print(F"Initial observation in Wrapper {obs}")
if self.step_count == 0:
for _ in range(self.framestack):
self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
if self.vector_index == 0:
if self.x_positions:
max_diff = max(
np.sqrt(
(np.array(self.x_positions) - self.init_x) ** 2
+ (np.array(self.y_positions) - self.init_y) ** 2
)
)
self.x_y_delta_buffer.append(max_diff)
print(
"100-average dist travelled={}".format(
np.mean(self.x_y_delta_buffer)
)
)
self.x_positions = []
self.y_positions = []
self.init_x = self.agent_pos[0]
self.init_y = self.agent_pos[1]
self.x_positions.append(self.agent_pos[0])
self.y_positions.append(self.agent_pos[1])
image = obs["data"]
# One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX))
colors = one_hot(image[:, :, 1], depth=len(COLORS))
states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX))
all_ = np.concatenate([objects, colors, states], -1)
all_flat = np.reshape(all_, (-1,))
direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
single_frame = np.concatenate([all_flat, direction])
self.frame_buffer.append(single_frame)
tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] }
return tmp
class MiniGridShieldingWrapper(gym.core.Wrapper):
def __init__(self,
env,
shield_creator : ShieldHandler,
shield_query_creator,
create_shield_at_reset=True,
mask_actions=True):
super(MiniGridShieldingWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = Dict(
{
"data": env.observation_space.spaces["image"],
"action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
}
)
self.shield_creator = shield_creator
self.create_shield_at_reset = create_shield_at_reset
self.shield = shield_creator.create_shield(env=self.env)
self.mask_actions = mask_actions
self.shield_query_creator = shield_query_creator
print(F"Shielding is {self.mask_actions}")
def create_action_mask(self):
# print(F"{self.mask_actions} No shielding")
if not self.mask_actions:
# print("No shielding")
ret = np.array([1.0] * self.max_available_actions, dtype=np.int8)
# print(ret)
return ret
cur_pos_str = self.shield_query_creator(self.env)
# print(F"Pos string {cur_pos_str}")
# print(F"Shield {list(self.shield.keys())[0]}")
# print(F"Is pos str in shield: {cur_pos_str in self.shield}, Position Str {cur_pos_str}")
# Create the mask
# If shield restricts action mask only valid with 1.0
# else set all actions as valid
allowed_actions = []
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8)
has_allowed_actions = False
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set
if index is None:
print(F"No mapping for action {list(allowed_action.labels)}")
print(F"Shield at pos {cur_pos_str}, shield {self.shield[cur_pos_str]}")
assert(False)
allowed = 1.0 # random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
if allowed_action.prob == 0 and allowed:
assert False
if allowed:
has_allowed_actions = True
mask[index] = allowed
# if not has_allowed_actions:
# print(F"No action allowed for pos string {cur_pos_str}")
# assert(False)
else:
for index, x in enumerate(mask):
mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
if front_tile is not None and front_tile.type == "key":
mask[Actions.pickup] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
# print(F"Mask is {mask} State: {cur_pos_str}")
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
if self.create_shield_at_reset and self.mask_actions:
self.shield = self.shield_creator.create_shield(env=self.env)
mask = self.create_action_mask()
return {
"data": obs["image"],
"action_mask": mask
}, infos
def step(self, action):
orig_obs, rew, done, truncated, info = self.env.step(action)
mask = self.create_action_mask()
obs = {
"data": orig_obs["image"],
"action_mask": mask,
}
return obs, rew, done, truncated, info
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def __init__(self,
env,
shield_creator : ShieldHandler,
shield_query_creator,
create_shield_at_reset = True,
mask_actions=True,
):
super(MiniGridSbShieldingWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = env.observation_space.spaces["image"]
self.shield_creator = shield_creator
self.mask_actions = mask_actions
self.shield_query_creator = shield_query_creator
print(F"Shielding is {self.mask_actions}")
def create_action_mask(self):
if not self.mask_actions:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
cur_pos_str = self.shield_query_creator(self.env)
allowed_actions = []
# Create the mask
# If shield restricts actions, mask only valid actions with 1.0
# else set all actions valid
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action.labels)
if index is None:
assert(False)
mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
else:
for index, x in enumerate(mask):
mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
shield = self.shield_creator.create_shield(env=self.env)
self.shield = shield
return obs["image"], infos
def step(self, action):
orig_obs, rew, done, truncated, info = self.env.step(action)
obs = orig_obs["image"]
return obs, rew, done, truncated, info
Loading…
Cancel
Save