5 changed files with 650 additions and 63 deletions
			
			
		- 
					255examples/shields/rl/11_minigridrl.py
- 
					132examples/shields/rl/12_basic_training.py
- 
					91examples/shields/rl/MaskEnvironments.py
- 
					81examples/shields/rl/MaskModels.py
- 
					152examples/shields/rl/Wrapper.py
| @ -0,0 +1,255 @@ | |||
| from typing import Dict, Optional, Union | |||
| from ray.rllib.env.base_env import BaseEnv | |||
| from ray.rllib.evaluation import RolloutWorker | |||
| from ray.rllib.evaluation.episode import Episode | |||
| from ray.rllib.evaluation.episode_v2 import EpisodeV2 | |||
| from ray.rllib.policy import Policy | |||
| from ray.rllib.utils.typing import PolicyID | |||
| import stormpy | |||
| import stormpy.core | |||
| import stormpy.simulator | |||
| 
 | |||
| 
 | |||
| import stormpy.shields | |||
| import stormpy.logic | |||
| 
 | |||
| import stormpy.examples | |||
| import stormpy.examples.files | |||
| import os | |||
| 
 | |||
| import gymnasium as gym | |||
| 
 | |||
| import minigrid | |||
| import numpy as np | |||
| 
 | |||
| import ray | |||
| from ray.tune import register_env | |||
| from ray.rllib.algorithms.ppo import PPOConfig | |||
| from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator | |||
| from ray import tune, air | |||
| from ray.rllib.algorithms.callbacks import DefaultCallbacks | |||
| from ray.tune.logger import pretty_print | |||
| from ray.rllib.algorithms import ppo | |||
| from ray.rllib.models import ModelCatalog | |||
| 
 | |||
| from ray.rllib.utils.torch_utils import FLOAT_MIN | |||
| 
 | |||
| from ray.rllib.models.preprocessors import get_preprocessor | |||
| from MaskEnvironments import ParametricActionsMiniGridEnv | |||
| from MaskModels import TorchActionMaskModel | |||
| from Wrapper import OneHotWrapper, MiniGridEnvWrapper, ImgObsWrapper | |||
| 
 | |||
| import matplotlib.pyplot as plt | |||
| 
 | |||
| import argparse | |||
| 
 | |||
| 
 | |||
| 
 | |||
| class MyCallbacks(DefaultCallbacks): | |||
|     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: | |||
|         # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") | |||
|         env = base_env.get_sub_environments()[0] | |||
|         episode.user_data["count"] = 0 | |||
|         # print(env.printGrid()) | |||
|         # print(env.action_space.n) | |||
|         # print(env.actions) | |||
|         # print(env.mission) | |||
|         # print(env.observation_space) | |||
|         # img = env.get_frame() | |||
|         # plt.imshow(img) | |||
|         # plt.show() | |||
|         | |||
|     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: | |||
|          episode.user_data["count"] = episode.user_data["count"] + 1 | |||
|          env = base_env.get_sub_environments()[0] | |||
|          print(env.env.env.printGrid()) | |||
|      | |||
|     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: | |||
|         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") | |||
|         env = base_env.get_sub_environments()[0] | |||
|         # print(env.env.env.printGrid()) | |||
|         # print(episode.user_data["count"]) | |||
|          | |||
|      | |||
|         | |||
| 
 | |||
| def parse_arguments(argparse): | |||
|     parser = argparse.ArgumentParser() | |||
|     # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") | |||
|     parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0") | |||
|     parser.add_argument("--seed", type=int, help="seed for environment", default=1) | |||
|     parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32) | |||
|     parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees") | |||
|     parser.add_argument("--grid_path", default="Grid.txt") | |||
|     parser.add_argument("--prism_path", default="Grid.PRISM") | |||
|      | |||
|     args = parser.parse_args() | |||
|      | |||
|     return args | |||
| 
 | |||
| 
 | |||
