diff --git a/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp b/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp new file mode 100644 index 000000000..cdb3afa88 --- /dev/null +++ b/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp @@ -0,0 +1,69 @@ +#include "storm/storage/expressions/Expression.h" +#include "storm-pomdp/analysis/WinningRegionQueryInterface.h" + + +namespace storm { + namespace pomdp { + template + WinningRegionQueryInterface::WinningRegionQueryInterface(storm::models::sparse::Pomdp const& pomdp, WinningRegion const& winningRegion) : + pomdp(pomdp), winningRegion(winningRegion) { + uint64_t nrObservations = pomdp.getNrObservations(); + for (uint64_t observation = 0; observation < nrObservations; ++observation) { + statesPerObservation.push_back(std::vector()); + } + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + statesPerObservation[pomdp.getObservation(state)].push_back(state); + } + } + + template + bool WinningRegionQueryInterface::isInWinningRegion(storm::storage::BitVector const& beliefSupport) const { + STORM_LOG_ASSERT(beliefSupport.getNumberOfSetBits() > 0, "One cannot think one is literally nowhere"); + uint64_t observation = pomdp.getObservation(beliefSupport.getNextSetIndex(0)); + // TODO consider optimizations after testing. + storm::storage::BitVector queryVector(statesPerObservation[observation].size()); + auto stateWithObsIt = statesPerObservation[observation].begin(); + uint64_t offset = 0; + for (uint64_t possibleState : beliefSupport) { + STORM_LOG_ASSERT(pomdp.getObservation(possibleState) == observation, "Support must be observation-consistent"); + while(possibleState > *stateWithObsIt) { + stateWithObsIt++; + offset++; + } + if (possibleState == *stateWithObsIt) { + queryVector.set(offset); + } + } + return winningRegion.query(observation, queryVector); + } + + template + bool WinningRegionQueryInterface::staysInWinningRegion(storm::storage::BitVector const& currentBeliefSupport, uint64_t actionIndex) const { + std::map successors; + + for (uint64_t oldState : currentBeliefSupport) { + uint64_t row = pomdp.getTransitionMatrix().getRowGroupIndices()[oldState] + actionIndex; + for (auto const& successor : pomdp.getTransitionMatrix().getRow(row)) { + assert(!storm::utility::isZero(successor.getValue())); + uint32_t obs = pomdp.getObservation(successor.getColumn()); + if (successors.count(obs) == 0) { + successors[obs] = storm::storage::BitVector(pomdp.getNumberOfStates()); + } + successors[obs].set(successor.getColumn(), true); + + } + } + + for (auto const& entry : successors) { + if(!isInWinningRegion(entry.second)) { + return false; + } + } + return true; + } + + template class WinningRegionQueryInterface; + } + +} + diff --git a/src/storm-pomdp/analysis/WinningRegionQueryInterface.h b/src/storm-pomdp/analysis/WinningRegionQueryInterface.h new file mode 100644 index 000000000..4dad3193d --- /dev/null +++ b/src/storm-pomdp/analysis/WinningRegionQueryInterface.h @@ -0,0 +1,23 @@ +#include "storm/models/sparse/Pomdp.h" +#include "storm-pomdp/analysis/WinningRegion.h" + + +namespace storm { + namespace pomdp { + template + class WinningRegionQueryInterface { + public: + WinningRegionQueryInterface(storm::models::sparse::Pomdp const& pomdp, WinningRegion const& winningRegion); + + bool isInWinningRegion(storm::storage::BitVector const& beliefSupport) const; + + bool staysInWinningRegion(storm::storage::BitVector const& beliefSupport, uint64_t actionIndex) const; + + private: + storm::models::sparse::Pomdp const& pomdp; + WinningRegion const& winningRegion; + // TODO consider sharing this. + std::vector> statesPerObservation; + }; + } +} \ No newline at end of file