|
|
@ -14,6 +14,7 @@ from abc import ABC |
|
|
|
import re |
|
|
|
import sys |
|
|
|
|
|
|
|
import gymnasium as gym |
|
|
|
|
|
|
|
from minigrid.core.actions import Actions |
|
|
|
from minigrid.core.state import to_state |
|
|
@ -128,10 +129,10 @@ class MiniGridShieldHandler(ShieldHandler): |
|
|
|
|
|
|
|
|
|
|
|
def create_log_dir(args): |
|
|
|
return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}" |
|
|
|
return f"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}" |
|
|
|
|
|
|
|
def test_name(args): |
|
|
|
return F"{args.expname}" |
|
|
|
return f"{args.expname}" |
|
|
|
|
|
|
|
def get_allowed_actions_mask(actions): |
|
|
|
action_mask = [0.0] * 3 + [1.0] * 4 |
|
|
@ -162,3 +163,19 @@ def common_parser(): |
|
|
|
parser.add_argument("--shield_value", default=0.9, type=float) |
|
|
|
parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute']) |
|
|
|
return parser |
|
|
|
|
|
|
|
class MiniWrapper(gym.Wrapper): |
|
|
|
def __init__(self, env): |
|
|
|
super().__init__(env) |
|
|
|
self.env = env |
|
|
|
|
|
|
|
def reset(self, *, seed=None, options=None): |
|
|
|
obs, info = self.env.reset(seed=seed, options=options) |
|
|
|
return obs.transpose(1,0,2), info |
|
|
|
|
|
|
|
def observations(self, obs): |
|
|
|
return obs |
|
|
|
|
|
|
|
def step(self, action): |
|
|
|
obs, reward, terminated, truncated, info = self.env.step(action) |
|
|
|
return obs.transpose(1,0,2), reward, terminated, truncated, info |