|
|
@ -9,24 +9,25 @@ from minigrid.core.actions import Actions |
|
|
|
|
|
|
|
import time |
|
|
|
|
|
|
|
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig |
|
|
|
from sb3utils import MiniGridSbShieldingWrapper |
|
|
|
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig |
|
|
|
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments |
|
|
|
|
|
|
|
GRID_TO_PRISM_BINARY="/home/spranger/research/tempestpy/Minigrid2PRISM/build/main" |
|
|
|
|
|
|
|
def mask_fn(env: gym.Env): |
|
|
|
return env.create_action_mask() |
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
import argparse |
|
|
|
args = parse_arguments(argparse) |
|
|
|
|
|
|
|
args.grid_path = F"{args.grid_path}.txt" |
|
|
|
args.prism_path = F"{args.prism_path}.prism" |
|
|
|
args = parse_sb3_arguments() |
|
|
|
|
|
|
|
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) |
|
|
|
formula = args.formula |
|
|
|
shield_value = args.shield_value |
|
|
|
shield_comparison = args.shield_comparison |
|
|
|
|
|
|
|
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) |
|
|
|
env = gym.make(args.env, render_mode="rgb_array") |
|
|
|
env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full) |
|
|
|
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) |
|
|
|
env = ActionMasker(env, mask_fn) |
|
|
|
model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) |
|
|
|
|
|
|
@ -35,14 +36,16 @@ def main(): |
|
|
|
|
|
|
|
model.learn(steps) |
|
|
|
|
|
|
|
#W mean_reward, std_reward = evaluate_policy(model, model.get_env()) |
|
|
|
|
|
|
|
print("Learning done, hit enter") |
|
|
|
input("") |
|
|
|
vec_env = model.get_env() |
|
|
|
obs = vec_env.reset() |
|
|
|
terminated = truncated = False |
|
|
|
while not terminated and not truncated: |
|
|
|
action_masks = None |
|
|
|
action, _states = model.predict(obs, action_masks=action_masks) |
|
|
|
print(action) |
|
|
|
obs, reward, terminated, truncated, info = env.step(action) |
|
|
|
# action, _states = model.predict(obs, deterministic=True) |
|
|
|
# obs, rewards, dones, info = vec_env.step(action) |
|
|
|