Browse Source

minor changes

refactoring
Thomas Knoll 12 months ago
parent
commit
80cbbe5a3a
  1. 13
      examples/shields/rl/11_minigridrl.py
  2. 4
      examples/shields/rl/15_train_eval_tune.py
  3. 8
      examples/shields/rl/callbacks.py
  4. 13
      examples/shields/rl/helpers.py
  5. 15
      examples/shields/rl/shieldhandlers.py
  6. 19
      examples/shields/rl/wrappers.py

13
examples/shields/rl/11_minigridrl.py

@ -20,13 +20,15 @@ def shielding_env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
args = config.get("args", None)
args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
args.grid_path = F"{args.grid_path}_{config.worker_index}_{args.prism_config}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}_{args.prism_config}.prism"
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula, args.shield_value, args.prism_config)
env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator,
shield_query_creator=create_shield_query,
mask_actions=args.shielding != ShieldingConfig.Disabled,
create_shield_at_reset=args.shield_creation_at_reset)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env)
@ -64,14 +66,17 @@ def ppo(args):
"type": TBXLogger,
"logdir": create_log_dir(args)
})
# .exploration(exploration_config={"exploration_fraction": 0.1})
.training(_enable_learner_api=False ,model={
"custom_model": "shielding_model"
}))
# config.entropy_coeff = 0.05
algo =(
config.build()
)
for i in range(args.evaluations):
result = algo.train()
print(pretty_print(result))

4
examples/shields/rl/15_train_eval_tune.py

@ -33,8 +33,8 @@ def shielding_env_creater(config):
prism_path=args.prism_path,
formula=args.formula)
env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding)
env = gym.make(name, randomize_start=True)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled)
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,

8
examples/shields/rl/callbacks.py

@ -45,6 +45,14 @@ class MyCallbacks(DefaultCallbacks):
env = base_env.get_sub_environments()[0]
# print(env.printGrid())
if hasattr(env, "adversaries"):
for adversary in env.adversaries.values():
if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
print(F"Adversary ran into agent. Adversary {adversary.cur_pos}, Agent {env.agent_pos}")
# assert False
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]

13
examples/shields/rl/helpers.py

@ -64,7 +64,7 @@ def extract_adversaries(env):
return adv
def create_log_dir(args):
return F"{args.log_dir}sh:{args.shielding}-env:{args.env}"
return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-env:{args.env}-conf:{args.prism_config}"
def test_name(args):
return F"{args.expname}"
@ -99,7 +99,7 @@ def parse_arguments(argparse):
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
parser.add_argument("--env",
help="gym environment to load",
default="MiniGrid-LavaCrossingS9N1-v0",
default="MiniGrid-LavaSlipperyS12-v2",
choices=[
"MiniGrid-Adv-8x8-v0",
"MiniGrid-AdvSimple-8x8-v0",
@ -128,14 +128,17 @@ def parse_arguments(argparse):
parser.add_argument("--prism_path", default="grid")
parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
parser.add_argument("--log_dir", default="../log_results/")
parser.add_argument("--evaluations", type=int, default=10 )
# parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
parser.add_argument("--formula", default="Pmax=? [G !\"AgentRanIntoAdversary\"]")
parser.add_argument("--evaluations", type=int, default=30 )
parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
parser.add_argument("--steps", default=20_000, type=int)
parser.add_argument("--expname", default="exp")
parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
parser.add_argument("--prism_config", default=None)
parser.add_argument("--shield_value", default=0.9, type=float)
# parser.add_argument("--random_starts", default=1, type=int)
args = parser.parse_args()
return args

15
examples/shields/rl/shieldhandlers.py

@ -27,11 +27,13 @@ class ShieldHandler(ABC):
pass
class MiniGridShieldHandler(ShieldHandler):
def __init__(self, grid_file, grid_to_prism_path, prism_path, formula) -> None:
def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None) -> None:
self.grid_file = grid_file
self.grid_to_prism_path = grid_to_prism_path
self.prism_path = prism_path
self.formula = formula
self.prism_config = prism_config
self.shield_value = shield_value
def __export_grid_to_text(self, env):
f = open(self.grid_file, "w")
@ -40,7 +42,12 @@ class MiniGridShieldHandler(ShieldHandler):
def __create_prism(self):
result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
# result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
if self.prism_config is None:
result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path}")
else:
result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
# result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
assert result == 0, "Prism file could not be generated"
@ -51,7 +58,7 @@ class MiniGridShieldHandler(ShieldHandler):
def __create_shield_dict(self):
print(self.prism_path)
program = stormpy.parse_prism_program(self.prism_path)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, self.shield_value)
formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
@ -175,7 +182,7 @@ def create_shield_query(env):
if adversaries:
move_text = F"move=0\t& "
agent_position = F"xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}"
agent_position = F"& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}"
query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
return query

19
examples/shields/rl/wrappers.py

@ -102,15 +102,20 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
self.shield = shield_creator.create_shield(env=self.env)
self.mask_actions = mask_actions
self.shield_query_creator = shield_query_creator
print(F"Shielding is {self.mask_actions}")
def create_action_mask(self):
# print(F"{self.mask_actions} No shielding")
if not self.mask_actions:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
# print("No shielding")
ret = np.array([1.0] * self.max_available_actions, dtype=np.int8)
# print(ret)
return ret
cur_pos_str = self.shield_query_creator(self.env)
# print(F"Pos string {cur_pos_str}")
# print(F"Shield {list(self.shield.keys())[0]}")
# print(F"Is pos str in shield: {cur_pos_str in self.shield}")
# print(F"Is pos str in shield: {cur_pos_str in self.shield}, Position Str {cur_pos_str}")
# Create the mask
# If shield restricts action mask only valid with 1.0
# else set all actions as valid
@ -119,6 +124,9 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8)
has_allowed_actions = False
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set
if index is None:
@ -129,8 +137,14 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
if allowed_action.prob == 0 and allowed:
assert False
if allowed:
has_allowed_actions = True
mask[index] = allowed
# if not has_allowed_actions:
# print(F"No action allowed for pos string {cur_pos_str}")
# assert(False)
else:
for index, x in enumerate(mask):
mask[index] = 1.0
@ -186,6 +200,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
self.shield_creator = shield_creator
self.mask_actions = mask_actions
self.shield_query_creator = shield_query_creator
print(F"Shielding is {self.mask_actions}")
def create_action_mask(self):
if not self.mask_actions:

Loading…
Cancel
Save