Thomas Knoll
1 year ago
5 changed files with 650 additions and 63 deletions
-
255examples/shields/rl/11_minigridrl.py
-
134examples/shields/rl/12_basic_training.py
-
91examples/shields/rl/MaskEnvironments.py
-
81examples/shields/rl/MaskModels.py
-
152examples/shields/rl/Wrapper.py
@ -0,0 +1,255 @@ |
|||||
|
from typing import Dict, Optional, Union |
||||
|
from ray.rllib.env.base_env import BaseEnv |
||||
|
from ray.rllib.evaluation import RolloutWorker |
||||
|
from ray.rllib.evaluation.episode import Episode |
||||
|
from ray.rllib.evaluation.episode_v2 import EpisodeV2 |
||||
|
from ray.rllib.policy import Policy |
||||
|
from ray.rllib.utils.typing import PolicyID |
||||
|
import stormpy |
||||
|
import stormpy.core |
||||
|
import stormpy.simulator |
||||
|
|
||||
|
|
||||
|
import stormpy.shields |
||||
|
import stormpy.logic |
||||
|
|
||||
|
import stormpy.examples |
||||
|
import stormpy.examples.files |
||||
|
import os |
||||
|
|
||||
|
import gymnasium as gym |
||||
|
|
||||
|
import minigrid |
||||
|
import numpy as np |
||||
|
|
||||
|
import ray |
||||
|
from ray.tune import register_env |
||||
|
from ray.rllib.algorithms.ppo import PPOConfig |
||||
|
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator |
||||
|
from ray import tune, air |
||||
|
from ray.rllib.algorithms.callbacks import DefaultCallbacks |
||||
|
from ray.tune.logger import pretty_print |
||||
|
from ray.rllib.algorithms import ppo |
||||
|
from ray.rllib.models import ModelCatalog |
||||
|
|
||||
|
from ray.rllib.utils.torch_utils import FLOAT_MIN |
||||
|
|
||||
|
from ray.rllib.models.preprocessors import get_preprocessor |
||||
|
from MaskEnvironments import ParametricActionsMiniGridEnv |
||||
|
from MaskModels import TorchActionMaskModel |
||||
|
from Wrapper import OneHotWrapper, MiniGridEnvWrapper, ImgObsWrapper |
||||
|
|
||||
|
import matplotlib.pyplot as plt |
||||
|
|
||||
|
import argparse |
||||
|
|
||||
|
|
||||
|
|
||||
|
class MyCallbacks(DefaultCallbacks): |
||||
|
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: |
||||
|
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}") |
||||
|
env = base_env.get_sub_environments()[0] |
||||
|
episode.user_data["count"] = 0 |
||||
|
# print(env.printGrid()) |
||||
|
# print(env.action_space.n) |
||||
|
# print(env.actions) |
||||
|
# print(env.mission) |
||||
|
# print(env.observation_space) |
||||
|
# img = env.get_frame() |
||||
|
# plt.imshow(img) |
||||
|
# plt.show() |
||||
|
|
||||
|
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: |
||||
|
episode.user_data["count"] = episode.user_data["count"] + 1 |
||||
|
env = base_env.get_sub_environments()[0] |
||||
|
print(env.env.env.printGrid()) |
||||
|
|
||||
|
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: |
||||
|
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}") |
||||
|
env = base_env.get_sub_environments()[0] |
||||
|
# print(env.env.env.printGrid()) |
||||
|
# print(episode.user_data["count"]) |
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
def parse_arguments(argparse): |
||||
|
parser = argparse.ArgumentParser() |
||||
|
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") |
||||
|
parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0") |
||||
|
parser.add_argument("--seed", type=int, help="seed for environment", default=1) |
||||
|
parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32) |
||||
|
parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees") |
||||
|
parser.add_argument("--grid_path", default="Grid.txt") |
||||
|
parser.add_argument("--prism_path", default="Grid.PRISM") |
||||
|
|
||||
|
args = parser.parse_args() |
||||
|
|
||||
|
return args |
||||
|
|
||||
|
|
||||
|
def env_creater_custom(config): |
||||
|
# name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") |
||||
|
# # name = config.get("name", "MiniGrid-Empty-8x8-v0") |
||||
|
framestack = config.get("framestack", 4) |
||||
|
|
||||
|
# env = gym.make(name) |
||||
|
# env = ParametricActionsMiniGridEnv(config) |
||||
|
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") |
||||
|
framestack = config.get("framestack", 4) |
||||
|
|
||||
|
env = gym.make(name) |
||||
|
env = MiniGridEnvWrapper(env) |
||||
|
# env = minigrid.wrappers.ImgObsWrapper(env) |
||||
|
# env = ImgObsWrapper(env) |
||||
|
env = OneHotWrapper(env, |
||||
|
config.vector_index if hasattr(config, "vector_index") else 0, |
||||
|
framestack=framestack |
||||
|
) |
||||
|
|
||||
|
obs = env.observation_space.sample() |
||||
|
obs2, infos = env.reset(seed=None, options={}) |
||||
|
|
||||
|
print(F"Obs is {obs} before reset. After reset: {obs2}") |
||||
|
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env) |
||||
|
|
||||
|
print(F"Created Custom Minigrid Environment is {env}") |
||||
|
|
||||
|
return env |
||||
|
|
||||
|
def env_creater_cart(config): |
||||
|
return gym.make("CartPole-v1") |
||||
|
|
||||
|
def env_creater(config): |
||||
|
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") |
||||
|
# name = config.get("name", "MiniGrid-Empty-8x8-v0") |
||||
|
framestack = config.get("framestack", 4) |
||||
|
|
||||
|
env = gym.make(name) |
||||
|
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env) |
||||
|
env = minigrid.wrappers.ImgObsWrapper(env) |
||||
|
env = OneHotWrapper(env, |
||||
|
config.vector_index if hasattr(config, "vector_index") else 0, |
||||
|
framestack=framestack |
||||
|
) |
||||
|
|
||||
|
print(F"Created Minigrid Environment is {env}") |
||||
|
|
||||
|
return env |
||||
|
|
||||
|
|
||||
|
|
||||
|
def create_shield(grid_file, prism_path): |
||||
|
os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}") |
||||
|
|
||||
|
f = open(prism_path, "a") |
||||
|
f.write("label \"AgentIsInLava\" = AgentIsInLava;") |
||||
|
f.close() |
||||
|
|
||||
|
|
||||
|
program = stormpy.parse_prism_program(prism_path) |
||||
|
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" |
||||
|
|
||||
|
formulas = stormpy.parse_properties_for_prism_program(formula_str, program) |
||||
|
options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) |
||||
|
options.set_build_state_valuations(True) |
||||
|
options.set_build_choice_labels(True) |
||||
|
options.set_build_all_labels() |
||||
|
model = stormpy.build_sparse_model_with_options(program, options) |
||||
|
|
||||
|
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) |
||||
|
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) |
||||
|
|
||||
|
assert result.has_scheduler |
||||
|
assert result.has_shield |
||||
|
shield = result.shield |
||||
|
|
||||
|
stormpy.shields.export_shield(model, shield, "Grid.shield") |
||||
|
|
||||
|
return shield.construct(), model |
||||
|
|
||||
|
def export_grid_to_text(env, grid_file): |
||||
|
f = open(grid_file, "w") |
||||
|
# print(env) |
||||
|
f.write(env.printGrid(init=True)) |
||||
|
# f.write(env.pprint_grid()) |
||||
|
f.close() |
||||
|
|
||||
|
def create_environment(args): |
||||
|
env_id= args.env |
||||
|
env = gym.make(env_id) |
||||
|
env.reset() |
||||
|
return env |
||||
|
|
||||
|
|
||||
|
def main(): |
||||
|
args = parse_arguments(argparse) |
||||
|
|
||||
|
env = create_environment(args) |
||||
|
ray.init(num_cpus=3) |
||||
|
|
||||
|
# print(env.pprint_grid()) |
||||
|
# print(env.printGrid(init=False)) |
||||
|
|
||||
|
grid_file = args.grid_path |
||||
|
export_grid_to_text(env, grid_file) |
||||
|
|
||||
|
prism_path = args.prism_path |
||||
|
shield, model = create_shield(grid_file, prism_path) |
||||
|
shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} |
||||
|
|
||||
|
print(shield_dict) |
||||
|
for state_id in model.states: |
||||
|
choices = shield.get_choice(state_id) |
||||
|
print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") |
||||
|
|
||||
|
env_name = "mini-grid" |
||||
|
register_env(env_name, env_creater_custom) |
||||
|
ModelCatalog.register_custom_model( |
||||
|
"pa_model", |
||||
|
TorchActionMaskModel |
||||
|
) |
||||
|
|
||||
|
config = (PPOConfig() |
||||
|
.rollouts(num_rollout_workers=1) |
||||
|
.resources(num_gpus=0) |
||||
|
.environment(env="mini-grid") |
||||
|
.framework("torch") |
||||
|
.experimental(_disable_preprocessor_api=False) |
||||
|
.callbacks(MyCallbacks) |
||||
|
.rl_module(_enable_rl_module_api = False) |
||||
|
.training(_enable_learner_api=False ,model={ |
||||
|
"custom_model": "pa_model", |
||||
|
"custom_model_config" : {"shield": shield_dict, "no_masking": True} |
||||
|
# "fcnet_hiddens": [256,256], |
||||
|
# "fcnet_activation": "relu", |
||||
|
|
||||
|
})) |
||||
|
|
||||
|
|
||||
|
algo =( |
||||
|
|
||||
|
config.build() |
||||
|
) |
||||
|
episode_reward = 0 |
||||
|
terminated = truncated = False |
||||
|
obs, info = env.reset() |
||||
|
|
||||
|
# while not terminated and not truncated: |
||||
|
# action = algo.compute_single_action(obs) |
||||
|
# obs, reward, terminated, truncated = env.step(action) |
||||
|
|
||||
|
for i in range(30): |
||||
|
result = algo.train() |
||||
|
print(pretty_print(result)) |
||||
|
|
||||
|
if i % 5 == 0: |
||||
|
checkpoint_dir = algo.save() |
||||
|
print(f"Checkpoint saved in directory {checkpoint_dir}") |
||||
|
|
||||
|
|
||||
|
|
||||
|
ray.shutdown() |
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
main() |
@ -0,0 +1,91 @@ |
|||||
|
import random |
||||
|
import minigrid |
||||
|
|
||||
|
import gymnasium as gym |
||||
|
import numpy as np |
||||
|
from gymnasium.spaces import Box, Dict, Discrete |
||||
|
from Wrapper import OneHotWrapper |
||||
|
|
||||
|
|
||||
|
class ParametricActionsMiniGridEnv(gym.Env): |
||||
|
"""Parametric action version of MiniGrid. |
||||
|
|
||||
|
""" |
||||
|
|
||||
|
def __init__(self, config): |
||||
|
|
||||
|
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") |
||||
|
self.left_action_embed = np.random.randn(2) |
||||
|
self.right_action_embed = np.random.randn(2) |
||||
|
framestack = config.get("framestack", 4) |
||||
|
|
||||
|
# env = gym.make(name) |
||||
|
# env = minigrid.wrappers.ImgObsWrapper(env) |
||||
|
# env = OneHotWrapper(env, |
||||
|
# config.vector_index if hasattr(config, "vector_index") else 0, |
||||
|
# framestack=framestack |
||||
|
# ) |
||||
|
self.wrapped = gym.make(name) |
||||
|
# self.observation_space = Dict( |
||||
|
# { |
||||
|
# "action_mask": None, |
||||
|
# "avail_actions": None, |
||||
|
# "cart": self.wrapped.observation_space, |
||||
|
# } |
||||
|
# ) |
||||
|
print(F"Wrapped environment is {self.wrapped}") |
||||
|
self.step_count = 0 |
||||
|
self.action_space = self.wrapped.action_space |
||||
|
self.observation_space = self.wrapped.observation_space |
||||
|
|
||||
|
|
||||
|
def update_avail_actions(self): |
||||
|
self.action_assignments = np.array( |
||||
|
[[0.0, 0.0]] * self.action_space.n, dtype=np.float32 |
||||
|
) |
||||
|
self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.int8) |
||||
|
self.left_idx, self.right_idx = random.sample(range(self.action_space.n), 2) |
||||
|
self.action_assignments[self.left_idx] = self.left_action_embed |
||||
|
self.action_assignments[self.right_idx] = self.right_action_embed |
||||
|
self.action_mask[self.left_idx] = 1 |
||||
|
self.action_mask[self.right_idx] = 1 |
||||
|
|
||||
|
def reset(self, *, seed=None, options=None): |
||||
|
self.update_avail_actions() |
||||
|
obs, infos = self.wrapped.reset() |
||||
|
return obs, infos |
||||
|
return { |
||||
|
"action_mask": self.action_mask, |
||||
|
"avail_actions": self.action_assignments, |
||||
|
"cart": obs, |
||||
|
}, infos |
||||
|
|
||||
|
def step(self, action): |
||||
|
if action == self.left_idx: |
||||
|
actual_action = 0 |
||||
|
elif action == self.right_idx: |
||||
|
actual_action = 1 |
||||
|
else: |
||||
|
actual_action = 0 |
||||
|
# raise ValueError( |
||||
|
# "Chosen action was not one of the non-zero action embeddings", |
||||
|
# action, |
||||
|
# self.action_assignments, |
||||
|
# self.action_mask, |
||||
|
# self.left_idx, |
||||
|
# self.right_idx, |
||||
|
# ) |
||||
|
orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action) |
||||
|
self.update_avail_actions() |
||||
|
self.action_mask = self.action_mask.astype(np.int8) |
||||
|
print(F"Info is {info}") |
||||
|
info["Hello" : "Ich kenn mich nix aus"] |
||||
|
return orig_obs, rew, done, truncated, info |
||||
|
obs = { |
||||
|
"action_mask": self.action_mask, |
||||
|
"avail_actions": self.action_assignments, |
||||
|
"cart": orig_obs, |
||||
|
} |
||||
|
return obs, rew, done, truncated, info |
||||
|
|
||||
|
|
@ -0,0 +1,81 @@ |
|||||
|
from typing import Dict, Optional, Union |
||||
|
|
||||
|
from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel |
||||
|
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC |
||||
|
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork |
||||
|
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 |
||||
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch |
||||
|
from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX |
||||
|
|
||||
|
torch, nn = try_import_torch() |
||||
|
|
||||
|
|
||||
|
|
||||
|
class TorchActionMaskModel(TorchModelV2, nn.Module): |
||||
|
"""PyTorch version of above ActionMaskingModel.""" |
||||
|
|
||||
|
def __init__( |
||||
|
self, |
||||
|
obs_space, |
||||
|
action_space, |
||||
|
num_outputs, |
||||
|
model_config, |
||||
|
name, |
||||
|
**kwargs, |
||||
|
): |
||||
|
orig_space = getattr(obs_space, "original_space", obs_space) |
||||
|
custom_config = model_config['custom_model_config'] |
||||
|
print(F"Original Space is: {orig_space}") |
||||
|
#print(model_config) |
||||
|
print(F"Observation space in model: {obs_space}") |
||||
|
|
||||
|
TorchModelV2.__init__( |
||||
|
self, obs_space, action_space, num_outputs, model_config, name, **kwargs |
||||
|
) |
||||
|
nn.Module.__init__(self) |
||||
|
|
||||
|
assert("shield" in custom_config) |
||||
|
|
||||
|
self.shield = custom_config["shield"] |
||||
|
|
||||
|
self.internal_model = TorchFC( |
||||
|
orig_space["data"], |
||||
|
action_space, |
||||
|
num_outputs, |
||||
|
model_config, |
||||
|
name + "_internal", |
||||
|
) |
||||
|
|
||||
|
# disable action masking --> will likely lead to invalid actions |
||||
|
self.no_masking = False |
||||
|
if "no_masking" in model_config["custom_model_config"]: |
||||
|
self.no_masking = model_config["custom_model_config"]["no_masking"] |
||||
|
|
||||
|
def forward(self, input_dict, state, seq_lens): |
||||
|
# Extract the available actions tensor from the observation. |
||||
|
# print(F"Input dict is {input_dict} at obs: {input_dict['obs']}") |
||||
|
# print(F"State is {state}") |
||||
|
|
||||
|
action_mask = [] |
||||
|
|
||||
|
# print(input_dict["env"]) |
||||
|
|
||||
|
# Compute the unmasked logits. |
||||
|
logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]}) |
||||
|
|
||||
|
# If action masking is disabled, directly return unmasked logits |
||||
|
if self.no_masking: |
||||
|
return logits, state |
||||
|
|
||||
|
assert(False) |
||||
|
|
||||
|
return logits, state |
||||
|
# Convert action_mask into a [0.0 || -inf]-type mask. |
||||
|
# inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) |
||||
|
# masked_logits = logits + inf_mask |
||||
|
|
||||
|
# # Return masked logits. |
||||
|
# return masked_logits, state |
||||
|
|
||||
|
def value_function(self): |
||||
|
return self.internal_model.value_function() |
@ -0,0 +1,152 @@ |
|||||
|
import gymnasium as gym |
||||
|
import numpy as np |
||||
|
|
||||
|
|
||||
|
from gymnasium.spaces import Dict, Box |
||||
|
from collections import deque |
||||
|
from ray.rllib.utils.numpy import one_hot |
||||
|
|
||||
|
class OneHotWrapper(gym.core.ObservationWrapper): |
||||
|
def __init__(self, env, vector_index, framestack): |
||||
|
super().__init__(env) |
||||
|
self.framestack = framestack |
||||
|
# 49=7x7 field of vision; 11=object types; 6=colors; 3=state types. |
||||
|
# +4: Direction. |
||||
|
self.single_frame_dim = 49 * (11 + 6 + 3) + 4 |
||||
|
self.init_x = None |
||||
|
self.init_y = None |
||||
|
self.x_positions = [] |
||||
|
self.y_positions = [] |
||||
|
self.x_y_delta_buffer = deque(maxlen=100) |
||||
|
self.vector_index = vector_index |
||||
|
self.frame_buffer = deque(maxlen=self.framestack) |
||||
|
for _ in range(self.framestack): |
||||
|
self.frame_buffer.append(np.zeros((self.single_frame_dim,))) |
||||
|
|
||||
|
self.observation_space = Dict( |
||||
|
{ |
||||
|
"data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), |
||||
|
"avail_actions": gym.spaces.Box(0, 10, shape=(10,), dtype=int), |
||||
|
} |
||||
|
) |
||||
|
|
||||
|
|
||||
|
print(F"Set obersvation space to {self.observation_space}") |
||||
|
|
||||
|
|
||||
|
def observation(self, obs): |
||||
|
# Debug output: max-x/y positions to watch exploration progress. |
||||
|
# print(F"Initial observation in Wrapper {obs}") |
||||
|
if self.step_count == 0: |
||||
|
for _ in range(self.framestack): |
||||
|
self.frame_buffer.append(np.zeros((self.single_frame_dim,))) |
||||
|
if self.vector_index == 0: |
||||
|
if self.x_positions: |
||||
|
max_diff = max( |
||||
|
np.sqrt( |
||||
|
(np.array(self.x_positions) - self.init_x) ** 2 |
||||
|
+ (np.array(self.y_positions) - self.init_y) ** 2 |
||||
|
) |
||||
|
) |
||||
|
self.x_y_delta_buffer.append(max_diff) |
||||
|
print( |
||||
|
"100-average dist travelled={}".format( |
||||
|
np.mean(self.x_y_delta_buffer) |
||||
|
) |
||||
|
) |
||||
|
self.x_positions = [] |
||||
|
self.y_positions = [] |
||||
|
self.init_x = self.agent_pos[0] |
||||
|
self.init_y = self.agent_pos[1] |
||||
|
|
||||
|
|
||||
|
self.x_positions.append(self.agent_pos[0]) |
||||
|
self.y_positions.append(self.agent_pos[1]) |
||||
|
|
||||
|
image = obs["data"] |
||||
|
|
||||
|
# One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten. |
||||
|
objects = one_hot(image[:, :, 0], depth=11) |
||||
|
colors = one_hot(image[:, :, 1], depth=6) |
||||
|
states = one_hot(image[:, :, 2], depth=3) |
||||
|
|
||||
|
all_ = np.concatenate([objects, colors, states], -1) |
||||
|
all_flat = np.reshape(all_, (-1,)) |
||||
|
direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) |
||||
|
single_frame = np.concatenate([all_flat, direction]) |
||||
|
self.frame_buffer.append(single_frame) |
||||
|
|
||||
|
#obs["one-hot"] = np.concatenate(self.frame_buffer) |
||||
|
tmp = {"data": np.concatenate(self.frame_buffer), "avail_actions": obs["avail_actions"] } |
||||
|
return tmp#np.concatenate(self.frame_buffer) |
||||
|
|
||||
|
|
||||
|
class MiniGridEnvWrapper(gym.core.Wrapper): |
||||
|
def __init__(self, env): |
||||
|
super(MiniGridEnvWrapper, self).__init__(env) |
||||
|
self.observation_space = Dict( |
||||
|
{ |
||||
|
"data": env.observation_space.spaces["image"], |
||||
|
"avail_actions" : Box(0, 10, shape=(10,), dtype=np.int8), |
||||
|
} |
||||
|
) |
||||
|
|
||||
|
|
||||
|
def test(self): |
||||
|
print("Testing some stuff") |
||||
|
|
||||
|
def reset(self, *, seed=None, options=None): |
||||
|
obs, infos = self.env.reset() |
||||
|
return { |
||||
|
"data": obs["image"], |
||||
|
"avail_actions": np.array([0.0] * 10, dtype=np.int8) |
||||
|
}, infos |
||||
|
|
||||
|
def step(self, action): |
||||
|
orig_obs, rew, done, truncated, info = self.env.step(action) |
||||
|
|
||||
|
self.test() |
||||
|
#print(F"Original observation is {orig_obs}") |
||||
|
obs = { |
||||
|
"data": orig_obs["image"], |
||||
|
"avail_actions": np.array([0.0] * 10, dtype=np.int8), |
||||
|
} |
||||
|
|
||||
|
#print(F"Info is {info}") |
||||
|
return obs, rew, done, truncated, info |
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
class ImgObsWrapper(gym.core.ObservationWrapper): |
||||
|
""" |
||||
|
Use the image as the only observation output, no language/mission. |
||||
|
|
||||
|
Example: |
||||
|
>>> import gymnasium as gym |
||||
|
>>> from minigrid.wrappers import ImgObsWrapper |
||||
|
>>> env = gym.make("MiniGrid-Empty-5x5-v0") |
||||
|
>>> obs, _ = env.reset() |
||||
|
>>> obs.keys() |
||||
|
dict_keys(['image', 'direction', 'mission']) |
||||
|
>>> env = ImgObsWrapper(env) |
||||
|
>>> obs, _ = env.reset() |
||||
|
>>> obs.shape |
||||
|
(7, 7, 3) |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, env): |
||||
|
"""A wrapper that makes image the only observation. |
||||
|
|
||||
|
Args: |
||||
|
env: The environment to apply the wrapper |
||||
|
""" |
||||
|
super().__init__(env) |
||||
|
self.observation_space = env.observation_space.spaces["image"] |
||||
|
print(F"Set obersvation space to {self.observation_space}") |
||||
|
|
||||
|
def observation(self, obs): |
||||
|
#print(F"obs in img obs wrapper {obs}") |
||||
|
tmp = {"data": obs["image"], "Test": obs["Test"]} |
||||
|
|
||||
|
return tmp |
Write
Preview
Loading…
Cancel
Save
Reference in new issue