Compare commits
merge into: sp:main
sp:add_velocity_into_framework
sp:main
pull from: sp:add_velocity_into_framework
sp:add_velocity_into_framework
sp:main
19 Commits
main
...
add_veloci
3 changed files with 1191 additions and 128 deletions
@ -0,0 +1,605 @@ |
|||
import sys |
|||
import operator |
|||
from copy import deepcopy |
|||
from os import listdir, system |
|||
import subprocess |
|||
import re |
|||
from collections import defaultdict |
|||
|
|||
from random import randrange |
|||
from ale_py import ALEInterface, SDL_SUPPORT, Action |
|||
from PIL import Image |
|||
from matplotlib import pyplot as plt |
|||
import cv2 |
|||
import pickle |
|||
import queue |
|||
from dataclasses import dataclass, field |
|||
|
|||
from sklearn.cluster import KMeans, DBSCAN |
|||
|
|||
from enum import Enum |
|||
|
|||
from copy import deepcopy |
|||
|
|||
import numpy as np |
|||
|
|||
import logging |
|||
logger = logging.getLogger(__name__) |
|||
|
|||
#import readchar |
|||
|
|||
from sample_factory.algo.utils.tensor_dict import TensorDict |
|||
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper |
|||
|
|||
import time |
|||
|
|||
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm" |
|||
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin" |
|||
|
|||
def tic(): |
|||
import time |
|||
global startTime_for_tictoc |
|||
startTime_for_tictoc = time.time() |
|||
|
|||
def toc(): |
|||
import time |
|||
if 'startTime_for_tictoc' in globals(): |
|||
return time.time() - startTime_for_tictoc |
|||
|
|||
class Verdict(Enum): |
|||
INCONCLUSIVE = 1 |
|||
GOOD = 2 |
|||
BAD = 3 |
|||
|
|||
verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"} |
|||
|
|||
def convert(tuples): |
|||
return dict(tuples) |
|||
|
|||
@dataclass(frozen=True) |
|||
class State: |
|||
x: int |
|||
y: int |
|||
ski_position: int |
|||
velocity: int |
|||
def default_value(): |
|||
return {'action' : None, 'choiceValue' : None} |
|||
|
|||
@dataclass(frozen=True) |
|||
class StateValue: |
|||
ranking: float |
|||
choices: dict = field(default_factory=default_value) |
|||
|
|||
@dataclass(frozen=False) |
|||
class TestResult: |
|||
init_check_pes_min: float |
|||
init_check_pes_max: float |
|||
init_check_pes_avg: float |
|||
init_check_opt_min: float |
|||
init_check_opt_max: float |
|||
init_check_opt_avg: float |
|||
safe_states: int |
|||
unsafe_states: int |
|||
safe_cluster: int |
|||
unsafe_cluster: int |
|||
good_verdicts: int |
|||
bad_verdicts: int |
|||
policy_queries: int |
|||
def __str__(self): |
|||
return f"""Test Result: |
|||
init_check_pes_min: {self.init_check_pes_min} |
|||
init_check_pes_max: {self.init_check_pes_max} |
|||
init_check_pes_avg: {self.init_check_pes_avg} |
|||
init_check_opt_min: {self.init_check_opt_min} |
|||
init_check_opt_max: {self.init_check_opt_max} |
|||
init_check_opt_avg: {self.init_check_opt_avg} |
|||
""" |
|||
@staticmethod |
|||
def csv_header(ws=" "): |
|||
string = f"pesmin{ws}pesmax{ws}pesavg{ws}" |
|||
string += f"optmin{ws}optmax{ws}optavg{ws}" |
|||
string += f"sState{ws}uState{ws}" |
|||
string += f"sClust{ws}uClust{ws}" |
|||
string += f"gVerd{ws}bVerd{ws}queries" |
|||
return string |
|||
|
|||
def csv(self): |
|||
ws = " " |
|||
string = f"{self.init_check_pes_min:0.04f}{ws}{self.init_check_pes_max:0.04f}{ws}{self.init_check_pes_avg:0.04f}{ws}" |
|||
string += f"{self.init_check_opt_min:0.04f}{ws}{self.init_check_opt_max:0.04f}{ws}{self.init_check_opt_avg:0.04f}{ws}" |
|||
ws = "\t" |
|||
string += f"{self.safe_states}{ws}{self.unsafe_states}{ws}" |
|||
string += f"{self.safe_cluster}{ws}{self.unsafe_cluster}{ws}" |
|||
string += f"{self.good_verdicts}{ws}{self.bad_verdicts}{ws}{self.policy_queries}" |
|||
return string |
|||
|
|||
|
|||
def exec(command,verbose=True): |
|||
if verbose: print(f"Executing {command}") |
|||
system(f"echo {command} >> list_of_exec") |
|||
return system(command) |
|||
|
|||
num_tests_per_cluster = 50 |
|||
#factor_tests_per_cluster = 0.2 |
|||
num_ski_positions = 8 |
|||
num_velocities = 5 |
|||
|
|||
def input_to_action(char): |
|||
if char == "0": |
|||
return Action.NOOP |
|||
if char == "1": |
|||
return Action.RIGHT |
|||
if char == "2": |
|||
return Action.LEFT |
|||
if char == "3": |
|||
return "reset" |
|||
if char == "4": |
|||
return "set_x" |
|||
if char == "5": |
|||
return "set_vel" |
|||
if char in ["w", "a", "s", "d"]: |
|||
return char |
|||
|
|||
def saveObservations(observations, verdict, testDir): |
|||
testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}" |
|||
if len(observations) < 20: |
|||
logger.warn(f"Potentially spurious test case for {testDir}") |
|||
testDir = f"{testDir}_pot_spurious" |
|||
exec(f"mkdir {testDir}", verbose=False) |
|||
for i, obs in enumerate(observations): |
|||
img = Image.fromarray(obs) |
|||
img.save(f"{testDir}/{i:003}.png") |
|||
|
|||
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) } |
|||
|
|||
def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50): |
|||
#print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="") |
|||
testDir = f"{x}_{y}_{ski_position}_{velocity}" |
|||
try: |
|||
for i, r in enumerate(ramDICT[y]): |
|||
ale.setRAM(i,r) |
|||
ski_position_setting = ski_position_counter[ski_position] |
|||
for i in range(0,ski_position_setting[1]): |
|||
ale.act(ski_position_setting[0]) |
|||
ale.setRAM(14,0) |
|||
ale.setRAM(25,x) |
|||
ale.setRAM(14,180) # TODO |
|||
except Exception as e: |
|||
print(e) |
|||
logger.warn(f"Could not run test for x: {x}, y: {y}, ski_position: {ski_position}, velocity: {velocity}") |
|||
return (Verdict.INCONCLUSIVE, 0) |
|||
|
|||
num_queries = 0 |
|||
all_obs = list() |
|||
speed_list = list() |
|||
resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA) |
|||
for i in range(0,4): |
|||
all_obs.append(resized_obs) |
|||
for i in range(0,duration-4): |
|||
resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA) |
|||
all_obs.append(resized_obs) |
|||
if i % 4 == 0: |
|||
stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])}) |
|||
action = nn_wrapper.query(stack_tensor) |
|||
num_queries += 1 |
|||
ale.act(input_to_action(str(action))) |
|||
else: |
|||
ale.act(input_to_action(str(action))) |
|||
speed_list.append(ale.getRAM()[14]) |
|||
if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0: |
|||
#saveObservations(all_obs, Verdict.BAD, testDir) |
|||
return (Verdict.BAD, num_queries) |
|||
#saveObservations(all_obs, Verdict.GOOD, testDir) |
|||
return (Verdict.GOOD, num_queries) |
|||
|
|||
def skiPositionFormulaList(name): |
|||
formulas = list() |
|||
for i in range(1, num_ski_positions+1): |
|||
formulas.append(f"\"{name}_{i}\"") |
|||
return createBalancedDisjunction(formulas) |
|||
|
|||
|
|||
def computeStateRanking(mdp_file, iteration): |
|||
logger.info("Computing state ranking") |
|||
tic() |
|||
prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|||
prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|||
prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|||
prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|||
prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|||
prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|||
prop += 'Rmax=? [C <= 200]' |
|||
results = list() |
|||
try: |
|||
command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'" |
|||
output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n') |
|||
num_states = 0 |
|||
for line in output: |
|||
#print(line) |
|||
if "States:" in line: |
|||
num_states = int(line.split(" ")[-1]) |
|||
if "Result" in line and not len(results) >= 6: |
|||
range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line) |
|||
if range_value: |
|||
results.append(float(range_value.group(2))) |
|||
results.append(float(range_value.group(3))) |
|||
else: |
|||
value = re.search(r"(.*:)(.*)", line) |
|||
results.append(float(value.group(2))) |
|||
exec(f"mv action_ranking action_ranking_{iteration:03}") |
|||
except subprocess.CalledProcessError as e: |
|||
# todo die gracefully if ranking is uniform |
|||
print(e.output) |
|||
logger.info(f"Computing state ranking - DONE: took {toc()} seconds") |
|||
return TestResult(*tuple(results),0,0,0,0,0,0,0), num_states |
|||
|
|||
def fillStateRanking(file_name, match=""): |
|||
logger.info(f"Parsing state ranking, {file_name}") |
|||
tic() |
|||
state_ranking = dict() |
|||
try: |
|||
with open(file_name, "r") as f: |
|||
file_content = f.readlines() |
|||
for line in file_content: |
|||
if not "move=0" in line: continue |
|||
ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:","")) |
|||
if ranking_value <= 0.1: |
|||
continue |
|||
stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line)) |
|||
choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line)) |
|||
choices = {key:float(value) for (key,value) in choices.items()} |
|||
state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2) |
|||
value = StateValue(ranking_value, choices) |
|||
state_ranking[state] = value |
|||
logger.info(f"Parsing state ranking - DONE: took {toc()} seconds") |
|||
return state_ranking |
|||
except EnvironmentError: |
|||
print("Ranking file not available. Exiting.") |
|||
toc() |
|||
sys.exit(-1) |
|||
except: |
|||
toc() |
|||
|
|||
def createDisjunction(formulas): |
|||
return " | ".join(formulas) |
|||
|
|||
def statesFormulaTrimmed(states, name): |
|||
#states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster] |
|||
skiPositionGroup = defaultdict(list) |
|||
for item in states: |
|||
skiPositionGroup[item[2]].append(item) |
|||
|
|||
formulas = list() |
|||
for skiPosition, skiPos_group in skiPositionGroup.items(): |
|||
formula = f"formula {name}_{skiPosition} = ( ski_position={skiPosition} & " |
|||
#print(f"{name} ski_pos:{skiPosition}") |
|||
velocityGroup = defaultdict(list) |
|||
velocityFormulas = list() |
|||
for item in skiPos_group: |
|||
velocityGroup[item[3]].append(item) |
|||
for velocity, velocity_group in velocityGroup.items(): |
|||
#print(f"\tvel:{velocity}") |
|||
formulasPerSkiPosition = list() |
|||
yPosGroup = defaultdict(list) |
|||
yFormulas = list() |
|||
for item in velocity_group: |
|||
yPosGroup[item[1]].append(item) |
|||
for y, y_group in yPosGroup.items(): |
|||
#print(f"\t\ty:{y}") |
|||
sorted_y_group = sorted(y_group, key=lambda s: s[0]) |
|||
current_x_min = sorted_y_group[0][0] |
|||
current_x = sorted_y_group[0][0] |
|||
x_ranges = list() |
|||
for state in sorted_y_group[1:-1]: |
|||
if state[0] - current_x == 1: |
|||
current_x = state[0] |
|||
else: |
|||
x_ranges.append(f" ({current_x_min}<=x&x<={current_x})") |
|||
current_x_min = state[0] |
|||
current_x = state[0] |
|||
x_ranges.append(f" {current_x_min}<=x&x<={sorted_y_group[-1][0]}") |
|||
yFormulas.append(f" (y={y} & {createBalancedDisjunction(x_ranges)})") |
|||
#x_ranges.clear() |
|||
|
|||
#velocityFormulas.append(f"(velocity={velocity} & {createBalancedDisjunction(yFormulas)})") |
|||
velocityFormulas.append(f"({createBalancedDisjunction(yFormulas)})") |
|||
#yFormulas.clear() |
|||
formula += createBalancedDisjunction(velocityFormulas) + ");" |
|||
#velocityFormulas.clear() |
|||
formulas.append(formula) |
|||
for i in range(1, num_ski_positions+1): |
|||
if i in skiPositionGroup: |
|||
continue |
|||
formulas.append(f"formula {name}_{i} = false;") |
|||
return "\n".join(formulas) + "\n" |
|||
|
|||
# https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list |
|||
def pairwise(iterable): |
|||
"s -> (s0, s1), (s2, s3), (s4, s5), ..." |
|||
a = iter(iterable) |
|||
return zip(a, a) |
|||
|
|||
def createBalancedDisjunction(formulas): |
|||
if len(formulas) == 0: |
|||
return "false" |
|||
while len(formulas) > 1: |
|||
formulas_tmp = [f"({f} | {g})" for f,g in pairwise(formulas)] |
|||
if len(formulas) % 2 == 1: |
|||
formulas_tmp.append(formulas[-1]) |
|||
formulas = formulas_tmp |
|||
return " ".join(formulas) |
|||
|
|||
def updatePrismFile(newFile, iteration, safeStates, unsafeStates): |
|||
logger.info("Creating next prism file") |
|||
tic() |
|||
initFile = f"{newFile}_no_formulas.prism" |
|||
newFile = f"{newFile}_{iteration:03}.prism" |
|||
exec(f"cp {initFile} {newFile}", verbose=False) |
|||
with open(newFile, "a") as prism: |
|||
prism.write(statesFormulaTrimmed(safeStates, "Safe")) |
|||
prism.write(statesFormulaTrimmed(unsafeStates, "Unsafe")) |
|||
for i in range(1,num_ski_positions+1): |
|||
prism.write(f"label \"Safe_{i}\" = Safe_{i};\n") |
|||
prism.write(f"label \"Unsafe_{i}\" = Unsafe_{i};\n") |
|||
|
|||
logger.info(f"Creating next prism file - DONE: took {toc()} seconds") |
|||
|
|||
|
|||
ale = ALEInterface() |
|||
|
|||
|
|||
#if SDL_SUPPORT: |
|||
# ale.setBool("sound", True) |
|||
# ale.setBool("display_screen", True) |
|||
|
|||
# Load the ROM file |
|||
ale.loadROM(rom_file) |
|||
|
|||
with open('all_positions_v2.pickle', 'rb') as handle: |
|||
ramDICT = pickle.load(handle) |
|||
y_ram_setting = 60 |
|||
x = 70 |
|||
|
|||
|
|||
nn_wrapper = SampleFactoryNNQueryWrapper() |
|||
|
|||
experiment_id = int(time.time()) |
|||
init_mdp = "velocity_safety" |
|||
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False) |
|||
|
|||
|
|||
imagesDir = f"images/testing_{experiment_id}" |
|||
|
|||
def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0, markerSize=1, drawCircle=False): |
|||
#markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)} |
|||
markerList = {(ski_position, velocity):list() for velocity in range(0, num_velocities) for ski_position in range(1,num_ski_positions + 1)} |
|||
images = dict() |
|||
mergedImages = dict() |
|||
for ski_position in range(1, num_ski_positions + 1): |
|||
for velocity in range(0,num_velocities): |
|||
images[(ski_position, velocity)] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png") |
|||
mergedImages[ski_position] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png") |
|||
for state in states: |
|||
s = state[0] |
|||
marker = [color, alpha_factor * state[1].ranking, (s.x-markerSize, s.y-markerSize), (s.x+markerSize, s.y+markerSize)] |
|||
markerList[(s.ski_position, s.velocity)].append(marker) |
|||
for (pos, vel), marker in markerList.items(): |
|||
if len(marker) == 0: continue |
|||
if drawCircle: |
|||
for m in marker: |
|||
images[(pos,vel)] = cv2.circle(images[(pos,vel)], m[2], 1, m[0], thickness=-1) |
|||
mergedImages[pos] = cv2.circle(mergedImages[pos], m[2], 1, m[0], thickness=-1) |
|||
else: |
|||
for m in marker: |
|||
images[(pos,vel)] = cv2.rectangle(images[(pos,vel)], m[2], m[3], m[0], cv2.FILLED) |
|||
mergedImages[pos] = cv2.rectangle(mergedImages[pos], m[2], m[3], m[0], cv2.FILLED) |
|||
for (ski_position, velocity), image in images.items(): |
|||
cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png", image) |
|||
for ski_position, image in mergedImages.items(): |
|||
cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png", image) |
|||
|
|||
|
|||
def concatImages(prefix, iteration): |
|||
logger.info(f"Concatenating images") |
|||
images = [f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png" for vel in range(0,num_velocities) for pos in range(1,num_ski_positions+1)] |
|||
mergedImages = [f"{imagesDir}/{prefix}_{pos:02}_individual.png" for pos in range(1,num_ski_positions+1)] |
|||
for vel in range(0, num_velocities): |
|||
for pos in range(1, num_ski_positions + 1): |
|||
command = f"convert {imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png " |
|||
command += f"-pointsize 10 -gravity NorthEast -annotate +8+0 'p{pos:02}v{vel:02}' " |
|||
command += f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png" |
|||
exec(command, verbose=False) |
|||
exec(f"montage {' '.join(images)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration:03}.png", verbose=False) |
|||
exec(f"montage {' '.join(mergedImages)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration:03}_merged.png", verbose=False) |
|||
#exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False) |
|||
logger.info(f"Concatenating images - DONE") |
|||
|
|||
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0): |
|||
""" |
|||
Useful to draw a set of states, e.g. a single cluster |
|||
TODO |
|||
markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()} |
|||
logger.info(f"Drawing {len(states)} states onto {target}") |
|||
tic() |
|||
for state in states: |
|||
s = state[0] |
|||
marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '" |
|||
markerList[s.ski_position].append(marker) |
|||
for pos, marker in markerList.items(): |
|||
command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png" |
|||
exec(command, verbose=False) |
|||
exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False) |
|||
logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds") |
|||
""" |
|||
|
|||
def drawClusters(clusterDict, target, iteration, alpha_factor=1.0): |
|||
logger.info(f"Drawing {len(clusterDict)} clusters") |
|||
tic() |
|||
for _, clusterStates in clusterDict.items(): |
|||
color = (np.random.choice(range(256)), np.random.choice(range(256)), np.random.choice(range(256))) |
|||
color = (int(color[0]), int(color[1]), int(color[2])) |
|||
drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor) |
|||
concatImages(target, iteration) |
|||
logger.info(f"Drawing {len(clusterDict)} clusters - DONE: took {toc()} seconds") |
|||
|
|||
def drawResult(clusterDict, target, iteration, drawnCluster=set()): |
|||
logger.info(f"Drawing {len(clusterDict)} results") |
|||
tic() |
|||
for id, (clusterStates, result) in clusterDict.items(): |
|||
if id in drawnCluster: continue |
|||
# opencv wants BGR |
|||
color = (100,100,100) |
|||
if result == Verdict.GOOD: |
|||
color = (0,200,0) |
|||
elif result == Verdict.BAD: |
|||
color = (0,0,200) |
|||
drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7) |
|||
logger.info(f"Drawing {len(clusterDict)} results - DONE: took {toc()} seconds") |
|||
|
|||
def _init_logger(): |
|||
logger = logging.getLogger('main') |
|||
logger.setLevel(logging.INFO) |
|||
handler = logging.StreamHandler(sys.stdout) |
|||
formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s') |
|||
handler.setFormatter(formatter) |
|||
logger.addHandler(handler) |
|||
|
|||
def clusterImportantStates(ranking, iteration): |
|||
logger.info(f"Starting to cluster {len(ranking)} states into clusters") |
|||
tic() |
|||
states = [[s[0].x,s[0].y, s[0].ski_position * 20, s[0].velocity * 20, s[1].ranking] for s in ranking] |
|||
#states = [[s[0].x,s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking] |
|||
kmeans = KMeans(len(states) // 15, random_state=0, n_init="auto").fit(states) |
|||
#dbscan = DBSCAN(eps=5).fit(states) |
|||
#labels = dbscan.labels_ |
|||
labels = kmeans.labels_ |
|||
n_clusters = len(set(labels)) - (1 if -1 in labels else 0) |
|||
logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster") |
|||
clusterDict = {i : list() for i in range(0,n_clusters)} |
|||
strayStates = list() |
|||
for i, state in enumerate(ranking): |
|||
if labels[i] == -1: |
|||
clusterDict[n_clusters + len(strayStates) + 1] = list() |
|||
clusterDict[n_clusters + len(strayStates) + 1].append(state) |
|||
strayStates.append(state) |
|||
continue |
|||
clusterDict[labels[i]].append(state) |
|||
if len(strayStates) > 0: logger.warning(f"{len(strayStates)} stray states with label -1") |
|||
#drawClusters(clusterDict, f"clusters", iteration) |
|||
return clusterDict |
|||
|
|||
|
|||
def run_experiment(factor_tests_per_cluster): |
|||
logger.info("Starting") |
|||
num_queries = 0 |
|||
|
|||
source = "images/1_full_scaled_down.png" |
|||
for ski_position in range(1, num_ski_positions + 1): |
|||
for velocity in range(0,num_velocities): |
|||
exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_{velocity:02}_individual.png", verbose=False) |
|||
exec(f"cp {source} {imagesDir}/result_{ski_position:02}_{velocity:02}_individual.png", verbose=False) |
|||
exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_individual.png", verbose=False) |
|||
exec(f"cp {source} {imagesDir}/result_{ski_position:02}_individual.png", verbose=False) |
|||
|
|||
goodVerdicts = 0 |
|||
badVerdicts = 0 |
|||
goodVerdictTestCases = list() |
|||
badVerdictTestCases = list() |
|||
safeClusters = 0 |
|||
unsafeClusters = 0 |
|||
safeStates = set() |
|||
unsafeStates = set() |
|||
iteration = 0 |
|||
results = list() |
|||
|
|||
eps = 0.1 |
|||
updatePrismFile(init_mdp, iteration, set(), set()) |
|||
#modelCheckingResult, numStates = TestResult(0,0,0,0,0,0,0,0,0,0,0,0,0), 10 |
|||
modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_000.prism", iteration) |
|||
results.append(modelCheckingResult) |
|||
ranking = fillStateRanking(f"action_ranking_000") |
|||
|
|||
sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking) |
|||
try: |
|||
clusters = clusterImportantStates(sorted_ranking, iteration) |
|||
except Exception as e: |
|||
print(e) |
|||
sys.exit(-1) |
|||
|
|||
clusterResult = dict() |
|||
logger.info(f"Running tests") |
|||
tic() |
|||
num_cluster_tested = 0 |
|||
iteration = 0 |
|||
drawnCluster = set() |
|||
for id, cluster in clusters.items(): |
|||
num_tests = int(factor_tests_per_cluster * len(cluster)) |
|||
if num_tests == 0: num_tests = 1 |
|||
logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}") |
|||
randomStates = np.random.choice(len(cluster), num_tests, replace=False) |
|||
randomStates = [cluster[i] for i in randomStates] |
|||
|
|||
verdictGood = True |
|||
for state in randomStates: |
|||
x = state[0].x |
|||
y = state[0].y |
|||
ski_pos = state[0].ski_position |
|||
velocity = state[0].velocity |
|||
result, num_queries_this_test_case = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50) |
|||
num_queries += num_queries_this_test_case |
|||
if result == Verdict.BAD: |
|||
clusterResult[id] = (cluster, Verdict.BAD) |
|||
verdictGood = False |
|||
unsafeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster]) |
|||
badVerdicts += 1 |
|||
badVerdictTestCases.append(state) |
|||
|
|||
elif result == Verdict.GOOD: |
|||
goodVerdicts += 1 |
|||
goodVerdictTestCases.append(state) |
|||
if verdictGood: |
|||
clusterResult[id] = (cluster, Verdict.GOOD) |
|||
safeClusters += 1 |
|||
safeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster]) |
|||
else: |
|||
unsafeClusters += 1 |
|||
results[-1].safe_states = len(safeStates) |
|||
results[-1].unsafe_states = len(unsafeStates) |
|||
results[-1].policy_queries = num_queries |
|||
results[-1].safe_cluster = safeClusters |
|||
results[-1].unsafe_cluster = unsafeClusters |
|||
results[-1].good_verdicts = goodVerdicts |
|||
results[-1].bad_verdicts = badVerdicts |
|||
num_cluster_tested += 1 |
|||
if num_cluster_tested % (len(clusters)//20) == 0: |
|||
iteration += 1 |
|||
logger.info(f"Tested Cluster: {num_cluster_tested:03}\tSafe Cluster States : {len(safeStates)}({safeClusters}/{len(clusters)})\tUnsafe Cluster States:{len(unsafeStates)}({unsafeClusters}/{len(clusters)})\tGood Test Cases:{goodVerdicts}\tFailing Test Cases:{badVerdicts}\t{len(safeStates)/len(unsafeStates)} - {goodVerdicts/badVerdicts}") |
|||
drawResult(clusterResult, "result", iteration, drawnCluster) |
|||
drawOntoSkiPosImage(goodVerdictTestCases, (10,255,50), "result", alpha_factor=0.7, markerSize=0, drawCircle=True) |
|||
drawOntoSkiPosImage(badVerdictTestCases, (0,0,0), "result", alpha_factor=0.7, markerSize=0, drawCircle=True) |
|||
concatImages("result", iteration) |
|||
drawnCluster.update(clusterResult.keys()) |
|||
#updatePrismFile(init_mdp, iteration, safeStates, unsafeStates) |
|||
#modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration) |
|||
results.append(deepcopy(modelCheckingResult)) |
|||
logger.info(f"Model Checking Result: {modelCheckingResult}") |
|||
# Account for self-loop states after first iteration |
|||
if iteration > 0: |
|||
results[-1].init_check_pes_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_pes_avg*numStates + 1.0*results[-2].unsafe_states + 0.0*results[-2].safe_states) |
|||
results[-1].init_check_opt_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_opt_avg*numStates + 0.0*results[-2].unsafe_states + 1.0*results[-2].safe_states) |
|||
print(TestResult.csv_header()) |
|||
for result in results[:-1]: |
|||
print(result.csv()) |
|||
|
|||
|
|||
with open(f"data_new_method_{factor_tests_per_cluster}", "w") as f: |
|||
f.write(TestResult.csv_header() + "\n") |
|||
for result in results[:-1]: |
|||
f.write(result.csv() + "\n") |
|||
|
|||
|
|||
_init_logger() |
|||
logger = logging.getLogger('main') |
|||
if __name__ == '__main__': |
|||
for factor_tests_per_cluster in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]: |
|||
run_experiment(factor_tests_per_cluster) |
@ -0,0 +1,137 @@ |
|||
mdp |
|||
|
|||
const int initY = 40; |
|||
const int initX = 80; |
|||
|
|||
const int maxY = 580; |
|||
//const int maxY = 360; |
|||
//const int maxY = 200; |
|||
const int minX = 10; |
|||
const int maxX = 152; |
|||
const int maxVel = 8; |
|||
|
|||
|
|||
formula Gate_1 = (((42<x & x<50) | (74<x & x<82)) & 164<y & y<172); |
|||
formula Gate_2 = (((72<x & x<80) | (104<x & x<112)) & 256<y & y<264); |
|||
formula Gate_3 = (((80<x & x<88) | (112<x & x<120)) & 349<y & y<357); |
|||
formula Gate_4 = (((54<x & x<62) | (88<x & x<96)) & 442<y & y<450); |
|||
formula Gate_5 = (((80<x & x<88) | (112<x & x<120)) & 530<y & y<538); |
|||
|
|||
formula S_Gate_1 = (((32<x & x<60) | (64<x & x<92)) & 124<y & y<172); |
|||
formula S_Gate_2 = (((62<x & x<90) | (94<x & x<132)) & 216<y & y<264); |
|||
formula S_Gate_3 = (((70<x & x<98) | (102<x & x<130)) & 309<y & y<357); |
|||
formula S_Gate_4 = (((44<x & x<72) | (78<x & x<106)) & 402<y & y<450); |
|||
formula S_Gate_5 = (((70<x & x<98) | (102<x & x<130)) & 490<y & y<538); |
|||
|
|||
|
|||
formula Tree_1 = ((x>=124 & x<=142) & (y>=190 & y<=200)); |
|||
formula Tree_2 = ((x>=32 & x<=49) & (y>=284 & y<=295)); |
|||
formula Tree_3 = ((x>=30 & x<=49) & (y>=317 & y<=327)); |
|||
formula Tree_4 = ((x>=12 & x<=30) & (y>=408 & y<=418)); |
|||
formula Tree_5 = ((x>=129 & x<=146) & (y>=468 & y<=480)); |
|||
formula Tree_6 = ((x>=140 & x<=152) & (y>=496 & y<=510)); |
|||
|
|||
formula S_Tree_1 = ((x>=114 & x<=152) & (y>=150 & y<=200)); |
|||
formula S_Tree_2 = ((x>=22 & x<=59) & (y>=244 & y<=295)); |
|||
formula S_Tree_3 = ((x>=20 & x<=59) & (y>=277 & y<=327)); |
|||
formula S_Tree_4 = ((x>=2 & x<=40) & (y>=368 & y<=418)); |
|||
formula S_Tree_5 = ((x>=119 & x<=156) & (y>=438 & y<=480)); |
|||
formula S_Tree_6 = ((x>=130 & x<=162) & (y>=456 & y<=510)); |
|||
|
|||
formula Hit_Tree = Tree_1 | Tree_2 | Tree_3 | Tree_4 | Tree_5 | Tree_6; |
|||
formula Hit_Gate = Gate_1 | Gate_2 | Gate_3 | Gate_4 | Gate_5; |
|||
formula S_Hit_Tree = S_Tree_1 | S_Tree_2 | S_Tree_3 | S_Tree_4 | S_Tree_5 | S_Tree_6; |
|||
formula S_Hit_Gate = S_Gate_1 | S_Gate_2 | S_Gate_3 | S_Gate_4 | S_Gate_5; |
|||
|
|||
formula Safe = ( (Safe_1 | Safe_2) | (Safe_3 | Safe_4) ) | ( (Safe_5 | Safe_6) | (Safe_7 | Safe_8) ); |
|||
formula Unsafe = ( (Unsafe_1 | Unsafe_2) | (Unsafe_3 | Unsafe_4) ) | ( (Unsafe_5 | Unsafe_6) | (Unsafe_7 | Unsafe_8) ); |
|||
|
|||
|
|||
label "Hit_Tree" = Hit_Tree; |
|||
label "Hit_Gate" = Hit_Gate; |
|||
label "S_Hit_Tree" = S_Hit_Tree; |
|||
label "S_Hit_Gate" = S_Hit_Gate; |
|||
|
|||
|
|||
global move : [0..3]; |
|||
|
|||
|
|||
|
|||
module skier |
|||
ski_position : [1..8] init 4; |
|||
reward_given: bool init false; |
|||
//done: bool init false; |
|||
|
|||
|
|||
[left] !reward_given & !Safe & !Unsafe & !Hit_Gate & !Hit_Tree & move=0 & ski_position>1 -> (ski_position'=ski_position-1) & (move'=1); |
|||
[right] !reward_given & !Safe & !Unsafe & !Hit_Gate & !Hit_Tree & move=0 & ski_position<8 -> (ski_position'=ski_position+1) & (move'=1); |
|||
[noop] !reward_given & !Safe & !Unsafe & !Hit_Gate & !Hit_Tree & move=0 -> (move'=1); |
|||
|
|||
|
|||
[done] !reward_given & (Hit_Tree | Hit_Gate | Safe | Unsafe) & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Unsafe_1 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Unsafe_2 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Unsafe_3 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Unsafe_4 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Unsafe_5 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Unsafe_6 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Unsafe_7 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Unsafe_8 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Safe_1 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Safe_2 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Safe_3 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Safe_4 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Safe_5 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Safe_6 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Safe_7 & move=0 -> (reward_given'=true); |
|||
//[done] !reward_given & Safe_8 & move=0 -> (reward_given'=true); |
|||
|
|||
|
|||
endmodule |
|||
|
|||
module updateY |
|||
y : [initY..maxY] ; |
|||
|
|||
velocity: [0..16]; |
|||
standstill : [0..8] ; |
|||
[update_y] move=1 & standstill>=5 -> (y'=y) & (move'=2); |
|||
[update_y] move=1 & standstill<5 -> (y'=min(maxY,y+velocity)) & (move'=2); |
|||
|
|||
[update_y] move=2 & (ski_position=1 | ski_position = 8) & standstill>=5 -> (standstill'=min(8,standstill+1)) & (move'=3); |
|||
[update_y] move=2 & (ski_position=1 | ski_position = 8) & standstill<5 -> (velocity'=max(0,velocity-4)) &(move'=3); |
|||
[update_y] move=2 & (ski_position=2 | ski_position = 7) -> (velocity'=max(0 ,velocity-2)) & (standstill'=0) & (move'=3); |
|||
[update_y] move=2 & (ski_position=3 | ski_position = 6) -> (velocity'=min(maxVel,velocity+2)) & (standstill'=0) & (move'=3); |
|||
[update_y] move=2 & (ski_position=4 | ski_position = 5) -> (velocity'=min(maxVel,velocity+4)) & (standstill'=0) & (move'=3); |
|||
endmodule |
|||
|
|||
module updateX |
|||
x : [minX..maxX] init initX; |
|||
|
|||
[update_x] move=3 & standstill>=8 -> (move'=0); |
|||
[update_x] move=3 & standstill<8 & (ski_position=4 | ski_position=5) -> (move'=0); |
|||
|
|||
[update_x] move=3 & standstill<8 & (ski_position=3) -> 0.4: (x'=max(minX,x-0)) & (move'=0) + 0.6: (x'=max(minX,x-1)) & (move'=0); |
|||
[update_x] move=3 & standstill<8 & (ski_position=6) -> 0.4: (x'=min(maxX,x+0)) & (move'=0) + 0.6: (x'=min(maxX,x+1)) & (move'=0); |
|||
|
|||
[update_x] move=3 & standstill<8 & (ski_position=2) -> 0.3: (x'=max(minX,x-1)) & (move'=0) + 0.7: (x'=max(minX,x-2)) & (move'=0); |
|||
[update_x] move=3 & standstill<8 & (ski_position=7) -> 0.3: (x'=min(maxX,x+1)) & (move'=0) + 0.7: (x'=min(maxX,x+2)) & (move'=0); |
|||
|
|||
[update_x] move=3 & standstill<8 & (ski_position=1) -> 0.2: (x'=max(minX,x-2)) & (move'=0) + 0.8: (x'=max(minX,x-3)) & (move'=0); |
|||
[update_x] move=3 & standstill<8 & (ski_position=8) -> 0.2: (x'=min(maxX,x+2)) & (move'=0) + 0.8: (x'=min(maxX,x+3)) & (move'=0); |
|||
endmodule |
|||
|
|||
//rewards |
|||
// [left] !done & !reward_given & Hit_Tree : -100; |
|||
// [left] !done & !reward_given & Hit_Gate : -100; |
|||
// [left] !done & !reward_given & (Unsafe_1 | Unsafe_2 | Unsafe_3 | Unsafe_4 | Unsafe_5 | Unsafe_6 | Unsafe_7 | Unsafe_8) : -100; |
|||
// [right] !done & !reward_given & Hit_Tree : -100; |
|||
// [right] !done & !reward_given & Hit_Gate : -100; |
|||
// [right] !done & !reward_given & (Unsafe_1 | Unsafe_2 | Unsafe_3 | Unsafe_4 | Unsafe_5 | Unsafe_6 | Unsafe_7 | Unsafe_8) : -100; |
|||
// [noop] !done & !reward_given & Hit_Tree : -100; |
|||
// [noop] !done & !reward_given & Hit_Gate : -100; |
|||
// [noop] !done & !reward_given & (Unsafe_1 | Unsafe_2 | Unsafe_3 | Unsafe_4 | Unsafe_5 | Unsafe_6 | Unsafe_7 | Unsafe_8) : -100; |
|||
//endrewards |
|||
|
|||
rewards |
|||
[done] !reward_given & (Hit_Gate | Hit_Tree | Unsafe_1 | Unsafe_2 | Unsafe_3 | Unsafe_4 | Unsafe_5 | Unsafe_6 | Unsafe_7 | Unsafe_8) : -100; |
|||
endrewards |
Write
Preview
Loading…
Cancel
Save
Reference in new issue