commit
						8f7efe48db
					
				 3 changed files with 227 additions and 0 deletions
			
			
		- 
					71first_try.prism
- 
					78test.py
- 
					78test_model.py
| @ -0,0 +1,71 @@ | |||||
// MDP model of a skier described by an (x, y) position and a ski orientation.
mdp

// Initial coordinates of the skier.
const int initY = 40;
const int initX = 80;

// Bounds of the variable ranges: y is in [initY..maxY], x is in [minX..maxX].
const int maxY = 240;
const int minX = 12;
const int maxX = 147;


// True for 122<x<139 and y=232.
// NOTE(review): "220<y & y=232" reduces to y=232; possibly "y<232" was intended -- confirm.
formula HitTree = (122<x & x<139) & (220<y & y=232);
// True when x is exactly 51 or exactly 79 while y=164.
formula HitGate = ((x=51 | x=79) & y=164);

// True when x lies strictly between 51 and 79 while 158<y<165.
formula PassedGates = (158<y & y<165 & 51<x & x<79);

// Phase of the current step: 0 = skier action, 1 = y update, 2 = x update.
// The value 3 is inside the declared range but is never assigned in this model.
// NOTE(review): standard PRISM restricts updates of global variables in
// action-labelled commands -- confirm the target tool accepts this model.
global move : [0..3] init 0;
|  | 
 | ||||
