Thomas Knoll
1 year ago
5 changed files with 310 additions and 198 deletions
-
164examples/shields/rl/11_minigridrl.py
-
130examples/shields/rl/13_minigridsb.py
-
18examples/shields/rl/MaskModels.py
-
61examples/shields/rl/Wrapper.py
-
135examples/shields/rl/helpers.py
@ -0,0 +1,130 @@ |
|||
from sb3_contrib import MaskablePPO |
|||
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
|||
from sb3_contrib.common.maskable.utils import get_action_masks |
|||
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy |
|||
from sb3_contrib.common.wrappers import ActionMasker |
|||
from stable_baselines3.common.callbacks import BaseCallback |
|||
|
|||
import gymnasium as gym |
|||
from gymnasium.spaces import Dict, Box |
|||
|
|||
import numpy as np |
|||
import time |
|||
|
|||
from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping |
|||
|
|||
class CustomCallback(BaseCallback): |
|||
def __init__(self, verbose: int = 0, env=None): |
|||
super(CustomCallback, self).__init__(verbose) |
|||
self.env = env |
|||
|
|||
|
|||
def _on_step(self) -> bool: |
|||
#print(self.env.printGrid()) |
|||
return super()._on_step() |
|||
|
|||
|
|||
class MiniGridEnvWrapper(gym.core.Wrapper): |
|||
def __init__(self, env, shield={}, keys=[], no_masking=False): |
|||
super(MiniGridEnvWrapper, self).__init__(env) |
|||
self.max_available_actions = env.action_space.n |
|||
self.observation_space = env.observation_space.spaces["image"] |
|||
|
|||
self.keys = keys |
|||
self.shield = shield |
|||
self.no_masking = no_masking |
|||
|
|||
def create_action_mask(self): |
|||
coordinates = self.env.agent_pos |
|||
view_direction = self.env.agent_dir |
|||
|
|||
key_text = "" |
|||
|
|||
# only support one key for now |
|||
if self.keys: |
|||
key_text = F"!Agent_has_{self.keys[0]}_key\t& " |
|||
|
|||
|
|||
if self.env.carrying and self.env.carrying.type == "key": |
|||
key_text = F"Agent_has_{self.env.carrying.color}_key\t& " |
|||
|
|||
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") |
|||
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" |
|||
|
|||
allowed_actions = [] |
|||
|
|||
|
|||
# Create the mask |
|||
# If shield restricts action mask only valid with 1.0 |
|||
# else set all actions as valid |
|||
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) |
|||
|
|||
if cur_pos_str in self.shield and self.shield[cur_pos_str]: |
|||
allowed_actions = self.shield[cur_pos_str] |
|||
for allowed_action in allowed_actions: |
|||
index = get_action_index_mapping(allowed_action[1]) |
|||
if index is None: |
|||
assert(False) |
|||
mask[index] = 1.0 |
|||
else: |
|||
# print(F"Not in shield {cur_pos_str}") |
|||
for index, x in enumerate(mask): |
|||
mask[index] = 1.0 |
|||
|
|||
if self.no_masking: |
|||
return np.array([1.0] * self.max_available_actions, dtype=np.int8) |
|||
|
|||
return mask |
|||
|
|||
def reset(self, *, seed=None, options=None): |
|||
obs, infos = self.env.reset(seed=seed, options=options) |
|||
return obs["image"], infos |
|||
|
|||
def step(self, action): |
|||
# print(F"Performed action in step: {action}") |
|||
orig_obs, rew, done, truncated, info = self.env.step(action) |
|||
|
|||
#print(F"Original observation is {orig_obs}") |
|||
obs = orig_obs["image"] |
|||
|
|||
#print(F"Info is {info}") |
|||
return obs, rew, done, truncated, info |
|||
|
|||
|
|||
|
|||
def mask_fn(env: gym.Env): |
|||
return env.create_action_mask() |
|||
|
|||
|
|||
|
|||
def main(): |
|||
import argparse |
|||
args = parse_arguments(argparse) |
|||
shield = create_shield_dict(args) |
|||
|
|||
env = gym.make(args.env, render_mode="rgb_array") |
|||
keys = extract_keys(env) |
|||
env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking) |
|||
env = ActionMasker(env, mask_fn) |
|||
callback = CustomCallback(1, env) |
|||
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=args.log_dir) |
|||
model.learn(args.iterations, callback=callback) |
|||
|
|||
mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10) |
|||
|
|||
vec_env = model.get_env() |
|||
obs = vec_env.reset() |
|||
terminated = truncated = False |
|||
while not terminated and not truncated: |
|||
action_masks = None |
|||
action, _states = model.predict(obs, action_masks=action_masks) |
|||
obs, reward, terminated, truncated, info = env.step(action) |
|||
# action, _states = model.predict(obs, deterministic=True) |
|||
# obs, rewards, dones, info = vec_env.step(action) |
|||
vec_env.render("human") |
|||
time.sleep(0.2) |
|||
|
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
main() |
Write
Preview
Loading…
Cancel
Save
Reference in new issue