Thomas Knoll
1 year ago
2 changed files with 6 additions and 278 deletions
@ -1,275 +0,0 @@ |
|||||
from typing import Dict, Optional, Union |
|
||||
from ray.rllib.env.base_env import BaseEnv |
|
||||
from ray.rllib.evaluation import RolloutWorker |
|
||||
from ray.rllib.evaluation.episode import Episode |
|
||||
from ray.rllib.evaluation.episode_v2 import EpisodeV2 |
|
||||
from ray.rllib.policy import Policy |
|
||||
from ray.rllib.utils.typing import PolicyID |
|
||||
import stormpy |
|
||||
import stormpy.core |
|
||||
import stormpy.simulator |
|
||||
|
|
||||
from collections import deque |
|
||||
|
|
||||
import stormpy.shields |
|
||||
import stormpy.logic |
|
||||
|
|
||||
import stormpy.examples |
|
||||
import stormpy.examples.files |
|
||||
import os |
|
||||
|
|
||||
import gymnasium as gym |
|
||||
|
|
||||
import minigrid |
|
||||
import numpy as np |
|
||||
|
|
||||
import ray |
|
||||
from ray.tune import register_env |
|
||||
from ray.rllib.algorithms.ppo import PPOConfig |
|
||||
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator |
|
||||
from ray import tune, air |
|
||||
from ray.rllib.algorithms.callbacks import DefaultCallbacks |
|
||||
from ray.tune.logger import pretty_print |
|
||||
from ray.rllib.utils.numpy import one_hot |
|
||||
from ray.rllib.algorithms import ppo |
|
||||
|
|
||||
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC |
|
||||
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork |
|
||||
from ray.rllib.utils.torch_utils import FLOAT_MIN |
|
||||
|
|
||||
from ray.rllib.models.preprocessors import get_preprocessor |
|
||||
|
|
||||
import matplotlib.pyplot as plt |
|
||||
|
|
||||
import argparse |
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 |
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch |
|
||||
|
|
||||
from examples.shields.rl.Wrappers import OneHotShieldingWrapper |
|
||||
|
|
||||
|
|
||||
torch, nn = try_import_torch() |
|
||||
|
|
||||
class TorchActionMaskModel(TorchModelV2, nn.Module): |
|
||||
"""PyTorch version of above ActionMaskingModel.""" |
|
||||
|
|
||||
def __init__( |
|
||||
self, |
|
||||
obs_space, |
|
||||
action_space, |
|
||||
num_outputs, |
|
||||
model_config, |
|
||||
name, |
|
||||
**kwargs, |
|
||||
): |
|
||||
orig_space = getattr(obs_space, "original_space", obs_space) |
|
||||
assert ( |
|
||||
isinstance(orig_space, Dict) |
|
||||
and "action_mask" in orig_space.spaces |
|
||||
and "observations" in orig_space.spaces |
|
||||
) |
|
||||
|
|
||||
TorchModelV2.__init__( |
|
||||
self, obs_space, action_space, num_outputs, model_config, name, **kwargs |
|
||||
) |
|
||||
nn.Module.__init__(self) |
|
||||
|
|
||||
self.internal_model = TorchFC( |
|
||||
orig_space["observations"], |
|
||||
action_space, |
|
||||
num_outputs, |
|
||||
model_config, |
|
||||
name + "_internal", |
|
||||
) |
|
||||
|
|
||||
# disable action masking --> will likely lead to invalid actions |
|
||||
self.no_masking = False |
|
||||
if "no_masking" in model_config["custom_model_config"]: |
|
||||
self.no_masking = model_config["custom_model_config"]["no_masking"] |
|
||||
|
|
||||
def forward(self, input_dict, state, seq_lens): |
|
||||
# Extract the available actions tensor from the observation. |
|
||||
action_mask = input_dict["obs"]["action_mask"] |
|
||||
|
|
||||
# Compute the unmasked logits. |
|
||||
logits, _ = self.internal_model({"obs": input_dict["obs"]["observations"]}) |
|
||||
|
|
||||
# If action masking is disabled, directly return unmasked logits |
|
||||
if self.no_masking: |
|
||||
return logits, state |
|
||||
|
|
||||
# Convert action_mask into a [0.0 || -inf]-type mask. |
|
||||
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) |
|
||||
masked_logits = logits + inf_mask |
|
||||
|
|
||||
# Return masked logits. |
|
||||
return masked_logits, state |
|
||||
|
|
||||
def value_function(self): |
|
||||
return self.internal_model.value_function() |
|
||||
|
|
||||
|
|
||||
|
|
||||
class MyCallbacks(DefaultCallbacks): |
|
||||
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: |
|
||||
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}") |
|
||||
env = base_env.get_sub_environments()[0] |
|
||||
episode.user_data["count"] = 0 |
|
||||
# print(env.printGrid()) |
|
||||
# print(env.action_space.n) |
|
||||
# print(env.actions) |
|
||||
# print(env.mission) |
|
||||
# print(env.observation_space) |
|
||||
# img = env.get_frame() |
|
||||
# plt.imshow(img) |
|
||||
# plt.show() |
|
||||
|
|
||||
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: |
|
||||
episode.user_data["count"] = episode.user_data["count"] + 1 |
|
||||
env = base_env.get_sub_environments()[0] |
|
||||
# print(env.env.env.printGrid()) |
|
||||
|
|
||||
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: |
|
||||
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}") |
|
||||
env = base_env.get_sub_environments()[0] |
|
||||
# print(env.env.env.printGrid()) |
|
||||
# print(episode.user_data["count"]) |
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
def parse_arguments(argparse): |
|
||||
parser = argparse.ArgumentParser() |
|
||||
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") |
|
||||
parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0") |
|
||||
parser.add_argument("--seed", type=int, help="seed for environment", default=1) |
|
||||
parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32) |
|
||||
parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees") |
|
||||
parser.add_argument("--grid_path", default="Grid.txt") |
|
||||
parser.add_argument("--prism_path", default="Grid.PRISM") |
|
||||
|
|
||||
args = parser.parse_args() |
|
||||
|
|
||||
return args |
|
||||
|
|
||||
|
|
||||
def env_creater(config): |
|
||||
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") |
|
||||
# name = config.get("name", "MiniGrid-Empty-8x8-v0") |
|
||||
framestack = config.get("framestack", 4) |
|
||||
|
|
||||
env = gym.make(name) |
|
||||
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env) |
|
||||
env = minigrid.wrappers.ImgObsWrapper(env) |
|
||||
env = OneHotShieldingWrapper(env, |
|
||||
config.vector_index if hasattr(config, "vector_index") else 0, |
|
||||
framestack=framestack |
|
||||
) |
|
||||
|
|
||||
|
|
||||
return env |
|
||||
|
|
||||
|
|
||||
def create_shield(grid_file, prism_path): |
|
||||
os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}") |
|
||||
|
|
||||
f = open(prism_path, "a") |
|
||||
f.write("label \"AgentIsInLava\" = AgentIsInLava;") |
|
||||
f.close() |
|
||||
|
|
||||
|
|
||||
program = stormpy.parse_prism_program(prism_path) |
|
||||
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" |
|
||||
|
|
||||
formulas = stormpy.parse_properties_for_prism_program(formula_str, program) |
|
||||
options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) |
|
||||
options.set_build_state_valuations(True) |
|
||||
options.set_build_choice_labels(True) |
|
||||
options.set_build_all_labels() |
|
||||
model = stormpy.build_sparse_model_with_options(program, options) |
|
||||
|
|
||||
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) |
|
||||
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) |
|
||||
|
|
||||
assert result.has_scheduler |
|
||||
assert result.has_shield |
|
||||
shield = result.shield |
|
||||
|
|
||||
stormpy.shields.export_shield(model, shield,"Grid.shield") |
|
||||
|
|
||||
return shield.construct(), model |
|
||||
|
|
||||
def export_grid_to_text(env, grid_file): |
|
||||
f = open(grid_file, "w") |
|
||||
print(env) |
|
||||
f.write(env.printGrid(init=True)) |
|
||||
# f.write(env.pprint_grid()) |
|
||||
f.close() |
|
||||
|
|
||||
def create_environment(args): |
|
||||
env_id= args.env |
|
||||
env = gym.make(env_id) |
|
||||
env.reset() |
|
||||
return env |
|
||||
|
|
||||
|
|
||||
def main(): |
|
||||
args = parse_arguments(argparse) |
|
||||
|
|
||||
env = create_environment(args) |
|
||||
ray.init(num_cpus=3) |
|
||||
|
|
||||
# print(env.pprint_grid()) |
|
||||
# print(env.printGrid(init=False)) |
|
||||
|
|
||||
grid_file = args.grid_path |
|
||||
export_grid_to_text(env, grid_file) |
|
||||
|
|
||||
prism_path = args.prism_path |
|
||||
shield, model = create_shield(grid_file, prism_path) |
|
||||
|
|
||||
for state_id in model.states: |
|
||||
choices = shield.get_choice(state_id) |
|
||||
print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") |
|
||||
|
|
||||
env_name = "mini-grid" |
|
||||
register_env(env_name, env_creater) |
|
||||
|
|
||||
|
|
||||
algo =( |
|
||||
PPOConfig() |
|
||||
.rollouts(num_rollout_workers=1) |
|
||||
.resources(num_gpus=0) |
|
||||
.environment(env="mini-grid") |
|
||||
.framework("torch") |
|
||||
.callbacks(MyCallbacks) |
|
||||
.training(model={ |
|
||||
"fcnet_hiddens": [256,256], |
|
||||
"fcnet_activation": "relu", |
|
||||
|
|
||||
}) |
|
||||
.build() |
|
||||
) |
|
||||
episode_reward = 0 |
|
||||
terminated = truncated = False |
|
||||
obs, info = env.reset() |
|
||||
|
|
||||
# while not terminated and not truncated: |
|
||||
# action = algo.compute_single_action(obs) |
|
||||
# obs, reward, terminated, truncated = env.step(action) |
|
||||
|
|
||||
for i in range(30): |
|
||||
result = algo.train() |
|
||||
print(pretty_print(result)) |
|
||||
|
|
||||
if i % 5 == 0: |
|
||||
checkpoint_dir = algo.save() |
|
||||
print(f"Checkpoint saved in directory {checkpoint_dir}") |
|
||||
|
|
||||
|
|
||||
|
|
||||
ray.shutdown() |
|
||||
|
|
||||
if __name__ == '__main__': |
|
||||
main() |
|
Write
Preview
Loading…
Cancel
Save
Reference in new issue