Thomas Knoll
1 year ago
5 changed files with 310 additions and 198 deletions
@ -0,0 +1,130 @@ |
from sb3_contrib import MaskablePPO |
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
from sb3_contrib.common.maskable.utils import get_action_masks |
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy |
from sb3_contrib.common.wrappers import ActionMasker |
from stable_baselines3.common.callbacks import BaseCallback |
import gymnasium as gym |
from gymnasium.spaces import Dict, Box |
import numpy as np |
import time |
from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping |
class CustomCallback(BaseCallback):
    """Training callback that keeps a reference to the environment.

    The per-step hook adds no behaviour of its own; it simply defers to the
    parent class. The stored ``env`` is available for ad-hoc debugging.
    """

    def __init__(self, verbose: int = 0, env=None):
        """Forward ``verbose`` to the base callback and remember ``env``."""
        super().__init__(verbose)
        self.env = env

    def _on_step(self) -> bool:
        """Run the parent's step hook and hand back whatever it returns."""
        parent_result = super()._on_step()
        return parent_result
class MiniGridEnvWrapper(gym.core.Wrapper):
    """Wrap an environment to expose image observations and a shield-based action mask.

    The wrapped environment must have a dict observation space containing an
    ``"image"`` entry; that entry becomes this wrapper's observation space and
    is what ``reset`` and ``step`` return as the observation.

    Args:
        env: Environment to wrap. ``create_action_mask`` reads ``agent_pos``,
            ``agent_dir`` and ``carrying`` from it.
        shield: Mapping from a textual state description (see
            ``create_action_mask``) to an iterable of allowed-action entries.
            ``None`` means an empty shield.
        keys: Sequence of key identifiers; only the first one is used.
            Presumably key colours, since the carried key's ``color`` fills
            the same slot of the state description -- confirm against
            ``extract_keys``. ``None`` means no keys.
        no_masking: When true, ``create_action_mask`` allows every action.
    """

    def __init__(self, env, shield=None, keys=None, no_masking=False):
        super().__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = env.observation_space.spaces["image"]
        # Build fresh containers per instance instead of sharing one mutable
        # default object between every wrapper created without arguments.
        self.keys = [] if keys is None else keys
        self.shield = {} if shield is None else shield
        self.no_masking = no_masking

    def create_action_mask(self):
        """Return an ``int8`` array with 1 for allowed actions and 0 otherwise.

        The current agent state is rendered into a string and looked up in the
        shield. If the shield has a non-empty entry for that string, only the
        actions listed there are enabled; otherwise every action is enabled.
        With ``no_masking`` set, every action is enabled without consulting
        the shield.

        Raises:
            ValueError: If a shield entry names an action that
                ``get_action_index_mapping`` cannot map to an index.
        """
        if self.no_masking:
            # Masking is disabled, so there is no point in querying the shield.
            return np.ones(self.max_available_actions, dtype=np.int8)

        coordinates = self.env.agent_pos
        view_direction = self.env.agent_dir

        key_text = ""
        # only support one key for now
        if self.keys:
            key_text = f"!Agent_has_{self.keys[0]}_key\t& "
        if self.env.carrying and self.env.carrying.type == "key":
            key_text = f"Agent_has_{self.env.carrying.color}_key\t& "

        # This string must match the shield's keys character for character
        # (including the tab separators).
        state_key = (
            f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& "
            f"yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
        )

        allowed_actions = self.shield.get(state_key)
        if not allowed_actions:
            # State unknown to the shield (or listed without restrictions):
            # treat every action as valid.
            return np.ones(self.max_available_actions, dtype=np.int8)

        mask = np.zeros(self.max_available_actions, dtype=np.int8)
        for allowed_action in allowed_actions:
            # Element 1 of each entry is what the helper maps to an index
            # (presumably the action label -- confirm against the shield format).
            index = get_action_index_mapping(allowed_action[1])
            if index is None:
                # An explicit raise survives ``python -O``; a stripped assert
                # would let ``mask[None] = 1`` enable every action silently.
                raise ValueError(
                    f"Shield action {allowed_action[1]!r} has no action index mapping"
                )
            mask[index] = 1
        return mask

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment and return ``(image_observation, info)``."""
        obs, infos = self.env.reset(seed=seed, options=options)
        return obs["image"], infos

    def step(self, action):
        """Step the wrapped environment, replacing the observation by its ``"image"`` entry."""
        orig_obs, rew, done, truncated, info = self.env.step(action)
        obs = orig_obs["image"]
        return obs, rew, done, truncated, info
def mask_fn(env: gym.Env):
    """Return the action mask produced by ``env.create_action_mask()``.

    Passed to ``ActionMasker`` in ``main`` as the mask-producing function.
    """
    action_mask = env.create_action_mask()
    return action_mask
def main():
    """Train a MaskablePPO agent on a shielded environment, evaluate it, and replay one episode.

    Steps:
        1. Parse the command-line arguments and build the shield dictionary.
        2. Create the environment and wrap it with ``MiniGridEnvWrapper`` and
           ``ActionMasker`` so the policy receives action masks.
        3. Train for ``args.iterations`` timesteps.
        4. Evaluate over 10 episodes and print the mean and standard deviation
           of the reward.
        5. Replay a single episode on screen, applying the same action masks
           that were used during training.
    """
    # The helper expects the argparse module itself as its argument.
    import argparse

    args = parse_arguments(argparse)
    shield = create_shield_dict(args)

    env = gym.make(args.env, render_mode="rgb_array")
    keys = extract_keys(env)
    env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking)
    env = ActionMasker(env, mask_fn)

    callback = CustomCallback(1, env)
    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=args.log_dir)
    model.learn(args.iterations, callback=callback)

    vec_env = model.get_env()
    mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=10)
    print(f"Mean reward over 10 evaluation episodes: {mean_reward:.2f} +/- {std_reward:.2f}")

    # Replay through the vectorised env only: reset, step and render all go
    # through the same object, so observations stay consistently batched.
    obs = vec_env.reset()
    episode_over = False
    while not episode_over:
        # Query the current masks so the shield is enforced during the replay,
        # just as it was during training and evaluation.
        action_masks = get_action_masks(vec_env)
        action, _states = model.predict(obs, action_masks=action_masks)
        obs, _rewards, dones, _infos = vec_env.step(action)
        # A single environment is wrapped, so only index 0 is relevant.
        episode_over = bool(dones[0])
        vec_env.render("human")
        time.sleep(0.2)
# Run the training/replay pipeline only when executed as a script.
if __name__ == "__main__":
    main()
Reference in new issue