Browse Source

some refactoring as preparation for sb3 example

added sb3 example
refactoring
Thomas Knoll 1 year ago
parent
commit
b1b014dbd6
  1. 164
      examples/shields/rl/11_minigridrl.py
  2. 130
      examples/shields/rl/13_minigridsb.py
  3. 18
      examples/shields/rl/MaskModels.py
  4. 61
      examples/shields/rl/Wrapper.py
  5. 135
      examples/shields/rl/helpers.py

164
examples/shields/rl/11_minigridrl.py

@ -1,21 +1,13 @@
from typing import Dict, Optional, Union
from typing import Dict
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
import stormpy
import stormpy.core
import stormpy.simulator
from datetime import datetime
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
import os
from datetime import datetime
import gymnasium as gym
@ -26,8 +18,6 @@ import ray
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator
from ray import tune, air
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog
@ -37,11 +27,10 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor
from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper
from helpers import extract_keys
from helpers import extract_keys, parse_arguments, create_shield_dict
import matplotlib.pyplot as plt
import argparse
@ -58,6 +47,7 @@ class MyCallbacks(DefaultCallbacks):
# img = env.get_frame()
# plt.imshow(img)
# plt.show()
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1
@ -70,47 +60,7 @@ class MyCallbacks(DefaultCallbacks):
# print(env.printGrid())
# print(episode.user_data["count"])
def parse_arguments(argparse):
parser = argparse.ArgumentParser()
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
parser.add_argument("--env",
help="gym environment to load",
default="MiniGrid-LavaCrossingS9N1-v0",
choices=[
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-DoorKey-8x8-v0",
"MiniGrid-Dynamic-Obstacles-8x8-v0",
"MiniGrid-Empty-Random-6x6-v0",
"MiniGrid-Fetch-6x6-N2-v0",
"MiniGrid-FourRooms-v0",
"MiniGrid-KeyCorridorS6R3-v0",
"MiniGrid-GoToDoor-8x8-v0",
"MiniGrid-LavaGapS7-v0",
"MiniGrid-SimpleCrossingS9N3-v0",
"MiniGrid-BlockedUnlockPickup-v0",
"MiniGrid-LockedRoom-v0",
"MiniGrid-ObstructedMaze-1Dlh-v0",
"MiniGrid-DoorKey-16x16-v0",
"MiniGrid-RedBlueDoors-6x6-v0",])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_to_prism_path", default="./main")
parser.add_argument("--grid_path", default="Grid.txt")
parser.add_argument("--prism_path", default="Grid.PRISM")
parser.add_argument("--no_masking", default=False)
parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
parser.add_argument("--log_dir", default="../log_results/")
parser.add_argument("--iterations", type=int, default=30 )
args = parser.parse_args()
return args
def env_creater_custom(config):
framestack = config.get("framestack", 4)
@ -130,86 +80,10 @@ def env_creater_custom(config):
return env
def env_creater_cart(config):
return gym.make("CartPole-v1")
def env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
env = gym.make(name)
env = minigrid.wrappers.ImgObsWrapper(env)
env = OneHotWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack
)
print(F"Created Minigrid Environment is {env}")
return env
def create_log_dir(args):
return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}"
def create_shield(grid_to_prism_path, grid_file, prism_path):
os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}")
f = open(prism_path, "a")
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
f.close()
program = stormpy.parse_prism_program(prism_path)
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
# formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)
# shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9)
formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
options.set_build_state_valuations(True)
options.set_build_choice_labels(True)
options.set_build_all_labels()
model = stormpy.build_sparse_model_with_options(program, options)
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
assert result.has_scheduler
assert result.has_shield
shield = result.shield
action_dictionary = {}
shield_scheduler = shield.construct()
for stateID in model.states:
choice = shield_scheduler.get_choice(stateID)
choices = choice.choice_map
state_valuation = model.state_valuations.get_string(stateID)
actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
action_dictionary[state_valuation] = actions_to_be_executed
stormpy.shields.export_shield(model, shield, "Grid.shield")
return action_dictionary
def export_grid_to_text(env, grid_file):
f = open(grid_file, "w")
# print(env)
f.write(env.printGrid(init=True))
f.close()
def create_environment(args):
env_id= args.env
env = gym.make(env_id)
env.reset()
return env
def register_custom_minigrid_env(args):
env_name = "mini-grid"
register_env(env_name, env_creater_custom)
@ -218,25 +92,7 @@ def register_custom_minigrid_env(args):
"pa_model",
TorchActionMaskModel
)
def create_shield_dict(args):
env = create_environment(args)
# print(env.printGrid(init=False))
grid_file = args.grid_path
grid_to_prism_path = args.grid_to_prism_path
export_grid_to_text(env, grid_file)
prism_path = args.prism_path
shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path)
#shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states}
#print(F"Shield dictionary {shield_dict}")
# for state_id in model.states:
# choices = shield.get_choice(state_id)
# print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")
return shield_dict
def ppo(args):
@ -311,14 +167,16 @@ def dqn(args):
result = algo.train()
print(pretty_print(result))
# if i % 5 == 0:
# checkpoint_dir = algo.save()
# print(f"Checkpoint saved in directory {checkpoint_dir}")
if i % 5 == 0:
print("Saving checkpoint")
checkpoint_dir = algo.save()
print(f"Checkpoint saved in directory {checkpoint_dir}")
ray.shutdown()
def main():
import argparse
args = parse_arguments(argparse)
if args.algorithm == "ppo":

