Browse Source

changes in sb3 rl training

- included callbacks for initial image and info plotting
- switched to CnnPolicy
- changed GRID_TO_PRISM_BINARY to environment var M2P_BINARY
refactoring
sp 10 months ago
parent
commit
59c795348e
  1. 17
      examples/shields/rl/13_minigridsb.py

17
examples/shields/rl/13_minigridsb.py

@ -1,19 +1,20 @@
from sb3_contrib import MaskablePPO from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker from sb3_contrib.common.wrappers import ActionMasker
import gymnasium as gym import gymnasium as gym
from minigrid.core.actions import Actions from minigrid.core.actions import Actions
from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
import time import time
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
GRID_TO_PRISM_BINARY="/home/spranger/research/tempestpy/Minigrid2PRISM/build/main"
import os
GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY")
def mask_fn(env: gym.Env): def mask_fn(env: gym.Env):
return env.create_action_mask() return env.create_action_mask()
@ -27,14 +28,16 @@ def main():
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison)
env = gym.make(args.env, render_mode="rgb_array") env = gym.make(args.env, render_mode="rgb_array")
env = RGBImgObsWrapper(env) # Get pixel observations
env = ImgObsWrapper(env) # Get rid of the 'mission' field
env = MiniWrapper(env)
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full)
env = ActionMasker(env, mask_fn) env = ActionMasker(env, mask_fn)
model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args))
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=create_log_dir(args))
steps = args.steps steps = args.steps
model.learn(steps)
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()], log_interval=1)
print("Learning done, hit enter") print("Learning done, hit enter")

Loading…
Cancel
Save