Browse Source

initial layout for rl test

refactoring
Thomas Knoll 1 year ago
parent
commit
dd9dd43036
  1. 267
      examples/shields/12_basic_training.py

267
examples/shields/12_basic_training.py

@ -0,0 +1,267 @@
from typing import Dict, Optional, Union
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
import stormpy
import stormpy.core
import stormpy.simulator
from collections import deque
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
import os
import gymnasium as gym
import minigrid
import numpy as np
import ray
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator
from ray import tune, air
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.utils.numpy import one_hot
from ray.rllib.algorithms import ppo
from ray.rllib.models.preprocessors import get_preprocessor
import matplotlib.pyplot as plt
import argparse
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
episode.user_data["count"] = 0
# print(env.printGrid())
# print(env.action_space.n)
# print(env.actions)
# print(env.mission)
# print(env.observation_space)
# img = env.get_frame()
# plt.imshow(img)
# plt.show()
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0]
# print(env.env.env.printGrid())
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
# print(env.env.env.printGrid())
# print(episode.user_data["count"])
class OneHotWrapper(gym.core.ObservationWrapper):
def __init__(self, env, vector_index, framestack):
super().__init__(env)
self.framestack = framestack
# 49=7x7 field of vision; 11=object types; 6=colors; 3=state types.
# +4: Direction.
self.single_frame_dim = 49 * (11 + 6 + 3) + 4
self.init_x = None
self.init_y = None
self.x_positions = []
self.y_positions = []
self.x_y_delta_buffer = deque(maxlen=100)
self.vector_index = vector_index
self.frame_buffer = deque(maxlen=self.framestack)
for _ in range(self.framestack):
self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
self.observation_space = gym.spaces.Box(
0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32
)
def observation(self, obs):
# Debug output: max-x/y positions to watch exploration progress.
if self.step_count == 0:
for _ in range(self.framestack):
self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
if self.vector_index == 0:
if self.x_positions:
max_diff = max(
np.sqrt(
(np.array(self.x_positions) - self.init_x) ** 2
+ (np.array(self.y_positions) - self.init_y) ** 2
)
)
self.x_y_delta_buffer.append(max_diff)
print(
"100-average dist travelled={}".format(
np.mean(self.x_y_delta_buffer)
)
)
self.x_positions = []
self.y_positions = []
self.init_x = self.agent_pos[0]
self.init_y = self.agent_pos[1]
self.x_positions.append(self.agent_pos[0])
self.y_positions.append(self.agent_pos[1])
# One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.
objects = one_hot(obs[:, :, 0], depth=11)
colors = one_hot(obs[:, :, 1], depth=6)
states = one_hot(obs[:, :, 2], depth=3)
all_ = np.concatenate([objects, colors, states], -1)
all_flat = np.reshape(all_, (-1,))
direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
single_frame = np.concatenate([all_flat, direction])
self.frame_buffer.append(single_frame)
return np.concatenate(self.frame_buffer)
def parse_arguments(argparse):
parser = argparse.ArgumentParser()
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0")
parser.add_argument("--seed", type=int, help="seed for environment", default=1)
parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32)
parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees")
parser.add_argument("--grid_path", default="Grid.txt")
parser.add_argument("--prism_path", default="Grid.PRISM")
args = parser.parse_args()
return args
def env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
# name = config.get("name", "MiniGrid-Empty-8x8-v0")
framestack = config.get("framestack", 4)
env = gym.make(name)
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
env = minigrid.wrappers.ImgObsWrapper(env)
env = OneHotWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack
)
return env
def create_shield(grid_file, prism_path):
os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}")
f = open(prism_path, "a")
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
f.close()
program = stormpy.parse_prism_program(prism_path)
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
options.set_build_state_valuations(True)
options.set_build_choice_labels(True)
options.set_build_all_labels()
model = stormpy.build_sparse_model_with_options(program, options)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
assert result.has_scheduler
assert result.has_shield
shield = result.shield
stormpy.shields.export_shield(model, shield,"Grid.shield")
return shield.construct(), model
def export_grid_to_text(env, grid_file):
f = open(grid_file, "w")
print(env)
f.write(env.printGrid(init=True))
# f.write(env.pprint_grid())
f.close()
def create_environment(args):
env_id= args.env
env = gym.make(env_id)
env.reset()
return env
def main():
args = parse_arguments(argparse)
env = create_environment(args)
ray.init(num_cpus=3)
# print(env.pprint_grid())
# print(env.printGrid(init=False))
grid_file = args.grid_path
export_grid_to_text(env, grid_file)
prism_path = args.prism_path
shield, model = create_shield(grid_file, prism_path)
for state_id in model.states:
choices = shield.get_choice(state_id)
print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")
env_name = "mini-grid"
register_env(env_name, env_creater)
algo =(
PPOConfig()
.rollouts(num_rollout_workers=1)
.resources(num_gpus=0)
.environment(env="mini-grid")
.framework("torch")
.callbacks(MyCallbacks)
.training(model={
"fcnet_hiddens": [256,256],
"fcnet_activation": "relu",
})
.build()
)
episode_reward = 0
terminated = truncated = False
obs, info = env.reset()
# while not terminated and not truncated:
# action = algo.compute_single_action(obs)
# obs, reward, terminated, truncated = env.step(action)
for i in range(30):
result = algo.train()
print(pretty_print(result))
if i % 5 == 0:
checkpoint_dir = algo.save()
print(f"Checkpoint saved in directory {checkpoint_dir}")
ray.shutdown()
if __name__ == '__main__':
main()
Loading…
Cancel
Save