130
examples/shields/rl/13_minigridsb.py

@ -0,0 +1,130 @@
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback
import gymnasium as gym
from gymnasium.spaces import Dict, Box
import numpy as np
import time
from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping
class CustomCallback(BaseCallback):
def __init__(self, verbose: int = 0, env=None):
super(CustomCallback, self).__init__(verbose)
self.env = env
def _on_step(self) -> bool:
#print(self.env.printGrid())
return super()._on_step()
class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env, shield={}, keys=[], no_masking=False):
super(MiniGridEnvWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = env.observation_space.spaces["image"]
self.keys = keys
self.shield = shield
self.no_masking = no_masking
def create_action_mask(self):
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
key_text = ""
# only support one key for now
if self.keys:
key_text = F"!Agent_has_{self.keys[0]}_key\t& "
if self.env.carrying and self.env.carrying.type == "key":
key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
allowed_actions = []
# Create the mask
# If shield restricts action mask only valid with 1.0
# else set all actions as valid
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action[1])
if index is None:
assert(False)
mask[index] = 1.0
else:
# print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask):
mask[index] = 1.0
if self.no_masking:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
return obs["image"], infos
def step(self, action):
# print(F"Performed action in step: {action}")
orig_obs, rew, done, truncated, info = self.env.step(action)
#print(F"Original observation is {orig_obs}")
obs = orig_obs["image"]
#print(F"Info is {info}")
return obs, rew, done, truncated, info
def mask_fn(env: gym.Env):
return env.create_action_mask()
def main():
import argparse
args = parse_arguments(argparse)
shield = create_shield_dict(args)
env = gym.make(args.env, render_mode="rgb_array")
keys = extract_keys(env)
env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking)
env = ActionMasker(env, mask_fn)
callback = CustomCallback(1, env)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=args.log_dir)
model.learn(args.iterations, callback=callback)
mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10)
vec_env = model.get_env()
obs = vec_env.reset()
terminated = truncated = False
while not terminated and not truncated:
action_masks = None
action, _states = model.predict(obs, action_masks=action_masks)
obs, reward, terminated, truncated, info = env.step(action)
# action, _states = model.predict(obs, deterministic=True)
# obs, rewards, dones, info = vec_env.step(action)
vec_env.render("human")
time.sleep(0.2)
if __name__ == '__main__':
main()

18
examples/shields/rl/MaskModels.py

@ -54,19 +54,12 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
def forward(self, input_dict, state, seq_lens):
# Extract the available actions tensor from the observation.
# print(F"Input dict is {input_dict} at obs: {input_dict['obs']}")
# print(F"State is {state}")
# print(input_dict["env"])
# Compute the unmasked logits.
# Compute the unmasked logits.
logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]})
# print(F"Caluclated Logits {logits} with size {logits.size()} Count: {self.count}")
action_mask = input_dict["obs"]["action_mask"]
#print(F"Action mask is {action_mask} with dimension {action_mask.size()}")
# If action masking is disabled, directly return unmasked logits
if self.no_masking:
return logits, state
@ -74,12 +67,9 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
# assert(False)
# Convert action_mask into a [0.0 || -inf]-type mask.
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
# print(F"Logits Size: {logits.size()} Inf-Mask Size: {inf_mask.size()}")
# print(F"Logits:{logits} Inf-Mask: {inf_mask}")
masked_logits = logits + inf_mask
# print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}")
# # Return masked logits.
return masked_logits, state

61
examples/shields/rl/Wrapper.py

