You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

152 lines
8.1 KiB

  1. #include <storm/exceptions/UnexpectedException.h>
  2. #include "storm/storage/expressions/Expression.h"
  3. #include "storm-pomdp/analysis/WinningRegionQueryInterface.h"
  4. namespace storm {
  5. namespace pomdp {
  6. template<typename ValueType>
  7. WinningRegionQueryInterface<ValueType>::WinningRegionQueryInterface(storm::models::sparse::Pomdp<ValueType> const& pomdp, WinningRegion const& winningRegion) :
  8. pomdp(pomdp), winningRegion(winningRegion) {
  9. uint64_t nrObservations = pomdp.getNrObservations();
  10. for (uint64_t observation = 0; observation < nrObservations; ++observation) {
  11. statesPerObservation.push_back(std::vector<uint64_t>());
  12. }
  13. for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
  14. statesPerObservation[pomdp.getObservation(state)].push_back(state);
  15. }
  16. }
  17. template<typename ValueType>
  18. bool WinningRegionQueryInterface<ValueType>::isInWinningRegion(storm::storage::BitVector const& beliefSupport) const {
  19. STORM_LOG_ASSERT(beliefSupport.getNumberOfSetBits() > 0, "One cannot think one is literally nowhere");
  20. uint64_t observation = pomdp.getObservation(beliefSupport.getNextSetIndex(0));
  21. // TODO consider optimizations after testing.
  22. storm::storage::BitVector queryVector(statesPerObservation[observation].size());
  23. auto stateWithObsIt = statesPerObservation[observation].begin();
  24. uint64_t offset = 0;
  25. for (uint64_t possibleState : beliefSupport) {
  26. STORM_LOG_ASSERT(pomdp.getObservation(possibleState) == observation, "Support must be observation-consistent");
  27. while(possibleState > *stateWithObsIt) {
  28. stateWithObsIt++;
  29. offset++;
  30. }
  31. if (possibleState == *stateWithObsIt) {
  32. queryVector.set(offset);
  33. }
  34. }
  35. return winningRegion.query(observation, queryVector);
  36. }
  37. template<typename ValueType>
  38. bool WinningRegionQueryInterface<ValueType>::staysInWinningRegion(storm::storage::BitVector const& currentBeliefSupport, uint64_t actionIndex) const {
  39. STORM_LOG_ASSERT(currentBeliefSupport.getNumberOfSetBits() > 0, "One cannot think one is literally nowhere");
  40. std::map<uint32_t, storm::storage::BitVector> successors;
  41. STORM_LOG_DEBUG("Stays in winning region? (" << currentBeliefSupport << ", " << actionIndex << ")");
  42. for (uint64_t oldState : currentBeliefSupport) {
  43. uint64_t row = pomdp.getTransitionMatrix().getRowGroupIndices()[oldState] + actionIndex;
  44. for (auto const& successor : pomdp.getTransitionMatrix().getRow(row)) {
  45. assert(!storm::utility::isZero(successor.getValue()));
  46. uint32_t obs = pomdp.getObservation(successor.getColumn());
  47. if (successors.count(obs) == 0) {
  48. successors[obs] = storm::storage::BitVector(pomdp.getNumberOfStates());
  49. }
  50. successors[obs].set(successor.getColumn(), true);
  51. }
  52. }
  53. for (auto const& entry : successors) {
  54. if(!isInWinningRegion(entry.second)) {
  55. STORM_LOG_DEBUG("Belief support " << entry.second << " (obs " << entry.first << ") is not winning");
  56. return false;
  57. } else {
  58. STORM_LOG_DEBUG("Belief support " << entry.second << " (obs " << entry.first << ") is winning");
  59. }
  60. }
  61. return true;
  62. }
  63. template<typename ValueType>
  64. void WinningRegionQueryInterface<ValueType>::validate(storm::storage::BitVector const& badStates) const {
  65. for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
  66. for(auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) {
  67. storm::storage::BitVector states(pomdp.getNumberOfStates());
  68. for (uint64_t offset : winningBelief) {
  69. states.set(statesPerObservation[obs][offset]);
  70. }
  71. bool safeActionExists = false;
  72. for(uint64_t actionIndex = 0; actionIndex < pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]); ++actionIndex) {
  73. if (staysInWinningRegion(states,actionIndex)) {
  74. safeActionExists = true;
  75. break;
  76. }
  77. }
  78. STORM_LOG_THROW(safeActionExists, storm::exceptions::UnexpectedException, "Observation " << obs << " , support " << states);
  79. }
  80. }
  81. }
  82. template<typename ValueType>
  83. void WinningRegionQueryInterface<ValueType>::validateIsMaximal(storm::storage::BitVector const& badStates) const {
  84. for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
  85. STORM_LOG_DEBUG("Check listed belief supports for observation " << obs << " are maximal");
  86. for(auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) {
  87. storm::storage::BitVector remainders = ~winningBelief;
  88. for(auto const& additional : remainders) {
  89. uint64_t addState = statesPerObservation[obs][additional];
  90. if (badStates.get(addState)) {
  91. continue;
  92. }
  93. storm::storage::BitVector states(pomdp.getNumberOfStates());
  94. for (uint64_t offset : winningBelief) {
  95. states.set(statesPerObservation[obs][offset]);
  96. }
  97. states.set(statesPerObservation[obs][additional]);
  98. assert(states.getNumberOfSetBits() == winningBelief.getNumberOfSetBits() + 1);
  99. bool safeActionExists = false;
  100. for(uint64_t actionIndex = 0; actionIndex < pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]); ++actionIndex) {
  101. if (staysInWinningRegion(states,actionIndex)) {
  102. STORM_LOG_DEBUG("Action " << actionIndex << " from " << states << " is safe. ");
  103. safeActionExists = true;
  104. break;
  105. }
  106. }
  107. STORM_LOG_THROW(!safeActionExists,storm::exceptions::UnexpectedException, "Observation " << obs << ", support " << states);
  108. }
  109. }
  110. STORM_LOG_DEBUG("All listed belief supports for observation " << obs << " are maximal. Continue with single states.");
  111. for (uint64_t offset = 0; offset < statesPerObservation[obs].size(); ++offset) {
  112. if(winningRegion.isWinning(obs,offset)) {
  113. continue;
  114. }
  115. uint64_t addState = statesPerObservation[obs][offset];
  116. if(badStates.get(addState)) {
  117. continue;
  118. }
  119. storm::storage::BitVector states(pomdp.getNumberOfStates());
  120. states.set(addState);
  121. bool safeActionExists = false;
  122. for(uint64_t actionIndex = 0; actionIndex < pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]); ++actionIndex) {
  123. if (staysInWinningRegion(states,actionIndex)) {
  124. safeActionExists = true;
  125. break;
  126. }
  127. }
  128. STORM_LOG_THROW(!safeActionExists, storm::exceptions::UnexpectedException, "Observation " << obs << " , support " << states);
  129. }
  130. }
  131. }
  // Explicit instantiations of the class template for the value types double and storm::RationalNumber.
  132. template class WinningRegionQueryInterface<double>;
  133. template class WinningRegionQueryInterface<storm::RationalNumber>;
  134. }
  135. }