| def env_creater_custom(config): | |||
|     # name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") | |||
|     # # name = config.get("name", "MiniGrid-Empty-8x8-v0") | |||
|     framestack = config.get("framestack", 4) | |||
|      | |||
|     # env = gym.make(name) | |||
|     # env = ParametricActionsMiniGridEnv(config) | |||
|     name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") | |||
|     framestack = config.get("framestack", 4) | |||
|      | |||
|     env = gym.make(name) | |||
|     env = MiniGridEnvWrapper(env) | |||
|     # env = minigrid.wrappers.ImgObsWrapper(env) | |||
|     # env = ImgObsWrapper(env) | |||
|     env = OneHotWrapper(env, | |||
|                         config.vector_index if hasattr(config, "vector_index") else 0, | |||
|                         framestack=framestack | |||
|                         ) | |||
|      | |||
|     obs = env.observation_space.sample() | |||
|     obs2, infos = env.reset(seed=None, options={}) | |||
|      | |||
|     print(F"Obs is {obs} before reset. After reset: {obs2}") | |||
|     # env = minigrid.wrappers.RGBImgPartialObsWrapper(env) | |||
|      | |||
|     print(F"Created Custom Minigrid Environment is {env}") | |||
| 
 | |||
|     return env | |||
| 
 | |||
| def env_creater_cart(config): | |||
|     return gym.make("CartPole-v1") | |||
| 
 | |||
| def env_creater(config): | |||
|     name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") | |||
|     # name = config.get("name", "MiniGrid-Empty-8x8-v0") | |||
|     framestack = config.get("framestack", 4) | |||
|      | |||
|     env = gym.make(name) | |||
|     # env = minigrid.wrappers.RGBImgPartialObsWrapper(env) | |||
|     env = minigrid.wrappers.ImgObsWrapper(env) | |||
|     env = OneHotWrapper(env, | |||
|                         config.vector_index if hasattr(config, "vector_index") else 0, | |||
|                         framestack=framestack | |||
|                         ) | |||
|        | |||
|     print(F"Created Minigrid Environment is {env}") | |||
| 
 | |||
|     return env | |||
| 
 | |||
| 
 | |||
| 
 | |||
| def create_shield(grid_file, prism_path): | |||
|     os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}") | |||
|      | |||
|     f = open(prism_path, "a") | |||
|     f.write("label \"AgentIsInLava\" = AgentIsInLava;") | |||
|     f.close() | |||
|      | |||
|      | |||
|     program = stormpy.parse_prism_program(prism_path) | |||
|     formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" | |||
|      | |||
|     formulas = stormpy.parse_properties_for_prism_program(formula_str, program) | |||
|     options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) | |||
|     options.set_build_state_valuations(True) | |||
|     options.set_build_choice_labels(True) | |||
|     options.set_build_all_labels() | |||
|     model = stormpy.build_sparse_model_with_options(program, options) | |||
|      | |||
|     shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)  | |||
|     result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) | |||
|      | |||
|     assert result.has_scheduler | |||
|     assert result.has_shield | |||
|     shield = result.shield | |||
| 
 | |||
|     stormpy.shields.export_shield(model, shield, "Grid.shield") | |||
|      | |||
|     return shield.construct(), model | |||
| 
 | |||
| def export_grid_to_text(env, grid_file): | |||
|     f = open(grid_file, "w") | |||
|     # print(env) | |||
|     f.write(env.printGrid(init=True)) | |||
|     # f.write(env.pprint_grid()) | |||
|     f.close() | |||
| 
 | |||
| def create_environment(args): | |||
|     env_id= args.env | |||
|     env = gym.make(env_id) | |||
|     env.reset() | |||
|     return env | |||
| 
 | |||
| 
 | |||
| def main(): | |||
|     args = parse_arguments(argparse) | |||
| 
 | |||
|     env = create_environment(args) | |||
|     ray.init(num_cpus=3) | |||
| 
 | |||
