Patch for examples/shields/rl/utils.py.

Summary of the change:
  * imports gymnasium as gym;
  * switches the f-string prefix from upper-case F to lower-case f in
    create_log_dir() and test_name() (the string contents are untouched);
  * appends a MiniWrapper(gym.Wrapper) class whose reset() and step()
    return the wrapped environment's observation transposed with axes
    (1, 0, 2); its observations() method returns its argument unchanged.

NOTE(review): this patch was recovered from a copy in which all line breaks
had been collapsed. Blank context lines and the 4-space indentation were
reconstructed from the hunk line counts; confirm against the target file
before applying.

diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py
index 1c9c8a6..1dd01cf 100644
--- a/examples/shields/rl/utils.py
+++ b/examples/shields/rl/utils.py
@@ -14,6 +14,7 @@ from abc import ABC
 import re
 import sys
 
+import gymnasium as gym
 from minigrid.core.actions import Actions
 from minigrid.core.state import to_state
 
@@ -128,10 +129,10 @@ class MiniGridShieldHandler(ShieldHandler):
 
 
 def create_log_dir(args):
-    return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}"
+    return f"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}"
 
 def test_name(args):
-    return F"{args.expname}"
+    return f"{args.expname}"
 
 def get_allowed_actions_mask(actions):
     action_mask = [0.0] * 3 + [1.0] * 4
@@ -162,3 +163,19 @@ def common_parser():
     parser.add_argument("--shield_value", default=0.9, type=float)
     parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
     return parser
+
+class MiniWrapper(gym.Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        self.env = env
+
+    def reset(self, *, seed=None, options=None):
+        obs, info = self.env.reset(seed=seed, options=options)
+        return obs.transpose(1,0,2), info
+
+    def observations(self, obs):
+        return obs
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        return obs.transpose(1,0,2), reward, terminated, truncated, info