Browse Source

fixed shield generation

worker handling
refactoring
Thomas Knoll 1 year ago
parent
commit
757fbbcc0d
  1. 30
      examples/shields/rl/11_minigridrl.py
  2. 4
      examples/shields/rl/13_minigridsb.py
  3. 10
      examples/shields/rl/MaskModels.py
  4. 11
      examples/shields/rl/Wrapper.py
  5. 38
      examples/shields/rl/helpers.py

30
examples/shields/rl/11_minigridrl.py

@ -22,10 +22,9 @@ from ray.rllib.models import ModelCatalog
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor
from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper
from helpers import extract_keys, parse_arguments, create_shield_dict, create_log_dir
from helpers import parse_arguments, create_log_dir
import matplotlib.pyplot as plt
@ -34,7 +33,9 @@ class MyCallbacks(DefaultCallbacks):
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
episode.user_data["count"] = 0
# print("On episode start print")
# print(env.printGrid())
# print(worker)
# print(env.action_space.n)
# print(env.actions)
# print(env.mission)
@ -52,17 +53,19 @@ class MyCallbacks(DefaultCallbacks):
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
#print("On episode end print")
#print(env.printGrid())
# print(episode.user_data["count"])
def env_creater_custom(config):
framestack = config.get("framestack", 4)
shield = config.get("shield", {})
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
args = config.get("args", None)
args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
env = gym.make(name)
env = MiniGridEnvWrapper(env, args=args)
# env = minigrid.wrappers.ImgObsWrapper(env)
@ -88,14 +91,10 @@ def register_custom_minigrid_env(args):
def ppo(args):
ray.init(num_cpus=1)
register_custom_minigrid_env(args)
config = (PPOConfig()
.rollouts(num_rollout_workers=1)
.rollouts(num_rollout_workers=args.workers)
.resources(num_gpus=0)
.environment(env="mini-grid", env_config={"name": args.env, "args": args})
.framework("torch")
@ -123,15 +122,6 @@ def ppo(args):
checkpoint_dir = algo.save()
print(f"Checkpoint saved in directory {checkpoint_dir}")
# terminated = truncated = False
# while not terminated and not truncated:
# action = algo.compute_single_action(obs)
# obs, reward, terminated, truncated = env.step(action)
ray.shutdown()
def dqn(args):
register_custom_minigrid_env(args)
@ -139,7 +129,7 @@ def dqn(args):
config = DQNConfig()
config = config.resources(num_gpus=0)
config = config.rollouts(num_rollout_workers=1)
config = config.rollouts(num_rollout_workers=args.workers)
config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
config = config.framework("torch")
config = config.callbacks(MyCallbacks)
@ -166,8 +156,6 @@ def dqn(args):
checkpoint_dir = algo.save()
print(f"Checkpoint saved in directory {checkpoint_dir}")
ray.shutdown()
def main():
import argparse

4
examples/shields/rl/13_minigridsb.py

@ -102,13 +102,9 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
return obs["image"], infos
def step(self, action):
# print(F"Performed action in step: {action}")
orig_obs, rew, done, truncated, info = self.env.step(action)
#print(F"Original observation is {orig_obs}")
obs = orig_obs["image"]
#print(F"Info is {info}")
return obs, rew, done, truncated, info

10
examples/shields/rl/MaskModels.py

@ -24,10 +24,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
):
orig_space = getattr(obs_space, "original_space", obs_space)
custom_config = model_config['custom_model_config']
# print(F"Original Space is: {orig_space}")
#print(model_config)
#print(F"Observation space in model: {obs_space}")
#print(F"Provided action space in model {action_space}")
TorchModelV2.__init__(
self, obs_space, action_space, num_outputs, model_config, name, **kwargs
@ -44,7 +40,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
name + "_internal",
)
# disable action masking --> will likely lead to invalid actions
self.no_masking = False
if "no_masking" in model_config["custom_model_config"]:
self.no_masking = model_config["custom_model_config"]["no_masking"]
@ -54,20 +49,17 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
# Compute the unmasked logits.
logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]})
action_mask = input_dict["obs"]["action_mask"]
# If action masking is disabled, directly return unmasked logits
if self.no_masking:
return logits, state
# assert(False)
# Convert action_mask into a [0.0 || -inf]-type mask.
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
masked_logits = logits + inf_mask
# # Return masked logits.
# Return masked logits.
return masked_logits, state
def value_function(self):

