Browse Source

init Miniwrapper to switch to WxHxC observations

refactoring
sp 10 months ago
parent
commit
16490a74f1
  1. 21
      examples/shields/rl/utils.py

21
examples/shields/rl/utils.py

@ -14,6 +14,7 @@ from abc import ABC
import re
import sys
import gymnasium as gym
from minigrid.core.actions import Actions
from minigrid.core.state import to_state
@ -128,10 +129,10 @@ class MiniGridShieldHandler(ShieldHandler):
def create_log_dir(args):
return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}"
return f"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}"
def test_name(args):
return F"{args.expname}"
return f"{args.expname}"
def get_allowed_actions_mask(actions):
action_mask = [0.0] * 3 + [1.0] * 4
@ -162,3 +163,19 @@ def common_parser():
parser.add_argument("--shield_value", default=0.9, type=float)
parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
return parser
class MiniWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.env = env
def reset(self, *, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options=options)
return obs.transpose(1,0,2), info
def observations(self, obs):
return obs
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return obs.transpose(1,0,2), reward, terminated, truncated, info
Loading…
Cancel
Save