From cf18349819ff76c49becde9d583e0fdc9f68e905 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Wed, 23 Aug 2023 10:46:14 +0200 Subject: [PATCH] basic action embedding --- examples/shields/rl/11_minigridrl.py | 255 ++++++++++++++++++ .../shields/{ => rl}/12_basic_training.py | 134 ++++----- examples/shields/rl/MaskEnvironments.py | 91 +++++++ examples/shields/rl/MaskModels.py | 81 ++++++ examples/shields/rl/Wrapper.py | 152 +++++++++++ 5 files changed, 650 insertions(+), 63 deletions(-) create mode 100644 examples/shields/rl/11_minigridrl.py rename examples/shields/{ => rl}/12_basic_training.py (72%) create mode 100644 examples/shields/rl/MaskEnvironments.py create mode 100644 examples/shields/rl/MaskModels.py create mode 100644 examples/shields/rl/Wrapper.py diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py new file mode 100644 index 0000000..b19309d --- /dev/null +++ b/examples/shields/rl/11_minigridrl.py @@ -0,0 +1,255 @@ +from typing import Dict, Optional, Union +from ray.rllib.env.base_env import BaseEnv +from ray.rllib.evaluation import RolloutWorker +from ray.rllib.evaluation.episode import Episode +from ray.rllib.evaluation.episode_v2 import EpisodeV2 +from ray.rllib.policy import Policy +from ray.rllib.utils.typing import PolicyID +import stormpy +import stormpy.core +import stormpy.simulator + + +import stormpy.shields +import stormpy.logic + +import stormpy.examples +import stormpy.examples.files +import os + +import gymnasium as gym + +import minigrid +import numpy as np + +import ray +from ray.tune import register_env +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator +from ray import tune, air +from ray.rllib.algorithms.callbacks import DefaultCallbacks +from ray.tune.logger import pretty_print +from ray.rllib.algorithms import ppo +from ray.rllib.models import ModelCatalog + +from ray.rllib.utils.torch_utils import FLOAT_MIN + +from ray.rllib.models.preprocessors import get_preprocessor +from MaskEnvironments import ParametricActionsMiniGridEnv +from MaskModels import TorchActionMaskModel +from Wrapper import OneHotWrapper, MiniGridEnvWrapper, ImgObsWrapper + +import matplotlib.pyplot as plt + +import argparse + + + +class MyCallbacks(DefaultCallbacks): + def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: + # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") + env = base_env.get_sub_environments()[0] + episode.user_data["count"] = 0 + # print(env.printGrid()) + # print(env.action_space.n) + # print(env.actions) + # print(env.mission) + # print(env.observation_space) + # img = env.get_frame() + # plt.imshow(img) + # plt.show() + + def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: + episode.user_data["count"] = episode.user_data["count"] + 1 + env = base_env.get_sub_environments()[0] + print(env.env.env.printGrid()) + + def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: + # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") + env = base_env.get_sub_environments()[0] + # print(env.env.env.printGrid()) + # print(episode.user_data["count"]) + + + + +def parse_arguments(argparse): + parser = argparse.ArgumentParser() + # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") + parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0") + parser.add_argument("--seed", type=int, help="seed for environment", default=1) + parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32) + parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees") + parser.add_argument("--grid_path", default="Grid.txt") + parser.add_argument("--prism_path", default="Grid.PRISM") + + args = parser.parse_args() + + return args + + +def env_creater_custom(config): + # name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") + # # name = config.get("name", "MiniGrid-Empty-8x8-v0") + framestack = config.get("framestack", 4) + + # env = gym.make(name) + # env = ParametricActionsMiniGridEnv(config) + name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") + framestack = config.get("framestack", 4) + + env = gym.make(name) + env = MiniGridEnvWrapper(env) + # env = minigrid.wrappers.ImgObsWrapper(env) + # env = ImgObsWrapper(env) + env = OneHotWrapper(env, + config.vector_index if hasattr(config, "vector_index") else 0, + framestack=framestack + ) + + obs = env.observation_space.sample() + obs2, infos = env.reset(seed=None, options={}) + + print(F"Obs is {obs} before reset. After reset: {obs2}") + # env = minigrid.wrappers.RGBImgPartialObsWrapper(env) + + print(F"Created Custom Minigrid Environment is {env}") + + return env + +def env_creater_cart(config): + return gym.make("CartPole-v1") + +def env_creater(config): + name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") + # name = config.get("name", "MiniGrid-Empty-8x8-v0") + framestack = config.get("framestack", 4) + + env = gym.make(name) + # env = minigrid.wrappers.RGBImgPartialObsWrapper(env) + env = minigrid.wrappers.ImgObsWrapper(env) + env = OneHotWrapper(env, + config.vector_index if hasattr(config, "vector_index") else 0, + framestack=framestack + ) + + print(F"Created Minigrid Environment is {env}") + + return env + + + +def create_shield(grid_file, prism_path): + os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}") + + f = open(prism_path, "a") + f.write("label \"AgentIsInLava\" = AgentIsInLava;") + f.close() + + + program = stormpy.parse_prism_program(prism_path) + formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" + + formulas = stormpy.parse_properties_for_prism_program(formula_str, program) + options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) + options.set_build_state_valuations(True) + options.set_build_choice_labels(True) + options.set_build_all_labels() + model = stormpy.build_sparse_model_with_options(program, options) + + shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) + result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) + + assert result.has_scheduler + assert result.has_shield + shield = result.shield + + stormpy.shields.export_shield(model, shield, "Grid.shield") + + return shield.construct(), model + +def export_grid_to_text(env, grid_file): + f = open(grid_file, "w") + # print(env) + f.write(env.printGrid(init=True)) + # f.write(env.pprint_grid()) + f.close() + +def create_environment(args): + env_id= args.env + env = gym.make(env_id) + env.reset() + return env + + +def main(): + args = parse_arguments(argparse) + + env = create_environment(args) + ray.init(num_cpus=3) + + # print(env.pprint_grid()) + # print(env.printGrid(init=False)) + + grid_file = args.grid_path + export_grid_to_text(env, grid_file) + + prism_path = args.prism_path + shield, model = create_shield(grid_file, prism_path) + shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} + + print(shield_dict) + for state_id in model.states: + choices = shield.get_choice(state_id) + print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") + + env_name = "mini-grid" + register_env(env_name, env_creater_custom) + ModelCatalog.register_custom_model( + "pa_model", + TorchActionMaskModel + ) + + config = (PPOConfig() + .rollouts(num_rollout_workers=1) + .resources(num_gpus=0) + .environment(env="mini-grid") + .framework("torch") + .experimental(_disable_preprocessor_api=False) + .callbacks(MyCallbacks) + .rl_module(_enable_rl_module_api = False) + .training(_enable_learner_api=False ,model={ + "custom_model": "pa_model", + "custom_model_config" : {"shield": shield_dict, "no_masking": True} + # "fcnet_hiddens": [256,256], + # "fcnet_activation": "relu", + + })) + + + algo =( + + config.build() + ) + episode_reward = 0 + terminated = truncated = False + obs, info = env.reset() + + # while not terminated and not truncated: + # action = algo.compute_single_action(obs) + # obs, reward, terminated, truncated = env.step(action) + + for i in range(30): + result = algo.train() + print(pretty_print(result)) + + if i % 5 == 0: + checkpoint_dir = algo.save() + print(f"Checkpoint saved in directory {checkpoint_dir}") + + + + ray.shutdown() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/shields/12_basic_training.py b/examples/shields/rl/12_basic_training.py similarity index 72% rename from examples/shields/12_basic_training.py rename to examples/shields/rl/12_basic_training.py index a1e0623..72bb868 100644 --- a/examples/shields/12_basic_training.py +++ b/examples/shields/rl/12_basic_training.py @@ -33,11 +33,81 @@ from ray.tune.logger import pretty_print from ray.rllib.utils.numpy import one_hot from ray.rllib.algorithms import ppo +from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork +from ray.rllib.utils.torch_utils import FLOAT_MIN + from ray.rllib.models.preprocessors import get_preprocessor import matplotlib.pyplot as plt import argparse +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.framework import try_import_tf, try_import_torch + +from Wrapper import OneHotWrapper + + +torch, nn = try_import_torch() + +class TorchActionMaskModel(TorchModelV2, nn.Module): + """PyTorch version of above ActionMaskingModel.""" + + def __init__( + self, + obs_space, + action_space, + num_outputs, + model_config, + name, + **kwargs, + ): + orig_space = getattr(obs_space, "original_space", obs_space) + assert ( + isinstance(orig_space, Dict) + and "action_mask" in orig_space.spaces + and "observations" in orig_space.spaces + ) + + TorchModelV2.__init__( + self, obs_space, action_space, num_outputs, model_config, name, **kwargs + ) + nn.Module.__init__(self) + + self.internal_model = TorchFC( + orig_space["observations"], + action_space, + num_outputs, + model_config, + name + "_internal", + ) + + # disable action masking --> will likely lead to invalid actions + self.no_masking = False + if "no_masking" in model_config["custom_model_config"]: + self.no_masking = model_config["custom_model_config"]["no_masking"] + + def forward(self, input_dict, state, seq_lens): + # Extract the available actions tensor from the observation. + action_mask = input_dict["obs"]["action_mask"] + + # Compute the unmasked logits. + logits, _ = self.internal_model({"obs": input_dict["obs"]["observations"]}) + + # If action masking is disabled, directly return unmasked logits + if self.no_masking: + return logits, state + + # Convert action_mask into a [0.0 || -inf]-type mask. + inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) + masked_logits = logits + inf_mask + + # Return masked logits. + return masked_logits, state + + def value_function(self): + return self.internal_model.value_function() + class MyCallbacks(DefaultCallbacks): @@ -66,69 +136,7 @@ class MyCallbacks(DefaultCallbacks): # print(episode.user_data["count"]) - - -class OneHotWrapper(gym.core.ObservationWrapper): - def __init__(self, env, vector_index, framestack): - super().__init__(env) - self.framestack = framestack - # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types. - # +4: Direction. - self.single_frame_dim = 49 * (11 + 6 + 3) + 4 - self.init_x = None - self.init_y = None - self.x_positions = [] - self.y_positions = [] - self.x_y_delta_buffer = deque(maxlen=100) - self.vector_index = vector_index - self.frame_buffer = deque(maxlen=self.framestack) - for _ in range(self.framestack): - self.frame_buffer.append(np.zeros((self.single_frame_dim,))) - - self.observation_space = gym.spaces.Box( - 0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32 - ) - - def observation(self, obs): - # Debug output: max-x/y positions to watch exploration progress. - if self.step_count == 0: - for _ in range(self.framestack): - self.frame_buffer.append(np.zeros((self.single_frame_dim,))) - if self.vector_index == 0: - if self.x_positions: - max_diff = max( - np.sqrt( - (np.array(self.x_positions) - self.init_x) ** 2 - + (np.array(self.y_positions) - self.init_y) ** 2 - ) - ) - self.x_y_delta_buffer.append(max_diff) - print( - "100-average dist travelled={}".format( - np.mean(self.x_y_delta_buffer) - ) - ) - self.x_positions = [] - self.y_positions = [] - self.init_x = self.agent_pos[0] - self.init_y = self.agent_pos[1] - - - self.x_positions.append(self.agent_pos[0]) - self.y_positions.append(self.agent_pos[1]) - - # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten. - objects = one_hot(obs[:, :, 0], depth=11) - colors = one_hot(obs[:, :, 1], depth=6) - states = one_hot(obs[:, :, 2], depth=3) - - all_ = np.concatenate([objects, colors, states], -1) - all_flat = np.reshape(all_, (-1,)) - direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) - single_frame = np.concatenate([all_flat, direction]) - self.frame_buffer.append(single_frame) - return np.concatenate(self.frame_buffer) - + def parse_arguments(argparse): diff --git a/examples/shields/rl/MaskEnvironments.py b/examples/shields/rl/MaskEnvironments.py new file mode 100644 index 0000000..ccd63ba --- /dev/null +++ b/examples/shields/rl/MaskEnvironments.py @@ -0,0 +1,91 @@ +import random +import minigrid + +import gymnasium as gym +import numpy as np +from gymnasium.spaces import Box, Dict, Discrete +from Wrapper import OneHotWrapper + + +class ParametricActionsMiniGridEnv(gym.Env): + """Parametric action version of MiniGrid. + + """ + + def __init__(self, config): + + name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") + self.left_action_embed = np.random.randn(2) + self.right_action_embed = np.random.randn(2) + framestack = config.get("framestack", 4) + + # env = gym.make(name) + # env = minigrid.wrappers.ImgObsWrapper(env) + # env = OneHotWrapper(env, + # config.vector_index if hasattr(config, "vector_index") else 0, + # framestack=framestack + # ) + self.wrapped = gym.make(name) + # self.observation_space = Dict( + # { + # "action_mask": None, + # "avail_actions": None, + # "cart": self.wrapped.observation_space, + # } + # ) + print(F"Wrapped environment is {self.wrapped}") + self.step_count = 0 + self.action_space = self.wrapped.action_space + self.observation_space = self.wrapped.observation_space + + + def update_avail_actions(self): + self.action_assignments = np.array( + [[0.0, 0.0]] * self.action_space.n, dtype=np.float32 + ) + self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.int8) + self.left_idx, self.right_idx = random.sample(range(self.action_space.n), 2) + self.action_assignments[self.left_idx] = self.left_action_embed + self.action_assignments[self.right_idx] = self.right_action_embed + self.action_mask[self.left_idx] = 1 + self.action_mask[self.right_idx] = 1 + + def reset(self, *, seed=None, options=None): + self.update_avail_actions() + obs, infos = self.wrapped.reset() + return obs, infos + return { + "action_mask": self.action_mask, + "avail_actions": self.action_assignments, + "cart": obs, + }, infos + + def step(self, action): + if action == self.left_idx: + actual_action = 0 + elif action == self.right_idx: + actual_action = 1 + else: + actual_action = 0 + # raise ValueError( + # "Chosen action was not one of the non-zero action embeddings", + # action, + # self.action_assignments, + # self.action_mask, + # self.left_idx, + # self.right_idx, + # ) + orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action) + self.update_avail_actions() + self.action_mask = self.action_mask.astype(np.int8) + print(F"Info is {info}") + info["Hello" : "Ich kenn mich nix aus"] + return orig_obs, rew, done, truncated, info + obs = { + "action_mask": self.action_mask, + "avail_actions": self.action_assignments, + "cart": orig_obs, + } + return obs, rew, done, truncated, info + + \ No newline at end of file diff --git a/examples/shields/rl/MaskModels.py b/examples/shields/rl/MaskModels.py new file mode 100644 index 0000000..b017740 --- /dev/null +++ b/examples/shields/rl/MaskModels.py @@ -0,0 +1,81 @@ +from typing import Dict, Optional, Union + +from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel +from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX + +torch, nn = try_import_torch() + + + +class TorchActionMaskModel(TorchModelV2, nn.Module): + """PyTorch version of above ActionMaskingModel.""" + + def __init__( + self, + obs_space, + action_space, + num_outputs, + model_config, + name, + **kwargs, + ): + orig_space = getattr(obs_space, "original_space", obs_space) + custom_config = model_config['custom_model_config'] + print(F"Original Space is: {orig_space}") + #print(model_config) + print(F"Observation space in model: {obs_space}") + + TorchModelV2.__init__( + self, obs_space, action_space, num_outputs, model_config, name, **kwargs + ) + nn.Module.__init__(self) + + assert("shield" in custom_config) + + self.shield = custom_config["shield"] + + self.internal_model = TorchFC( + orig_space["data"], + action_space, + num_outputs, + model_config, + name + "_internal", + ) + + # disable action masking --> will likely lead to invalid actions + self.no_masking = False + if "no_masking" in model_config["custom_model_config"]: + self.no_masking = model_config["custom_model_config"]["no_masking"] + + def forward(self, input_dict, state, seq_lens): + # Extract the available actions tensor from the observation. + # print(F"Input dict is {input_dict} at obs: {input_dict['obs']}") + # print(F"State is {state}") + + action_mask = [] + + # print(input_dict["env"]) + + # Compute the unmasked logits. + logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]}) + + # If action masking is disabled, directly return unmasked logits + if self.no_masking: + return logits, state + + assert(False) + + return logits, state + # Convert action_mask into a [0.0 || -inf]-type mask. + # inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) + # masked_logits = logits + inf_mask + + # # Return masked logits. + # return masked_logits, state + + def value_function(self): + return self.internal_model.value_function() diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrapper.py new file mode 100644 index 0000000..c1e5e69 --- /dev/null +++ b/examples/shields/rl/Wrapper.py @@ -0,0 +1,152 @@ +import gymnasium as gym +import numpy as np + + +from gymnasium.spaces import Dict, Box +from collections import deque +from ray.rllib.utils.numpy import one_hot + +class OneHotWrapper(gym.core.ObservationWrapper): + def __init__(self, env, vector_index, framestack): + super().__init__(env) + self.framestack = framestack + # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types. + # +4: Direction. + self.single_frame_dim = 49 * (11 + 6 + 3) + 4 + self.init_x = None + self.init_y = None + self.x_positions = [] + self.y_positions = [] + self.x_y_delta_buffer = deque(maxlen=100) + self.vector_index = vector_index + self.frame_buffer = deque(maxlen=self.framestack) + for _ in range(self.framestack): + self.frame_buffer.append(np.zeros((self.single_frame_dim,))) + + self.observation_space = Dict( + { + "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), + "avail_actions": gym.spaces.Box(0, 10, shape=(10,), dtype=int), + } + ) + + + print(F"Set obersvation space to {self.observation_space}") + + + def observation(self, obs): + # Debug output: max-x/y positions to watch exploration progress. + # print(F"Initial observation in Wrapper {obs}") + if self.step_count == 0: + for _ in range(self.framestack): + self.frame_buffer.append(np.zeros((self.single_frame_dim,))) + if self.vector_index == 0: + if self.x_positions: + max_diff = max( + np.sqrt( + (np.array(self.x_positions) - self.init_x) ** 2 + + (np.array(self.y_positions) - self.init_y) ** 2 + ) + ) + self.x_y_delta_buffer.append(max_diff) + print( + "100-average dist travelled={}".format( + np.mean(self.x_y_delta_buffer) + ) + ) + self.x_positions = [] + self.y_positions = [] + self.init_x = self.agent_pos[0] + self.init_y = self.agent_pos[1] + + + self.x_positions.append(self.agent_pos[0]) + self.y_positions.append(self.agent_pos[1]) + + image = obs["data"] + + # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten. + objects = one_hot(image[:, :, 0], depth=11) + colors = one_hot(image[:, :, 1], depth=6) + states = one_hot(image[:, :, 2], depth=3) + + all_ = np.concatenate([objects, colors, states], -1) + all_flat = np.reshape(all_, (-1,)) + direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) + single_frame = np.concatenate([all_flat, direction]) + self.frame_buffer.append(single_frame) + + #obs["one-hot"] = np.concatenate(self.frame_buffer) + tmp = {"data": np.concatenate(self.frame_buffer), "avail_actions": obs["avail_actions"] } + return tmp#np.concatenate(self.frame_buffer) + + +class MiniGridEnvWrapper(gym.core.Wrapper): + def __init__(self, env): + super(MiniGridEnvWrapper, self).__init__(env) + self.observation_space = Dict( + { + "data": env.observation_space.spaces["image"], + "avail_actions" : Box(0, 10, shape=(10,), dtype=np.int8), + } + ) + + + def test(self): + print("Testing some stuff") + + def reset(self, *, seed=None, options=None): + obs, infos = self.env.reset() + return { + "data": obs["image"], + "avail_actions": np.array([0.0] * 10, dtype=np.int8) + }, infos + + def step(self, action): + orig_obs, rew, done, truncated, info = self.env.step(action) + + self.test() + #print(F"Original observation is {orig_obs}") + obs = { + "data": orig_obs["image"], + "avail_actions": np.array([0.0] * 10, dtype=np.int8), + } + + #print(F"Info is {info}") + return obs, rew, done, truncated, info + + + + +class ImgObsWrapper(gym.core.ObservationWrapper): + """ + Use the image as the only observation output, no language/mission. + + Example: + >>> import gymnasium as gym + >>> from minigrid.wrappers import ImgObsWrapper + >>> env = gym.make("MiniGrid-Empty-5x5-v0") + >>> obs, _ = env.reset() + >>> obs.keys() + dict_keys(['image', 'direction', 'mission']) + >>> env = ImgObsWrapper(env) + >>> obs, _ = env.reset() + >>> obs.shape + (7, 7, 3) + """ + + def __init__(self, env): + """A wrapper that makes image the only observation. + + Args: + env: The environment to apply the wrapper + """ + super().__init__(env) + self.observation_space = env.observation_space.spaces["image"] + print(F"Set obersvation space to {self.observation_space}") + + def observation(self, obs): + #print(F"obs in img obs wrapper {obs}") + tmp = {"data": obs["image"], "Test": obs["Test"]} + + return tmp