|     # print(env.pprint_grid()) | |||
|     # print(env.printGrid(init=False)) | |||
|      | |||
|     grid_file = args.grid_path | |||
|     export_grid_to_text(env, grid_file) | |||
|      | |||
|     prism_path = args.prism_path | |||
|     shield, model = create_shield(grid_file, prism_path) | |||
|     shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} | |||
|     | |||
|     print(shield_dict) | |||
|     for state_id in model.states: | |||
|         choices = shield.get_choice(state_id) | |||
|         print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") | |||
|          | |||
|     env_name = "mini-grid" | |||
|     register_env(env_name, env_creater_custom) | |||
|     ModelCatalog.register_custom_model( | |||
|         "pa_model",  | |||
|         TorchActionMaskModel | |||
|     ) | |||
|      | |||
|     config = (PPOConfig() | |||
|         .rollouts(num_rollout_workers=1) | |||
|         .resources(num_gpus=0) | |||
|         .environment(env="mini-grid") | |||
|         .framework("torch")        | |||
|         .experimental(_disable_preprocessor_api=False) | |||
|         .callbacks(MyCallbacks) | |||
|         .rl_module(_enable_rl_module_api = False) | |||
|         .training(_enable_learner_api=False ,model={ | |||
|             "custom_model": "pa_model", | |||
|             "custom_model_config" : {"shield": shield_dict, "no_masking": True} | |||
|             # "fcnet_hiddens": [256,256], | |||
|             # "fcnet_activation": "relu", | |||
|              | |||
|         })) | |||
| 
 | |||
|      | |||
|     algo =( | |||
|          | |||
|         config.build() | |||
|     ) | |||
|     episode_reward = 0 | |||
|     terminated = truncated = False | |||
|     obs, info = env.reset() | |||
|      | |||
|     # while not terminated and not truncated: | |||
|     #     action = algo.compute_single_action(obs) | |||
|     #     obs, reward, terminated, truncated = env.step(action) | |||
|      | |||
|     for i in range(30): | |||
|         result = algo.train() | |||
|         print(pretty_print(result)) | |||
| 
 | |||
|         if i % 5 == 0: | |||
|             checkpoint_dir = algo.save() | |||
|             print(f"Checkpoint saved in directory {checkpoint_dir}") | |||
| 
 | |||
|     | |||
| 
 | |||
|     ray.shutdown() | |||
| 
 | |||
| if __name__ == '__main__': | |||
|     main() | |||
| @ -0,0 +1,91 @@ | |||
| import random | |||
| import minigrid | |||
| 
 | |||
| import gymnasium as gym | |||
| import numpy as np | |||
| from gymnasium.spaces import Box, Dict, Discrete | |||
| from Wrapper import OneHotWrapper | |||
| 
 | |||
| 
 | |||
| class ParametricActionsMiniGridEnv(gym.Env): | |||
|     """Parametric action version of MiniGrid. | |||
| 
 | |||
|     """ | |||
| 
 | |||
|     def __init__(self, config): | |||
|         | |||
|         name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") | |||
|         self.left_action_embed = np.random.randn(2) | |||
|         self.right_action_embed = np.random.randn(2) | |||
|         framestack = config.get("framestack", 4) | |||
|          | |||
|         # env = gym.make(name) | |||
|         # env = minigrid.wrappers.ImgObsWrapper(env) | |||
|         # env = OneHotWrapper(env, | |||
|         #                 config.vector_index if hasattr(config, "vector_index") else 0, | |||
|         #                 framestack=framestack | |||
|         #                 ) | |||
|         self.wrapped = gym.make(name) | |||
|         # self.observation_space = Dict( | |||
|         #     { | |||
|         #          "action_mask": None, | |||
|         #          "avail_actions": None, | |||
|         #         "cart": self.wrapped.observation_space, | |||
|         #     } | |||
|         # ) | |||
|         print(F"Wrapped environment is {self.wrapped}") | |||
|         self.step_count = 0 | |||
|         self.action_space = self.wrapped.action_space | |||
|         self.observation_space = self.wrapped.observation_space | |||
|          | |||
|          | |||
|     def update_avail_actions(self): | |||
|         self.action_assignments = np.array( | |||
|             [[0.0, 0.0]] * self.action_space.n, dtype=np.float32 | |||
|         ) | |||
|         self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.int8) | |||
|         self.left_idx, self.right_idx = random.sample(range(self.action_space.n), 2) | |||
|         self.action_assignments[self.left_idx] = self.left_action_embed | |||
|         self.action_assignments[self.right_idx] = self.right_action_embed | |||
|         self.action_mask[self.left_idx] = 1 | |||
|         self.action_mask[self.right_idx] = 1 | |||
| 
 | |||
