You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
275 lines
8.8 KiB
275 lines
8.8 KiB
from typing import Dict, Optional, Union
|
|||||||
from ray.rllib.env.base_env import BaseEnv
|
|||||||
from ray.rllib.evaluation import RolloutWorker
|
|||||||
from ray.rllib.evaluation.episode import Episode
|
|||||||
from ray.rllib.evaluation.episode_v2 import EpisodeV2
|
|||||||
from ray.rllib.policy import Policy
|
|||||||
from ray.rllib.utils.typing import PolicyID
|
|||||||
import stormpy
|
|||||||
import stormpy.core
|
|||||||
import stormpy.simulator
|
|||||||
|
|||||||
from collections import deque
|
|||||||
|
|||||||
import stormpy.shields
|
|||||||
import stormpy.logic
|
|||||||
|
|||||||
import stormpy.examples
|
|||||||
import stormpy.examples.files
|
|||||||
import os
|
|||||||
|
|||||||
import gymnasium as gym
|
|||||||
|
|||||||
import minigrid
|
|||||||
import numpy as np
|
|||||||
|
|||||||
import ray
|
|||||||
from ray.tune import register_env
|
|||||||
from ray.rllib.algorithms.ppo import PPOConfig
|
|||||||
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator
|
|||||||
from ray import tune, air
|
|||||||
from ray.rllib.algorithms.callbacks import DefaultCallbacks
|
|||||||
from ray.tune.logger import pretty_print
|
|||||||
from ray.rllib.utils.numpy import one_hot
|
|||||||
from ray.rllib.algorithms import ppo
|
|||||||
|
|||||||
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
|
|||||||
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
|
|||||||
from ray.rllib.utils.torch_utils import FLOAT_MIN
|
|||||||
|
|||||||
from ray.rllib.models.preprocessors import get_preprocessor
|
|||||||
|
|||||||
import matplotlib.pyplot as plt
|
|||||||
|
|||||||
import argparse
|
|||||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
|||||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
|||||||
|
|||||||
from Wrapper import OneHotWrapper
|
|||||||
|
|||||||
|
|||||||
torch, nn = try_import_torch()
|
|||||||
|
|||||||
class TorchActionMaskModel(TorchModelV2, nn.Module):
|
|||||||
"""PyTorch version of above ActionMaskingModel."""
|
|||||||
|
|||||||
def __init__(
|
|||||||
self,
|
|||||||
obs_space,
|
|||||||
action_space,
|
|||||||
num_outputs,
|
|||||||
model_config,
|
|||||||
name,
|
|||||||
**kwargs,
|
|||||||
):
|
|||||||
orig_space = getattr(obs_space, "original_space", obs_space)
|
|||||||
assert (
|
|||||||
isinstance(orig_space, Dict)
|
|||||||
and "action_mask" in orig_space.spaces
|
|||||||
and "observations" in orig_space.spaces
|
|||||||
)
|
|||||||
|
|||||||
TorchModelV2.__init__(
|
|||||||
self, obs_space, action_space, num_outputs, model_config, name, **kwargs
|
|||||||
)
|
|||||||
nn.Module.__init__(self)
|
|||||||
|
|||||||
self.internal_model = TorchFC(
|
|||||||
orig_space["observations"],
|
|||||||
action_space,
|
|||||||
num_outputs,
|
|||||||
model_config,
|
|||||||
name + "_internal",
|
|||||||
)
|
|||||||
|
|||||||
# disable action masking --> will likely lead to invalid actions
|
|||||||
self.no_masking = False
|
|||||||
if "no_masking" in model_config["custom_model_config"]:
|
|||||||
self.no_masking = model_config["custom_model_config"]["no_masking"]
|
|||||||
|
|||||||
def forward(self, input_dict, state, seq_lens):
|
|||||||
# Extract the available actions tensor from the observation.
|
|||||||
action_mask = input_dict["obs"]["action_mask"]
|
|||||||
|
|||||||
# Compute the unmasked logits.
|
|||||||
logits, _ = self.internal_model({"obs": input_dict["obs"]["observations"]})
|
|||||||
|
|||||||
# If action masking is disabled, directly return unmasked logits
|
|||||||
if self.no_masking:
|
|||||||
return logits, state
|
|||||||
|
|||||||
# Convert action_mask into a [0.0 || -inf]-type mask.
|
|||||||
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
|
|||||||
masked_logits = logits + inf_mask
|
|||||||
|
|||||||
# Return masked logits.
|
|||||||
return masked_logits, state
|
|||||||
|
|||||||
def value_function(self):
|
|||||||
return self.internal_model.value_function()
|
|||||||
|
|||||||
|
|||||||
|
|||||||
class MyCallbacks(DefaultCallbacks):
|
|||||||
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
|
|||||||
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
|
|||||||
env = base_env.get_sub_environments()[0]
|
|||||||
episode.user_data["count"] = 0
|
|||||||
# print(env.printGrid())
|
|||||||
# print(env.action_space.n)
|
|||||||
# print(env.actions)
|
|||||||
# print(env.mission)
|
|||||||
# print(env.observation_space)
|
|||||||
# img = env.get_frame()
|
|||||||
# plt.imshow(img)
|
|||||||
# plt.show()
|
|||||||
|
|||||||
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
|
|||||||
episode.user_data["count"] = episode.user_data["count"] + 1
|
|||||||
env = base_env.get_sub_environments()[0]
|
|||||||
# print(env.env.env.printGrid())
|
|||||||
|
|||||||
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
|
|||||||
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
|
|||||||
env = base_env.get_sub_environments()[0]
|
|||||||
# print(env.env.env.printGrid())
|
|||||||
# print(episode.user_data["count"])
|
|||||||
|
|||||||
|
|||||||
|
|||||||
|
|||||||
|
|||||||
def parse_arguments(argparse):
|
|||||||
parser = argparse.ArgumentParser()
|
|||||||
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
|
|||||||
parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0")
|
|||||||
parser.add_argument("--seed", type=int, help="seed for environment", default=1)
|
|||||||
parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32)
|
|||||||
parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees")
|
|||||||
parser.add_argument("--grid_path", default="Grid.txt")
|
|||||||
parser.add_argument("--prism_path", default="Grid.PRISM")
|
|||||||
|
|||||||
args = parser.parse_args()
|
|||||||
|
|||||||
return args
|
|||||||
|
|||||||
|
|||||||
def env_creater(config):
|
|||||||
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
|
|||||||
# name = config.get("name", "MiniGrid-Empty-8x8-v0")
|
|||||||
framestack = config.get("framestack", 4)
|
|||||||
|
|||||||
env = gym.make(name)
|
|||||||
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
|
|||||||
env = minigrid.wrappers.ImgObsWrapper(env)
|
|||||||
env = OneHotWrapper(env,
|
|||||||
config.vector_index if hasattr(config, "vector_index") else 0,
|
|||||||
framestack=framestack
|
|||||||
)
|
|||||||
|
|||||||
|
|||||||
return env
|
|||||||
|
|||||||
|
|||||||
def create_shield(grid_file, prism_path):
|
|||||||
os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}")
|
|||||||
|
|||||||
f = open(prism_path, "a")
|
|||||||
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
|
|||||||
f.close()
|
|||||||
|
|||||||
|
|||||||
program = stormpy.parse_prism_program(prism_path)
|
|||||||
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
|
|||||||
|
|||||||
formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
|
|||||||
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
|
|||||||
options.set_build_state_valuations(True)
|
|||||||
options.set_build_choice_labels(True)
|
|||||||
options.set_build_all_labels()
|
|||||||
model = stormpy.build_sparse_model_with_options(program, options)
|
|||||||
|
|||||||
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
|
|||||||
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
|
|||||||
|
|||||||
assert result.has_scheduler
|
|||||||
assert result.has_shield
|
|||||||
shield = result.shield
|
|||||||
|
|||||||
stormpy.shields.export_shield(model, shield,"Grid.shield")
|
|||||||
|
|||||||
return shield.construct(), model
|
|||||||
|
|||||||
def export_grid_to_text(env, grid_file):
|
|||||||
f = open(grid_file, "w")
|
|||||||
print(env)
|
|||||||
f.write(env.printGrid(init=True))
|
|||||||
# f.write(env.pprint_grid())
|
|||||||
f.close()
|
|||||||
|
|||||||
def create_environment(args):
|
|||||||
env_id= args.env
|
|||||||
env = gym.make(env_id)
|
|||||||
env.reset()
|
|||||||
return env
|
|||||||
|
|||||||
|
|||||||
def main():
|
|||||||
args = parse_arguments(argparse)
|
|||||||
|
|||||||
env = create_environment(args)
|
|||||||
ray.init(num_cpus=3)
|
|||||||
|
|||||||
# print(env.pprint_grid())
|
|||||||
# print(env.printGrid(init=False))
|
|||||||
|
|||||||
grid_file = args.grid_path
|
|||||||
export_grid_to_text(env, grid_file)
|
|||||||
|
|||||||
prism_path = args.prism_path
|
|||||||
shield, model = create_shield(grid_file, prism_path)
|
|||||||
|
|||||||
for state_id in model.states:
|
|||||||
choices = shield.get_choice(state_id)
|
|||||||
print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")
|
|||||||
|
|||||||
env_name = "mini-grid"
|
|||||||
register_env(env_name, env_creater)
|
|||||||
|
|||||||
|
|||||||
algo =(
|
|||||||
PPOConfig()
|
|||||||
.rollouts(num_rollout_workers=1)
|
|||||||
.resources(num_gpus=0)
|
|||||||
.environment(env="mini-grid")
|
|||||||
.framework("torch")
|
|||||||
.callbacks(MyCallbacks)
|
|||||||
.training(model={
|
|||||||
"fcnet_hiddens": [256,256],
|
|||||||
"fcnet_activation": "relu",
|
|||||||
|
|||||||
})
|
|||||||
.build()
|
|||||||
)
|
|||||||
episode_reward = 0
|
|||||||
terminated = truncated = False
|
|||||||
obs, info = env.reset()
|
|||||||
|
|||||||
# while not terminated and not truncated:
|
|||||||
# action = algo.compute_single_action(obs)
|
|||||||
# obs, reward, terminated, truncated = env.step(action)
|
|||||||
|
|||||||
for i in range(30):
|
|||||||
result = algo.train()
|
|||||||
print(pretty_print(result))
|
|||||||
|
|||||||
if i % 5 == 0:
|
|||||||
checkpoint_dir = algo.save()
|
|||||||
print(f"Checkpoint saved in directory {checkpoint_dir}")
|
|||||||
|
|||||||
|
|||||||
|
|||||||
ray.shutdown()
|
|||||||
|
|||||||
if __name__ == '__main__':
|
|||||||
HTTP/1.1 200 OK
Content-Type: text/html; charset=UTF-8
Set-Cookie: i_like_gitea=f85b620387b6ef86; Path=/; HttpOnly; SameSite=Lax
Set-Cookie: _csrf=ahrN0tp_wsPvn1NKjgC3CK9zVVY6MTczMjM5NTU0MTU0NjU5OTc2MA; Path=/; Expires=Sun, 24 Nov 2024 20:59:01 GMT; HttpOnly; SameSite=Lax
Set-Cookie: macaron_flash=; Path=/; Max-Age=0; HttpOnly; SameSite=Lax
X-Frame-Options: SAMEORIGIN
Date: Sat, 23 Nov 2024 20:59:01 GMT
Transfer-Encoding: chunked
6d39
|