You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

315 lines
10 KiB

from typing import Dict, Optional, Union
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
import os
import gymnasium as gym
import minigrid
import numpy as np
import ray
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator
from ray import tune, air
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor
from MaskEnvironments import ParametricActionsMiniGridEnv
from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper
import matplotlib.pyplot as plt
import argparse
class MyCallbacks(DefaultCallbacks):
    """RLlib callbacks that count the steps taken in each episode.

    The running count lives in ``episode.user_data["count"]``: it is reset to
    0 when an episode starts and incremented once per environment step.
    """

    def on_episode_start(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode | EpisodeV2,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Reset the per-episode step counter."""
        episode.user_data["count"] = 0

    def on_episode_step(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy] | None = None,
        episode: Episode | EpisodeV2,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Increment the per-episode step counter by one."""
        # Runs on every environment step, so keep it to the bare counter update.
        episode.user_data["count"] += 1

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode | EpisodeV2 | Exception,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """End-of-episode hook; currently does nothing.

        When ``episode`` is an episode object (not an ``Exception``), its step
        count is available in ``episode.user_data["count"]`` if it needs to be
        logged here.
        """
def parse_arguments(argparse, argv=None):
    """Parse the command-line options for a shielded MiniGrid training run.

    Args:
        argparse: The ``argparse`` module (the caller passes it in).
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; handy for tests and programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """

    def str_to_bool(value):
        # Without this, ``--no_masking False`` would store the string "False",
        # which is truthy when the model config later checks it.
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0")
    parser.add_argument("--seed", type=int, help="seed for environment", default=1)
    parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32)
    parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees")
    parser.add_argument("--grid_path", default="Grid.txt")
    parser.add_argument("--prism_path", default="Grid.PRISM")
    # nargs="?" + const=True accepts both the bare flag and an explicit value
    # (e.g. "--no_masking false"), so existing invocations keep working.
    parser.add_argument(
        "--no_masking",
        nargs="?",
        const=True,
        default=False,
        type=str_to_bool,
        help="disable action masking (bare flag, or true/false)",
    )
    parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
    return parser.parse_args(argv)
def env_creater_custom(config):
    """Build the shielded, one-hot-encoded MiniGrid environment used for training.

    Keys read from ``config``:
        name: gym environment id (default "MiniGrid-LavaCrossingS9N1-v0").
        framestack: frame-stack depth passed to OneHotWrapper (default 4).
        shield: shield dictionary handed to MiniGridEnvWrapper (default {});
            in this script it is the dict returned by create_shield_dict.

    If ``config`` has a ``vector_index`` attribute (as an RLlib EnvContext
    does), it is forwarded to OneHotWrapper; otherwise 0 is used.
    """
    env_name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    stack_size = config.get("framestack", 4)
    shield_table = config.get("shield", {})

    shielded = MiniGridEnvWrapper(gym.make(env_name), shield=shield_table)
    worker_index = getattr(config, "vector_index", 0)
    return OneHotWrapper(shielded, worker_index, framestack=stack_size)
def env_creater_cart(config):
    """Create a plain CartPole-v1 environment; ``config`` is ignored."""
    cartpole_id = "CartPole-v1"
    return gym.make(cartpole_id)
def env_creater(config):
    """Build an unshielded MiniGrid environment with image-only observations.

    Reads ``name`` (default "MiniGrid-LavaCrossingS9N1-v0") and ``framestack``
    (default 4) from ``config``. The env is wrapped in ImgObsWrapper and then
    OneHotWrapper; ``config.vector_index`` is forwarded when present (else 0).
    The resulting environment is printed before being returned.
    """
    env_name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    stack_size = config.get("framestack", 4)

    image_env = minigrid.wrappers.ImgObsWrapper(gym.make(env_name))
    worker_index = getattr(config, "vector_index", 0)
    encoded_env = OneHotWrapper(image_env, worker_index, framestack=stack_size)

    print(f"Created Minigrid Environment is {encoded_env}")
    return encoded_env
