Thomas Knoll
1 year ago
5 changed files with 310 additions and 198 deletions
-
164examples/shields/rl/11_minigridrl.py
-
130examples/shields/rl/13_minigridsb.py
-
18examples/shields/rl/MaskModels.py
-
61examples/shields/rl/Wrapper.py
-
135examples/shields/rl/helpers.py
@ -0,0 +1,130 @@ |
|||||
|
from sb3_contrib import MaskablePPO |
||||
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
||||
|
from sb3_contrib.common.maskable.utils import get_action_masks |
||||
|
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy |
||||
|
from sb3_contrib.common.wrappers import ActionMasker |
||||
|
from stable_baselines3.common.callbacks import BaseCallback |
||||
|
|
||||
|
import gymnasium as gym |
||||
|
from gymnasium.spaces import Dict, Box |
||||
|
|
||||
|
import numpy as np |
||||
|
import time |
||||
|
|
||||
|
from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping |
||||
|
|
||||
|
class CustomCallback(BaseCallback): |
||||
|
def __init__(self, verbose: int = 0, env=None): |
||||
|
super(CustomCallback, self).__init__(verbose) |
||||
|
self.env = env |
||||
|
|
||||
|
|
||||
|
def _on_step(self) -> bool: |
||||
|
#print(self.env.printGrid()) |
||||
|
return super()._on_step() |
||||
|
|
||||
|
|
||||
|
class MiniGridEnvWrapper(gym.core.Wrapper): |
||||
|
def __init__(self, env, shield={}, keys=[], no_masking=False): |
||||
|
super(MiniGridEnvWrapper, self).__init__(env) |
||||
|
self.max_available_actions = env.action_space.n |
||||
|
self.observation_space = env.observation_space.spaces["image"] |
||||
|
|
||||
|
self.keys = keys |
||||
|
self.shield = shield |
||||
|
self.no_masking = no_masking |
||||
|
|
||||
|
def create_action_mask(self): |
||||
|
coordinates = self.env.agent_pos |
||||
|
view_direction = self.env.agent_dir |
||||
|
|
||||
|
key_text = "" |
||||
|
|
||||
|
# only support one key for now |
||||
|
if self.keys: |
||||
|
key_text = F"!Agent_has_{self.keys[0]}_key\t& " |
||||
|
|
||||
|
|
||||
|
if self.env.carrying and self.env.carrying.type == "key": |
||||
|
key_text = F"Agent_has_{self.env.carrying.color}_key\t& " |
||||
|
|
||||
|
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") |
||||
|
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" |
||||
|
|
||||
|
allowed_actions = [] |
||||
|
|
||||
|
|
||||
|
# Create the mask |
||||
|
# If shield restricts action mask only valid with 1.0 |
||||
|
# else set all actions as valid |
||||
|
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) |
||||
|
|
||||
|
if cur_pos_str in self.shield and self.shield[cur_pos_str]: |
||||
|
allowed_actions = self.shield[cur_pos_str] |
||||
|
for allowed_action in allowed_actions: |
||||
|
index = get_action_index_mapping(allowed_action[1]) |
||||
|
if index is None: |
||||
|
assert(False) |
||||
|
mask[index] = 1.0 |
||||
|
else: |
||||
|
# print(F"Not in shield {cur_pos_str}") |
||||
|
for index, x in enumerate(mask): |
||||
|
mask[index] = 1.0 |
||||
|
|
||||
|
if self.no_masking: |
||||
|
return np.array([1.0] * self.max_available_actions, dtype=np.int8) |
||||
|
|
||||
|
return mask |
||||
|
|
||||
|
def reset(self, *, seed=None, options=None): |
||||
|
obs, infos = self.env.reset(seed=seed, options=options) |
||||
|
return obs["image"], infos |
||||
|
|
||||
|
def step(self, action): |
||||
|
# print(F"Performed action in step: {action}") |
||||
|
orig_obs, rew, done, truncated, info = self.env.step(action) |
||||
|
|
||||
|
#print(F"Original observation is {orig_obs}") |
||||
|
obs = orig_obs["image"] |
||||
|
|
||||
|
#print(F"Info is {info}") |
||||
|
return obs, rew, done, truncated, info |
||||
|
|
||||
|
|
||||
|
|
||||
|
def mask_fn(env: gym.Env): |
||||
|
return env.create_action_mask() |
||||
|
|
||||
|
|
||||
|
|
||||
|
def main(): |
||||
|
import argparse |
||||
|
args = parse_arguments(argparse) |
||||
|
shield = create_shield_dict(args) |
||||
|
|
||||
|
env = gym.make(args.env, render_mode="rgb_array") |
||||
|
keys = extract_keys(env) |
||||
|
env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking) |
||||
|
env = ActionMasker(env, mask_fn) |
||||
|
callback = CustomCallback(1, env) |
||||
|
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=args.log_dir) |
||||
|
model.learn(args.iterations, callback=callback) |
||||
|
|
||||
|
mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10) |
||||
|
|
||||
|
vec_env = model.get_env() |
||||
|
obs = vec_env.reset() |
||||
|
terminated = truncated = False |
||||
|
while not terminated and not truncated: |
||||
|
action_masks = None |
||||
|
action, _states = model.predict(obs, action_masks=action_masks) |
||||
|
obs, reward, terminated, truncated, info = env.step(action) |
||||
|
# action, _states = model.predict(obs, deterministic=True) |
||||
|
# obs, rewards, dones, info = vec_env.step(action) |
||||
|
vec_env.render("human") |
||||
|
time.sleep(0.2) |
||||
|
|
||||
|
|
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
main() |
Write
Preview
Loading…
Cancel
Save
Reference in new issue