// Picks the ski orientation for the current step (phase move=0) and latches
// the "done" flag; every command hands over to the y update via move'=1.
module skier
  // Ski orientation, 1..14, starting at 8.
  ski_position : [1..14] init 8;
  // Becomes true once PassedGates, HitTree or HitGate holds when an action is taken.
  done : bool init false;

  // While not done: left decrements and right increments ski_position (kept
  // inside 1..14 by the guards), noop keeps it. done' is evaluated on the
  // state in which the action is taken.
  [left]  !done & move=0 & ski_position>1 -> (ski_position'=ski_position-1) & (move'=1) & (done'=(PassedGates|HitTree|HitGate));
  [right] !done & move=0 & ski_position<14 -> (ski_position'=ski_position+1) & (move'=1) & (done'=(PassedGates|HitTree|HitGate));
  [noop]  !done & move=0 -> (move'=1) & (done'=(PassedGates|HitTree|HitGate));

  // Once done: the same orientation changes, but done is not assigned any more.
  [left]  done & move=0 & ski_position>1 -> (ski_position'=ski_position-1) & (move'=1);
  [right] done & move=0 & ski_position<14 -> (ski_position'=ski_position+1) & (move'=1);
  [noop]  done & move=0 -> (move'=1);

endmodule
|  | 
 | ||||
// Advances y during phase move=1 and then hands over to the x update (move'=2).
// The advance depends on the ski orientation: 16 for 6..9, 12 for 4,5,10,11,
// 8 for 2,3,12,13, and 4 (or 0) for the outermost orientations 1 and 14.
// Every advance is clamped at maxY.
module updateY
  y : [initY..maxY] init initY;
  // Number of consecutive steps spent in orientation 1 or 14, capped at 8;
  // reset to 0 by every other orientation.
  standstill : [0..8] init 0;
  // Orientation 1/14 with standstill>=5: y no longer advances.
  [update_y] move=1 & (ski_position=1 | ski_position = 14) & standstill>=5 -> (y'=y)  & (standstill'=min(8,standstill+1)) & (move'=2);
  // Orientation 1/14 with standstill<5: y still advances by 4.
  [update_y] move=1 & (ski_position=1 | ski_position = 14) & standstill<5  -> (y'=min(maxY,y+4)) & (standstill'=min(8,standstill+1)) & (move'=2);
  [update_y] move=1 & (ski_position=2 | ski_position = 3 | ski_position = 12 | ski_position = 13) -> (y'=min(maxY,y+8))  & (standstill'=0) & (move'=2);
  [update_y] move=1 & (ski_position=4 | ski_position = 5 | ski_position = 10 | ski_position = 11) -> (y'=min(maxY,y+12)) & (standstill'=0) & (move'=2);
  [update_y] move=1 & (ski_position=6 | ski_position = 7 | ski_position =  8 | ski_position =  9) -> (y'=min(maxY,y+16)) & (standstill'=0) & (move'=2);
endmodule
|  | 
 | ||||
// Moves x during phase move=2 and then returns control to the skier (move'=0).
// Orientations below 7 shift x to the left (clamped at minX), orientations
// above 8 shift it to the right (clamped at maxX); the shift is probabilistic.
//
// Every probabilistic branch carries its own (move'=0): "&" binds tighter than
// "+" in an update, so a single trailing "& (move'=0)" would only apply to the
// last branch and leave move=2 (re-enabling update_x) in the other branches.
module updateX
  x : [minX..maxX] init initX;

  // No horizontal movement once standstill has reached 8 ...
  [update_x]  move=2 & standstill>=8                                       -> (move'=0);
  // ... nor for the two central orientations.
  [update_x]  move=2 & standstill<8 & (ski_position=7 | ski_position=8)    -> (move'=0);

  [update_x]  move=2 & standstill<8 &  ski_position=6                      -> 0.1: (x'=max(minX,x-3)) & (move'=0) + 0.7: (x'=max(minX,x-4)) & (move'=0) + 0.2: (x'=max(minX,x-5)) & (move'=0);
  [update_x]  move=2 & standstill<8 &  ski_position=9                      -> 0.1: (x'=min(maxX,x+3)) & (move'=0) + 0.7: (x'=min(maxX,x+4)) & (move'=0) + 0.2: (x'=min(maxX,x+5)) & (move'=0);

  [update_x]  move=2 & standstill<8 & (ski_position=4 | ski_position=5)    -> 0.1: (x'=max(minX,x-5)) & (move'=0) + 0.7: (x'=max(minX,x-6)) & (move'=0) + 0.2: (x'=max(minX,x-7)) & (move'=0);
  [update_x]  move=2 & standstill<8 & (ski_position=10 | ski_position=11)  -> 0.1: (x'=min(maxX,x+5)) & (move'=0) + 0.7: (x'=min(maxX,x+6)) & (move'=0) + 0.2: (x'=min(maxX,x+7)) & (move'=0);

  // NOTE(review): orientations 2,3 / 12,13 use the same offsets as 4,5 / 10,11 -- confirm this is intended.
  [update_x]  move=2 & standstill<8 & (ski_position=2 | ski_position=3)    -> 0.1: (x'=max(minX,x-5)) & (move'=0) + 0.7: (x'=max(minX,x-6)) & (move'=0) + 0.2: (x'=max(minX,x-7)) & (move'=0);
  [update_x]  move=2 & standstill<8 & (ski_position=12 | ski_position=13)  -> 0.1: (x'=min(maxX,x+5)) & (move'=0) + 0.7: (x'=min(maxX,x+6)) & (move'=0) + 0.2: (x'=min(maxX,x+7)) & (move'=0);

  // Outermost orientations: with probability 0.1 x does not change at all.
  [update_x]  move=2 & standstill<8 & (ski_position=1)                     -> 0.1: (x'=max(minX,x-0)) & (move'=0) + 0.7: (x'=max(minX,x-2)) & (move'=0) + 0.2: (x'=max(minX,x-3)) & (move'=0);
  [update_x]  move=2 & standstill<8 & (ski_position=14)                    -> 0.1: (x'=min(maxX,x+0)) & (move'=0) + 0.7: (x'=min(maxX,x+2)) & (move'=0) + 0.2: (x'=min(maxX,x+3)) & (move'=0);
endmodule
|  | 
 | ||||
// Transition rewards on the skier's action commands, granted only while done
// is still false: +100 when PassedGates holds, -200 when HitTree holds and
// -150 when HitGate holds in the state where the action is taken.
rewards
  [left]  !done & PassedGates : 100;
  [right] !done & PassedGates : 100;
  [noop]  !done & PassedGates : 100;
  [left]  !done & (HitTree) : -200;
  [right] !done & (HitTree) : -200;
  [noop]  !done & (HitTree) : -200;
  [left]  !done & (HitGate) : -150;
  [right] !done & (HitGate) : -150;
  [noop]  !done & (HitGate) : -150;
endrewards
| @ -0,0 +1,78 @@ | |||||
|  | import gym | ||||
|  | from PIL import Image | ||||
|  | from copy import deepcopy | ||||
|  | import numpy as np | ||||
|  | 
 | ||||
|  | from matplotlib import pyplot as plt | ||||
|  | import readchar | ||||
|  | 
 | ||||
|  | import queue | ||||
|  | 
 | ||||
# Queue for ski positions; it is not referenced again in the visible part of
# this script.
ski_position_queue = queue.Queue()

# The Skiing environment the script steps through, created with render_mode="human".
env = gym.make("ALE/Skiing-v5", render_mode="human")


# First observation of the episode; the main loop reads the skier's x position from it.
observation, info = env.reset()
# Modelled y coordinate of the skier (same start value as initY in first_try.prism).
y = 40
|  | 
 | ||||
|  | 
 | ||||
# Number of consecutive steps spent in one of the outermost ski positions (1 or 14).
standstillcounter = 0

# Downhill advance per step for every ski position that keeps the skier moving.
_Y_UPDATE_BY_SKI_POSITION = {
    6: 16, 7: 16, 8: 16, 9: 16,
    4: 12, 5: 12, 10: 12, 11: 12,
    2: 8, 3: 8, 12: 8, 13: 8,
}


def update_y(y, ski_position):
    """Return how far the skier advances in y for the given ski position.

    Ski positions 2..13 advance by 8, 12 or 16 and reset the module-level
    ``standstillcounter``. The outermost positions 1 and 14 advance by 4
    while the counter is below 5 and by 0 afterwards; each such call
    increments the counter, and a notice is printed once it has reached 8.

    Args:
        y: current y coordinate; accepted for the caller's convenience, the
            advance does not depend on it.
        ski_position: ski orientation, an integer in 1..14.

    Returns:
        The y advance for this step (0, 4, 8, 12 or 16).

    Raises:
        ValueError: if ``ski_position`` is not in 1..14 (previously this
            surfaced as an ``UnboundLocalError``).
    """
    global standstillcounter

    if ski_position in (1, 14):
        if standstillcounter >= 5:
            if standstillcounter >= 8:
                print("!!!!!!!!!! no more x updates!!!!!!!!!!!")
            y_update = 0
        else:
            y_update = 4
        # The counter is read before it is incremented, so the first five
        # consecutive calls still advance by 4.
        standstillcounter += 1
        return y_update

    if ski_position not in _Y_UPDATE_BY_SKI_POSITION:
        raise ValueError(f"ski_position must be in 1..14, got {ski_position!r}")

    standstillcounter = 0
    return _Y_UPDATE_BY_SKI_POSITION[ski_position]
|  | 
 | ||||
def update_ski_position(ski_position, action):
    """Return the ski position that results from applying ``action``.

    Args:
        ski_position: current ski orientation (1..14).
        action: 0 keeps the position, 1 increases it by one (at most 14),
            2 decreases it by one (at least 1).

    Returns:
        The new ski position.

    Raises:
        ValueError: if ``action`` is not 0, 1 or 2. The function used to
            return ``None`` silently in that case, which only failed later
            in the caller.
    """
    if action == 0:
        return ski_position
    if action == 1:
        return min(ski_position + 1, 14)
    if action == 2:
        return max(ski_position - 1, 1)
    raise ValueError(f"unknown action {action!r}; expected 0, 1 or 2")
|  | 
 | ||||
# Interactive loop: read an action key (0, 1 or 2), advance the hand-written
# position model, print it next to the x coordinate measured from the current
# frame, then apply the action to the environment followed by four no-op steps.
approx_x_coordinate = 80
ski_position = 8
try:
    for _ in range(1000000):
        key = readchar.readchar()
        if key not in ("0", "1", "2"):
            # Any other key ends the session cleanly instead of crashing
            # while parsing it.
            print(f"Unsupported key {key!r}; expected 0, 1 or 2. Stopping.")
            break
        action = int(key)

        ski_position = update_ski_position(ski_position, action)
        y_update = update_y(y, ski_position)
        y += y_update

        old_x = approx_x_coordinate
        # Columns of all pixels whose second colour channel equals 92 --
        # presumably the skier sprite; TODO confirm the colour value.
        skier_columns = np.where(observation[:, :, 1] == 92)[1]
        if skier_columns.size:
            approx_x_coordinate = int(np.mean(skier_columns))
        # else: no matching pixel in this frame, keep the previous estimate
        # (the mean of an empty selection is NaN and cannot be cast to int).
        print(f"Action: {action},\tski position: {ski_position},\ty_update: {y_update},\ty: {y},\tx: {approx_x_coordinate},\tx_update:{approx_x_coordinate - old_x}")

        # One step with the chosen action, then four no-op steps; stop
        # stepping as soon as the episode ends.
        episode_over = False
        for step_action in (action, 0, 0, 0, 0):
            observation, reward, terminated, truncated, info = env.step(step_action)
            if terminated or truncated:
                episode_over = True
                break

        if episode_over:
            observation, info = env.reset()
            # Restart the modelled state together with the environment.
            y = 40
            ski_position = 8
            approx_x_coordinate = 80
            standstillcounter = 0

        # Uncomment to inspect the current frame or to dump it as init.png
        # (the image test_model.py draws its markers on).
        #plt.imshow(observation)
        #plt.show()
        #im = Image.fromarray(observation)
        #im.save("init.png")
finally:
    # Release the environment even when the loop is left through an exception.
    env.close()
| @ -0,0 +1,78 @@ | |||||
|  | import time, re, sys, csv, os | ||||
|  | 
 | ||||
|  | from subprocess import call | ||||
|  | from os import listdir, system | ||||
|  | from os.path import isfile, join, getctime | ||||
|  | from dataclasses import dataclass, field | ||||
|  | 
 | ||||
# Colour per ski position (1..14), given as an "R, G, B" string that is
# inserted into the rgba(...) fill of the draw commands built further down.
ski_position_to_rgb_map = {1: "230, 138, 0",
                           2: "255,153,0",
                           3: "255, 184, 77",
                           4: "255, 194, 100",
                           5: "255, 194, 120",
                           6: "255, 210, 130",
                           7: "255, 210, 140",
                           8: "230, 204, 255",
                           9: "204, 153, 255",
                           10: "255, 102, 179",
                           11: "255, 51, 153",
                           12: "255, 0, 255",
                           13: "179, 0, 179",
                           14: "102, 0, 102"}
|  | 
 | ||||
def convert(tuples):
    """Build a dictionary from key/value pairs; a later duplicate key wins."""
    mapping = {}
    mapping.update(tuples)
    return mapping
@dataclass(frozen=True)
class State:
    """Immutable state identifier, used as the key of the state ranking."""

    # Values of the "x", "y" and "ski_position" assignments of a ranking line.
    x: int
    y: int
    ski_position: int
|  | 
 | ||||
def default_value():
    """Return a fresh placeholder choice mapping with both entries unset."""
    return dict(action=None, choiceValue=None)
|  | 
 | ||||
@dataclass(frozen=True)
class StateValue:
    """Ranking information stored for one State."""

    # Number parsed from the "Value:" entry of a ranking line.
    ranking: float
    # Per-action entries of the line; falls back to default_value() when omitted.
    choices: dict = field(default_factory=default_value)
def exec(command):
    """Log a shell command to ``list_of_exec`` and run it.

    The command is appended verbatim (one command per line) to the file
    ``list_of_exec`` in the current directory. Writing the log from Python
    rather than through ``echo ... >> list_of_exec`` keeps quotes intact and
    stops the shell from interpreting metacharacters of the logged text.

    NOTE: the name shadows the built-in ``exec``; it is kept so existing
    callers continue to work.

    Args:
        command: shell command line. It is run through the shell, so it must
            come from a trusted source.

    Returns:
        Whatever ``os.system`` returns for the command.
    """
    print(f"Executing {command}")
    with open("list_of_exec", "a") as log_file:
        log_file.write(f"{command}\n")
    return system(command)
|  | 
 | ||||
def fillStateRanking(file_name, match=""):
    """Parse a ranking file into a ``State -> StateValue`` dictionary.

    A ranking line is expected to contain ``name=<digits>`` assignments for
    at least ``x``, ``y`` and ``ski_position``, optional
    ``<name containing left|right|noop>:<number>`` choice entries and a
    ``Value:<number>`` entry. Lines without a ``Value:`` entry (for example
    blank lines) are skipped.

    Args:
        file_name: path of the ranking file.
        match: optional regular expression; every line it matches at its
            start (``re.match``) is skipped.

    Returns:
        A dict mapping each parsed ``State`` to its ``StateValue``. The
        choice values are kept as the strings found in the file.

    Raises:
        KeyError: if a line with a ``Value:`` entry lacks one of the
            ``x``, ``y`` or ``ski_position`` assignments.

    The process exits with status 1 when the file cannot be read.
    """
    try:
        with open(file_name, "r") as f:
            file_content = f.readlines()
    except EnvironmentError:
        print("TODO file not available. Exiting.")
        sys.exit(1)

    # Compiled once; the original referenced an undefined name here whenever
    # a match pattern was passed.
    skip_line = re.compile(match) if match else None
    state_pattern = re.compile(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?")
    choice_pattern = re.compile(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)")
    # The sign applies to integers and decimals alike; an exponent is optional.
    value_pattern = re.compile(r"Value:([+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)")

    state_ranking = dict()
    for line in file_content:
        if skip_line is not None and skip_line.match(line):
            continue
        value_match = value_pattern.search(line)
        if value_match is None:
            continue
        stateMapping = convert(state_pattern.findall(line))
        choices = convert(choice_pattern.findall(line))
        ranking_value = float(value_match.group(1))
        state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
        state_ranking[state] = StateValue(ranking_value, choices)
    return state_ranking
|  | 
 | ||||
# Load the ranking, order it by ascending ranking value and mark the 20
# highest-ranked states on init.png, coloured by their ski position.
ranking = fillStateRanking("action_ranking")
sorted_ranking = sorted(ranking.items(), key=lambda x: x[1].ranking)

draw_commands = list()

# Half the edge length of the square drawn around each state.
markerSize = 2

# [-20:] keeps the best-ranked state as well; the former [-20:-1] dropped it.
for state, state_value in sorted_ranking[-20:]:
    print((state, state_value))
    x = state.x
    y = state.y
    print(state.ski_position)
    draw_commands.append(f"-fill 'rgba({ski_position_to_rgb_map[state.ski_position]},0.7)' -draw 'rectangle {x-markerSize},{y-markerSize} {x+markerSize},{y+markerSize} '")
command = f"convert init.png  {' '.join(draw_commands)} first_try.png"
exec(command)
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue