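"""Train a MaskablePPO agent on MiniGrid with shield-based action masking.

The shield (built by helpers.create_shield_dict) maps state descriptions to the
actions allowed in that state; MiniGridEnvWrapper turns it into an action mask
that MaskablePPO uses during training and at rollout time.
"""
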
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback

import gymnasium as gym
from gymnasium.spaces import Dict, Box

from minigrid.core.actions import Actions

import numpy as np
import time

from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping, create_log_dir
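
# The shield built by helpers.create_shield_dict is assumed to map state strings of
# the form "[<key predicate>!AgentDone\t& xAgent=<x>\t& yAgent=<y>\t& viewAgent=<dir>]"
# to lists of allowed actions, where element [1] of each entry is an action label
# understood by get_action_index_mapping (see MiniGridEnvWrapper.create_action_mask).

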
class CustomCallback(BaseCallback):
    """Minimal callback hook; keeps a reference to the env for inspection during training."""

    def __init__(self, verbose: int = 0, env=None):
        super().__init__(verbose)
        self.env = env

    def _on_step(self) -> bool:
        # print(self.env.printGrid())
        return super()._on_step()

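
# Masking flow (sketch): ActionMasker calls mask_fn(env), which forwards to
# MiniGridEnvWrapper.create_action_mask(); MaskablePPO then restricts sampling
# to actions whose mask entry is 1.0.

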
class MiniGridEnvWrapper(gym.core.Wrapper):
    """Exposes the 'image' observation of a MiniGrid env and a shield-based action mask."""

    def __init__(self, env, shield={}, keys=[], no_masking=False):
        super().__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = env.observation_space.spaces["image"]

        self.keys = keys
        self.shield = shield
        self.no_masking = no_masking

    def create_action_mask(self):
        if self.no_masking:
            return np.array([1.0] * self.max_available_actions, dtype=np.int8)

        coordinates = self.env.agent_pos
        view_direction = self.env.agent_dir

        key_text = ""

        # Only one key is supported for now.
        if self.keys:
            key_text = f"!Agent_has_{self.keys[0]}_key\t& "

        if self.env.carrying and self.env.carrying.type == "key":
            key_text = f"Agent_has_{self.env.carrying.color}_key\t& "

        # print(f"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir}")
        cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"

        # Create the mask: if the shield restricts the current state, mark only
        # the shielded actions as valid (1.0); otherwise allow all actions.
        mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)

        if cur_pos_str in self.shield and self.shield[cur_pos_str]:
            allowed_actions = self.shield[cur_pos_str]
            for allowed_action in allowed_actions:
                index = get_action_index_mapping(allowed_action[1])
                assert index is not None, f"No action index for {allowed_action[1]}"
                mask[index] = 1.0
        else:
            # print(f"Not in shield {cur_pos_str}")
            mask[:] = 1.0

        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])

        # Interaction actions are enabled whenever they make sense locally.
        if front_tile is not None and front_tile.type == "key":
            mask[Actions.pickup] = 1.0

        if self.env.carrying:
            mask[Actions.drop] = 1.0

        if front_tile and front_tile.type == "door":
            mask[Actions.toggle] = 1.0

        return mask

    def reset(self, *, seed=None, options=None):
        obs, infos = self.env.reset(seed=seed, options=options)
        return obs["image"], infos

    def step(self, action):
        # print(f"Performed action in step: {action}")
        orig_obs, rew, done, truncated, info = self.env.step(action)

        # print(f"Original observation is {orig_obs}")
        obs = orig_obs["image"]

        # print(f"Info is {info}")
        return obs, rew, done, truncated, info


def mask_fn(env: gym.Env):
    # ActionMasker hook: delegate to the wrapper's shield-based mask.
    return env.create_action_mask()

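
# Note: parse_arguments (from helpers) is expected to provide at least args.env
# (a MiniGrid environment id), args.iterations and args.no_masking; the exact
# command-line flags are defined in the helpers module.

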
def main():
    import argparse
    args = parse_arguments(argparse)
    shield = create_shield_dict(args)

    env = gym.make(args.env, render_mode="rgb_array")
    keys = extract_keys(env)
    env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking)
    env = ActionMasker(env, mask_fn)
    callback = CustomCallback(1, env)
    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))

    # Train for at least 10,000 steps.
    iterations = max(args.iterations, 10_000)
    model.learn(iterations, callback=callback)

    mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10)
    print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

    # Roll out the trained policy once, applying the shield mask at every step.
    vec_env = model.get_env()
    obs = vec_env.reset()
    terminated = truncated = False
    while not terminated and not truncated:
        action_masks = get_action_masks(env)
        action, _states = model.predict(obs, action_masks=action_masks)
        obs, reward, terminated, truncated, info = env.step(action)
        # action, _states = model.predict(obs, deterministic=True)
        # obs, rewards, dones, info = vec_env.step(action)
        vec_env.render("human")
        time.sleep(0.2)


if __name__ == '__main__':
    main()