You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

417 lines
16 KiB

import sys
import operator
from os import listdir, system
import re
from collections import defaultdict
from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT, Action
#from colors import *
from PIL import Image
from matplotlib import pyplot as plt
import cv2
import pickle
import queue
from dataclasses import dataclass, field
from sklearn.cluster import KMeans
from enum import Enum
from copy import deepcopy
import numpy as np
import logging
logger = logging.getLogger(__name__)
#import readchar
from sample_factory.algo.utils.tensor_dict import TensorDict
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
import time
# Machine-specific absolute paths; adjust them for the local installation.
# `storm` binary of the tempest build; computeStateRanking invokes it on a prism file.
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
# Skiing ROM loaded into the ALE emulator during module set-up (ale.loadROM).
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
def tic():
    """Start the module-level stopwatch that toc() reads.

    Stores the current wall-clock time in the global ``startTime_for_tictoc``.
    """
    import time
    global startTime_for_tictoc
    now = time.time()
    startTime_for_tictoc = now
def toc():
    """Return the seconds elapsed since the most recent tic().

    Returns None when tic() has never been called in this module.
    """
    import time
    module_state = globals()
    if 'startTime_for_tictoc' not in module_state:
        return None
    return time.time() - module_state['startTime_for_tictoc']
class Verdict(Enum):
    """Outcome of testing a state or a whole cluster of states."""

    # Neither GOOD nor BAD; run_single_test never returns it.
    INCONCLUSIVE = 1
    # The skier kept moving for the whole test run.
    GOOD = 2
    # The recorded speed stalled during the test run.
    BAD = 3
# "R,G,B" colour strings per verdict, in the format used inside rgba(...) draw commands.
verdict_to_color_map = {
    Verdict.BAD: "200,0,0",
    Verdict.INCONCLUSIVE: "40,40,200",
    Verdict.GOOD: "00,200,100",
}
def convert(tuples):
    """Build a dict from an iterable of (key, value) pairs.

    Later pairs overwrite earlier pairs with the same key.
    """
    mapping = dict(tuples)
    return mapping
@dataclass(frozen=True)
class State:
    """Immutable, hashable skier state; used as the key of the state ranking."""

    # Horizontal position (``x=`` field of a ranking line).
    x: int
    # Vertical position (``y=`` field of a ranking line).
    y: int
    # Ski orientation (``ski_position=`` field of a ranking line).
    ski_position: int
def default_value():
    """Return a fresh placeholder choices dict with no action and no value."""
    return dict(action=None, choiceValue=None)
@dataclass(frozen=True)
class StateValue:
    """Ranking value of a state together with its per-action choice values."""

    # Number following ``Value:`` on the state's ranking line.
    ranking: float
    # Action name -> value ("left"/"right"/"noop" when built by fillStateRanking);
    # each instance gets its own fresh default_value() dict.
    choices: dict = field(default_factory=default_value)
def exec(command, verbose=True):
    """Run ``command`` through the shell and return its ``os.system`` status.

    Every command is also appended verbatim, one per line, to the file
    ``list_of_exec`` in the current directory.

    Args:
        command: shell command line to execute.
        verbose: when True, print the command before running it.

    Returns:
        The status returned by ``os.system`` (0 on success).
    """
    if verbose:
        print(f"Executing {command}")
    # Append the command text from Python. Echoing it through the shell would
    # re-interpret it: quotes are stripped, globs expand, and redirections or
    # ';' inside the command would be executed by the logging step itself.
    with open("list_of_exec", "a", encoding="utf-8") as log_file:
        log_file.write(f"{command}\n")
    return system(command)
# Not referenced by any function or the main loop in this file.
num_tests_per_cluster = 50
# Fraction of a cluster's states intended for testing; the main loop currently
# overrides the resulting count with a single test per cluster.
factor_tests_per_cluster = 0.2
# Number of ski orientations; per-position images are numbered 1..num_ski_positions.
num_ski_positions = 8
def input_to_action(char):
    """Map a single-character command to an ALE action or a control keyword.

    "0", "1" and "2" become Action.NOOP, Action.RIGHT and Action.LEFT;
    "3", "4" and "5" become the strings "reset", "set_x" and "set_vel";
    "w", "a", "s" and "d" are returned unchanged. Anything else yields None.
    """
    if char == "0":
        return Action.NOOP
    elif char == "1":
        return Action.RIGHT
    elif char == "2":
        return Action.LEFT
    for key, keyword in (("3", "reset"), ("4", "set_x"), ("5", "set_vel")):
        if char == key:
            return keyword
    return char if char in ("w", "a", "s", "d") else None
def drawImportantStates(important_states):
    """Draw important states as amber squares onto one image per ski position.

    Each entry is a (State, StateValue) pair; the square's alpha is the state's
    ranking. ImageMagick ``convert`` writes ``first_try_01.png`` through
    ``first_try_14.png`` from ``images/1_full_scaled_down.png``.
    """
    markerSize = 2
    draw_commands = {pos: list() for pos in range(1, 15)}
    for entry in important_states:
        s = entry[0]
        lo_x, lo_y = s.x - markerSize, s.y - markerSize
        hi_x, hi_y = s.x + markerSize, s.y + markerSize
        draw_commands[s.ski_position].append(
            f"-fill 'rgba(255,204,0,{entry[1].ranking})' -draw 'rectangle {lo_x},{lo_y} {hi_x},{hi_y} '"
        )
    for pos in range(1, 15):
        exec(f"convert images/1_full_scaled_down.png {' '.join(draw_commands[pos])} first_try_{pos:02}.png")
def saveObservations(observations, verdict, testDir):
    """Save every observation frame of one test run as a numbered PNG.

    Frames go to ``images/testing_<experiment_id>/<VERDICT>_<testDir>_<n>``,
    where ``n`` is the number of frames. The directory gets the suffix
    ``_pot_spurious`` when fewer than 20 frames were recorded. Files are named
    ``000.png``, ``001.png``, ...

    Args:
        observations: sequence of image arrays accepted by ``Image.fromarray``.
        verdict: Verdict of the run; its name prefixes the directory.
        testDir: run identifier (``<x>_<y>_<ski_position>`` in run_single_test).
    """
    testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
    if len(observations) < 20:
        # logger.warn is a deprecated alias of logger.warning.
        logger.warning(f"Potentially spurious test case for {testDir}")
        testDir = f"{testDir}_pot_spurious"
    # -p: the same start state can be tested again in a later iteration, giving
    # the same directory name; plain mkdir would fail in that case.
    exec(f"mkdir -p {testDir}", verbose=False)
    for i, obs in enumerate(observations):
        Image.fromarray(obs).save(f"{testDir}/{i:03}.png")
# Maps a ski_position (1-8) to (action, repetitions): run_single_test performs
# `action` that many times to turn the skis into that orientation.
# NOTE(review): presumably relative to the orientation stored in the RAM snapshot -- confirm.
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
def run_single_test(ale, nn_wrapper, x, y, ski_position, duration=200):
    """Run the NN policy from a given skier state and judge the outcome.

    The emulator RAM is restored from ``ramDICT[y]``. The skis are then turned
    using ``ski_position_counter[ski_position]``, and RAM cell 25 is set to
    ``x``. After that, the policy controls the game for up to ``duration`` steps.

    Args:
        ale: ALEInterface with the Skiing ROM loaded.
        nn_wrapper: policy wrapper whose ``query`` returns an action id.
        x, y, ski_position: start state to test.
        duration: maximum number of policy steps.

    Returns:
        Verdict.BAD if, after more than 15 steps, RAM cell 14 (recorded in
        ``speed_list``) was zero for the five steps preceding the latest one;
        otherwise Verdict.GOOD. The recorded frames are saved in both cases.
    """
    testDir = f"{x}_{y}_{ski_position}"
    for address, value in enumerate(ramDICT[y]):
        ale.setRAM(address, value)
    turn_action, num_turns = ski_position_counter[ski_position]
    for _ in range(num_turns):
        ale.act(turn_action)
    ale.setRAM(14, 0)
    ale.setRAM(25, x)
    ale.setRAM(14, 180)
    all_obs = list()
    speed_list = list()
    for _ in range(duration):
        resized_obs = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
        # The current frame is appended four times, so the 4-frame stack fed to
        # the policy always exists and consists of copies of the latest frame.
        all_obs.extend([resized_obs] * 4)
        stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
        action = nn_wrapper.query(stack_tensor)
        ale.act(input_to_action(str(action)))
        # Convert to int: getRAM() yields byte-sized numpy integers (uint8), and
        # summing those directly can wrap around to 0 and fake a stalled skier.
        speed_list.append(int(ale.getRAM()[14]))
        if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
            saveObservations(all_obs, Verdict.BAD, testDir)
            return Verdict.BAD
    saveObservations(all_obs, Verdict.GOOD, testDir)
    return Verdict.GOOD
