You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

157 lines
8.4 KiB

  1. #include "storm/exceptions/InvalidArgumentException.h"
  2. #include "storm-pomdp/transformer/ObservationTraceUnfolder.h"
  3. namespace storm {
  4. namespace pomdp {
  5. template<typename ValueType>
  6. ObservationTraceUnfolder<ValueType>::ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model) : model(model) {
  7. statesPerObservation = std::vector<storm::storage::BitVector>(model.getNrObservations(), storm::storage::BitVector(model.getNumberOfStates()));
  8. for (uint64_t state = 0; state < model.getNumberOfStates(); ++state) {
  9. statesPerObservation[model.getObservation(state)].set(state, true);
  10. }
  11. }
  12. template<typename ValueType>
  13. std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<ValueType>::transform(
  14. const std::vector<uint32_t> &observations, std::vector<ValueType> const& risk) {
  15. std::vector<uint32_t> modifiedObservations = observations;
  16. // First observation should be special.
  17. // This just makes the algorithm simpler because we do not treat the first step as a special case later.
  18. modifiedObservations[0] = model.getNrObservations();
  19. storm::storage::BitVector initialStates = model.getInitialStates();
  20. storm::storage::BitVector actualInitialStates = initialStates;
  21. for (uint64_t state : initialStates) {
  22. if (model.getObservation(state) != observations[0]) {
  23. actualInitialStates.set(state, false);
  24. }
  25. }
  26. STORM_LOG_THROW(actualInitialStates.getNumberOfSetBits() == 1, storm::exceptions::InvalidArgumentException, "Must have unique initial state matching the observation");
  27. //
  28. statesPerObservation.resize(model.getNrObservations() + 1);
  29. statesPerObservation[model.getNrObservations()] = actualInitialStates;
  30. std::map<uint64_t,uint64_t> unfoldedToOld;
  31. std::map<uint64_t,uint64_t> unfoldedToOldNextStep;
  32. std::map<uint64_t,uint64_t> oldToUnfolded;
  33. // Add this initial state state:
  34. unfoldedToOldNextStep[0] = actualInitialStates.getNextSetIndex(0);
  35. storm::storage::SparseMatrixBuilder<ValueType> transitionMatrixBuilder(0,0,0,true,true);
  36. uint64_t newStateIndex = 1;
  37. uint64_t newRowGroupStart = 0;
  38. uint64_t newRowCount = 0;
  39. // Notice that we are going to use a special last step
  40. for (uint64_t step = 0; step < observations.size() - 1; ++step) {
  41. std::cout << "step " << step << std::endl;
  42. oldToUnfolded.clear();
  43. unfoldedToOld = unfoldedToOldNextStep;
  44. unfoldedToOldNextStep.clear();
  45. for (auto const& unfoldedToOldEntry : unfoldedToOld) {
  46. transitionMatrixBuilder.newRowGroup(newRowGroupStart);
  47. std::cout << "\tconsider new state " << unfoldedToOldEntry.first << std::endl;
  48. assert(step == 0 || newRowCount == transitionMatrixBuilder.getLastRow() + 1);
  49. uint64_t oldRowIndexStart = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second];
  50. uint64_t oldRowIndexEnd = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second+1];
  51. for (uint64_t oldRowIndex = oldRowIndexStart; oldRowIndex != oldRowIndexEnd; oldRowIndex++) {
  52. std::cout << "\t\tconsider old action " << oldRowIndex << std::endl;
  53. std::cout << "\t\tconsider new row nr " << newRowCount << std::endl;
  54. ValueType resetProb = storm::utility::zero<ValueType>();
  55. // We first find the reset probability
  56. for (auto const &oldRowEntry : model.getTransitionMatrix().getRow(oldRowIndex)) {
  57. if (model.getObservation(oldRowEntry.getColumn()) != observations[step + 1]) {
  58. resetProb += oldRowEntry.getValue();
  59. }
  60. }
  61. std::cout << "\t\t\t add reset" << std::endl;
  62. // Add the resets
  63. if (resetProb != storm::utility::zero<ValueType>()) {
  64. transitionMatrixBuilder.addNextValue(newRowCount, 0, resetProb);
  65. }
  66. std::cout << "\t\t\t add other transitions..." << std::endl;
  67. // Now, we build the outgoing transitions.
  68. for (auto const &oldRowEntry : model.getTransitionMatrix().getRow(oldRowIndex)) {
  69. if (model.getObservation(oldRowEntry.getColumn()) != observations[step + 1]) {
  70. continue;// already handled.
  71. }
  72. uint64_t column = 0;
  73. auto entryIt = oldToUnfolded.find(oldRowEntry.getColumn());
  74. if (entryIt == oldToUnfolded.end()) {
  75. column = newStateIndex;
  76. oldToUnfolded[oldRowEntry.getColumn()] = column;
  77. unfoldedToOldNextStep[column] = oldRowEntry.getColumn();
  78. newStateIndex++;
  79. } else {
  80. column = entryIt->second;
  81. }
  82. std::cout << "\t\t\t\t transition to " << column << std::endl;
  83. transitionMatrixBuilder.addNextValue(newRowCount, column,
  84. oldRowEntry.getValue());
  85. }
  86. newRowCount++;
  87. }
  88. newRowGroupStart = transitionMatrixBuilder.getLastRow() + 1;
  89. }
  90. }
  91. std::cout << "Adding last step..." << std::endl;
  92. // Now, take care of the last step.
  93. uint64_t sinkState = newStateIndex;
  94. uint64_t targetState = newStateIndex + 1;
  95. for (auto const& unfoldedToOldEntry : unfoldedToOldNextStep) {
  96. transitionMatrixBuilder.newRowGroup(newRowGroupStart);
  97. if (!storm::utility::isZero(storm::utility::one<ValueType>() - risk[unfoldedToOldEntry.second])) {
  98. transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState,
  99. storm::utility::one<ValueType>() - risk[unfoldedToOldEntry.second]);
  100. }
  101. if (!storm::utility::isZero(risk[unfoldedToOldEntry.second])) {
  102. transitionMatrixBuilder.addNextValue(newRowGroupStart, targetState,
  103. risk[unfoldedToOldEntry.second]);
  104. }
  105. newRowGroupStart++;
  106. }
  107. // sink state
  108. transitionMatrixBuilder.newRowGroup(newRowGroupStart);
  109. transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, storm::utility::one<ValueType>());
  110. newRowGroupStart++;
  111. transitionMatrixBuilder.newRowGroup(newRowGroupStart);
  112. // target state
  113. transitionMatrixBuilder.addNextValue(newRowGroupStart, targetState, storm::utility::one<ValueType>());
  114. storm::storage::sparse::ModelComponents<ValueType> components;
  115. components.transitionMatrix = transitionMatrixBuilder.build();
  116. std::cout << components.transitionMatrix << std::endl;
  117. STORM_LOG_ASSERT(components.transitionMatrix.getRowGroupCount() == targetState + 1, "Expect row group count (" << components.transitionMatrix.getRowGroupCount() << ") one more as target state index " << targetState << ")");
  118. storm::models::sparse::StateLabeling labeling(components.transitionMatrix.getRowGroupCount());
  119. labeling.addLabel("_goal");
  120. labeling.addLabelToState("_goal", targetState);
  121. labeling.addLabel("init");
  122. labeling.addLabelToState("init", 0);
  123. components.stateLabeling = labeling;
  124. return std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(components));
  125. }
  126. template class ObservationTraceUnfolder<double>;
  127. template class ObservationTraceUnfolder<storm::RationalFunction>;
  128. }
  129. }