11
examples/shields/rl/Wrapper.py

@ -34,10 +34,6 @@ class OneHotWrapper(gym.core.ObservationWrapper):
}
)
# print(F"Set obersvation space to {self.observation_space}")
def observation(self, obs):
# Debug output: max-x/y positions to watch exploration progress.
# print(F"Initial observation in Wrapper {obs}")
@ -80,9 +76,8 @@ class OneHotWrapper(gym.core.ObservationWrapper):
single_frame = np.concatenate([all_flat, direction])
self.frame_buffer.append(single_frame)
#obs["one-hot"] = np.concatenate(self.frame_buffer)
tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] }
return tmp#np.concatenate(self.frame_buffer)
return tmp
class MiniGridEnvWrapper(gym.core.Wrapper):
@ -111,7 +106,6 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
if self.env.carrying and self.env.carrying.type == "key":
key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
allowed_actions = []
@ -130,7 +124,6 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
assert(False)
mask[index] = 1.0
else:
# print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask):
mask[index] = 1.0
@ -158,11 +151,9 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
}, infos
def step(self, action):
# print(F"Performed action in step: {action}")
orig_obs, rew, done, truncated, info = self.env.step(action)
mask = self.create_action_mask()
#print(F"Original observation is {orig_obs}")
obs = {
"data": orig_obs["image"],
"action_mask": mask,

38
examples/shields/rl/helpers.py

@ -18,7 +18,6 @@ import os
def extract_keys(env):
env.reset()
keys = []
#print(env.grid)
for j in range(env.grid.height):
@ -66,6 +65,7 @@ def parse_arguments(argparse):
default="MiniGrid-LavaCrossingS9N1-v0",
choices=[
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-LavaCrossingS9N3-v0",
"MiniGrid-DoorKey-8x8-v0",
"MiniGrid-LockedRoom-v0",
"MiniGrid-FourRooms-v0",
@ -77,26 +77,21 @@ def parse_arguments(argparse):
# parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_to_prism_path", default="./main")
parser.add_argument("--grid_path", default="Grid.txt")
parser.add_argument("--prism_path", default="Grid.PRISM")
parser.add_argument("--grid_path", default="grid")
parser.add_argument("--prism_path", default="grid")
parser.add_argument("--no_masking", default=False)
parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
parser.add_argument("--log_dir", default="../log_results/")
parser.add_argument("--iterations", type=int, default=30 )
parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
parser.add_argument("--workers", type=int, default=1)
args = parser.parse_args()
return args
def create_environment(args):
env_id= args.env
env = gym.make(env_id)
env.reset()
return env
def export_grid_to_text(env, grid_file):
f = open(grid_file, "w")
# print(env)
@ -104,24 +99,18 @@ def export_grid_to_text(env, grid_file):
f.close()
def create_shield(grid_to_prism_path, grid_file, prism_path):
def create_shield(grid_to_prism_path, grid_file, prism_path, formula):
os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}")
f = open(prism_path, "a")
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
f.close()
program = stormpy.parse_prism_program(prism_path)
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
# formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
# stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
# shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9)
formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
formulas = stormpy.parse_properties_for_prism_program(formula, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
options.set_build_state_valuations(True)
options.set_build_choice_labels(True)
@ -151,21 +140,12 @@ def create_shield(grid_to_prism_path, grid_file, prism_path):
def create_shield_dict(env, args):
env = create_environment(args)
# print(env.printGrid(init=False))
grid_file = args.grid_path
grid_to_prism_path = args.grid_to_prism_path
export_grid_to_text(env, grid_file)
prism_path = args.prism_path
shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path)
#shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states}
#print(F"Shield dictionary {shield_dict}")
# for state_id in model.states:
# choices = shield.get_choice(state_id)
# print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")
shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path, args.formula)
return shield_dict
Loading…
Cancel
Save