Browse Source

basic action embedding

refactoring
Thomas Knoll 1 year ago
parent
commit
cf18349819
  1. 255
      examples/shields/rl/11_minigridrl.py
  2. 134
      examples/shields/rl/12_basic_training.py
  3. 91
      examples/shields/rl/MaskEnvironments.py
  4. 81
      examples/shields/rl/MaskModels.py
  5. 152
      examples/shields/rl/Wrapper.py

255
examples/shields/rl/11_minigridrl.py

@ -0,0 +1,255 @@
from typing import Dict, Optional, Union
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
import os
import gymnasium as gym
import minigrid
import numpy as np
import ray
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator
from ray import tune, air
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.algorithms import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor
from MaskEnvironments import ParametricActionsMiniGridEnv
from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper, ImgObsWrapper
import matplotlib.pyplot as plt
import argparse
class MyCallbacks(DefaultCallbacks):
    """Episode callbacks that maintain a per-episode step counter.

    The counter is stored in ``episode.user_data["count"]``. On every step
    the value returned by ``printGrid()`` of the twice-unwrapped first
    sub-environment is printed.
    """

    def on_episode_start(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode | EpisodeV2,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Initialise the step counter of the episode to zero."""
        # The first sub-environment is looked up but not used further here.
        _first_env = base_env.get_sub_environments()[0]
        episode.user_data["count"] = 0

    def on_episode_step(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy] | None = None,
        episode: Episode | EpisodeV2,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Increment the step counter and print the current grid."""
        episode.user_data["count"] += 1
        sub_envs = base_env.get_sub_environments()
        outer_env = sub_envs[0]
        # Two ``.env`` hops skip the wrappers stacked around the grid env.
        grid_env = outer_env.env.env
        print(grid_env.printGrid())

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode | EpisodeV2 | Exception,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Look up the first sub-environment; nothing else is done yet."""
        _first_env = base_env.get_sub_environments()[0]
def parse_arguments(argparse, argv=None):
    """Build the command-line parser of this example and parse the arguments.

    Args:
        argparse: The ``argparse`` module (or any object exposing a
            compatible ``ArgumentParser``); it is supplied by the caller.
        argv: Optional list of argument strings to parse. With the default
            ``None`` the parser reads ``sys.argv[1:]``, exactly as before.

    Returns:
        The parsed namespace with the attributes ``env``, ``seed``,
        ``tile_size``, ``agent_view``, ``grid_path`` and ``prism_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env",
        help="gym environment to load",
        default="MiniGrid-LavaCrossingS9N1-v0",
    )
    parser.add_argument("--seed", type=int, help="seed for environment", default=1)
    parser.add_argument(
        "--tile_size", type=int, help="size at which to render tiles", default=32
    )
    parser.add_argument(
        "--agent_view",
        default=False,
        action="store_true",
        help="draw the agent sees",
    )
    parser.add_argument("--grid_path", default="Grid.txt")
    parser.add_argument("--prism_path", default="Grid.PRISM")
    # Passing ``None`` keeps the historical behaviour of reading sys.argv.
    return parser.parse_args(argv)
def env_creater_custom(config):
    """Create the MiniGrid environment wrapped for the shielded model.

    Args:
        config: Environment config offering ``get``. The keys ``"name"``
            (default ``"MiniGrid-LavaCrossingS9N1-v0"``) and ``"framestack"``
            (default ``4``) are read; a ``vector_index`` attribute is used
            when present, otherwise ``0``.

    Returns:
        The environment wrapped in ``MiniGridEnvWrapper`` and then
        ``OneHotWrapper``, after it has been reset once.
    """
    framestack = config.get("framestack", 4)
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")

    base_env = gym.make(name)
    dict_env = MiniGridEnvWrapper(base_env)
    vector_index = getattr(config, "vector_index", 0)
    env = OneHotWrapper(dict_env, vector_index, framestack=framestack)

    # Debug output: compare a random sample with the first real observation.
    obs = env.observation_space.sample()
    obs2, infos = env.reset(seed=None, options={})
    print(F"Obs is {obs} before reset. After reset: {obs2}")
    print(F"Created Custom Minigrid Environment is {env}")
    return env
def env_creater_cart(config):
    """Return a new ``CartPole-v1`` environment; ``config`` is not used."""
    cartpole_env = gym.make("CartPole-v1")
    return cartpole_env
