From 1c2dbf706ef32ba7d3b982be8bfd91733e4adb64 Mon Sep 17 00:00:00 2001
From: Thomas Knoll <thomas.knoll@student.tugraz.at>
Date: Wed, 30 Aug 2023 15:41:23 +0200
Subject: [PATCH] Create shield on environment reset instead of at environment creation
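
Previously, the shield dictionary and the key list were computed once up
front and threaded through env_config and the model's custom_model_config.
The MiniGrid wrappers now receive the parsed args instead and build the
shield and keys themselves in reset(), so the shield is recomputed for the
grid each episode actually starts in. Accordingly, the shield entry is
dropped from TorchActionMaskModel's custom config, create_shield_dict()
takes the environment as a parameter, the number of PPO training iterations
comes from args.iterations instead of a hard-coded 30, and the active
safety formula in create_shield() is switched to the lava-avoidance
property.

A minimal sketch of the resulting flow (illustrative only; assumes the
modules under examples/shields/rl are on the import path):

    import argparse

    import gymnasium as gym

    from Wrapper import MiniGridEnvWrapper
    from helpers import parse_arguments

    args = parse_arguments(argparse)
    env = gym.make(args.env)                  # e.g. MiniGrid-LavaCrossingS9N1-v0
    env = MiniGridEnvWrapper(env, args=args)  # no precomputed shield is passed in
    obs, info = env.reset()                   # shield and keys are built here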

---
 examples/shields/rl/11_minigridrl.py | 36 +++++++++++++---------------
 examples/shields/rl/13_minigridsb.py | 17 +++++++-----
 examples/shields/rl/MaskModels.py    |  3 ---
 examples/shields/rl/Wrapper.py       | 14 ++++++-----
 examples/shields/rl/helpers.py       |  9 +++++----
 5 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py
index 035b195..ff7d8e1 100644
--- a/examples/shields/rl/11_minigridrl.py
+++ b/examples/shields/rl/11_minigridrl.py
@@ -5,7 +5,7 @@ from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.episode_v2 import EpisodeV2
 from ray.rllib.policy import Policy
 from ray.rllib.utils.typing import PolicyID
-
+from ray.rllib.algorithms.algorithm import Algorithm
 
 import gymnasium as gym
 
@@ -29,9 +29,6 @@ from helpers import extract_keys, parse_arguments, create_shield_dict, create_lo
 
 import matplotlib.pyplot as plt
 
-
-
-
 class MyCallbacks(DefaultCallbacks):
     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
         # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
@@ -50,7 +47,7 @@ class MyCallbacks(DefaultCallbacks):
     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
          episode.user_data["count"] = episode.user_data["count"] + 1
          env = base_env.get_sub_environments()[0]
-         #print(env.printGrid())
+        # print(env.printGrid())
     
     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
@@ -65,10 +62,9 @@ def env_creater_custom(config):
     shield = config.get("shield", {})
     name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
     framestack = config.get("framestack", 4)
-    
+    args = config.get("args", None)
     env = gym.make(name)
-    keys = extract_keys(env)
-    env = MiniGridEnvWrapper(env, shield=shield, keys=keys)
+    env = MiniGridEnvWrapper(env, args=args)
     # env = minigrid.wrappers.ImgObsWrapper(env)
     # env = ImgObsWrapper(env)
     env = OneHotWrapper(env,
@@ -76,6 +72,7 @@ def env_creater_custom(config):
                         framestack=framestack
                         )
     
+    
     return env
 
 
@@ -96,12 +93,11 @@ def ppo(args):
 
     
     register_custom_minigrid_env(args)
-    shield_dict = create_shield_dict(args)
     
     config = (PPOConfig()
         .rollouts(num_rollout_workers=1)
         .resources(num_gpus=0)
-        .environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env})
+        .environment(env="mini-grid", env_config={"name": args.env, "args": args})
         .framework("torch")       
         .callbacks(MyCallbacks)
         .rl_module(_enable_rl_module_api = False)
@@ -111,7 +107,7 @@ def ppo(args):
         })
         .training(_enable_learner_api=False ,model={
             "custom_model": "pa_model",
-            "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}            
+            "custom_model_config" : {"no_masking": args.no_masking}            
         }))
     
     algo =(
@@ -119,11 +115,7 @@ def ppo(args):
         config.build()
     )
     
-    # while not terminated and not truncated:
-    #     action = algo.compute_single_action(obs)
-    #     obs, reward, terminated, truncated = env.step(action)
-    
-    for i in range(30):
+    for i in range(args.iterations):
         result = algo.train()
         print(pretty_print(result))
 
@@ -131,18 +123,24 @@ def ppo(args):
             checkpoint_dir = algo.save()
             print(f"Checkpoint saved in directory {checkpoint_dir}")
             
+    # terminated = truncated = False
+    
+    # while not terminated and not truncated:
+    #      action = algo.compute_single_action(obs)
+    #      obs, reward, terminated, truncated = env.step(action)
+    
+            
     ray.shutdown()
 
 
 def dqn(args):
     register_custom_minigrid_env(args)
-    shield_dict = create_shield_dict(args)
 
     
     config = DQNConfig()
     config = config.resources(num_gpus=0)
     config = config.rollouts(num_rollout_workers=1)
-    config = config.environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env })
+    config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
     config = config.framework("torch")
     config = config.callbacks(MyCallbacks)
     config = config.rl_module(_enable_rl_module_api = False)
@@ -152,7 +150,7 @@ def dqn(args):
         })
     config = config.training(hiddens=[], dueling=False, model={    
             "custom_model": "pa_model",
-            "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
+            "custom_model_config" : {"no_masking": args.no_masking}
     })
     
     algo = (
diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py
index 5873ed6..6a2fb20 100644
--- a/examples/shields/rl/13_minigridsb.py
+++ b/examples/shields/rl/13_minigridsb.py
@@ -27,13 +27,12 @@ class CustomCallback(BaseCallback):
 
 
 class MiniGridEnvWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield={}, keys=[], no_masking=False):
+    def __init__(self, env, args=None, no_masking=False):
         super(MiniGridEnvWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = env.observation_space.spaces["image"]
         
-        self.keys = keys
-        self.shield = shield
+        self.args = args
         self.no_masking = no_masking
 
     def create_action_mask(self):
@@ -94,6 +93,13 @@
 
     def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset(seed=seed, options=options)
+        
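+        # Rebuild the shield and key list from the freshly generated grid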
+        keys = extract_keys(self.env)
+        shield = create_shield_dict(self.env, self.args)
+        
+        self.keys = keys
+        self.shield = shield
         return obs["image"], infos
 
     def step(self, action):
@@ -116,11 +122,10 @@ def mask_fn(env: gym.Env):
 def main():
     import argparse
     args = parse_arguments(argparse)
-    shield = create_shield_dict(args)
+    
     
     env = gym.make(args.env, render_mode="rgb_array")
-    keys = extract_keys(env)
-    env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking)
+    env = MiniGridEnvWrapper(env, args=args, no_masking=args.no_masking)
     env = ActionMasker(env, mask_fn)
     callback = CustomCallback(1, env)
     model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))
diff --git a/examples/shields/rl/MaskModels.py b/examples/shields/rl/MaskModels.py
index 607b529..0ee4154 100644
--- a/examples/shields/rl/MaskModels.py
+++ b/examples/shields/rl/MaskModels.py
@@ -34,9 +34,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
         )
         nn.Module.__init__(self)
         
-        assert("shield" in custom_config)
-        
-        self.shield = custom_config["shield"]
         self.count = 0
 
         self.internal_model = TorchFC(
diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrapper.py
index 41aaf2a..6650369 100644
--- a/examples/shields/rl/Wrapper.py
+++ b/examples/shields/rl/Wrapper.py
@@ -7,7 +7,7 @@ from gymnasium.spaces import Dict, Box
 from collections import deque
 from ray.rllib.utils.numpy import one_hot
 
-from helpers import get_action_index_mapping
+from helpers import get_action_index_mapping, create_shield_dict, extract_keys
 
 
 class OneHotWrapper(gym.core.ObservationWrapper):
@@ -86,7 +86,7 @@ class OneHotWrapper(gym.core.ObservationWrapper):
 
 
 class MiniGridEnvWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield={}, keys=[]):
+    def __init__(self, env, args=None):
         super(MiniGridEnvWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = Dict(
@@ -95,8 +95,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
                 "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
             }
         )
-        self.keys = keys
-        self.shield = shield
+        self.args = args
 
     def create_action_mask(self):
         coordinates = self.env.agent_pos
@@ -140,8 +139,8 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
         if front_tile is not None and front_tile.type == "key":
             mask[Actions.pickup] = 1.0
             
-        if self.env.carrying:
-            mask[Actions.drop] = 1.0
+        # if self.env.carrying:
+        #     mask[Actions.drop] = 1.0
             
         if front_tile and front_tile.type == "door":
             mask[Actions.toggle] = 1.0
@@ -150,6 +149,9 @@
 
     def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset(seed=seed, options=options)
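+        # Recompute the shield for the new grid layout before building the action mask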
+        self.shield = create_shield_dict(self.env, self.args)
+        self.keys = extract_keys(self.env)
         mask = self.create_action_mask()
         return {
             "data": obs["image"],
diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py
index 714ac00..ca15f14 100644
--- a/examples/shields/rl/helpers.py
+++ b/examples/shields/rl/helpers.py
@@ -20,7 +20,7 @@ import os
 def extract_keys(env):
     env.reset()
     keys = []
-    print(env.grid)
+    # print(env.grid)
     for j in range(env.grid.height):
         for i in range(env.grid.width):
             obj = env.grid.get(i,j)
@@ -113,8 +113,9 @@ def create_shield(grid_to_prism_path, grid_file, prism_path):
     
     
     program = stormpy.parse_prism_program(prism_path)
-    # formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
-    formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
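+    # Safety objective: maximize the probability of globally avoiding "AgentIsInLavaAndNotDone"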
+    formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
+    # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
     # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
     #                                                       stormpy.logic.ShieldComparison.ABSOLUTE, 0.9) 
  
@@ -150,7 +151,7 @@ def create_shield(grid_to_prism_path, grid_file, prism_path):
     return action_dictionary
 
         
-def create_shield_dict(args):
+def create_shield_dict(env, args):
     env = create_environment(args)
     # print(env.printGrid(init=False))