|     def reset(self, *, seed=None, options=None): | |||
|         self.update_avail_actions() | |||
|         obs, infos = self.wrapped.reset() | |||
|         return obs, infos | |||
|         return { | |||
|             "action_mask": self.action_mask, | |||
|             "avail_actions": self.action_assignments, | |||
|             "cart": obs, | |||
|         }, infos | |||
| 
 | |||
|     def step(self, action): | |||
|         if action == self.left_idx: | |||
|             actual_action = 0 | |||
|         elif action == self.right_idx: | |||
|             actual_action = 1 | |||
|         else: | |||
|             actual_action = 0 | |||
|             # raise ValueError( | |||
|             #     "Chosen action was not one of the non-zero action embeddings", | |||
|             #     action, | |||
|             #     self.action_assignments, | |||
|             #     self.action_mask, | |||
|             #     self.left_idx, | |||
|             #     self.right_idx, | |||
|             # ) | |||
|         orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action) | |||
|         self.update_avail_actions() | |||
|         self.action_mask = self.action_mask.astype(np.int8) | |||
|         print(F"Info is {info}") | |||
|         info["Hello" : "Ich kenn mich nix aus"] | |||
|         return orig_obs, rew, done, truncated, info | |||
|         obs = { | |||
|             "action_mask": self.action_mask, | |||
|             "avail_actions": self.action_assignments, | |||
|             "cart": orig_obs, | |||
|         } | |||
|         return obs, rew, done, truncated, info | |||
| 
 | |||
|     | |||
| @ -0,0 +1,81 @@ | |||
| from typing import Dict, Optional, Union | |||
| 
 | |||
| from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel | |||
| from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC | |||
| from ray.rllib.models.tf.fcnet import FullyConnectedNetwork | |||
| from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 | |||
| from ray.rllib.utils.framework import try_import_tf, try_import_torch | |||
| from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX | |||
| 
 | |||
| torch, nn = try_import_torch() | |||
| 
 | |||
| 
 | |||
| 
 | |||
| class TorchActionMaskModel(TorchModelV2, nn.Module): | |||
|     """PyTorch version of above ActionMaskingModel.""" | |||
| 
 | |||
|     def __init__( | |||
|         self, | |||
|         obs_space, | |||
|         action_space, | |||
|         num_outputs, | |||
|         model_config, | |||
|         name, | |||
|         **kwargs, | |||
|     ): | |||
|         orig_space = getattr(obs_space, "original_space", obs_space) | |||
|         custom_config = model_config['custom_model_config'] | |||
|         print(F"Original Space is: {orig_space}") | |||
|         #print(model_config) | |||
|         print(F"Observation space in model: {obs_space}") | |||
|          | |||
|         TorchModelV2.__init__( | |||
|             self, obs_space, action_space, num_outputs, model_config, name, **kwargs | |||
|         ) | |||
|         nn.Module.__init__(self) | |||
|          | |||
|         assert("shield" in custom_config) | |||
|          | |||
|         self.shield = custom_config["shield"] | |||
| 
 | |||
|         self.internal_model = TorchFC( | |||
|             orig_space["data"], | |||
|             action_space, | |||
|             num_outputs, | |||
|             model_config, | |||
|             name + "_internal", | |||
|         ) | |||
| 
 | |||