@ -8,6 +8,7 @@ from ray.rllib.utils.numpy import one_hot
from helpers import get_action_index_mapping
class OneHotWrapper(gym.core.ObservationWrapper):
def __init__(self, env, vector_index, framestack):
super().__init__(env)
@ -30,11 +31,11 @@ class OneHotWrapper(gym.core.ObservationWrapper):
"data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
"action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
}
)
)
# print(F"Set obersvation space to {self.observation_space}")
def observation(self, obs):
# Debug output: max-x/y positions to watch exploration progress.
@ -61,23 +62,23 @@ class OneHotWrapper(gym.core.ObservationWrapper):
self.init_x = self.agent_pos[0]
self.init_y = self.agent_pos[1]
self.x_positions.append(self.agent_pos[0])
self.y_positions.append(self.agent_pos[1])
image = obs["data"]
# One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.
objects = one_hot(image[:, :, 0], depth=11)
colors = one_hot(image[:, :, 1], depth=6)
states = one_hot(image[:, :, 2], depth=3)
all_ = np.concatenate([objects, colors, states], -1)
all_flat = np.reshape(all_, (-1,))
direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
single_frame = np.concatenate([all_flat, direction])
self.frame_buffer.append(single_frame)
#obs["one-hot"] = np.concatenate(self.frame_buffer)
tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] }
return tmp#np.concatenate(self.frame_buffer)
@ -95,49 +96,49 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
)
self.keys = keys
self.shield = shield
def create_action_mask(self):
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
key_text = ""
# only support one key for now
if self.keys:
key_text = F"!Agent_has_{self.keys[0]}_key\t& "
if self.env.carrying and self.env.carrying.type == "key":
key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
allowed_actions = []
# Create the mask
# If shield restricts action mask only valid with 1.0
# else set all actions as valid
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action[1])
index = get_action_index_mapping(allowed_action[1])
if index is None:
assert(False)
assert(False)
mask[index] = 1.0
else:
print("Not in shield")
# print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask):
mask[index] = 1.0
#print(F"Action Mask for position {coordinates} and view {view_direction} is {mask} Position String: {cur_pos_str})")
# mask[0] = 1.0
# print(F"Action Mask for position {coordinates} and view {view_direction} is {mask} Position String: {cur_pos_str})")
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
mask = self.create_action_mask()
@ -145,19 +146,19 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
"data": obs["image"],
"action_mask": mask
}, infos
def step(self, action):
# print(F"Performed action in step: {action}")
orig_obs, rew, done, truncated, info = self.env.step(action)
mask = self.create_action_mask()
#print(F"Original observation is {orig_obs}")
obs = {
"data": orig_obs["image"],
"action_mask": mask,
}
#print(F"Info is {info}")
return obs, rew, done, truncated, info

135
examples/shields/rl/helpers.py

@ -1,5 +1,18 @@
import minigrid
from minigrid.core.actions import Actions
import gymnasium as gym
import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
import os
def extract_keys(env):
@ -36,4 +49,124 @@ def get_action_index_mapping(actions):
return Actions.done
raise ValueError(F"Action string {action_str} not supported")
raise ValueError(F"Action string {action_str} not supported")
def parse_arguments(argparse):
parser = argparse.ArgumentParser()
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
parser.add_argument("--env",
help="gym environment to load",
default="MiniGrid-LavaCrossingS9N1-v0",
choices=[
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-DoorKey-8x8-v0",
"MiniGrid-Dynamic-Obstacles-8x8-v0",
"MiniGrid-Empty-Random-6x6-v0",
"MiniGrid-Fetch-6x6-N2-v0",
"MiniGrid-FourRooms-v0",
"MiniGrid-KeyCorridorS6R3-v0",
"MiniGrid-GoToDoor-8x8-v0",
"MiniGrid-LavaGapS7-v0",
"MiniGrid-SimpleCrossingS9N3-v0",
"MiniGrid-BlockedUnlockPickup-v0",
"MiniGrid-LockedRoom-v0",
"MiniGrid-ObstructedMaze-1Dlh-v0",
"MiniGrid-DoorKey-16x16-v0",
"MiniGrid-RedBlueDoors-6x6-v0",])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_to_prism_path", default="./main")
parser.add_argument("--grid_path", default="Grid.txt")
parser.add_argument("--prism_path", default="Grid.PRISM")
parser.add_argument("--no_masking", default=False)
parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
parser.add_argument("--log_dir", default="../log_results/")
parser.add_argument("--iterations", type=int, default=30 )
args = parser.parse_args()
return args
def create_environment(args):
env_id= args.env
env = gym.make(env_id)
env.reset()
return env
def export_grid_to_text(env, grid_file):
f = open(grid_file, "w")
# print(env)
f.write(env.printGrid(init=True))
f.close()
def create_shield(grid_to_prism_path, grid_file, prism_path):
os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}")
f = open(prism_path, "a")
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
f.close()
program = stormpy.parse_prism_program(prism_path)
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
# formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
# stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
# shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9)
formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
options.set_build_state_valuations(True)
options.set_build_choice_labels(True)
options.set_build_all_labels()
model = stormpy.build_sparse_model_with_options(program, options)
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
assert result.has_scheduler
assert result.has_shield
shield = result.shield
action_dictionary = {}
shield_scheduler = shield.construct()
for stateID in model.states:
choice = shield_scheduler.get_choice(stateID)
choices = choice.choice_map
state_valuation = model.state_valuations.get_string(stateID)
actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
action_dictionary[state_valuation] = actions_to_be_executed
stormpy.shields.export_shield(model, shield, "Grid.shield")
return action_dictionary
def create_shield_dict(args):
env = create_environment(args)
# print(env.printGrid(init=False))
grid_file = args.grid_path
grid_to_prism_path = args.grid_to_prism_path
export_grid_to_text(env, grid_file)
prism_path = args.prism_path
shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path)
#shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states}
#print(F"Shield dictionary {shield_dict}")
# for state_id in model.states:
# choices = shield.get_choice(state_id)
# print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")
return shield_dict
Loading…
Cancel
Save