From 4e182a8e5b1a6d916cd813058e9d6dc0b6e9d390 Mon Sep 17 00:00:00 2001
From: Thomas Knoll
Date: Thu, 7 Sep 2023 15:58:18 +0200
Subject: [PATCH] removed basic training

---
 examples/shields/rl/12_basic_training.py | 275 -----------------------
 examples/shields/rl/14_train_eval.py     |   9 +-
 2 files changed, 6 insertions(+), 278 deletions(-)
 delete mode 100644 examples/shields/rl/12_basic_training.py

diff --git a/examples/shields/rl/12_basic_training.py b/examples/shields/rl/12_basic_training.py
deleted file mode 100644
index 40759b2..0000000
--- a/examples/shields/rl/12_basic_training.py
+++ /dev/null
@@ -1,275 +0,0 @@
-from typing import Dict, Optional, Union
-from ray.rllib.env.base_env import BaseEnv
-from ray.rllib.evaluation import RolloutWorker
-from ray.rllib.evaluation.episode import Episode
-from ray.rllib.evaluation.episode_v2 import EpisodeV2
-from ray.rllib.policy import Policy
-from ray.rllib.utils.typing import PolicyID
-import stormpy
-import stormpy.core
-import stormpy.simulator
-
-from collections import deque
-
-import stormpy.shields
-import stormpy.logic
-
-import stormpy.examples
-import stormpy.examples.files
-import os
-
-import gymnasium as gym
-
-import minigrid
-import numpy as np
-
-import ray
-from ray.tune import register_env
-from ray.rllib.algorithms.ppo import PPOConfig
-from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator
-from ray import tune, air
-from ray.rllib.algorithms.callbacks import DefaultCallbacks
-from ray.tune.logger import pretty_print
-from ray.rllib.utils.numpy import one_hot
-from ray.rllib.algorithms import ppo
-
-from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
-from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
-from ray.rllib.utils.torch_utils import FLOAT_MIN
-
-from ray.rllib.models.preprocessors import get_preprocessor
-
-import matplotlib.pyplot as plt
-
-import argparse
-from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
-from ray.rllib.utils.framework import try_import_tf, try_import_torch
-
-from examples.shields.rl.Wrappers import OneHotShieldingWrapper
-
-
-torch, nn = try_import_torch()
-
-class TorchActionMaskModel(TorchModelV2, nn.Module):
-    """PyTorch version of above ActionMaskingModel."""
-
-    def __init__(
-        self,
-        obs_space,
-        action_space,
-        num_outputs,
-        model_config,
-        name,
-        **kwargs,
-    ):
-        orig_space = getattr(obs_space, "original_space", obs_space)
-        assert (
-            isinstance(orig_space, Dict)
-            and "action_mask" in orig_space.spaces
-            and "observations" in orig_space.spaces
-        )
-
-        TorchModelV2.__init__(
-            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
-        )
-        nn.Module.__init__(self)
-
-        self.internal_model = TorchFC(
-            orig_space["observations"],
-            action_space,
-            num_outputs,
-            model_config,
-            name + "_internal",
-        )
-
-        # disable action masking --> will likely lead to invalid actions
-        self.no_masking = False
-        if "no_masking" in model_config["custom_model_config"]:
-            self.no_masking = model_config["custom_model_config"]["no_masking"]
-
-    def forward(self, input_dict, state, seq_lens):
-        # Extract the available actions tensor from the observation.
-        action_mask = input_dict["obs"]["action_mask"]
-
-        # Compute the unmasked logits.
-        logits, _ = self.internal_model({"obs": input_dict["obs"]["observations"]})
-
-        # If action masking is disabled, directly return unmasked logits
-        if self.no_masking:
-            return logits, state
-
-        # Convert action_mask into a [0.0 || -inf]-type mask.
-        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
-        masked_logits = logits + inf_mask
-
-        # Return masked logits.
-        return masked_logits, state
-
-    def value_function(self):
-        return self.internal_model.value_function()
-
-
-
-class MyCallbacks(DefaultCallbacks):
-    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-        # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
-        env = base_env.get_sub_environments()[0]
-        episode.user_data["count"] = 0
-        # print(env.printGrid())
-        # print(env.action_space.n)
-        # print(env.actions)
-        # print(env.mission)
-        # print(env.observation_space)
-        # img = env.get_frame()
-        # plt.imshow(img)
-        # plt.show()
-
-    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-        episode.user_data["count"] = episode.user_data["count"] + 1
-        env = base_env.get_sub_environments()[0]
-        # print(env.env.env.printGrid())
-
-    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
-        # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
-        env = base_env.get_sub_environments()[0]
-        # print(env.env.env.printGrid())
-        # print(episode.user_data["count"])
-
-
-
-
-
-def parse_arguments(argparse):
-    parser = argparse.ArgumentParser()
-    # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
-    parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0")
-    parser.add_argument("--seed", type=int, help="seed for environment", default=1)
-    parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32)
-    parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees")
-    parser.add_argument("--grid_path", default="Grid.txt")
-    parser.add_argument("--prism_path", default="Grid.PRISM")
-
-    args = parser.parse_args()
-
-    return args
-
-
-def env_creater(config):
-    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
-    # name = config.get("name", "MiniGrid-Empty-8x8-v0")
-    framestack = config.get("framestack", 4)
-
-    env = gym.make(name)
-    # env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
-    env = minigrid.wrappers.ImgObsWrapper(env)
-    env = OneHotShieldingWrapper(env,
-        config.vector_index if hasattr(config, "vector_index") else 0,
-        framestack=framestack
-        )
-
-
-    return env
-
-
-def create_shield(grid_file, prism_path):
-    os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}")
-
-    f = open(prism_path, "a")
-    f.write("label \"AgentIsInLava\" = AgentIsInLava;")
-    f.close()
-
-
-    program = stormpy.parse_prism_program(prism_path)
-    formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
-
-    formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
-    options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
-    options.set_build_state_valuations(True)
-    options.set_build_choice_labels(True)
-    options.set_build_all_labels()
-    model = stormpy.build_sparse_model_with_options(program, options)
-
-    shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
-    result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
-
-    assert result.has_scheduler
-    assert result.has_shield
-    shield = result.shield
-
-    stormpy.shields.export_shield(model, shield,"Grid.shield")
-
-    return shield.construct(), model
-
-def export_grid_to_text(env, grid_file):
-    f = open(grid_file, "w")
-    print(env)
-    f.write(env.printGrid(init=True))
-    # f.write(env.pprint_grid())
-    f.close()
-
-def create_environment(args):
-    env_id= args.env
-    env = gym.make(env_id)
-    env.reset()
-    return env
-
-
-def main():
-    args = parse_arguments(argparse)
-
-    env = create_environment(args)
-    ray.init(num_cpus=3)
-
-    # print(env.pprint_grid())
-    # print(env.printGrid(init=False))
-
-    grid_file = args.grid_path
-    export_grid_to_text(env, grid_file)
-
-    prism_path = args.prism_path
-    shield, model = create_shield(grid_file, prism_path)
-
-    for state_id in model.states:
-        choices = shield.get_choice(state_id)
-        print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")
-
-    env_name = "mini-grid"
-    register_env(env_name, env_creater)
-
-
-    algo =(
-        PPOConfig()
-        .rollouts(num_rollout_workers=1)
-        .resources(num_gpus=0)
-        .environment(env="mini-grid")
-        .framework("torch")
-        .callbacks(MyCallbacks)
-        .training(model={
-            "fcnet_hiddens": [256,256],
-            "fcnet_activation": "relu",
-
-        })
-        .build()
-    )
-    episode_reward = 0
-    terminated = truncated = False
-    obs, info = env.reset()
-
-    # while not terminated and not truncated:
-    #     action = algo.compute_single_action(obs)
-    #     obs, reward, terminated, truncated = env.step(action)
-
-    for i in range(30):
-        result = algo.train()
-        print(pretty_print(result))
-
-        if i % 5 == 0:
-            checkpoint_dir = algo.save()
-            print(f"Checkpoint saved in directory {checkpoint_dir}")
-
-
-
-    ray.shutdown()
-
-if __name__ == '__main__':
-    main()
\ No newline at end of file
diff --git a/examples/shields/rl/14_train_eval.py b/examples/shields/rl/14_train_eval.py
index 30cf831..bba2db1 100644
--- a/examples/shields/rl/14_train_eval.py
+++ b/examples/shields/rl/14_train_eval.py
@@ -102,8 +102,10 @@ def ppo(args):
         if i % 5 == 0:
             algo.save()
 
-    writer = SummaryWriter(log_dir=F"{create_log_dir(args)}-eval")
-    csv_logger = CSVLogger()
+    eval_log_dir = F"{create_log_dir(args)}-eval"
+
+    writer = SummaryWriter(log_dir=eval_log_dir)
+    csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
 
     for i in range(iterations):
         eval_result = algo.evaluate()
@@ -111,6 +113,7 @@ def ppo(args):
         print(eval_result)
 
         # logger.on_result(eval_result)
+        csv_logger.on_result(eval_result)
 
         evaluation = eval_result['evaluation']
         epsiode_reward_mean = evaluation['episode_reward_mean']
@@ -118,7 +121,7 @@ def ppo(args):
         print(epsiode_reward_mean)
         writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i)
         writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
-
+    writer.close()