sp
1 year ago
commit
8f7efe48db
3 changed files with 227 additions and 0 deletions
-
71first_try.prism
-
78test.py
-
78test_model.py
@ -0,0 +1,71 @@ |
|||||
|
// Model type: Markov decision process.
mdp

// Initial values of the y and x state variables.
const int initY = 40;
const int initX = 80;

// Bounds of the y and x state variables (y ranges over initY..maxY, x over minX..maxX).
const int maxY = 240;
const int minX = 12;
const int maxX = 147;


// Region test used for the -200 reward and for setting `done`.
// NOTE(review): `220<y & y=232` is equivalent to just `y=232`; if a y-range
// was intended this presumably should read `y<232` - confirm.
formula HitTree = (122<x & x<139) & (220<y & y=232);
// True exactly at x=51 or x=79 on the line y=164 (used for the -150 reward).
formula HitGate = ((x=51 | x=79) & y=164);

// True strictly between x=51 and x=79 while 158<y<165 (used for the +100 reward).
formula PassedGates = (158<y & y<165 & 51<x & x<79);

// Phase counter sequencing the modules: the skier acts at 0, updateY runs at 1,
// updateX runs at 2 and sets it back to 0. The value 3 is never assigned.
// NOTE(review): this global is written from action-labelled commands, which
// stock PRISM is believed to reject; presumably the model targets a tool that
// allows it - confirm.
global move : [0..3] init 0;
||||
|
|
||||
|
// Agent module: picks one of left/right/noop whenever move=0, then hands
// control to updateY by setting move to 1.
module skier
  // Ski orientation index; [left] decrements it, [right] increments it.
  ski_position : [1..14] init 8;
  // Latches to true the first time an action is taken in a state where
  // PassedGates, HitTree or HitGate holds; it is never reset afterwards.
  done : bool init false;

  // `done | ...` keeps the flag once set and otherwise evaluates the three
  // formulas in the current state, so one command per action suffices.
  [left]  move=0 & ski_position>1  -> (ski_position'=ski_position-1) & (move'=1) & (done'=(done|PassedGates|HitTree|HitGate));
  [right] move=0 & ski_position<14 -> (ski_position'=ski_position+1) & (move'=1) & (done'=(done|PassedGates|HitTree|HitGate));
  [noop]  move=0 -> (move'=1) & (done'=(done|PassedGates|HitTree|HitGate));
endmodule
||||
|
|
||||
|
// Advances y once per round (when move=1) by an amount that depends on the
// ski orientation, then hands control to updateX by setting move to 2.
module updateY
  y : [initY..maxY] init initY;
  // Consecutive rounds spent at an outermost orientation (1 or 14), capped at 8.
  standstill : [0..8] init 0;

  // Orientations 6..9: step of 16.
  [update_y] move=1 & (ski_position>=6 & ski_position<=9) -> (y'=min(maxY,y+16)) & (standstill'=0) & (move'=2);
  // Orientations 4, 5, 10, 11: step of 12.
  [update_y] move=1 & ((ski_position>=4 & ski_position<=5) | (ski_position>=10 & ski_position<=11)) -> (y'=min(maxY,y+12)) & (standstill'=0) & (move'=2);
  // Orientations 2, 3, 12, 13: step of 8.
  [update_y] move=1 & ((ski_position>=2 & ski_position<=3) | (ski_position>=12 & ski_position<=13)) -> (y'=min(maxY,y+8)) & (standstill'=0) & (move'=2);
  // Orientations 1 and 14: step of 4 for the first five rounds, then y stops changing.
  [update_y] move=1 & (ski_position=1 | ski_position=14) -> (y'=(standstill>=5 ? y : min(maxY,y+4))) & (standstill'=min(8,standstill+1)) & (move'=2);
endmodule
||||
|
|
||||
|
// Updates x once per round (when move=2) with a probabilistic step whose
// direction and size depend on the ski orientation, then returns control to
// the skier by setting move back to 0.
//
// Every probabilistic branch carries its own (move'=0): in a PRISM update,
// `&` binds tighter than `+`, so a single trailing `& (move'=0)` would apply
// to the last branch only and leave move at 2 on the other branches.
module updateX
  x : [minX..maxX] init initX;

  // No x change once standstill has reached its cap, or for orientations 7 and 8.
  [update_x] move=2 & standstill>=8 -> (move'=0);
  [update_x] move=2 & standstill<8 & (ski_position=7 | ski_position=8) -> (move'=0);

  // Orientations 6 / 9: step of 3, 4 or 5 (decreasing x for 6, increasing x for 9).
  [update_x] move=2 & standstill<8 & ski_position=6 -> 0.1: (x'=max(minX,x-3)) & (move'=0) + 0.7: (x'=max(minX,x-4)) & (move'=0) + 0.2: (x'=max(minX,x-5)) & (move'=0);
  [update_x] move=2 & standstill<8 & ski_position=9 -> 0.1: (x'=min(maxX,x+3)) & (move'=0) + 0.7: (x'=min(maxX,x+4)) & (move'=0) + 0.2: (x'=min(maxX,x+5)) & (move'=0);

  // Orientations 4,5 / 10,11: step of 5, 6 or 7.
  [update_x] move=2 & standstill<8 & (ski_position=4 | ski_position=5) -> 0.1: (x'=max(minX,x-5)) & (move'=0) + 0.7: (x'=max(minX,x-6)) & (move'=0) + 0.2: (x'=max(minX,x-7)) & (move'=0);
  [update_x] move=2 & standstill<8 & (ski_position=10 | ski_position=11) -> 0.1: (x'=min(maxX,x+5)) & (move'=0) + 0.7: (x'=min(maxX,x+6)) & (move'=0) + 0.2: (x'=min(maxX,x+7)) & (move'=0);

  // Orientations 2,3 / 12,13: step of 5, 6 or 7.
  // NOTE(review): identical to the distribution for 4,5 / 10,11 - confirm this is intended.
  [update_x] move=2 & standstill<8 & (ski_position=2 | ski_position=3) -> 0.1: (x'=max(minX,x-5)) & (move'=0) + 0.7: (x'=max(minX,x-6)) & (move'=0) + 0.2: (x'=max(minX,x-7)) & (move'=0);
  [update_x] move=2 & standstill<8 & (ski_position=12 | ski_position=13) -> 0.1: (x'=min(maxX,x+5)) & (move'=0) + 0.7: (x'=min(maxX,x+6)) & (move'=0) + 0.2: (x'=min(maxX,x+7)) & (move'=0);

  // Orientations 1 / 14: step of 0, 2 or 3.
  [update_x] move=2 & standstill<8 & (ski_position=1) -> 0.1: (x'=max(minX,x-0)) & (move'=0) + 0.7: (x'=max(minX,x-2)) & (move'=0) + 0.2: (x'=max(minX,x-3)) & (move'=0);
  [update_x] move=2 & standstill<8 & (ski_position=14) -> 0.1: (x'=min(maxX,x+0)) & (move'=0) + 0.7: (x'=min(maxX,x+2)) & (move'=0) + 0.2: (x'=min(maxX,x+3)) & (move'=0);
endmodule
||||
|
|
||||
|
// Transition rewards, paid on the skier's action taken while `done` is still
// false: +100 if PassedGates holds, -200 if HitTree holds, -150 if HitGate
// holds. Items are grouped per action.
rewards
  [left]  !done & PassedGates : 100;
  [left]  !done & HitTree : -200;
  [left]  !done & HitGate : -150;

  [right] !done & PassedGates : 100;
  [right] !done & HitTree : -200;
  [right] !done & HitGate : -150;

  [noop]  !done & PassedGates : 100;
  [noop]  !done & HitTree : -200;
  [noop]  !done & HitGate : -150;
endrewards
@ -0,0 +1,78 @@ |
|||||
|
import gym |
||||
|
from PIL import Image |
||||
|
from copy import deepcopy |
||||
|
import numpy as np |
||||
|
|
||||
|
from matplotlib import pyplot as plt |
||||
|
import readchar |
||||
|
|
||||
|
import queue |
||||
|
|
||||
|
# Queue object; nothing else in the visible script reads or writes it.
ski_position_queue = queue.Queue()

# Skiing environment with on-screen rendering.
env = gym.make("ALE/Skiing-v5", render_mode="human")


observation, info = env.reset()
# Running y estimate; the main loop adds update_y()'s return value to it.
y = 40
||||
|
|
||||
|
|
||||
|
# Number of consecutive update_y() calls made with the skis at an outermost
# position (1 or 14); any other valid position resets it to 0.
standstillcounter = 0


def update_y(y, ski_position):
    """Return the y increment for one step at the given ski position.

    Positions 6-9 yield 16, positions 4, 5, 10, 11 yield 12 and positions
    2, 3, 12, 13 yield 8; each of these resets ``standstillcounter``.
    Positions 1 and 14 yield 4 for the first five consecutive calls and 0
    afterwards, and increment ``standstillcounter`` on every call.

    Args:
        y: Current y value. Unused; kept so existing callers keep working.
        ski_position: Ski orientation index, expected to be in 1..14.

    Returns:
        The amount to add to y.

    Raises:
        ValueError: If ``ski_position`` is not in 1..14 (this includes
            ``None``), instead of failing later with an unbound local.
    """
    global standstillcounter

    if ski_position in (1, 14):
        if standstillcounter >= 5:
            if standstillcounter >= 8:
                print("!!!!!!!!!! no more x updates!!!!!!!!!!!")
            y_update = 0
        else:
            y_update = 4
        standstillcounter += 1
        return y_update

    if ski_position in (6, 7, 8, 9):
        y_update = 16
    elif ski_position in (4, 5, 10, 11):
        y_update = 12
    elif ski_position in (2, 3, 12, 13):
        y_update = 8
    else:
        raise ValueError(f"ski_position must be in 1..14, got {ski_position!r}")

    standstillcounter = 0
    return y_update
||||
|
|
||||
|
def update_ski_position(ski_position, action):
    """Return the ski position that results from applying ``action``.

    Args:
        ski_position: Current ski orientation index.
        action: 0 keeps the position, 1 increments it (capped at 14),
            2 decrements it (floored at 1).

    Returns:
        The new ski position.

    Raises:
        ValueError: If ``action`` is not 0, 1 or 2. Previously such actions
            fell through and implicitly returned ``None``.
    """
    if action == 0:
        return ski_position
    if action == 1:
        return min(ski_position + 1, 14)
    if action == 2:
        return max(ski_position - 1, 1)
    raise ValueError(f"action must be 0, 1 or 2, got {action!r}")
||||
|
|
||||
|
# Interactive driver: read an action from the keyboard, advance the hand-written
# ski-position / y model, print it next to the x coordinate estimated from the
# current frame, then step the environment.
approx_x_coordinate = 80
ski_position = 8
try:
    for _ in range(1000000):
        # Block until the user presses 0, 1 or 2; other keys are ignored
        # instead of crashing the int() conversion or producing an action
        # that update_ski_position() does not handle.
        action = None
        while action is None:
            key = readchar.readchar()
            if key == "\x03":
                # Ctrl-C is delivered as a character here; keep it usable as
                # a way to stop the script.
                raise KeyboardInterrupt
            if key in ("0", "1", "2"):
                action = int(key)

        ski_position = update_ski_position(ski_position, action)
        y_update = update_y(y, ski_position)
        y += y_update

        old_x = approx_x_coordinate
        # Mean column index of the pixels whose channel 1 equals 92
        # (presumably the skier sprite - TODO confirm). If no pixel matches,
        # keep the previous estimate rather than converting NaN to int.
        matching_columns = np.where(observation[:, :, 1] == 92)[1]
        if matching_columns.size:
            approx_x_coordinate = int(np.mean(matching_columns))
        print(f"Action: {action},\tski position: {ski_position},\ty_update: {y_update},\ty: {y},\tx: {approx_x_coordinate},\tx_update:{approx_x_coordinate - old_x}")

        # The chosen action is followed by four no-op steps, as in the
        # original script.
        for step_action in (action, 0, 0, 0, 0):
            observation, reward, terminated, truncated, info = env.step(step_action)
            if terminated or truncated:
                observation, info = env.reset()
                # A new episode starts: restart the tracked model state so it
                # does not carry over values from the finished run.
                y = 40
                ski_position = 8
                standstillcounter = 0

        #plt.imshow(observation)
        #plt.show()
        #im = Image.fromarray(observation)
        #im.save("init.png")
finally:
    # Release the environment even when the loop is interrupted.
    env.close()
@ -0,0 +1,78 @@ |
|||||
|
import time, re, sys, csv, os |
||||
|
|
||||
|
from subprocess import call |
||||
|
from os import listdir, system |
||||
|
from os.path import isfile, join, getctime |
||||
|
from dataclasses import dataclass, field |
||||
|
|
||||
|
# "R, G, B" strings keyed by ski position 1..14; the drawing code further down
# interpolates them into an rgba(...) fill colour.
ski_position_to_rgb_map = dict(
    enumerate(
        (
            "230, 138, 0",
            "255,153,0",
            "255, 184, 77",
            "255, 194, 100",
            "255, 194, 120",
            "255, 210, 130",
            "255, 210, 140",
            "230, 204, 255",
            "204, 153, 255",
            "255, 102, 179",
            "255, 51, 153",
            "255, 0, 255",
            "179, 0, 179",
            "102, 0, 102",
        ),
        start=1,
    )
)
||||
|
|
||||
|
def convert(tuples):
    """Build a dict from ``tuples``, an iterable of (key, value) pairs."""
    mapping = {}
    mapping.update(tuples)
    return mapping
||||
|
@dataclass(frozen=True)
class State:
    """Immutable, hashable key identifying one model state.

    fillStateRanking() builds instances from the ``x``, ``y`` and
    ``ski_position`` assignments found on each line of the ranking file.
    """

    # x value parsed from the "x=<int>" assignment.
    x: int
    # y value parsed from the "y=<int>" assignment.
    y: int
    # Ski position parsed from the "ski_position=<int>" assignment.
    ski_position: int
||||
|
|
||||
|
def default_value():
    """Return a fresh placeholder ``choices`` dict with both entries unset."""
    return dict.fromkeys(("action", "choiceValue"))


@dataclass(frozen=True)
class StateValue:
    """Ranking value of a state plus the per-action values parsed with it."""

    # Numeric ranking used as the sort key for states.
    ranking: float
    # Mapping of choices for the state; defaults to a new placeholder dict
    # produced by default_value() for every instance.
    choices: dict = field(default_factory=default_value)
||||
|
def exec(command):
    """Run ``command`` through the shell and record it in ``list_of_exec``.

    The command is appended to the log file verbatim from Python. Logging it
    with a shell ``echo`` would let the shell strip quotes and interpret
    redirections or other metacharacters inside the command, so the log would
    not match what was actually executed.

    NOTE: the name shadows the built-in ``exec``; it is kept because the rest
    of the script calls it under this name.

    Args:
        command: Shell command line to execute.

    Returns:
        The value returned by ``os.system`` for the command.
    """
    print(f"Executing {command}")
    with open("list_of_exec", "a") as log_file:
        log_file.write(f"{command}\n")
    return system(command)
||||
|
|
||||
|
def fillStateRanking(file_name, match=""):
    """Parse a state-ranking file into a ``{State: StateValue}`` dict.

    Each non-blank line is expected to contain ``x=<int>``, ``y=<int>`` and
    ``ski_position=<int>`` assignments, a ``Value:<number>`` entry and,
    optionally, ``<name>:<number>`` entries whose name contains ``left``,
    ``right`` or ``noop``.

    Args:
        file_name: Path of the ranking file.
        match: Optional regular expression; lines it matches at their start
            are skipped. The empty default skips nothing.

    Returns:
        Dict mapping each parsed ``State`` to its ``StateValue``. A state
        that appears more than once keeps the value of its last line.

    Raises:
        ValueError: If a line that is not skipped has no ``Value:`` entry.
        KeyError: If such a line lacks ``x``, ``y`` or ``ski_position``.

    Exits the process with status 1 if the file cannot be read.
    """
    # Compile the skip pattern once. The previous code referenced an undefined
    # name here, so any non-empty ``match`` raised NameError.
    skip_line = re.compile(match) if match else None
    assignment_re = re.compile(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?")
    choice_re = re.compile(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)")
    # Accepts signed integers, decimals and exponent notation. The previous
    # pattern rejected signed integers and cut exponent notation short.
    value_re = re.compile(r"Value:([+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)")

    # Only the file access can raise an OS error, so only it is guarded.
    try:
        with open(file_name, "r") as f:
            file_content = f.readlines()
    except EnvironmentError:
        print("TODO file not available. Exiting.")
        sys.exit(1)

    state_ranking = dict()
    for line in file_content:
        if not line.strip():
            # Blank lines (e.g. a trailing newline) carry no state.
            continue
        if skip_line is not None and skip_line.match(line):
            continue
        stateMapping = convert(assignment_re.findall(line))
        choices = convert(choice_re.findall(line))
        value_match = value_re.search(line)
        if value_match is None:
            raise ValueError(f"no 'Value:' entry in ranking line: {line!r}")
        ranking_value = float(value_match.group(1))
        state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
        state_ranking[state] = StateValue(ranking_value, choices)
    return state_ranking
||||
|
|
||||
|
# Load the ranking file and order the (State, StateValue) pairs by ascending
# ranking value.
ranking = fillStateRanking("action_ranking")
sorted_ranking = sorted(ranking.items(), key=lambda x: x[1].ranking)

# ImageMagick option fragments, one semi-transparent rectangle per state.
draw_commands = list()

# NOTE(review): [-20:-1] selects 19 entries and leaves out the last one, i.e.
# the state with the highest ranking - confirm whether [-20:] was intended.
for state in sorted_ranking[-20:-1]:
    print(state)
    x = state[0].x
    y = state[0].y
    # Half the side length of the drawn square.
    markerSize = 2
    print(state[0].ski_position)
    draw_commands.append(f"-fill 'rgba({ski_position_to_rgb_map[state[0].ski_position]},0.7)' -draw 'rectangle {x-markerSize},{y-markerSize} {x+markerSize},{y+markerSize} '")
# Draw all collected rectangles onto init.png and write first_try.png.
command = f"convert init.png {' '.join(draw_commands)} first_try.png"
exec(command)
Write
Preview
Loading…
Cancel
Save
Reference in new issue