def optimalAction(choices):
    """Return the key of ``choices`` with the largest value (the first one on ties).

    Raises ValueError when ``choices`` is empty.
    """
    return max(choices, key=choices.__getitem__)
def computeStateRanking(mdp_file):
    """Run tempest on ``mdp_file`` with choice labels and state values enabled.

    The property checked is ``Rmax=? [C <= 1000]``. The shell exit status is
    ignored. NOTE(review): presumably this produces the ranking file that
    fillStateRanking parses -- confirm the output location.
    """
    logger.info("Computing state ranking")
    tic()
    exec(f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --prop 'Rmax=? [C <= 1000]'")
    elapsed = toc()
    logger.info(f"Computing state ranking - DONE: took {elapsed} seconds")
def fillStateRanking(file_name, match=""):
    """Parse a tempest state-ranking file into ``{State: StateValue}``.

    Only lines containing ``move=0`` are considered.
    - The number after ``Value:`` is the state's ranking; lines whose ranking is <= 0.1 are skipped.
    - The ``name=<int>`` assignments provide x, y and ski_position.
    - Every ``<...left|right|noop...>:<number>`` entry becomes a choice value keyed by "left", "right" or "noop".

    Args:
        file_name: path of the ranking file.
        match: unused; kept for backward compatibility.

    Returns:
        dict mapping State to StateValue.

    A line that cannot be parsed is skipped and counted; the count is logged
    as a warning. Examples: no ``Value:`` number, or a missing or empty
    x/y/ski_position. Previously one such line aborted the whole parse and
    the function silently returned None. If the file cannot be opened or
    read, a message is printed and the process exits with status 1.
    """
    logger.info("Parsing state ranking")
    tic()
    # Signed integers, decimals and exponent notation ("-3", "+2", ".5", "1.2e-05").
    value_pattern = re.compile(r"Value:([+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)")
    assignment_pattern = re.compile(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?")
    choice_pattern = re.compile(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*(?:[eE][+-]?\d+)?)")
    state_ranking = dict()
    skipped_lines = 0
    try:
        with open(file_name, "r") as f:
            for line in f:
                if "move=0" not in line:
                    continue
                try:
                    ranking_value = float(value_pattern.search(line).group(1))
                    if ranking_value <= 0.1:
                        continue
                    stateMapping = convert(assignment_pattern.findall(line))
                    choices = {key: float(value) for key, value in convert(choice_pattern.findall(line)).items()}
                    state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
                except (AttributeError, KeyError, ValueError):
                    # AttributeError: no Value: match; KeyError: missing variable;
                    # ValueError: empty/invalid number.
                    skipped_lines += 1
                    logger.debug(f"Skipping unparsable ranking line: {line.rstrip()}")
                    continue
                state_ranking[state] = StateValue(ranking_value, choices)
    except EnvironmentError:
        print("Ranking file not available. Exiting.")
        toc()
        sys.exit(1)
    if skipped_lines:
        logger.warning(f"Skipped {skipped_lines} unparsable ranking lines in {file_name}")
    logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
    return state_ranking
def createDisjunction(formulas):
    """Join the given formula strings with PRISM's `` | `` (or) operator."""
    separator = " | "
    return separator.join(formulas)
def _contiguous_x_ranges(sorted_xs):
    """Split ascending, duplicate-free x values into (start, end) runs of consecutive integers."""
    ranges = []
    start = previous = sorted_xs[0]
    for x in sorted_xs[1:]:
        if x - previous != 1:
            ranges.append((start, previous))
            start = x
        previous = x
    ranges.append((start, previous))
    return ranges


def clusterFormula(cluster):
    """Build a PRISM boolean expression that matches exactly the states of ``cluster``.

    ``cluster`` is a list of (State, StateValue) pairs. States are grouped by
    ski_position, then by y, in order of first occurrence. Within each y, the
    x values are merged into maximal runs of consecutive integers. For
    example, x in {1, 2, 3, 10} at y=5, ski_position=2 gives::

        ski_position=2 & (( y=5 & ( (1<= x & x<=3) |  (10<= x & x<=10)) ))

    Groups for different ski positions are joined with `` | ``; ``&`` binds
    tighter than ``|`` in PRISM. Previously the groups were concatenated
    without an operator, and the last x-range wrongly extended to the
    largest x. An empty cluster yields ``false``.
    """
    if not cluster:
        return "false"
    groups = defaultdict(lambda: defaultdict(set))
    for entry in cluster:
        state = entry[0]
        groups[state.ski_position][state.y].add(state.x)
    ski_position_terms = []
    for ski_position, y_groups in groups.items():
        y_terms = []
        for y, xs in y_groups.items():
            x_terms = " | ".join(f" ({lo}<= x & x<={hi})" for lo, hi in _contiguous_x_ranges(sorted(xs)))
            y_terms.append(f"( y={y} & ({x_terms}) )")
        ski_position_terms.append(f"ski_position={ski_position} & (" + " | ".join(y_terms) + ")")
    return " | ".join(ski_position_terms)
def createUnsafeFormula(clusters):
    """Return PRISM declarations marking the given clusters as unsafe.

    The output has three parts, in this order:
    - one ``formula Unsafe_<i> = <clusterFormula(cluster)>;`` line per cluster;
    - a blank line;
    - ``formula Unsafe = false | Unsafe_0 | Unsafe_1 ...;``.

    Each cluster's formula is now built once; previously ``clusterFormula``
    was called a second time per cluster and its result discarded.
    """
    declarations = []
    references = []
    for i, cluster in enumerate(clusters):
        declarations.append(f"formula Unsafe_{i} = {clusterFormula(cluster)};\n")
        references.append(f" | Unsafe_{i}")
    disjunction = "formula Unsafe = false" + "".join(references) + ";\n"
    return "".join(declarations) + "\n" + disjunction
def createSafeFormula(clusters):
    """Return PRISM declarations marking the given clusters as safe.

    The output has three parts, in this order:
    - one ``formula Safe_<i> = <clusterFormula(cluster)>;`` line per cluster;
    - a blank line;
    - ``formula Safe = false | Safe_0 | Safe_1 ...;``.
    """
    parts = [
        (f"formula Safe_{i} = {clusterFormula(cluster)};\n", f" | Safe_{i}")
        for i, cluster in enumerate(clusters)
    ]
    declarations = "".join(declaration for declaration, _ in parts)
    disjunction = "formula Safe = false" + "".join(reference for _, reference in parts) + ";\n"
    return declarations + "\n" + disjunction
def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
    """Create the prism model for ``iteration`` with Safe/Unsafe formulas appended.

    The template ``<newFile>_no_formulas.prism`` is copied to
    ``<newFile>_<iteration:03>.prism``. The createSafeFormula output for
    ``safeStates`` is then appended, followed by the createUnsafeFormula
    output for ``unsafeStates``.
    """
    logger.info("Creating next prism file")
    tic()
    template_path = f"{newFile}_no_formulas.prism"
    target_path = f"{newFile}_{iteration:03}.prism"
    exec(f"cp {template_path} {target_path}", verbose=False)
    with open(target_path, "a") as prism:
        # Build and write one section at a time: Safe formulas first, then Unsafe.
        for build_formulas, clusters in ((createSafeFormula, safeStates), (createUnsafeFormula, unsafeStates)):
            prism.write(build_formulas(clusters))
    logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
# --- Module-level set-up: runs on import -----------------------------------
# Emulator instance used by run_single_test in the main loop.
ale = ALEInterface()
#if SDL_SUPPORT:
# ale.setBool("sound", True)
# ale.setBool("display_screen", True)
# Load the ROM file
ale.loadROM(rom_file)
# RAM snapshots indexed by y; run_single_test writes ramDICT[y] back into the emulator.
# NOTE(review): pickle.load can execute arbitrary code -- only load a trusted file.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
# NOTE(review): y_ram_setting is not read anywhere in this file; the global x is
# only overwritten by the main loop.
y_ram_setting = 60
x = 70
# Policy wrapper queried by run_single_test for the next action.
nn_wrapper = SampleFactoryNNQueryWrapper()
# Reset to 0 again in the __main__ block.
iteration = 0
# Unique id for this run; names the output directory images/testing_<experiment_id>.
experiment_id = int(time.time())
# Base name of the prism files: <init_mdp>_no_formulas.prism is the template and
# <init_mdp>_<iteration:03>.prism the generated models (see updatePrismFile).
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)
#exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png")
#exec(f"cp {init_mdp}.prism {init_mdp}_000.prism")
# Half-width of the square drawn for each state by the drawing helpers.
markerSize = 1
#markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
# Output directory for all images of this run.
imagesDir = f"images/testing_{experiment_id}"
def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
    """Draw ``states`` onto the existing per-ski-position images, in place.

    For each ski position p in 1..num_ski_positions, ImageMagick ``convert``
    redraws ``<imagesDir>/<target_prefix>_<pp>.png``. It adds one square of
    half-width ``markerSize`` per state at that position, filled with
    ``color`` ("R,G,B") and alpha ``alpha_factor * ranking``.
    """
    markers_by_pos = {}
    for pos in range(1, num_ski_positions + 1):
        markers_by_pos[pos] = []
    for entry in states:
        s = entry[0]
        markers_by_pos[s.ski_position].append(
            f"-fill 'rgba({color}, {alpha_factor * entry[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        )
    for pos, markers in markers_by_pos.items():
        image_path = f"{imagesDir}/{target_prefix}_{pos:02}.png"
        exec(f"convert {image_path} {' '.join(markers)} {image_path}", verbose=False)
def concatImages(prefix, iteration):
    """Tile the per-ski-position images of ``prefix`` horizontally into one image.

    Reads ``<imagesDir>/<prefix>_01.png`` up to
    ``<imagesDir>/<prefix>_<num_ski_positions:02>.png`` and writes
    ``<imagesDir>/<prefix>_<iteration>.png`` via ImageMagick ``montage``.
    """
    # Name the tiles explicitly. A "<prefix>_*png" glob also matches the montages
    # written in earlier iterations ("<prefix>_0.png", "<prefix>_1.png", ...),
    # which would then be tiled into the new montage as well.
    tiles = " ".join(f"{imagesDir}/{prefix}_{pos:02}.png" for pos in range(1, num_ski_positions + 1))
    exec(f"montage {tiles} -geometry +0+0 -tile x1 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
    #exec(f"sxiv {imagesDir}/{prefix}.png&", verbose=False)
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    """Draw a set of states (e.g. a single cluster) onto fresh copies of ``source``.

    For each ski position 1..num_ski_positions, one image
    ``<imagesDir>/<target>_<pp>.png`` is written. Each state is drawn as a
    square of half-width ``markerSize`` in ``color`` ("R,G,B") with alpha
    ``alpha_factor * ranking``. The per-position images are then tiled
    horizontally into ``<imagesDir>/<target>.png``.
    """
    # Same key range as the other drawing helpers (was hard-coded to 1..8).
    markerList = {pos: list() for pos in range(1, num_ski_positions + 1)}
    logger.info(f"Drawing {len(states)} states onto {target}")
    tic()
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        markerList[s.ski_position].append(marker)
    tiles = []
    for pos, marker in markerList.items():
        tile = f"{imagesDir}/{target}_{pos:02}.png"
        exec(f"convert {source} {' '.join(marker)} {tile}", verbose=False)
        tiles.append(tile)
    # List the tiles explicitly; a "<target>_*png" glob could also pick up
    # unrelated files that share the prefix.
    exec(f"montage {' '.join(tiles)} -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
    """Draw every cluster in its own random colour and tile the result.

    The per-ski-position images ``<imagesDir>/<target>_<pp>.png`` are reset
    from ``images/1_full_scaled_down.png``. Each cluster's states are drawn
    onto them, and concatImages then writes ``<target>_<iteration>.png``.
    """
    source = "images/1_full_scaled_down.png"
    for ski_position in range(1, num_ski_positions + 1):
        exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png")
    for clusterStates in clusterDict.values():
        red, green, blue = (np.random.choice(range(256)) for _ in range(3))
        drawOntoSkiPosImage(clusterStates, f"{red}, {green}, {blue}", target, alpha_factor=alpha_factor)
    concatImages(target, iteration)
def drawResult(clusterDict, target, iteration):
    """Draw each tested cluster coloured by its verdict and tile the result.

    ``clusterDict`` maps a cluster id to ``(states, verdict)``. GOOD clusters
    are green, BAD clusters red, anything else grey, all at alpha factor 0.7.
    """
    source = "images/1_full_scaled_down.png"
    for ski_position in range(1, num_ski_positions + 1):
        exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png")
    for clusterStates, result in clusterDict.values():
        color = "0,200,0" if result == Verdict.GOOD else "200,0,0" if result == Verdict.BAD else "100,100,100"
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
    concatImages(target, iteration)
def _init_logger():
logger = logging.getLogger('main')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
def clusterImportantStates(ranking, iteration, n_clusters=40):
    """Cluster ranked states with KMeans and draw the clusters.

    Each (State, StateValue) entry of ``ranking`` becomes the feature vector
    ``[x, y, 10 * ski_position, ranking]``. The factor 10 presumably keeps
    different ski orientations apart. Returns a dict mapping each cluster
    label 0..n_clusters-1 to its list of entries.
    """
    logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster")
    tic()
    features = [[entry[0].x, entry[0].y, entry[0].ski_position * 10, entry[1].ranking] for entry in ranking]
    kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(features)
    logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster - DONE: took {toc()} seconds")
    clusterDict = {label: list() for label in range(n_clusters)}
    for label, entry in zip(kmeans.labels_, ranking):
        clusterDict[label].append(entry)
    drawClusters(clusterDict, "clusters", iteration)
    return clusterDict
if __name__ == '__main__':
    _init_logger()
    logger = logging.getLogger('main')
    logger.info("Starting")
    n_clusters = 2
    # When True, a failing state is recorded per cluster and testing continues
    # instead of stopping at the first failure.
    testAll = False
    # Clusters judged safe/unsafe so far; they become the Safe_i/Unsafe_i
    # formulas of the next prism file.
    safeStates = list()
    unsafeStates = list()
    iteration = 0
    while True:
        updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
        # NOTE: ranking computation is disabled; "action_ranking" is re-read every
        # iteration without being recomputed here.
        #computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
        ranking = fillStateRanking("action_ranking")
        sorted_ranking = sorted((x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
        clusters = clusterImportantStates(sorted_ranking, iteration, n_clusters)
        if testAll:
            failingPerCluster = {i: list() for i in range(0, n_clusters)}
        clusterResult = dict()
        # cluster_id instead of `id`, which shadowed the builtin at module scope.
        for cluster_id, cluster in clusters.items():
            # A sample of factor_tests_per_cluster * len(cluster) states is
            # currently overridden to one state per cluster. The count is clamped
            # so that sampling from an empty cluster does not raise.
            num_tests = min(1, len(cluster))
            logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {cluster_id}")
            randomStates = np.random.choice(len(cluster), num_tests, replace=False)
            randomStates = [cluster[i] for i in randomStates]
            verdictGood = True
            for state in randomStates:
                x = state[0].x
                y = state[0].y
                ski_pos = state[0].ski_position
                result = run_single_test(ale, nn_wrapper, x, y, ski_pos, duration=50)
                if result == Verdict.BAD:
                    if testAll:
                        # NOTE(review): in testAll mode a failure does not mark the
                        # cluster unsafe -- confirm this is intended.
                        failingPerCluster[cluster_id].append(state)
                    else:
                        clusterResult[cluster_id] = (cluster, Verdict.BAD)
                        verdictGood = False
                        unsafeStates.append(cluster)
                        break
            if verdictGood:
                clusterResult[cluster_id] = (cluster, Verdict.GOOD)
                safeStates.append(cluster)
        if testAll:
            # drawClusters requires the iteration (montage file name); it was missing.
            drawClusters(failingPerCluster, "failing", iteration)
        drawResult(clusterResult, "result", iteration)
        iteration += 1