sp
1 year ago
commit
8f7efe48db
3 changed files with 227 additions and 0 deletions
-
71first_try.prism
-
78test.py
-
78test_model.py
@ -0,0 +1,71 @@ |
|||
mdp |
|||
|
|||
const int initY = 40; |
|||
const int initX = 80; |
|||
|
|||
const int maxY = 240; |
|||
const int minX = 12; |
|||
const int maxX = 147; |
|||
|
|||
|
|||
formula HitTree = (122<x & x<139) & (220<y & y=232); |
|||
formula HitGate = ((x=51 | x=79) & y=164); |
|||
|
|||
formula PassedGates = (158<y & y<165 & 51<x & x<79); |
|||
|
|||
global move : [0..3] init 0; |
|||
|
|||
module skier |
|||
ski_position : [1..14] init 8; |
|||
done : bool init false; |
|||
|
|||
[left] !done & move=0 & ski_position>1 -> (ski_position'=ski_position-1) & (move'=1) & (done'=(PassedGates|HitTree|HitGate)); |
|||
[right] !done & move=0 & ski_position<14 -> (ski_position'=ski_position+1) & (move'=1) & (done'=(PassedGates|HitTree|HitGate)); |
|||
[noop] !done & move=0 -> (move'=1) & (done'=(PassedGates|HitTree|HitGate)); |
|||
|
|||
[left] done & move=0 & ski_position>1 -> (ski_position'=ski_position-1) & (move'=1); |
|||
[right] done & move=0 & ski_position<14 -> (ski_position'=ski_position+1) & (move'=1); |
|||
[noop] done & move=0 -> (move'=1); |
|||
|
|||
endmodule |
|||
|
|||
module updateY |
|||
y : [initY..maxY] init initY; |
|||
standstill : [0..8] init 0; |
|||
[update_y] move=1 & (ski_position=1 | ski_position = 14) & standstill>=5 -> (y'=y) & (standstill'=min(8,standstill+1)) & (move'=2); |
|||
[update_y] move=1 & (ski_position=1 | ski_position = 14) & standstill<5 -> (y'=min(maxY,y+4)) & (standstill'=min(8,standstill+1)) & (move'=2); |
|||
[update_y] move=1 & (ski_position=2 | ski_position = 3 | ski_position = 12 | ski_position = 13) -> (y'=min(maxY,y+8)) & (standstill'=0) & (move'=2); |
|||
[update_y] move=1 & (ski_position=4 | ski_position = 5 | ski_position = 10 | ski_position = 11) -> (y'=min(maxY,y+12)) & (standstill'=0) & (move'=2); |
|||
[update_y] move=1 & (ski_position=6 | ski_position = 7 | ski_position = 8 | ski_position = 9) -> (y'=min(maxY,y+16)) & (standstill'=0) & (move'=2); |
|||
endmodule |
|||
|
|||
module updateX |
|||
x : [minX..maxX] init initX; |
|||
|
|||
[update_x] move=2 & standstill>=8 -> (move'=0); |
|||
[update_x] move=2 & standstill<8 & (ski_position=7 | ski_position=8) -> (move'=0); |
|||
|
|||
[update_x] move=2 & standstill<8 & ski_position=6 -> 0.1: (x'=max(minX,x-3)) + 0.7: (x'=max(minX,x-4)) + 0.2: (x'=max(minX,x-5)) & (move'=0); |
|||
[update_x] move=2 & standstill<8 & ski_position=9 -> 0.1: (x'=min(maxX,x+3)) + 0.7: (x'=min(maxX,x+4)) + 0.2: (x'=min(maxX,x+5)) & (move'=0); |
|||
|
|||
[update_x] move=2 & standstill<8 & (ski_position=4 | ski_position=5) -> 0.1: (x'=max(minX,x-5)) + 0.7: (x'=max(minX,x-6)) + 0.2: (x'=max(minX,x-7)) & (move'=0); |
|||
[update_x] move=2 & standstill<8 & (ski_position=10 | ski_position=11) -> 0.1: (x'=min(maxX,x+5)) + 0.7: (x'=min(maxX,x+6)) + 0.2: (x'=min(maxX,x+7)) & (move'=0); |
|||
|
|||
[update_x] move=2 & standstill<8 & (ski_position=2 | ski_position=3) -> 0.1: (x'=max(minX,x-5)) + 0.7: (x'=max(minX,x-6)) + 0.2: (x'=max(minX,x-7)) & (move'=0); |
|||
[update_x] move=2 & standstill<8 & (ski_position=12 | ski_position=13) -> 0.1: (x'=min(maxX,x+5)) + 0.7: (x'=min(maxX,x+6)) + 0.2: (x'=min(maxX,x+7)) & (move'=0); |
|||
|
|||
[update_x] move=2 & standstill<8 & (ski_position=1) -> 0.1: (x'=max(minX,x-0)) + 0.7: (x'=max(minX,x-2)) + 0.2: (x'=max(minX,x-3)) & (move'=0); |
|||
[update_x] move=2 & standstill<8 & (ski_position=14) -> 0.1: (x'=min(maxX,x+0)) + 0.7: (x'=min(maxX,x+2)) + 0.2: (x'=min(maxX,x+3)) & (move'=0); |
|||
endmodule |
|||
|
|||
rewards |
|||
[left] !done & PassedGates : 100; |
|||
[right] !done & PassedGates : 100; |
|||
[noop] !done & PassedGates : 100; |
|||
[left] !done & (HitTree) : -200; |
|||
[right] !done & (HitTree) : -200; |
|||
[noop] !done & (HitTree) : -200; |
|||
[left] !done & (HitGate) : -150; |
|||
[right] !done & (HitGate) : -150; |
|||
[noop] !done & (HitGate) : -150; |
|||
endrewards |
@ -0,0 +1,78 @@ |
|||
import gym |
|||
from PIL import Image |
|||
from copy import deepcopy |
|||
import numpy as np |
|||
|
|||
from matplotlib import pyplot as plt |
|||
import readchar |
|||
|
|||
import queue |
|||
|
|||
ski_position_queue = queue.Queue() |
|||
|
|||
env = gym.make("ALE/Skiing-v5", render_mode="human") |
|||
|
|||
|
|||
observation, info = env.reset() |
|||
y = 40 |
|||
|
|||
|
|||
standstillcounter = 0 |
|||
def update_y(y, ski_position): |
|||
global standstillcounter |
|||
if ski_position in [6,7, 8,9]: |
|||
standstillcounter = 0 |
|||
y_update = 16 |
|||
elif ski_position in [4,5, 10,11]: |
|||
standstillcounter = 0 |
|||
y_update = 12 |
|||
elif ski_position in [2,3, 12,13]: |
|||
standstillcounter = 0 |
|||
y_update = 8 |
|||
elif ski_position in [1, 14] and standstillcounter >= 5: |
|||
if standstillcounter >= 8: |
|||
print("!!!!!!!!!! no more x updates!!!!!!!!!!!") |
|||
y_update = 0 |
|||
elif ski_position in [1, 14]: |
|||
y_update = 4 |
|||
|
|||
if ski_position in [1, 14]: |
|||
standstillcounter += 1 |
|||
return y_update |
|||
|
|||
def update_ski_position(ski_position, action): |
|||
if action == 0: |
|||
return ski_position |
|||
elif action == 1: |
|||
return min(ski_position+1, 14) |
|||
elif action == 2: |
|||
return max(ski_position-1, 1) |
|||
|
|||
approx_x_coordinate = 80 |
|||
ski_position = 8 |
|||
for _ in range(1000000): |
|||
action = env.action_space.sample() # agent policy that uses the observation and info |
|||
action = int(repr(readchar.readchar())[1]) |
|||
ski_position = update_ski_position(ski_position, action) |
|||
y_update = update_y(y, ski_position) |
|||
y += y_update if y_update else 0 |
|||
|
|||
old_x = deepcopy(approx_x_coordinate) |
|||
approx_x_coordinate = int(np.mean(np.where(observation[:,:,1] == 92)[1])) |
|||
print(f"Action: {action},\tski position: {ski_position},\ty_update: {y_update},\ty: {y},\tx: {approx_x_coordinate},\tx_update:{approx_x_coordinate - old_x}") |
|||
observation, reward, terminated, truncated, info = env.step(action) |
|||
if terminated or truncated: |
|||
observation, info = env.reset() |
|||
|
|||
|
|||
observation, reward, terminated, truncated, info = env.step(0) |
|||
observation, reward, terminated, truncated, info = env.step(0) |
|||
observation, reward, terminated, truncated, info = env.step(0) |
|||
observation, reward, terminated, truncated, info = env.step(0) |
|||
|
|||
#plt.imshow(observation) |
|||
#plt.show() |
|||
#im = Image.fromarray(observation) |
|||
#im.save("init.png") |
|||
|
|||
env.close() |
@ -0,0 +1,78 @@ |
|||
import time, re, sys, csv, os |
|||
|
|||
from subprocess import call |
|||
from os import listdir, system |
|||
from os.path import isfile, join, getctime |
|||
from dataclasses import dataclass, field |
|||
|
|||
ski_position_to_rgb_map = {1: "230, 138, 0", |
|||
2: "255,153,0", |
|||
3: "255, 184, 77", |
|||
4: "255, 194, 100", |
|||
5: "255, 194, 120", |
|||
6: "255, 210, 130", |
|||
7: "255, 210, 140", |
|||
8: "230, 204, 255", |
|||
9: "204, 153, 255", |
|||
10: "255, 102, 179", |
|||
11: "255, 51, 153", |
|||
12: "255, 0, 255", |
|||
13: "179, 0, 179", |
|||
14: "102, 0, 102"} |
|||
|
|||
def convert(tuples): |
|||
return dict(tuples) |
|||
@dataclass(frozen=True) |
|||
class State: |
|||
x: int |
|||
y: int |
|||
ski_position: int |
|||
|
|||
def default_value(): |
|||
return {'action' : None, 'choiceValue' : None} |
|||
|
|||
@dataclass(frozen=True) |
|||
class StateValue: |
|||
ranking: float |
|||
choices: dict = field(default_factory=default_value) |
|||
def exec(command): |
|||
print(f"Executing {command}") |
|||
system(f"echo {command} >> list_of_exec") |
|||
return system(command) |
|||
|
|||
def fillStateRanking(file_name, match=""): |
|||
state_ranking = dict() |
|||
try: |
|||
with open(file_name, "r") as f: |
|||
file_content = f.readlines() |
|||
for line in file_content: |
|||
if match and skip_line.match(line): continue |
|||
stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line)) |
|||
#print("stateMapping", stateMapping) |
|||
choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line)) |
|||
#print("choices", choices) |
|||
ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:","")) |
|||
#print("ranking_value", ranking_value) |
|||
state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"])) |
|||
value = StateValue(ranking_value, choices) |
|||
state_ranking[state] = value |
|||
return state_ranking |
|||
|
|||
except EnvironmentError: |
|||
print("TODO file not available. Exiting.") |
|||
sys.exit(1) |
|||
|
|||
ranking = fillStateRanking("action_ranking") |
|||
sorted_ranking = sorted(ranking.items(), key=lambda x: x[1].ranking) |
|||
|
|||
draw_commands = list() |
|||
|
|||
for state in sorted_ranking[-20:-1]: |
|||
print(state) |
|||
x = state[0].x |
|||
y = state[0].y |
|||
markerSize = 2 |
|||
print(state[0].ski_position) |
|||
draw_commands.append(f"-fill 'rgba({ski_position_to_rgb_map[state[0].ski_position]},0.7)' -draw 'rectangle {x-markerSize},{y-markerSize} {x+markerSize},{y+markerSize} '") |
|||
command = f"convert init.png {' '.join(draw_commands)} first_try.png" |
|||
exec(command) |
Reference in new issue
xxxxxxxxxx