Browse Source
first version of query interface that is more accessible than the winning region itself
tempestpy_adaptions
first version of query interface that is more accessible than the winning region itself
tempestpy_adaptions
Sebastian Junges
5 years ago
2 changed files with 92 additions and 0 deletions
-
69src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp
-
23src/storm-pomdp/analysis/WinningRegionQueryInterface.h
@ -0,0 +1,69 @@ |
|||
#include "storm/storage/expressions/Expression.h"
|
|||
#include "storm-pomdp/analysis/WinningRegionQueryInterface.h"
|
|||
|
|||
|
|||
namespace storm { |
|||
namespace pomdp { |
|||
template<typename ValueType> |
|||
WinningRegionQueryInterface<ValueType>::WinningRegionQueryInterface(storm::models::sparse::Pomdp<ValueType> const& pomdp, WinningRegion const& winningRegion) : |
|||
pomdp(pomdp), winningRegion(winningRegion) { |
|||
uint64_t nrObservations = pomdp.getNrObservations(); |
|||
for (uint64_t observation = 0; observation < nrObservations; ++observation) { |
|||
statesPerObservation.push_back(std::vector<uint64_t>()); |
|||
} |
|||
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|||
statesPerObservation[pomdp.getObservation(state)].push_back(state); |
|||
} |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
bool WinningRegionQueryInterface<ValueType>::isInWinningRegion(storm::storage::BitVector const& beliefSupport) const { |
|||
STORM_LOG_ASSERT(beliefSupport.getNumberOfSetBits() > 0, "One cannot think one is literally nowhere"); |
|||
uint64_t observation = pomdp.getObservation(beliefSupport.getNextSetIndex(0)); |
|||
// TODO consider optimizations after testing.
|
|||
storm::storage::BitVector queryVector(statesPerObservation[observation].size()); |
|||
auto stateWithObsIt = statesPerObservation[observation].begin(); |
|||
uint64_t offset = 0; |
|||
for (uint64_t possibleState : beliefSupport) { |
|||
STORM_LOG_ASSERT(pomdp.getObservation(possibleState) == observation, "Support must be observation-consistent"); |
|||
while(possibleState > *stateWithObsIt) { |
|||
stateWithObsIt++; |
|||
offset++; |
|||
} |
|||
if (possibleState == *stateWithObsIt) { |
|||
queryVector.set(offset); |
|||
} |
|||
} |
|||
return winningRegion.query(observation, queryVector); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
bool WinningRegionQueryInterface<ValueType>::staysInWinningRegion(storm::storage::BitVector const& currentBeliefSupport, uint64_t actionIndex) const { |
|||
std::map<uint32_t, storm::storage::BitVector> successors; |
|||
|
|||
for (uint64_t oldState : currentBeliefSupport) { |
|||
uint64_t row = pomdp.getTransitionMatrix().getRowGroupIndices()[oldState] + actionIndex; |
|||
for (auto const& successor : pomdp.getTransitionMatrix().getRow(row)) { |
|||
assert(!storm::utility::isZero(successor.getValue())); |
|||
uint32_t obs = pomdp.getObservation(successor.getColumn()); |
|||
if (successors.count(obs) == 0) { |
|||
successors[obs] = storm::storage::BitVector(pomdp.getNumberOfStates()); |
|||
} |
|||
successors[obs].set(successor.getColumn(), true); |
|||
|
|||
} |
|||
} |
|||
|
|||
for (auto const& entry : successors) { |
|||
if(!isInWinningRegion(entry.second)) { |
|||
return false; |
|||
} |
|||
} |
|||
return true; |
|||
} |
|||
|
|||
template class WinningRegionQueryInterface<double>; |
|||
} |
|||
|
|||
} |
|||
|
@ -0,0 +1,23 @@ |
|||
#include "storm/models/sparse/Pomdp.h" |
|||
#include "storm-pomdp/analysis/WinningRegion.h" |
|||
|
|||
|
|||
namespace storm { |
|||
namespace pomdp { |
|||
template<typename ValueType> |
|||
class WinningRegionQueryInterface { |
|||
public: |
|||
WinningRegionQueryInterface(storm::models::sparse::Pomdp<ValueType> const& pomdp, WinningRegion const& winningRegion); |
|||
|
|||
bool isInWinningRegion(storm::storage::BitVector const& beliefSupport) const; |
|||
|
|||
bool staysInWinningRegion(storm::storage::BitVector const& beliefSupport, uint64_t actionIndex) const; |
|||
|
|||
private: |
|||
storm::models::sparse::Pomdp<ValueType> const& pomdp; |
|||
WinningRegion const& winningRegion; |
|||
// TODO consider sharing this. |
|||
std::vector<std::vector<uint64_t>> statesPerObservation; |
|||
}; |
|||
} |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue