Browse Source

changed action handling for probabilities

refactoring
Thomas Knoll 1 year ago
parent
commit
604d2c2b76
  1. 19
      examples/shields/rl/helpers.py
  2. 13
      examples/shields/rl/shieldhandlers.py
  3. 20
      examples/shields/rl/wrappers.py

19
examples/shields/rl/helpers.py

@ -39,29 +39,22 @@ def extract_keys(env):
return keys return keys
def create_log_dir(args): def create_log_dir(args):
return F"{args.log_dir}{args.algorithm}-shielding:{args.shielding}-evaluations:{args.evaluations}-steps:{args.steps}"
return F"{args.log_dir}{args.algorithm}-shielding:{args.shielding}-evaluations:{args.evaluations}-steps:{args.steps}-env:{args.env}"
def get_action_index_mapping(actions): def get_action_index_mapping(actions):
for action_str in actions: for action_str in actions:
if "left" in action_str:
if "move" in action_str:
return Actions.forward
elif "left" in action_str:
return Actions.left return Actions.left
elif "right" in action_str: elif "right" in action_str:
return Actions.right return Actions.right
elif "east" in action_str:
return Actions.forward
elif "south" in action_str:
return Actions.forward
elif "west" in action_str:
return Actions.forward
elif "north" in action_str:
return Actions.forward
elif "pickup" in action_str: elif "pickup" in action_str:
return Actions.pickup return Actions.pickup
elif "done" in action_str: elif "done" in action_str:
return Actions.done return Actions.done
raise ValueError(F"Action string {action_str} not supported") raise ValueError(F"Action string {action_str} not supported")
@ -75,6 +68,10 @@ def parse_arguments(argparse):
choices=[ choices=[
"MiniGrid-LavaCrossingS9N1-v0", "MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-LavaCrossingS9N3-v0", "MiniGrid-LavaCrossingS9N3-v0",
"MiniGrid-LavaSlipperyS12-v0",
"MiniGrid-LavaSlipperyS12-v1",
"MiniGrid-LavaSlipperyS12-v2",
"MiniGrid-LavaSlipperyS12-v3",
# "MiniGrid-DoorKey-8x8-v0", # "MiniGrid-DoorKey-8x8-v0",
# "MiniGrid-LockedRoom-v0", # "MiniGrid-LockedRoom-v0",
# "MiniGrid-FourRooms-v0", # "MiniGrid-FourRooms-v0",

13
examples/shields/rl/shieldhandlers.py

@ -12,6 +12,12 @@ from abc import ABC
import os import os
class Action():
    """Plain record describing one action: its index, probability and labels.

    Attributes are stored exactly as passed to the constructor; no validation
    or normalisation is performed.
    """

    def __init__(self, idx, prob=1, labels=None) -> None:
        """Store the action's index, probability and labels.

        Args:
            idx: Index identifying the action/choice.
            prob: Probability associated with the action; defaults to 1.
            labels: Collection of labels attached to the action. When omitted
                (or None), a fresh empty list is used.
        """
        self.idx = idx
        self.prob = prob
        # A literal `[]` default is evaluated once at definition time and would
        # be shared by every instance created without labels, so mutating one
        # instance's labels would leak into the others. Allocate per instance.
        self.labels = [] if labels is None else labels
class ShieldHandler(ABC): class ShieldHandler(ABC):
def __init__(self) -> None: def __init__(self) -> None:
pass pass
@ -41,6 +47,7 @@ class MiniGridShieldHandler(ShieldHandler):
f.close() f.close()
def __create_shield_dict(self): def __create_shield_dict(self):
print(self.prism_path)
program = stormpy.parse_prism_program(self.prism_path) program = stormpy.parse_prism_program(self.prism_path)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
@ -56,6 +63,7 @@ class MiniGridShieldHandler(ShieldHandler):
assert result.has_scheduler assert result.has_scheduler
assert result.has_shield assert result.has_shield
shield = result.shield shield = result.shield
stormpy.shields.export_shield(model, shield, "Grid.shield")
action_dictionary = {} action_dictionary = {}
shield_scheduler = shield.construct() shield_scheduler = shield.construct()
@ -65,12 +73,11 @@ class MiniGridShieldHandler(ShieldHandler):
choices = choice.choice_map choices = choice.choice_map
state_valuation = model.state_valuations.get_string(stateID) state_valuation = model.state_valuations.get_string(stateID)
actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
#actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
action_dictionary[state_valuation] = actions_to_be_executed action_dictionary[state_valuation] = actions_to_be_executed
# stormpy.shields.export_shield(model, shield, "Grid.shield")
return action_dictionary return action_dictionary

20
examples/shields/rl/wrappers.py

@ -1,5 +1,6 @@
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
import random
from minigrid.core.actions import Actions from minigrid.core.actions import Actions
@ -15,9 +16,9 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
def __init__(self, env, vector_index, framestack): def __init__(self, env, vector_index, framestack):
super().__init__(env) super().__init__(env)
self.framestack = framestack self.framestack = framestack
# 49=7x7 field of vision; 11=object types; 6=colors; 3=state types.
# 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.
# +4: Direction. # +4: Direction.
self.single_frame_dim = 49 * (11 + 6 + 3) + 4
self.single_frame_dim = 49 * (16 + 6 + 3) + 4
self.init_x = None self.init_x = None
self.init_y = None self.init_y = None
self.x_positions = [] self.x_positions = []
@ -66,8 +67,8 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
image = obs["data"] image = obs["data"]
# One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.
objects = one_hot(image[:, :, 0], depth=11)
# One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
objects = one_hot(image[:, :, 0], depth=16)
colors = one_hot(image[:, :, 1], depth=6) colors = one_hot(image[:, :, 1], depth=6)
states = one_hot(image[:, :, 2], depth=3) states = one_hot(image[:, :, 2], depth=3)
@ -117,10 +118,13 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
if cur_pos_str in self.shield and self.shield[cur_pos_str]: if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str] allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions: for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action[1]) # Allowed_action is a set
index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set
if index is None: if index is None:
assert(False) assert(False)
mask[index] = 1.0
allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
mask[index] = allowed
else: else:
for index, x in enumerate(mask): for index, x in enumerate(mask):
mask[index] = 1.0 mask[index] = 1.0
@ -197,11 +201,11 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
if cur_pos_str in self.shield and self.shield[cur_pos_str]: if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str] allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions: for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action[1])
index = get_action_index_mapping(allowed_action.labels)
if index is None: if index is None:
assert(False) assert(False)
mask[index] = 1.0
mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
else: else:
for index, x in enumerate(mask): for index, x in enumerate(mask):
mask[index] = 1.0 mask[index] = 1.0

Loading…
Cancel
Save