@@ -8,10 +8,12 @@ from stable_baselines3.common.callbacks import BaseCallback
 import gymnasium as gym
 from gymnasium.spaces import Dict, Box
 from minigrid.core.actions import Actions
 import numpy as np
 import time

-from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping
+from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping, create_log_dir


 class CustomCallback(BaseCallback):
     def __init__(self, verbose: int = 0, env=None):

@@ -35,6 +37,10 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
         self.no_masking = no_masking

     def create_action_mask(self):
+        if self.no_masking:
+            return np.array([1.0] * self.max_available_actions, dtype=np.int8)
+
         coordinates = self.env.agent_pos
         view_direction = self.env.agent_dir

@@ -70,9 +76,19 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
             # print(F"Not in shield {cur_pos_str}")
             for index, x in enumerate(mask):
                 mask[index] = 1.0

+        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
+
+        if front_tile is not None and front_tile.type == "key":
+            mask[Actions.pickup] = 1.0
+
+        if self.env.carrying:
+            mask[Actions.drop] = 1.0
+
+        if front_tile and front_tile.type == "door":
+            mask[Actions.toggle] = 1.0
+
         if self.no_masking:
            return np.array([1.0] * self.max_available_actions, dtype=np.int8)

         return mask
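
Side note (not part of the diff): the next hunk passes a `mask_fn` to `ActionMasker`, but its definition is not shown here. With sb3-contrib's `ActionMasker`, the usual pattern is a thin callable that delegates to the wrapper above; a sketch under that assumption:

    def mask_fn(env):
        # ActionMasker calls this with the wrapped environment and expects a
        # per-action validity array back; delegate to the wrapper's method.
        return env.create_action_mask()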

@@ -107,8 +123,14 @@ def main():
     env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking)
     env = ActionMasker(env, mask_fn)
     callback = CustomCallback(1, env)
-    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=args.log_dir)
-    model.learn(args.iterations, callback=callback)
+    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))
+
+    iterations = args.iterations
+
+    if iterations < 10_000:
+        iterations = 10_000
+
+    model.learn(iterations, callback=callback)

     mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10)
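
One more note, and an assumption, since the import block above does not show where `evaluate_policy` comes from: for `MaskablePPO`, sb3-contrib provides a maskable variant of `evaluate_policy` that applies the action masks during evaluation, whereas the plain Stable-Baselines3 version ignores them. A minimal usage sketch:

    from sb3_contrib.common.maskable.evaluation import evaluate_policy

    # Evaluate the trained model for 10 episodes with action masking applied.
    mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)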