You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

68 lines
2.3 KiB

  1. import sys
  2. from enum import Flag, auto
  3. import numpy as np
  4. class Verdict(Flag):
  5. INCONCLUSIVE = auto()
  6. PASS = auto()
  7. FAIL = auto()
  8. class Simulator():
  9. def __init__(self, allStateActionPairs, strategy, deadlockStates, reachedStates, bound=3, numSimulations=1):
  10. self.allStateActionPairs = { ( pair.state_id, pair.action_id ) : pair.next_state_probabilities for pair in allStateActionPairs }
  11. self.strategy = strategy
  12. self.deadlockStates = deadlockStates
  13. self.reachedStates = reachedStates
  14. #print(f"Deadlock: {self.deadlockStates}")
  15. #print(f"GoalStates: {self.reachedStates}")
  16. self.bound = bound
  17. self.numSimulations = numSimulations
  18. allStates = set([state.state_id for state in allStateActionPairs])
  19. allStates = allStates.difference(set(deadlockStates))
  20. allStates = allStates.difference(set(reachedStates))
  21. self.allStates = np.array(list(allStates))
  22. def _pickRandomTestCase(self):
  23. testCase = np.random.choice(self.allStates, 1)[0]
  24. #self.allStates = np.delete(self.allStates, testCase)
  25. return testCase
  26. def _simulate(self, initialStateId):
  27. i = 0
  28. actionId = self.strategy[initialStateId]
  29. nextStatePair = (initialStateId, actionId)
  30. while i < self.bound:
  31. i += 1
  32. nextStateProbabilities = self.allStateActionPairs[nextStatePair]
  33. weights = list()
  34. nextStateIds = list()
  35. for nextStateId, probability in nextStateProbabilities.items():
  36. weights.append(probability)
  37. nextStateIds.append(nextStateId)
  38. nextStateId = np.random.choice(nextStateIds, 1, p=weights)[0]
  39. if nextStateId in self.deadlockStates:
  40. return Verdict.FAIL, i
  41. if nextStateId in self.reachedStates:
  42. return Verdict.PASS, i
  43. nextStatePair = (nextStateId, self.strategy[nextStateId])
  44. return Verdict.INCONCLUSIVE, i
  45. def runTest(self):
  46. testCase = self._pickRandomTestCase()
  47. histogram = [0,0,0]
  48. for i in range(self.numSimulations):
  49. result, numQueries = self._simulate(testCase)
  50. if result == Verdict.FAIL:
  51. return testCase, Verdict.FAIL, numQueries
  52. return testCase, Verdict.INCONCLUSIVE, numQueries