|
@ -1,19 +1,20 @@ |
|
|
from sb3_contrib import MaskablePPO |
|
|
from sb3_contrib import MaskablePPO |
|
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
|
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
|
|
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy |
|
|
|
|
|
from sb3_contrib.common.wrappers import ActionMasker |
|
|
from sb3_contrib.common.wrappers import ActionMasker |
|
|
|
|
|
|
|
|
import gymnasium as gym |
|
|
import gymnasium as gym |
|
|
|
|
|
|
|
|
from minigrid.core.actions import Actions |
|
|
from minigrid.core.actions import Actions |
|
|
|
|
|
from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper |
|
|
|
|
|
|
|
|
import time |
|
|
import time |
|
|
|
|
|
|
|
|
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig |
|
|
|
|
|
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments |
|
|
|
|
|
|
|
|
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper |
|
|
|
|
|
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback |
|
|
|
|
|
|
|
|
GRID_TO_PRISM_BINARY="/home/spranger/research/tempestpy/Minigrid2PRISM/build/main" |
|
|
|
|
|
|
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") |
|
|
def mask_fn(env: gym.Env): |
|
|
def mask_fn(env: gym.Env): |
|
|
return env.create_action_mask() |
|
|
return env.create_action_mask() |
|
|
|
|
|
|
|
@ -27,14 +28,16 @@ def main(): |
|
|
|
|
|
|
|
|
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) |
|
|
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) |
|
|
env = gym.make(args.env, render_mode="rgb_array") |
|
|
env = gym.make(args.env, render_mode="rgb_array") |
|
|
|
|
|
env = RGBImgObsWrapper(env) # Get pixel observations |
|
|
|
|
|
env = ImgObsWrapper(env) # Get rid of the 'mission' field |
|
|
|
|
|
env = MiniWrapper(env) |
|
|
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) |
|
|
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) |
|
|
env = ActionMasker(env, mask_fn) |
|
|
env = ActionMasker(env, mask_fn) |
|
|
model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) |
|
|
|
|
|
|
|
|
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=create_log_dir(args)) |
|
|
|
|
|
|
|
|
steps = args.steps |
|
|
steps = args.steps |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.learn(steps) |
|
|
|
|
|
|
|
|
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()], log_interval=1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("Learning done, hit enter") |
|
|
print("Learning done, hit enter") |
|
|