|         # disable action masking --> will likely lead to invalid actions | |||
|         self.no_masking = False | |||
|         if "no_masking" in model_config["custom_model_config"]: | |||
|             self.no_masking = model_config["custom_model_config"]["no_masking"] | |||
| 
 | |||
|     def forward(self, input_dict, state, seq_lens): | |||
|         # Extract the available actions tensor from the observation. | |||
|         # print(F"Input dict is {input_dict} at obs: {input_dict['obs']}") | |||
|         # print(F"State is {state}") | |||
|          | |||
|         action_mask = [] | |||
|          | |||
|         # print(input_dict["env"]) | |||
| 
 | |||
|         # Compute the unmasked logits. | |||
|         logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]}) | |||
| 
 | |||
|         # If action masking is disabled, directly return unmasked logits | |||
|         if self.no_masking: | |||
|             return logits, state | |||
| 
 | |||
|         assert(False) | |||
| 
 | |||
|         return logits, state | |||
|         # Convert action_mask into a [0.0 || -inf]-type mask. | |||
|         # inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) | |||
|         # masked_logits = logits + inf_mask | |||
| 
 | |||
|         # # Return masked logits. | |||
|         # return masked_logits, state | |||
| 
 | |||
|     def value_function(self): | |||
|         return self.internal_model.value_function() | |||
| @ -0,0 +1,152 @@ | |||
| import gymnasium as gym | |||
| import numpy as np | |||
| 
 | |||
| 
 | |||
| from gymnasium.spaces import Dict, Box | |||
| from collections import deque | |||
| from ray.rllib.utils.numpy import one_hot | |||
| 
 | |||
| class OneHotWrapper(gym.core.ObservationWrapper): | |||
|     def __init__(self, env, vector_index, framestack): | |||
|         super().__init__(env) | |||
|         self.framestack = framestack | |||
|         # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types. | |||
|         # +4: Direction. | |||
|         self.single_frame_dim = 49 * (11 + 6 + 3) + 4 | |||
|         self.init_x = None | |||
|         self.init_y = None | |||
|         self.x_positions = [] | |||
|         self.y_positions = [] | |||
|         self.x_y_delta_buffer = deque(maxlen=100) | |||
|         self.vector_index = vector_index | |||
|         self.frame_buffer = deque(maxlen=self.framestack) | |||
|         for _ in range(self.framestack): | |||
|             self.frame_buffer.append(np.zeros((self.single_frame_dim,))) | |||
| 
 | |||
|         self.observation_space = Dict( | |||
|             { | |||
|                 "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), | |||
|                 "avail_actions": gym.spaces.Box(0, 10, shape=(10,), dtype=int), | |||
|             } | |||
|             )  | |||
|          | |||
|          | |||
|         print(F"Set obersvation space to {self.observation_space}") | |||
|          | |||
| 
 | |||
|     def observation(self, obs): | |||
|         # Debug output: max-x/y positions to watch exploration progress. | |||
|         # print(F"Initial observation in Wrapper {obs}") | |||
|         if self.step_count == 0: | |||
|             for _ in range(self.framestack): | |||
|                 self.frame_buffer.append(np.zeros((self.single_frame_dim,))) | |||
|             if self.vector_index == 0: | |||
|                 if self.x_positions: | |||
|                     max_diff = max( | |||
|                         np.sqrt( | |||
|                             (np.array(self.x_positions) - self.init_x) ** 2 | |||
|                             + (np.array(self.y_positions) - self.init_y) ** 2 | |||
|                         ) | |||
|                     ) | |||
|                     self.x_y_delta_buffer.append(max_diff) | |||
|                     print( | |||
|                         "100-average dist travelled={}".format( | |||
|                             np.mean(self.x_y_delta_buffer) | |||
|                         ) | |||
|                     ) | |||
|                     self.x_positions = [] | |||
|                     self.y_positions = [] | |||
|                 self.init_x = self.agent_pos[0] | |||
|                 self.init_y = self.agent_pos[1] | |||
| 
 | |||