def create_shield(grid_file, prism_path, generator_path="/home/tknoll/Documents/main", shield_export_path="Grid.shield"):
    """Generate a PRISM model from a grid file and compute a pre-safety shield.

    Steps:
      1. Run the external grid-to-PRISM generator on ``grid_file``, writing
         ``prism_path``.
      2. Append the ``AgentIsInLava`` label to the generated program.
      3. Build the sparse model (with state valuations and choice labels) and
         model-check ``Pmax=? [G !"AgentIsInLavaAndNotDone"]`` with a
         relative PRE_SAFETY shield expression (value 0.1).
      4. Export the shield to ``shield_export_path`` and collect, per state,
         the choices the shield allows.

    Args:
        grid_file: Path of the text grid to translate.
        prism_path: Output path for the generated PRISM program.
        generator_path: Executable that converts the grid to PRISM (defaults
            to the previously hard-coded location).
        shield_export_path: File the shield is exported to via stormpy.

    Returns:
        dict mapping each state's valuation string to a list of
        ``(choice, labels)`` tuples, one per entry of the shield's choice map.

    Raises:
        FileNotFoundError: If ``generator_path`` does not exist.
        subprocess.CalledProcessError: If the generator exits non-zero.
        RuntimeError: If model checking yields no scheduler or no shield.
    """
    import subprocess

    # Argument list (no shell) so paths containing spaces or metacharacters are
    # passed verbatim; check=True stops us from silently reusing a stale PRISM
    # file when the generator fails.
    subprocess.run(
        [generator_path, "-v", "agent", "-i", grid_file, "-o", prism_path],
        check=True,
    )

    with open(prism_path, "a") as prism_file:
        # Leading newline keeps the label on its own line even if the generator
        # did not end its output with one (e.g. a trailing // comment).
        prism_file.write("\nlabel \"AgentIsInLava\" = AgentIsInLava;")

    program = stormpy.parse_prism_program(prism_path)
    # NOTE(review): the property refers to "AgentIsInLavaAndNotDone", not the
    # "AgentIsInLava" label appended above; presumably the generator defines it
    # — confirm.
    formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
    formulas = stormpy.parse_properties_for_prism_program(formula_str, program)

    options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
    options.set_build_state_valuations(True)
    options.set_build_choice_labels(True)
    options.set_build_all_labels()
    model = stormpy.build_sparse_model_with_options(program, options)

    shield_specification = stormpy.logic.ShieldExpression(
        stormpy.logic.ShieldingType.PRE_SAFETY,
        stormpy.logic.ShieldComparison.RELATIVE,
        0.1,
    )
    result = stormpy.model_checking(
        model, formulas[0], extract_scheduler=True, shield_expression=shield_specification
    )
    # Explicit checks rather than ``assert`` so they still run under ``python -O``.
    if not result.has_scheduler:
        raise RuntimeError("model checking did not produce a scheduler")
    if not result.has_shield:
        raise RuntimeError("model checking did not produce a shield")
    shield = result.shield

    action_dictionary = {}
    shield_scheduler = shield.construct()
    for state in model.states:
        choice_map = shield_scheduler.get_choice(state).choice_map
        state_valuation = model.state_valuations.get_string(state)
        # entry[1] is used as the local choice offset for get_choice_index;
        # presumably entry[0] is the shield's value for that choice — confirm.
        action_dictionary[state_valuation] = [
            (
                entry[1],
                model.choice_labeling.get_labels_of_choice(model.get_choice_index(state, entry[1])),
            )
            for entry in choice_map
        ]

    stormpy.shields.export_shield(model, shield, shield_export_path)
    return action_dictionary
def export_grid_to_text(env, grid_file):
    """Write the initial grid layout of ``env`` to ``grid_file``.

    ``env.printGrid(init=True)`` is rendered before the file is opened. A
    rendering failure therefore leaves an existing ``grid_file`` untouched
    instead of truncating it. The ``with`` block closes the file even if the
    write fails.

    Args:
        env: Environment exposing a ``printGrid(init=...)`` method that
            returns a string.
        grid_file: Destination path; overwritten on success.
    """
    grid_text = env.printGrid(init=True)
    with open(grid_file, "w") as handle:
        handle.write(grid_text)
def create_environment(args):
    """Instantiate the gym environment named by ``args.env`` and reset it once.

    The reset result is discarded; the reset environment itself is returned.
    """
    environment = gym.make(args.env)
    environment.reset()
    return environment
def register_custom_minigrid_env(args):
    """Register the shielded MiniGrid env and the action-masking model with Ray.

    The environment is registered under "mini-grid" (built by
    env_creater_custom) and TorchActionMaskModel under "pa_model".
    ``args`` is accepted for interface symmetry but not used.
    """
    register_env("mini-grid", env_creater_custom)
    ModelCatalog.register_custom_model("pa_model", TorchActionMaskModel)
def create_shield_dict(args):
    """Export the grid of ``args.env`` and compute its shield dictionary.

    The grid text is written to ``args.grid_path``, the PRISM program is
    generated at ``args.prism_path``, and the result of create_shield is
    returned unchanged.
    """
    environment = create_environment(args)
    grid_file, prism_path = args.grid_path, args.prism_path
    export_grid_to_text(environment, grid_file)
    return create_shield(grid_file, prism_path)