def env_creater(config):
    """Create a MiniGrid environment with image-only, one-hot observations.

    Args:
        config: Environment config offering ``get``. The keys ``"name"``
            (default ``"MiniGrid-LavaCrossingS9N1-v0"``) and ``"framestack"``
            (default ``4``) are read; a ``vector_index`` attribute is used
            when present, otherwise ``0``.

    Returns:
        The environment wrapped in ``minigrid.wrappers.ImgObsWrapper`` and
        then ``OneHotWrapper``.
    """
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    framestack = config.get("framestack", 4)

    image_env = minigrid.wrappers.ImgObsWrapper(gym.make(name))
    vector_index = getattr(config, "vector_index", 0)
    env = OneHotWrapper(image_env, vector_index, framestack=framestack)

    print(F"Created Minigrid Environment is {env}")
    return env
def create_shield(grid_file, prism_path):
    """Translate the grid to PRISM, model check it and extract a shield.

    The external converter is run on ``grid_file`` to produce ``prism_path``,
    a label definition is appended to that file, and the resulting program is
    model checked with a pre-safety shield expression (relative comparison,
    threshold ``0.1``). The shield is also exported to ``Grid.shield``.

    Args:
        grid_file: Path of the textual grid description to convert.
        prism_path: Path of the PRISM file the converter writes.

    Returns:
        A tuple ``(constructed_shield, model)``.

    Raises:
        subprocess.CalledProcessError: If the converter exits with a
            non-zero status.
        RuntimeError: If model checking yields no scheduler or no shield.
    """
    import subprocess

    # An argument list avoids shell interpretation of the user-supplied
    # paths; check=True stops here instead of parsing a stale or empty file.
    subprocess.run(
        [
            "/home/tknoll/Documents/main",
            "-v", "agent",
            "-i", str(grid_file),
            "-o", str(prism_path),
        ],
        check=True,
    )

    # NOTE(review): the label appended here is "AgentIsInLava" while the
    # formula below refers to "AgentIsInLavaAndNotDone" - presumably the
    # converter emits that label itself; confirm.
    with open(prism_path, "a") as prism_file:
        prism_file.write("label \"AgentIsInLava\" = AgentIsInLava;")

    program = stormpy.parse_prism_program(prism_path)
    formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
    formulas = stormpy.parse_properties_for_prism_program(formula_str, program)

    options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
    options.set_build_state_valuations(True)
    options.set_build_choice_labels(True)
    options.set_build_all_labels()
    model = stormpy.build_sparse_model_with_options(program, options)

    shield_specification = stormpy.logic.ShieldExpression(
        stormpy.logic.ShieldingType.PRE_SAFETY,
        stormpy.logic.ShieldComparison.RELATIVE,
        0.1,
    )
    result = stormpy.model_checking(
        model,
        formulas[0],
        extract_scheduler=True,
        shield_expression=shield_specification,
    )

    # Explicit checks instead of asserts, which vanish under ``python -O``.
    if not result.has_scheduler:
        raise RuntimeError("model checking did not produce a scheduler")
    if not result.has_shield:
        raise RuntimeError("model checking did not produce a shield")

    shield = result.shield
    stormpy.shields.export_shield(model, shield, "Grid.shield")
    return shield.construct(), model
def export_grid_to_text(env, grid_file):
    """Write the textual grid of ``env`` to ``grid_file``.

    Args:
        env: Object exposing ``printGrid(init=True)`` that returns the grid
            as a string.
        grid_file: Path of the file to (over)write.
    """
    # Render first so a failing printGrid does not leave a truncated file.
    grid_text = env.printGrid(init=True)
    with open(grid_file, "w") as grid_handle:
        grid_handle.write(grid_text)
def create_environment(args):
    """Create the environment named by ``args.env`` and reset it once.

    Args:
        args: Parsed arguments; only the ``env`` attribute is read.

    Returns:
        The environment returned by ``gym.make`` after one ``reset()`` call.
    """
    environment = gym.make(args.env)
    environment.reset()
    return environment