|        | |||
|         self.x_positions.append(self.agent_pos[0]) | |||
|         self.y_positions.append(self.agent_pos[1]) | |||
|          | |||
|         image = obs["data"] | |||
| 
 | |||
|         # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten. | |||
|         objects = one_hot(image[:, :, 0], depth=11) | |||
|         colors = one_hot(image[:, :, 1], depth=6) | |||
|         states = one_hot(image[:, :, 2], depth=3) | |||
|        | |||
|         all_ = np.concatenate([objects, colors, states], -1) | |||
|         all_flat = np.reshape(all_, (-1,)) | |||
|         direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) | |||
|         single_frame = np.concatenate([all_flat, direction]) | |||
|         self.frame_buffer.append(single_frame) | |||
|          | |||
|         #obs["one-hot"] = np.concatenate(self.frame_buffer) | |||
|         tmp = {"data": np.concatenate(self.frame_buffer), "avail_actions": obs["avail_actions"] } | |||
|         return tmp#np.concatenate(self.frame_buffer) | |||
| 
 | |||
| 
 | |||
| class MiniGridEnvWrapper(gym.core.Wrapper): | |||
|     def __init__(self, env): | |||
|         super(MiniGridEnvWrapper, self).__init__(env) | |||
|         self.observation_space = Dict( | |||
|             { | |||
|                 "data": env.observation_space.spaces["image"], | |||
|                 "avail_actions" : Box(0, 10, shape=(10,), dtype=np.int8), | |||
|             } | |||
|         ) | |||
|          | |||
|          | |||
|     def test(self): | |||
|         print("Testing some stuff") | |||
|      | |||
|     def reset(self, *, seed=None, options=None): | |||
|         obs, infos = self.env.reset() | |||
|         return { | |||
|             "data": obs["image"], | |||
|             "avail_actions": np.array([0.0] * 10, dtype=np.int8) | |||
|         }, infos | |||
|      | |||
|     def step(self, action): | |||
|         orig_obs, rew, done, truncated, info = self.env.step(action) | |||
|          | |||
|         self.test() | |||
|         #print(F"Original observation is {orig_obs}") | |||
|         obs = { | |||
|             "data": orig_obs["image"], | |||
|             "avail_actions":  np.array([0.0] * 10, dtype=np.int8), | |||
|         } | |||
|          | |||
|         #print(F"Info is {info}") | |||
|         return obs, rew, done, truncated, info | |||
|      | |||
|      | |||
| 
 | |||
| 
 | |||
| class ImgObsWrapper(gym.core.ObservationWrapper): | |||
|     """ | |||
|     Use the image as the only observation output, no language/mission. | |||
| 
 | |||
|     Example: | |||
|         >>> import gymnasium as gym | |||
|         >>> from minigrid.wrappers import ImgObsWrapper | |||
|         >>> env = gym.make("MiniGrid-Empty-5x5-v0") | |||
|         >>> obs, _ = env.reset() | |||
|         >>> obs.keys() | |||
|         dict_keys(['image', 'direction', 'mission']) | |||
|         >>> env = ImgObsWrapper(env) | |||
|         >>> obs, _ = env.reset() | |||
|         >>> obs.shape | |||
|         (7, 7, 3) | |||
|     """ | |||
| 
 | |||
|     def __init__(self, env): | |||
|         """A wrapper that makes image the only observation. | |||
| 
 | |||
|         Args: | |||
|             env: The environment to apply the wrapper | |||
|         """ | |||
|         super().__init__(env) | |||
|         self.observation_space = env.observation_space.spaces["image"] | |||
|         print(F"Set obersvation space to {self.observation_space}") | |||
| 
 | |||
|     def observation(self, obs): | |||
|         #print(F"obs in img obs wrapper {obs}") | |||
|         tmp = {"data": obs["image"], "Test": obs["Test"]} | |||
|          | |||
|         return tmp | |||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue