Browse Source

changes according to refactoring of utils

refactoring
sp 10 months ago
parent
commit
7ccbe8f9bc
  1. 23
      examples/shields/rl/13_minigridsb.py

23
examples/shields/rl/13_minigridsb.py

@ -9,24 +9,25 @@ from minigrid.core.actions import Actions
import time
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from sb3utils import MiniGridSbShieldingWrapper
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments
GRID_TO_PRISM_BINARY="/home/spranger/research/tempestpy/Minigrid2PRISM/build/main"
def mask_fn(env: gym.Env):
return env.create_action_mask()
def main():
import argparse
args = parse_arguments(argparse)
args.grid_path = F"{args.grid_path}.txt"
args.prism_path = F"{args.prism_path}.prism"
args = parse_sb3_arguments()
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
formula = args.formula
shield_value = args.shield_value
shield_comparison = args.shield_comparison
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison)
env = gym.make(args.env, render_mode="rgb_array")
env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full)
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full)
env = ActionMasker(env, mask_fn)
model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args))
@ -35,14 +36,16 @@ def main():
model.learn(steps)
#W mean_reward, std_reward = evaluate_policy(model, model.get_env())
print("Learning done, hit enter")
input("")
vec_env = model.get_env()
obs = vec_env.reset()
terminated = truncated = False
while not terminated and not truncated:
action_masks = None
action, _states = model.predict(obs, action_masks=action_masks)
print(action)
obs, reward, terminated, truncated, info = env.step(action)
# action, _states = model.predict(obs, deterministic=True)
# obs, rewards, dones, info = vec_env.step(action)

Loading…
Cancel
Save