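# Run the Atari ALE/Skiing-v5 environment and dump every observation frame to
# images/ as a PNG. Helper functions are included (wired up only via the
# commented-out lines in the main loop) for driving the agent from a scheduler
# file that maps (x, y, ski_position) states to left/right/noop choices and for
# tracking an approximate skier state between steps.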
import time, re, sys, csv, os
import gym
from PIL import Image
from copy import deepcopy
from dataclasses import dataclass, field
import numpy as np
from matplotlib import pyplot as plt
import readchar
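# Map a scheduler action label to the Skiing action id used in this script
# (0 = noop, 1 = right, 2 = left); unknown labels fall back to noop.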
def string_to_action(action):
    if action == "left":
        return 2
    if action == "right":
        return 1
    if action == "noop":
        return 0
    return 0
scheduler_file = "x80_y128_pos8.sched"

def convert(tuples):
    return dict(tuples)
@dataclass(frozen=True)
class State:
    x: int
    y: int
    ski_position: int
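# Parse the scheduler file: every line containing "move=0" is expected to carry
# key=value pairs for x, y and ski_position plus a {left|right|noop} choice.
# Returns a dict mapping State -> chosen action label.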
def parse_scheduler(scheduler_file):
    scheduler = dict()
    try:
        with open(scheduler_file, "r") as f:
            file_content = f.readlines()
        for line in file_content:
            if "move=0" not in line:
                continue
            stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
            #print("stateMapping", stateMapping)
            choice = re.findall(r"{(left|right|noop)}", line)
            if choice: choice = choice[0]
            #print("choice", choice)
            state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
            scheduler[state] = choice
        return scheduler
    except EnvironmentError:
        print(f"Scheduler file {scheduler_file} not available. Exiting.")
        sys.exit(1)
env = gym.make("ALE/Skiing-v5")#, render_mode="human")
#env = gym.wrappers.ResizeObservation(env, (84, 84))
#env = gym.wrappers.GrayScaleObservation(env)
observation, info = env.reset()
y = 40
standstillcounter = 0
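# Estimate how far the skier moves down the slope per step from the current ski
# position: the closer the skis point downhill (positions 6-9), the larger the
# y update. Positions 1/14 mean the skier is (almost) standing still; after
# several such steps in a row the y update is forced to 0.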
def update_y(y, ski_position):
    global standstillcounter
    y_update = 0
    if ski_position in [6, 7, 8, 9]:
        standstillcounter = 0
        y_update = 16
    elif ski_position in [4, 5, 10, 11]:
        standstillcounter = 0
        y_update = 12
    elif ski_position in [2, 3, 12, 13]:
        standstillcounter = 0
        y_update = 8
    elif ski_position in [1, 14] and standstillcounter >= 5:
        if standstillcounter >= 8:
            print("!!!!!!!!!! no more y updates !!!!!!!!!!!")
        y_update = 0
    elif ski_position in [1, 14]:
        y_update = 4
    if ski_position in [1, 14]:
        standstillcounter += 1
    return y_update
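# Apply an action to the discrete ski position: action 1 turns right
# (increments), action 2 turns left (decrements), clamped to the range [1, 14].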
def update_ski_position(ski_position, action):
    if action == 0:
        return ski_position
    elif action == 1:
        return min(ski_position + 1, 14)
    elif action == 2:
        return max(ski_position - 1, 1)
    return ski_position  # unknown action: keep the current position instead of returning None
approx_x_coordinate = 80
ski_position = 8
#scheduler = parse_scheduler(scheduler_file)
os.makedirs("images", exist_ok=True)  # frames are saved here in the loop below
j = 0
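# Main loop: step the environment (currently with a constant noop action; the
# commented-out lines show how to sample an action, read keyboard input or
# query the scheduler instead) and save each frame as images/<step>.png.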
for _ in range(1000000):
    j += 1
    #action = env.action_space.sample() # agent policy that uses the observation and info
    #action = int(repr(readchar.readchar())[1])
    #action = string_to_action(scheduler.get(State(approx_x_coordinate, y, ski_position), "noop"))
    action = 0
    #ski_position = update_ski_position(ski_position, action)
    #y_update = update_y(y, ski_position)
    #y += y_update if y_update else 0
    #old_x = deepcopy(approx_x_coordinate)
    #approx_x_coordinate = int(np.mean(np.where(observation[:,:,1] == 92)[1]))
    #print(f"Action: {action},\tski position: {ski_position},\ty_update: {y_update},\ty: {y},\tx: {approx_x_coordinate},\tx_update:{approx_x_coordinate - old_x}")
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        observation, info = env.reset()
        break
    img = Image.fromarray(observation)
    img.save(f"images/{j:05}.png")
    #observation, reward, terminated, truncated, info = env.step(0)
    #observation, reward, terminated, truncated, info = env.step(0)
    #observation, reward, terminated, truncated, info = env.step(0)
    #observation, reward, terminated, truncated, info = env.step(0)
env.close()