def main():
    """Build a shield for the selected MiniGrid environment and train PPO.

    The grid of the environment is exported to text, turned into a shield by
    ``create_shield``, and the allowed choices per model state are handed to
    the custom model through ``custom_model_config``. PPO is then trained for
    30 iterations; a checkpoint is saved whenever the iteration index is a
    multiple of five. Ray is shut down even if any of these steps fails.
    """
    args = parse_arguments(argparse)
    env = create_environment(args)
    ray.init(num_cpus=3)
    try:
        grid_file = args.grid_path
        export_grid_to_text(env, grid_file)

        prism_path = args.prism_path
        shield, model = create_shield(grid_file, prism_path)
        shield_dict = {
            state.id: shield.get_choice(state).choice_map
            for state in model.states
        }
        print(shield_dict)
        for state_id in model.states:
            choices = shield.get_choice(state_id)
            print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")

        env_name = "mini-grid"
        register_env(env_name, env_creater_custom)
        ModelCatalog.register_custom_model("pa_model", TorchActionMaskModel)

        config = (
            PPOConfig()
            .rollouts(num_rollout_workers=1)
            .resources(num_gpus=0)
            .environment(env=env_name)
            .framework("torch")
            .experimental(_disable_preprocessor_api=False)
            .callbacks(MyCallbacks)
            .rl_module(_enable_rl_module_api=False)
            .training(
                _enable_learner_api=False,
                model={
                    "custom_model": "pa_model",
                    "custom_model_config": {
                        "shield": shield_dict,
                        "no_masking": True,
                    },
                },
            )
        )
        algo = config.build()

        for i in range(30):
            result = algo.train()
            print(pretty_print(result))
            if i % 5 == 0:
                checkpoint_dir = algo.save()
                print(f"Checkpoint saved in directory {checkpoint_dir}")
    finally:
        # Always release the Ray runtime, also when setup or training fails.
        ray.shutdown()


if __name__ == '__main__':
    main()

134
examples/shields/12_basic_training.py → examples/shields/rl/12_basic_training.py

@ -33,11 +33,81 @@ from ray.tune.logger import pretty_print
from ray.rllib.utils.numpy import one_hot
from ray.rllib.algorithms import ppo
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor
import matplotlib.pyplot as plt
import argparse
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from Wrapper import OneHotWrapper
torch, nn = try_import_torch()
class TorchActionMaskModel(TorchModelV2, nn.Module):
    """Torch model that adds an action mask to the logits of an FC network.

    The observation space must carry the entries ``"action_mask"`` and
    ``"observations"``; the latter feeds the internal fully connected model.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        """Validate the observation layout and build the internal network."""
        orig_space = getattr(obs_space, "original_space", obs_space)
        has_expected_layout = (
            isinstance(orig_space, Dict)
            and "action_mask" in orig_space.spaces
            and "observations" in orig_space.spaces
        )
        assert has_expected_layout

        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )
        nn.Module.__init__(self)

        internal_name = name + "_internal"
        self.internal_model = TorchFC(
            orig_space["observations"],
            action_space,
            num_outputs,
            model_config,
            internal_name,
        )

        # Masking can be switched off through the custom model config; doing
        # so will likely lead to invalid actions.
        custom_cfg = model_config["custom_model_config"]
        if "no_masking" in custom_cfg:
            self.no_masking = custom_cfg["no_masking"]
        else:
            self.no_masking = False

    def forward(self, input_dict, state, seq_lens):
        """Return the (optionally masked) logits together with ``state``."""
        observation = input_dict["obs"]
        action_mask = observation["action_mask"]

        logits, _ = self.internal_model({"obs": observation["observations"]})
        if self.no_masking:
            # Masking disabled: hand back the raw logits.
            return logits, state

        # log(mask) is 0 for allowed and -inf for forbidden actions; clamp
        # the -inf to FLOAT_MIN before adding it to the logits.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        return logits + inf_mask, state

    def value_function(self):
        """Delegate to the value function of the internal model."""
        return self.internal_model.value_function()
class MyCallbacks(DefaultCallbacks):
@ -66,69 +136,7 @@ class MyCallbacks(DefaultCallbacks):
# print(episode.user_data["count"])
class OneHotWrapper(gym.core.ObservationWrapper):
    """One-hot encode image observations and stack the last frames.

    Each observation is split into object, colour and state channels, which
    are one-hot encoded, flattened and extended by a one-hot agent direction.
    The last ``framestack`` frames are concatenated into one flat vector.

    ``step_count``, ``agent_pos`` and ``agent_dir`` are not defined here;
    presumably they resolve on the wrapped environment - confirm for the
    installed gymnasium version.
    """

    def __init__(self, env, vector_index, framestack):
        """Set up the frame buffer and the flat observation space.

        Args:
            env: Environment to wrap.
            vector_index: Index of this environment; the travel-distance
                debug output is only produced for index ``0``.
            framestack: Number of frames concatenated per observation.
        """
        super().__init__(env)
        self.framestack = framestack
        # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types.
        # +4: Direction.
        self.single_frame_dim = 49 * (11 + 6 + 3) + 4
        self.init_x = None
        self.init_y = None
        self.x_positions = []
        self.y_positions = []
        self.x_y_delta_buffer = deque(maxlen=100)
        self.vector_index = vector_index
        self.frame_buffer = deque(maxlen=self.framestack)
        self._fill_with_empty_frames()
        self.observation_space = gym.spaces.Box(
            0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32
        )

    def _fill_with_empty_frames(self):
        """Overwrite the whole frame buffer with all-zero frames."""
        for _ in range(self.framestack):
            # float32 matches the declared space; the default float64 would
            # promote the stacked observation until real frames fill it.
            self.frame_buffer.append(
                np.zeros((self.single_frame_dim,), dtype=np.float32)
            )

    def observation(self, obs):
        """Return the stacked one-hot encoding including the new frame."""
        # Debug output: max-x/y positions to watch exploration progress.
        if self.step_count == 0:
            self._fill_with_empty_frames()
            if self.vector_index == 0:
                if self.x_positions:
                    max_diff = max(
                        np.sqrt(
                            (np.array(self.x_positions) - self.init_x) ** 2
                            + (np.array(self.y_positions) - self.init_y) ** 2
                        )
                    )
                    self.x_y_delta_buffer.append(max_diff)
                    print(
                        "100-average dist travelled={}".format(
                            np.mean(self.x_y_delta_buffer)
                        )
                    )
                    self.x_positions = []
                    self.y_positions = []
                self.init_x = self.agent_pos[0]
                self.init_y = self.agent_pos[1]

        self.x_positions.append(self.agent_pos[0])
        self.y_positions.append(self.agent_pos[1])

        # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.
        objects = one_hot(obs[:, :, 0], depth=11)
        colors = one_hot(obs[:, :, 1], depth=6)
        states = one_hot(obs[:, :, 2], depth=3)
        all_ = np.concatenate([objects, colors, states], -1)
        all_flat = np.reshape(all_, (-1,))
        direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
        single_frame = np.concatenate([all_flat, direction])
        self.frame_buffer.append(single_frame)
        return np.concatenate(self.frame_buffer)
def parse_arguments(argparse):

91
examples/shields/rl/MaskEnvironments.py

@ -0,0 +1,91 @@
import random
import minigrid
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict, Discrete
from Wrapper import OneHotWrapper
class ParametricActionsMiniGridEnv(gym.Env):
    """Parametric action version of MiniGrid.

    A MiniGrid environment is created with ``gym.make`` and its action and
    observation spaces are exposed unchanged. Before every reset and after
    every step two distinct action indices are drawn at random and marked as
    available in ``action_mask`` / ``action_assignments``.
    """

    def __init__(self, config):
        """Create the wrapped environment.

        Args:
            config: Mapping offering ``get``; the key ``"name"`` selects the
                environment id (default ``"MiniGrid-LavaCrossingS9N1-v0"``).
        """
        name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
        self.left_action_embed = np.random.randn(2)
        self.right_action_embed = np.random.randn(2)
        self.wrapped = gym.make(name)
        print(F"Wrapped environment is {self.wrapped}")
        self.step_count = 0
        self.action_space = self.wrapped.action_space
        self.observation_space = self.wrapped.observation_space

    def update_avail_actions(self):
        """Draw two distinct action indices and rebuild mask and embeddings."""
        num_actions = self.action_space.n
        self.action_assignments = np.array(
            [[0.0, 0.0]] * num_actions, dtype=np.float32
        )
        self.action_mask = np.zeros(num_actions, dtype=np.int8)
        self.left_idx, self.right_idx = random.sample(range(num_actions), 2)
        self.action_assignments[self.left_idx] = self.left_action_embed
        self.action_assignments[self.right_idx] = self.right_action_embed
        self.action_mask[self.left_idx] = 1
        self.action_mask[self.right_idx] = 1

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment and return its observation and info.

        ``seed`` and ``options`` are forwarded to the wrapped environment.
        """
        self.update_avail_actions()
        obs, infos = self.wrapped.reset(seed=seed, options=options)
        return obs, infos

    def step(self, action):
        """Translate ``action`` and step the wrapped environment.

        The index currently stored as ``right_idx`` maps to wrapped action
        ``1``; every other index (including ``left_idx``) maps to ``0``.

        Returns:
            The ``(obs, reward, done, truncated, info)`` tuple of the wrapped
            environment, with a debug entry added to ``info``.
        """
        actual_action = 1 if action == self.right_idx else 0
        orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action)
        self.update_avail_actions()
        print(F"Info is {info}")
        # Debug marker in the info dict. The original used a slice subscript
        # (``info[a : b]``), which raises on a dict instead of storing it.
        info["Hello"] = "Ich kenn mich nix aus"
        return orig_obs, rew, done, truncated, info

81
examples/shields/rl/MaskModels.py

@ -0,0 +1,81 @@
from typing import Dict, Optional, Union
from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX
torch, nn = try_import_torch()
class TorchActionMaskModel(TorchModelV2, nn.Module):
    """Torch model that receives a shield through its custom model config.

    The internal fully connected network is built on the ``"data"`` entry of
    the original observation space. Applying the shield as an action mask is
    not implemented yet, so the model only works with ``no_masking`` enabled.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        """Store the shield and build the internal network.

        Raises:
            ValueError: If ``custom_model_config`` has no ``"shield"`` entry.
        """
        orig_space = getattr(obs_space, "original_space", obs_space)
        custom_config = model_config['custom_model_config']
        print(F"Original Space is: {orig_space}")
        print(F"Observation space in model: {obs_space}")

        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )
        nn.Module.__init__(self)

        # Explicit check instead of an assert, which vanishes under -O.
        if "shield" not in custom_config:
            raise ValueError("custom_model_config must contain a 'shield' entry")
        self.shield = custom_config["shield"]

        self.internal_model = TorchFC(
            orig_space["data"],
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

        # disable action masking --> will likely lead to invalid actions
        self.no_masking = custom_config.get("no_masking", False)

    def forward(self, input_dict, state, seq_lens):
        """Return the unmasked logits of the internal model and ``state``.

        Raises:
            NotImplementedError: If masking is enabled, because the shield is
                not applied to the logits yet.
        """
        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]})

        # If action masking is disabled, directly return unmasked logits
        if self.no_masking:
            return logits, state

        # Raising (rather than ``assert False``) guarantees that unmasked
        # logits are never returned silently when asserts are disabled.
        # TODO: derive a [0.0 || -inf] mask from the shield and add it to
        # the logits.
        raise NotImplementedError(
            "shield-based action masking is not implemented; "
            "set 'no_masking' to True in custom_model_config"
        )

    def value_function(self):
        """Delegate to the value function of the internal model."""
        return self.internal_model.value_function()

152
examples/shields/rl/Wrapper.py

@ -0,0 +1,152 @@
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Dict, Box
from collections import deque
from ray.rllib.utils.numpy import one_hot
class OneHotWrapper(gym.core.ObservationWrapper):
    """One-hot encode the ``"data"`` image and stack the last frames.

    The incoming observation is a dict with the entries ``"data"`` and
    ``"avail_actions"``. The image is split into object, colour and state
    channels, which are one-hot encoded, flattened and extended by a one-hot
    agent direction; the last ``framestack`` frames are concatenated.
    ``"avail_actions"`` is passed through unchanged.

    ``step_count``, ``agent_pos`` and ``agent_dir`` are not defined here;
    presumably they resolve on the wrapped environment - confirm for the
    installed gymnasium version.
    """

    def __init__(self, env, vector_index, framestack):
        """Set up the frame buffer and the dict observation space.

        Args:
            env: Environment to wrap.
            vector_index: Index of this environment; the travel-distance
                debug output is only produced for index ``0``.
            framestack: Number of frames concatenated per observation.
        """
        super().__init__(env)
        self.framestack = framestack
        # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types.
        # +4: Direction.
        self.single_frame_dim = 49 * (11 + 6 + 3) + 4
        self.init_x = None
        self.init_y = None
        self.x_positions = []
        self.y_positions = []
        self.x_y_delta_buffer = deque(maxlen=100)
        self.vector_index = vector_index
        self.frame_buffer = deque(maxlen=self.framestack)
        self._fill_with_empty_frames()
        self.observation_space = Dict(
            {
                "data": gym.spaces.Box(
                    0.0,
                    1.0,
                    shape=(self.single_frame_dim * self.framestack,),
                    dtype=np.float32,
                ),
                "avail_actions": gym.spaces.Box(0, 10, shape=(10,), dtype=int),
            }
        )
        print(F"Set obersvation space to {self.observation_space}")

    def _fill_with_empty_frames(self):
        """Overwrite the whole frame buffer with all-zero frames."""
        for _ in range(self.framestack):
            # float32 matches the declared "data" space; the default float64
            # would promote the stacked observation until real frames fill it.
            self.frame_buffer.append(
                np.zeros((self.single_frame_dim,), dtype=np.float32)
            )

    def observation(self, obs):
        """Return the dict holding the stacked encoding and the action info."""
        # Debug output: max-x/y positions to watch exploration progress.
        if self.step_count == 0:
            self._fill_with_empty_frames()
            if self.vector_index == 0:
                if self.x_positions:
                    max_diff = max(
                        np.sqrt(
                            (np.array(self.x_positions) - self.init_x) ** 2
                            + (np.array(self.y_positions) - self.init_y) ** 2
                        )
                    )
                    self.x_y_delta_buffer.append(max_diff)
                    print(
                        "100-average dist travelled={}".format(
                            np.mean(self.x_y_delta_buffer)
                        )
                    )
                    self.x_positions = []
                    self.y_positions = []
                self.init_x = self.agent_pos[0]
                self.init_y = self.agent_pos[1]

        self.x_positions.append(self.agent_pos[0])
        self.y_positions.append(self.agent_pos[1])

        image = obs["data"]
        # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.
        objects = one_hot(image[:, :, 0], depth=11)
        colors = one_hot(image[:, :, 1], depth=6)
        states = one_hot(image[:, :, 2], depth=3)
        all_ = np.concatenate([objects, colors, states], -1)
        all_flat = np.reshape(all_, (-1,))
        direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
        single_frame = np.concatenate([all_flat, direction])
        self.frame_buffer.append(single_frame)

        return {
            "data": np.concatenate(self.frame_buffer),
            "avail_actions": obs["avail_actions"],
        }
class MiniGridEnvWrapper(gym.core.Wrapper):
    """Repackage observations as a dict of image and available actions.

    Every observation becomes ``{"data": <image>, "avail_actions": <mask>}``,
    where the mask currently is an all-zero ``int8`` vector of length 10.
    """

    def __init__(self, env):
        """Wrap ``env`` and declare the dict observation space."""
        super().__init__(env)
        self.observation_space = Dict(
            {
                "data": env.observation_space.spaces["image"],
                "avail_actions": Box(0, 10, shape=(10,), dtype=np.int8),
            }
        )

    def test(self):
        """Print a debug marker."""
        print("Testing some stuff")

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment and repackage its observation.

        ``seed`` and ``options`` are forwarded so that seeding takes effect.
        """
        obs, infos = self.env.reset(seed=seed, options=options)
        return {
            "data": obs["image"],
            "avail_actions": np.zeros(10, dtype=np.int8),
        }, infos

    def step(self, action):
        """Step the wrapped environment and repackage its observation."""
        orig_obs, rew, done, truncated, info = self.env.step(action)
        obs = {
            "data": orig_obs["image"],
            "avail_actions": np.zeros(10, dtype=np.int8),
        }
        return obs, rew, done, truncated, info
class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs.keys()
        dict_keys(['image', 'direction', 'mission'])
        >>> env = ImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (7, 7, 3)
    """

    def __init__(self, env):
        """A wrapper that makes image the only observation.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.observation_space = env.observation_space.spaces["image"]
        print(F"Set obersvation space to {self.observation_space}")

    def observation(self, obs):
        """Return only the ``"image"`` entry of ``obs``.

        This matches the observation space declared in ``__init__``; the
        previous dict result also read a ``"Test"`` key that the observations
        shown in the class docstring do not contain.
        """
        return obs["image"]
Loading…
Cancel
Save