From 16490a74f103306687eed30102b4ea21b64df1ed Mon Sep 17 00:00:00 2001
From: sp <stefan.pranger@iaik.tugraz.at>
Date: Mon, 15 Jan 2024 09:34:40 +0100
Subject: [PATCH] init Miniwrapper to switch to WxHxC observations

---
 examples/shields/rl/utils.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py
index 1c9c8a6..1dd01cf 100644
--- a/examples/shields/rl/utils.py
+++ b/examples/shields/rl/utils.py
@@ -14,6 +14,7 @@ from abc import ABC
 import re
 import sys
 
+import gymnasium as gym
 
 from minigrid.core.actions import Actions
 from minigrid.core.state import to_state
@@ -128,10 +129,10 @@ class MiniGridShieldHandler(ShieldHandler):
 
 
 def create_log_dir(args):
-    return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}"
+    return f"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}"
 
 def test_name(args):
-    return F"{args.expname}"
+    return f"{args.expname}"
 
 def get_allowed_actions_mask(actions):
     action_mask = [0.0] * 3 + [1.0] * 4
@@ -162,3 +163,19 @@ def common_parser():
     parser.add_argument("--shield_value", default=0.9, type=float)
     parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
     return parser
+
+class MiniWrapper(gym.Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        self.env = env
+
+    def reset(self, *, seed=None, options=None):
+        obs, info = self.env.reset(seed=seed, options=options)
+        return obs.transpose(1,0,2), info
+
+    def observations(self, obs):
+        return obs
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        return obs.transpose(1,0,2), reward, terminated, truncated, info