def ppo(args, num_iterations=30, checkpoint_every=5):
    """Train a PPO agent on the shielded MiniGrid environment.

    Ray is initialised with 3 CPUs. The "mini-grid" env and "pa_model" model
    are registered, the shield is computed, and PPO is trained for
    ``num_iterations`` iterations, printing each result. Ray is always shut
    down afterwards, even if setup or training raises.

    Args:
        args: Parsed command-line namespace (see parse_arguments).
        num_iterations: Number of ``algo.train()`` calls (default 30).
        checkpoint_every: Save a checkpoint on iteration 0 and every this many
            iterations after it. The final iteration is always saved, so the
            trained policy is never lost.

    Raises:
        ValueError: If ``checkpoint_every`` is less than 1.
    """
    if checkpoint_every < 1:
        raise ValueError(f"checkpoint_every must be >= 1, got {checkpoint_every}")

    ray.init(num_cpus=3)
    try:
        register_custom_minigrid_env(args)
        shield_dict = create_shield_dict(args)
        config = (
            PPOConfig()
            .rollouts(num_rollout_workers=1)
            .resources(num_gpus=0)
            .environment(env="mini-grid", env_config={"shield": shield_dict})
            .framework("torch")
            .callbacks(MyCallbacks)
            .rl_module(_enable_rl_module_api=False)
            .training(
                _enable_learner_api=False,
                model={
                    "custom_model": "pa_model",
                    "custom_model_config": {"shield": shield_dict, "no_masking": args.no_masking},
                },
            )
        )
        algo = config.build()
        try:
            last_iteration = num_iterations - 1
            for i in range(num_iterations):
                result = algo.train()
                print(pretty_print(result))
                if i % checkpoint_every == 0 or i == last_iteration:
                    checkpoint_dir = algo.save()
                    print(f"Checkpoint saved in directory {checkpoint_dir}")
        finally:
            # Release rollout workers before tearing Ray down.
            algo.stop()
    finally:
        ray.shutdown()
def dqn(args, num_iterations=30, checkpoint_every=5):
    """Train a DQN agent on the shielded MiniGrid environment.

    Ray is initialised with 3 CPUs, matching ``ppo``. Previously this function
    called ``ray.shutdown()`` without a matching ``ray.init``. The "mini-grid"
    env and "pa_model" model are registered, the shield is computed, and DQN
    (no extra Q-head hidden layers, no dueling) is trained for
    ``num_iterations`` iterations. Ray is always shut down afterwards, even if
    setup or training raises.

    Args:
        args: Parsed command-line namespace (see parse_arguments).
        num_iterations: Number of ``algo.train()`` calls (default 30).
        checkpoint_every: Save a checkpoint on iteration 0 and every this many
            iterations after it. The final iteration is always saved.

    Raises:
        ValueError: If ``checkpoint_every`` is less than 1.
    """
    if checkpoint_every < 1:
        raise ValueError(f"checkpoint_every must be >= 1, got {checkpoint_every}")

    ray.init(num_cpus=3)
    try:
        register_custom_minigrid_env(args)
        shield_dict = create_shield_dict(args)
        config = (
            DQNConfig()
            .resources(num_gpus=0)
            .rollouts(num_rollout_workers=1)
            .environment(env="mini-grid", env_config={"shield": shield_dict})
            .framework("torch")
            .callbacks(MyCallbacks)
            .rl_module(_enable_rl_module_api=False)
            .training(
                hiddens=[],
                dueling=False,
                model={
                    "custom_model": "pa_model",
                    "custom_model_config": {"shield": shield_dict, "no_masking": args.no_masking},
                },
            )
        )
        algo = config.build()
        try:
            last_iteration = num_iterations - 1
            for i in range(num_iterations):
                result = algo.train()
                print(pretty_print(result))
                if i % checkpoint_every == 0 or i == last_iteration:
                    checkpoint_dir = algo.save()
                    print(f"Checkpoint saved in directory {checkpoint_dir}")
        finally:
            # Release rollout workers before tearing Ray down.
            algo.stop()
    finally:
        ray.shutdown()
def main():
    """Parse the command line and run the training routine for ``--algorithm``."""
    args = parse_arguments(argparse)
    trainers = {"ppo": ppo, "dqn": dqn}
    run_training = trainers.get(args.algorithm)
    if run_training is not None:
        run_training(args)


if __name__ == "__main__":
    main()