"""Train a PPO agent on a MiniGrid environment and compute a Storm safety shield.

The script exports the environment's grid to a text file, converts it to a
PRISM model with an external tool, builds a pre-safety shield with stormpy,
prints the choices the shield allows in every model state, and then trains a
PPO agent with RLlib, saving a checkpoint every few iterations.
"""

# Standard library.
import argparse
import os
import subprocess
from collections import deque
from typing import Dict, Optional, Union

# Third-party imports; kept in their original relative order.
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
import gymnasium as gym
import minigrid
import numpy as np
import ray
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator
from ray import tune, air
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.utils.numpy import one_hot
from ray.rllib.algorithms import ppo
from ray.rllib.models.preprocessors import get_preprocessor
import matplotlib.pyplot as plt

# Environment used when neither the CLI nor the RLlib env config names one.
DEFAULT_ENV_NAME = "MiniGrid-LavaCrossingS9N1-v0"
# External grid-to-PRISM converter. NOTE(review): machine-specific path; adjust per setup.
GRID_TO_PRISM_TOOL = "/home/tknoll/Documents/main"
# Name under which the wrapped MiniGrid env is registered with Ray Tune.
ENV_NAME = "mini-grid"
# Number of PPO training iterations and checkpoint interval (in iterations).
NUM_TRAINING_ITERATIONS = 30
CHECKPOINT_EVERY = 5


class MyCallbacks(DefaultCallbacks):
    """RLlib callbacks that count the steps taken in each episode.

    The running count is kept in ``episode.user_data["count"]``.
    """

    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                         policies: Dict[PolicyID, Policy],
                         episode: Episode | EpisodeV2,
                         env_index: int | None = None, **kwargs) -> None:
        """Initialise the per-episode step counter to zero."""
        episode.user_data["count"] = 0

    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
                        policies: Dict[PolicyID, Policy] | None = None,
                        episode: Episode | EpisodeV2,
                        env_index: int | None = None, **kwargs) -> None:
        """Increment the per-episode step counter by one."""
        episode.user_data["count"] += 1

    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[PolicyID, Policy],
                       episode: Episode | EpisodeV2 | Exception,
                       env_index: int | None = None, **kwargs) -> None:
        """End-of-episode hook; currently a no-op kept for debugging/inspection."""


class OneHotWrapper(gym.core.ObservationWrapper):
    """One-hot encode MiniGrid image observations and stack the last frames.

    Each frame is the 7x7 partial view in which every cell's (object, color,
    state) triple is one-hot encoded and flattened, followed by a one-hot
    encoding of the agent's direction. The last ``framestack`` frames are
    concatenated into a single flat float32 vector.

    For ``vector_index == 0`` the wrapper also prints, at the start of each
    episode, the running 100-episode mean of the maximum Euclidean distance
    the agent got from its start cell (a rough exploration metric).
    """

    # One-hot depths of the three image channels and of the agent direction.
    NUM_OBJECT_TYPES = 11
    NUM_COLORS = 6
    NUM_STATES = 3
    NUM_DIRECTIONS = 4
    # 49 = 7x7 field of vision.
    VIEW_CELLS = 49

    def __init__(self, env, vector_index, framestack):
        """Wrap ``env``.

        Args:
            env: Environment whose observations are MiniGrid image arrays
                indexed as ``obs[row, col, channel]`` with 3 channels.
            vector_index: Index of this env in a vectorized env; only index 0
                tracks positions and prints the exploration metric.
            framestack: Number of consecutive frames to concatenate.
        """
        super().__init__(env)
        self.framestack = framestack
        self.single_frame_dim = (
            self.VIEW_CELLS * (self.NUM_OBJECT_TYPES + self.NUM_COLORS + self.NUM_STATES)
            + self.NUM_DIRECTIONS
        )
        self.init_x = None
        self.init_y = None
        self.x_positions = []
        self.y_positions = []
        self.x_y_delta_buffer = deque(maxlen=100)
        self.vector_index = vector_index
        self.frame_buffer = deque(maxlen=self.framestack)
        self._clear_frames()
        self.observation_space = gym.spaces.Box(
            0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32
        )

    def _clear_frames(self):
        """Fill the frame buffer with all-zero float32 frames."""
        for _ in range(self.framestack):
            self.frame_buffer.append(np.zeros((self.single_frame_dim,), dtype=np.float32))

    def _report_exploration(self):
        """Record/print the previous episode's max distance and reset position tracking."""
        if self.x_positions:
            dx = np.asarray(self.x_positions) - self.init_x
            dy = np.asarray(self.y_positions) - self.init_y
            self.x_y_delta_buffer.append(np.max(np.sqrt(dx ** 2 + dy ** 2)))
            print(
                "100-average dist travelled={}".format(
                    np.mean(self.x_y_delta_buffer)
                )
            )
        self.x_positions = []
        self.y_positions = []

    def observation(self, obs):
        """Return the stacked one-hot encoding for the image observation ``obs``."""
        # Read simulator state from the base env; implicit attribute forwarding
        # through wrappers is deprecated in Gymnasium.
        base = self.unwrapped
        track_positions = self.vector_index == 0

        if base.step_count == 0:
            # New episode: frames from the previous episode must not leak in.
            self._clear_frames()
            if track_positions:
                self._report_exploration()
                self.init_x = base.agent_pos[0]
                self.init_y = base.agent_pos[1]
        # Positions are only consumed by the vector_index == 0 debug metric;
        # recording them for other indices would grow the lists forever.
        if track_positions:
            self.x_positions.append(base.agent_pos[0])
            self.y_positions.append(base.agent_pos[1])

        # One-hot each image channel, then flatten all cells into one vector.
        objects = one_hot(obs[:, :, 0], depth=self.NUM_OBJECT_TYPES)
        colors = one_hot(obs[:, :, 1], depth=self.NUM_COLORS)
        states = one_hot(obs[:, :, 2], depth=self.NUM_STATES)
        all_flat = np.reshape(np.concatenate([objects, colors, states], -1), (-1,))
        direction = one_hot(np.array(base.agent_dir), depth=self.NUM_DIRECTIONS)
        single_frame = np.concatenate([all_flat, direction]).astype(np.float32)
        self.frame_buffer.append(single_frame)
        # Match the float32 dtype declared by observation_space.
        return np.concatenate(self.frame_buffer).astype(np.float32, copy=False)


def parse_arguments(argparse, argv=None):
    """Parse the command-line options.

    Args:
        argparse: The ``argparse`` module (passed in by the caller).
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", help="gym environment to load", default=DEFAULT_ENV_NAME)
    parser.add_argument("--seed", type=int, help="seed for environment", default=1)
    parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32)
    parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees")
    parser.add_argument("--grid_path", default="Grid.txt")
    parser.add_argument("--prism_path", default="Grid.PRISM")
    return parser.parse_args(argv)


def env_creater(config):
    """Env factory registered with Ray Tune.

    Reads ``name`` (default ``DEFAULT_ENV_NAME``) and ``framestack``
    (default 4) from ``config``, and the optional ``vector_index`` attribute
    (default 0).

    Returns:
        ``OneHotWrapper(ImgObsWrapper(gym.make(name)))``.
    """
    name = config.get("name", DEFAULT_ENV_NAME)
    framestack = config.get("framestack", 4)
    env = minigrid.wrappers.ImgObsWrapper(gym.make(name))
    return OneHotWrapper(env, getattr(config, "vector_index", 0), framestack=framestack)


def create_shield(grid_file, prism_path, tool_path=GRID_TO_PRISM_TOOL,
                  shield_path="Grid.shield"):
    """Build a pre-safety shield for the grid stored in ``grid_file``.

    Args:
        grid_file: Text grid produced by :func:`export_grid_to_text`.
        prism_path: Where the converter writes the PRISM model.
        tool_path: Executable that converts the grid into a PRISM model.
        shield_path: Where the shield is exported.

    Returns:
        ``(constructed_shield, model)``.

    Raises:
        subprocess.CalledProcessError: If the converter exits non-zero.
        RuntimeError: If model checking yields no scheduler or shield.
    """
    # Argument list with no shell, so paths are never shell-interpreted. Fail
    # loudly rather than going on to parse a stale or partial PRISM file.
    subprocess.run([tool_path, "-v", "agent", "-i", grid_file, "-o", prism_path], check=True)
    with open(prism_path, "a") as prism_file:
        # Leading newline so the label is not glued onto the file's last line.
        prism_file.write("\nlabel \"AgentIsInLava\" = AgentIsInLava;\n")
    program = stormpy.parse_prism_program(prism_path)
    # NOTE(review): the property uses "AgentIsInLavaAndNotDone", not the
    # "AgentIsInLava" label appended above; presumably the converter defines it — confirm.
    formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
    formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
    options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
    options.set_build_state_valuations(True)
    options.set_build_choice_labels(True)
    options.set_build_all_labels()
    model = stormpy.build_sparse_model_with_options(program, options)

    shield_specification = stormpy.logic.ShieldExpression(
        stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1
    )
    result = stormpy.model_checking(
        model, formulas[0], extract_scheduler=True, shield_expression=shield_specification
    )
    if not result.has_scheduler:
        raise RuntimeError("Model checking did not produce a scheduler")
    if not result.has_shield:
        raise RuntimeError("Model checking did not produce a shield")
    shield = result.shield
    stormpy.shields.export_shield(model, shield, shield_path)
    return shield.construct(), model


def export_grid_to_text(env, grid_file):
    """Print ``env`` and write its initial grid (via ``printGrid``) to ``grid_file``."""
    print(env)
    with open(grid_file, "w") as grid_out:
        # printGrid lives on the underlying MiniGrid env, not on the gym wrappers.
        grid_out.write(env.unwrapped.printGrid(init=True))


def create_environment(args):
    """Create ``args.env`` and reset it with ``args.seed`` so the exported grid is reproducible."""
    env = gym.make(args.env)
    env.reset(seed=args.seed)
    return env


def _print_shield(shield, model):
    """Print the choices the shield allows in every state of ``model``."""
    for state_id in model.states:
        choices = shield.get_choice(state_id)
        print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")


def _build_algorithm(env_name):
    """Build the PPO algorithm (torch, 1 rollout worker, no GPU) for ``env_name``."""
    return (
        PPOConfig()
        .rollouts(num_rollout_workers=1)
        .resources(num_gpus=0)
        .environment(env=env_name)
        .framework("torch")
        .callbacks(MyCallbacks)
        .training(model={
            "fcnet_hiddens": [256, 256],
            "fcnet_activation": "relu",
        })
        .build()
    )


def main():
    """Export the grid, build and print the shield, then train PPO with checkpoints."""
    args = parse_arguments(argparse)
    env = create_environment(args)

    ray.init(num_cpus=3)
    try:
        export_grid_to_text(env, args.grid_path)
        shield, model = create_shield(args.grid_path, args.prism_path)
        _print_shield(shield, model)

        register_env(ENV_NAME, env_creater)
        algo = _build_algorithm(ENV_NAME)

        for i in range(NUM_TRAINING_ITERATIONS):
            result = algo.train()
            print(pretty_print(result))
            if i % CHECKPOINT_EVERY == 0:
                checkpoint_dir = algo.save()
                print(f"Checkpoint saved in directory {checkpoint_dir}")
    finally:
        # Release the Ray cluster even when training or shield construction fails.
        ray.shutdown()


if __name__ == '__main__':
    main()