diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index f8b5008..d2f125c 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -71,7 +71,7 @@ def ppo(args): config.build() ) - for i in range(args.iterations): + for i in range(args.evaluations): result = algo.train() print(pretty_print(result)) @@ -103,7 +103,7 @@ def dqn(args): config.build() ) - for i in range(args.iterations): + for i in range(args.evaluations): result = algo.train() print(pretty_print(result)) diff --git a/examples/shields/rl/12_minigridrl_tune.py b/examples/shields/rl/12_minigridrl_tune.py index 57beb01..e0ff945 100644 --- a/examples/shields/rl/12_minigridrl_tune.py +++ b/examples/shields/rl/12_minigridrl_tune.py @@ -116,8 +116,7 @@ def main(): ), run_config=air.RunConfig( stop = {"episode_reward_mean": 94, - "timesteps_total": 12000, - "training_iteration": args.iterations}, + "timesteps_total": 12000,}, checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ), storage_path=F"{logdir}" ), diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index c6757da..43575dc 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -47,12 +47,10 @@ def main(): callback = CustomCallback(1, env) model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) - iterations = args.iterations + steps = args.steps - if iterations < 10_000: - iterations = 10_000 - model.learn(iterations, callback=callback) + model.learn(steps, callback=callback) #W mean_reward, std_reward = evaluate_policy(model, model.get_env()) diff --git a/examples/shields/rl/14_train_eval.py b/examples/shields/rl/14_train_eval.py index 1591075..56fa8ee 100644 --- a/examples/shields/rl/14_train_eval.py +++ b/examples/shields/rl/14_train_eval.py @@ -81,11 +81,11 @@ def ppo(args): config.build() ) - iterations = args.iterations + evaluations = args.evaluations - for i in range(iterations): + for i in range(evaluations): algo.train() if i % 5 == 0: @@ -96,7 +96,7 @@ def ppo(args): writer = SummaryWriter(log_dir=eval_log_dir) csv_logger = CSVLogger(config=config, logdir=eval_log_dir) - for i in range(iterations): + for i in range(evaluations): eval_result = algo.evaluate() print(pretty_print(eval_result)) print(eval_result) diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index 808b431..9706fdd 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -86,9 +86,12 @@ def ppo(args): ), run_config=air.RunConfig( stop = {"episode_reward_mean": 94, - "timesteps_total": args.steps, - "training_iteration": args.iterations}, - checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ), + "timesteps_total": args.steps,}, + checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, + num_to_keep=1, + checkpoint_score_attribute="episode_reward_mean", + ), + storage_path=F"{logdir}" ) , @@ -116,7 +119,7 @@ def ppo(args): csv_logger = CSVLogger(config=config, logdir=eval_log_dir) - for i in range(args.iterations): + for i in range(args.evaluations): eval_result = algo.evaluate() print(pretty_print(eval_result)) print(eval_result) diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index 3e67c81..73e58a5 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -39,7 +39,7 @@ def extract_keys(env): return keys def create_log_dir(args): - return F"{args.log_dir}{args.algorithm}-shielding:{args.shielding}-iterations:{args.iterations}" + return F"{args.log_dir}{args.algorithm}-shielding:{args.shielding}-evaluations:{args.evaluations}-steps:{args.steps}" def get_action_index_mapping(actions): @@ -90,7 +90,7 @@ def parse_arguments(argparse): parser.add_argument("--prism_path", default="grid") parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"]) parser.add_argument("--log_dir", default="../log_results/") - parser.add_argument("--iterations", type=int, default=10 ) + parser.add_argument("--evaluations", type=int, default=10 ) parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" parser.add_argument("--workers", type=int, default=1) parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) diff --git a/examples/shields/rl/ppo_sb.ipynb b/examples/shields/rl/ppo_sb.ipynb index 400f1db..aa76f87 100644 --- a/examples/shields/rl/ppo_sb.ipynb +++ b/examples/shields/rl/ppo_sb.ipynb @@ -9,38 +9,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "pygame 2.5.1 (SDL 2.28.2, Python 3.10.12)\n", - "Hello from the pygame community. https://www.pygame.org/contribute.html\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-09-08 10:00:46.717621: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", - "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2023-09-08 10:00:47.771352: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" - ] - }, - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'examples'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[1], line 9\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mstable_baselines3\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcommon\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcallbacks\u001b[39;00m \u001b[39mimport\u001b[39;00m BaseCallback\n\u001b[1;32m 7\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mgymnasium\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mgym\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mexamples\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mshields\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mrl\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mshieldhandlers\u001b[39;00m \u001b[39mimport\u001b[39;00m MiniGridShieldHandler, create_shield_query\n\u001b[1;32m 10\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mexamples\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mshields\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mrl\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mwrappers\u001b[39;00m \u001b[39mimport\u001b[39;00m MiniGridSbShieldingWrapper\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'examples'" - ] - } - ], + "outputs": [], "source": [ "from sb3_contrib import MaskablePPO\n", "from sb3_contrib.common.maskable.evaluation import evaluate_policy\n", @@ -56,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -66,239 +37,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using cpu device\n", - "Wrapping the env with a `Monitor` wrapper\n", - "Wrapping the env in a DummyVecEnv.\n", - "Wrapping the env in a VecTransposeImage.\n", - "\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Reading :\tWGXR WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWGVR VRVRVRVRVRWG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG GGWG\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "\n", - "Write to file Grid.shield.\n", - "\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Reading :\tWGXR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VRGGWG\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "\n", - "Write to file Grid.shield.\n", - "\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Reading :\tWGXR WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG VRVRVRVRVRVRWG\n", - "Reading :\tWG GGWG\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "\n", - "Write to file Grid.shield.\n", - "\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Reading :\tWGXR WG\n", - "Reading :\tWGVRVRVR VRVRVRWG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG GGWG\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "\n", - "Write to file Grid.shield.\n", - "\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Reading :\tWGXR VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG GGWG\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "\n", - "Write to file Grid.shield.\n", - "\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Reading :\tWGXR WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWGVRVRVR VRVRVRWG\n", - "Reading :\tWG GGWG\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "\n", - "Write to file Grid.shield.\n", - "\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Reading :\tWGXR VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG VR WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG VRGGWG\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "\n", - "Write to file Grid.shield.\n", - "\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Reading :\tWGXR WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWGVRVRVRVRVRVR WG\n", - "Reading :\tWG GGWG\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "\n", - "Write to file Grid.shield.\n", - "---------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 283 |\n", - "| ep_rew_mean | 0.157 |\n", - "| time/ | |\n", - "| fps | 165 |\n", - "| iterations | 1 |\n", - "| time_elapsed | 12 |\n", - "| total_timesteps | 2048 |\n", - "---------------------------------\n", - "\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Reading :\tWGXR WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWGVRVRVRVRVRVR WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG WG\n", - "Reading :\tWG GGWG\n", - "Reading :\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWG WG\n", - "Background:\tWGWGWGWGWGWGWGWGWG\n", - "\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 9\u001b[0m\n\u001b[1;32m 5\u001b[0m env \u001b[39m=\u001b[39m ActionMasker(env, mask_fn)\n\u001b[1;32m 6\u001b[0m model \u001b[39m=\u001b[39m MaskablePPO(MaskableActorCriticPolicy, env, verbose\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m model\u001b[39m.\u001b[39;49mlearn(\u001b[39m10_000\u001b[39;49m)\n", - "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/sb3_contrib/ppo_mask/ppo_mask.py:526\u001b[0m, in \u001b[0;36mMaskablePPO.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, use_masking, progress_bar)\u001b[0m\n\u001b[1;32m 523\u001b[0m callback\u001b[39m.\u001b[39mon_training_start(\u001b[39mlocals\u001b[39m(), \u001b[39mglobals\u001b[39m())\n\u001b[1;32m 525\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_timesteps \u001b[39m<\u001b[39m total_timesteps:\n\u001b[0;32m--> 526\u001b[0m continue_training \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcollect_rollouts(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49menv, callback, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrollout_buffer, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mn_steps, use_masking)\n\u001b[1;32m 528\u001b[0m \u001b[39mif\u001b[39;00m continue_training \u001b[39mis\u001b[39;00m \u001b[39mFalse\u001b[39;00m:\n\u001b[1;32m 529\u001b[0m \u001b[39mbreak\u001b[39;00m\n", - "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/sb3_contrib/ppo_mask/ppo_mask.py:306\u001b[0m, in \u001b[0;36mMaskablePPO.collect_rollouts\u001b[0;34m(self, env, callback, rollout_buffer, n_rollout_steps, use_masking)\u001b[0m\n\u001b[1;32m 303\u001b[0m actions, values, log_probs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpolicy(obs_tensor, action_masks\u001b[39m=\u001b[39maction_masks)\n\u001b[1;32m 305\u001b[0m actions \u001b[39m=\u001b[39m actions\u001b[39m.\u001b[39mcpu()\u001b[39m.\u001b[39mnumpy()\n\u001b[0;32m--> 306\u001b[0m new_obs, rewards, dones, infos \u001b[39m=\u001b[39m env\u001b[39m.\u001b[39;49mstep(actions)\n\u001b[1;32m 308\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_timesteps \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m env\u001b[39m.\u001b[39mnum_envs\n\u001b[1;32m 310\u001b[0m \u001b[39m# Give access to local variables\u001b[39;00m\n", - "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/stable_baselines3/common/vec_env/base_vec_env.py:197\u001b[0m, in \u001b[0;36mVecEnv.step\u001b[0;34m(self, actions)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 191\u001b[0m \u001b[39mStep the environments with the given action\u001b[39;00m\n\u001b[1;32m 192\u001b[0m \n\u001b[1;32m 193\u001b[0m \u001b[39m:param actions: the action\u001b[39;00m\n\u001b[1;32m 194\u001b[0m \u001b[39m:return: observation, reward, done, information\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 196\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstep_async(actions)\n\u001b[0;32m--> 197\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstep_wait()\n", - "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/stable_baselines3/common/vec_env/vec_transpose.py:95\u001b[0m, in \u001b[0;36mVecTransposeImage.step_wait\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mstep_wait\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m VecEnvStepReturn:\n\u001b[0;32m---> 95\u001b[0m observations, rewards, dones, infos \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mvenv\u001b[39m.\u001b[39;49mstep_wait()\n\u001b[1;32m 97\u001b[0m \u001b[39m# Transpose the terminal observations\u001b[39;00m\n\u001b[1;32m 98\u001b[0m \u001b[39mfor\u001b[39;00m idx, done \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(dones):\n", - "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py:70\u001b[0m, in \u001b[0;36mDummyVecEnv.step_wait\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuf_dones[env_idx]:\n\u001b[1;32m 68\u001b[0m \u001b[39m# save final observation where user can get it, then reset\u001b[39;00m\n\u001b[1;32m 69\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuf_infos[env_idx][\u001b[39m\"\u001b[39m\u001b[39mterminal_observation\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m obs\n\u001b[0;32m---> 70\u001b[0m obs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreset_infos[env_idx] \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49menvs[env_idx]\u001b[39m.\u001b[39;49mreset()\n\u001b[1;32m 71\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_save_obs(env_idx, obs)\n\u001b[1;32m 72\u001b[0m \u001b[39mreturn\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_obs_from_buf(), np\u001b[39m.\u001b[39mcopy(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuf_rews), np\u001b[39m.\u001b[39mcopy(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuf_dones), deepcopy(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuf_infos))\n", - "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/stable_baselines3/common/monitor.py:83\u001b[0m, in \u001b[0;36mMonitor.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mExpected you to pass keyword argument \u001b[39m\u001b[39m{\u001b[39;00mkey\u001b[39m}\u001b[39;00m\u001b[39m into reset\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 82\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcurrent_reset_info[key] \u001b[39m=\u001b[39m value\n\u001b[0;32m---> 83\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49menv\u001b[39m.\u001b[39;49mreset(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", - "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/gymnasium/core.py:414\u001b[0m, in \u001b[0;36mWrapper.reset\u001b[0;34m(self, seed, options)\u001b[0m\n\u001b[1;32m 410\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mreset\u001b[39m(\n\u001b[1;32m 411\u001b[0m \u001b[39mself\u001b[39m, \u001b[39m*\u001b[39m, seed: \u001b[39mint\u001b[39m \u001b[39m|\u001b[39m \u001b[39mNone\u001b[39;00m \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, options: \u001b[39mdict\u001b[39m[\u001b[39mstr\u001b[39m, Any] \u001b[39m|\u001b[39m \u001b[39mNone\u001b[39;00m \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 412\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mtuple\u001b[39m[WrapperObsType, \u001b[39mdict\u001b[39m[\u001b[39mstr\u001b[39m, Any]]:\n\u001b[1;32m 413\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 414\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49menv\u001b[39m.\u001b[39;49mreset(seed\u001b[39m=\u001b[39;49mseed, options\u001b[39m=\u001b[39;49moptions)\n", - "File \u001b[0;32m~/Documents/Projects/tempestpy/examples/shields/rl/Wrappers.py:222\u001b[0m, in \u001b[0;36mMiniGridSbShieldingWrapper.reset\u001b[0;34m(self, seed, options)\u001b[0m\n\u001b[1;32m 219\u001b[0m obs, infos \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39menv\u001b[39m.\u001b[39mreset(seed\u001b[39m=\u001b[39mseed, options\u001b[39m=\u001b[39moptions)\n\u001b[1;32m 221\u001b[0m keys \u001b[39m=\u001b[39m extract_keys(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39menv)\n\u001b[0;32m--> 222\u001b[0m shield \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mshield_creator\u001b[39m.\u001b[39;49mcreate_shield(env\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49menv)\n\u001b[1;32m 224\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkeys \u001b[39m=\u001b[39m keys\n\u001b[1;32m 225\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshield \u001b[39m=\u001b[39m shield\n", - "File \u001b[0;32m~/Documents/Projects/tempestpy/examples/shields/rl/ShieldHandlers.py:82\u001b[0m, in \u001b[0;36mMiniGridShieldHandler.create_shield\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__export_grid_to_text(env)\n\u001b[1;32m 80\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__create_prism()\n\u001b[0;32m---> 82\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__create_shield_dict()\n", - "File \u001b[0;32m~/Documents/Projects/tempestpy/examples/shields/rl/ShieldHandlers.py:66\u001b[0m, in \u001b[0;36mMiniGridShieldHandler.__create_shield_dict\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 64\u001b[0m choice \u001b[39m=\u001b[39m shield_scheduler\u001b[39m.\u001b[39mget_choice(stateID)\n\u001b[1;32m 65\u001b[0m choices \u001b[39m=\u001b[39m choice\u001b[39m.\u001b[39mchoice_map\n\u001b[0;32m---> 66\u001b[0m state_valuation \u001b[39m=\u001b[39m model\u001b[39m.\u001b[39;49mstate_valuations\u001b[39m.\u001b[39;49mget_string(stateID)\n\u001b[1;32m 68\u001b[0m actions_to_be_executed \u001b[39m=\u001b[39m [(choice[\u001b[39m1\u001b[39m] ,model\u001b[39m.\u001b[39mchoice_labeling\u001b[39m.\u001b[39mget_labels_of_choice(model\u001b[39m.\u001b[39mget_choice_index(stateID, choice[\u001b[39m1\u001b[39m]))) \u001b[39mfor\u001b[39;00m choice \u001b[39min\u001b[39;00m choices]\n\u001b[1;32m 70\u001b[0m action_dictionary[state_valuation] \u001b[39m=\u001b[39m actions_to_be_executed\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "shield_creator = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n", "\n",