|
|
@ -1,3 +1,4 @@ |
|
|
|
#include <storm/exceptions/UnexpectedException.h>
|
|
|
|
#include "storm/storage/expressions/Expression.h"
|
|
|
|
#include "storm-pomdp/analysis/WinningRegionQueryInterface.h"
|
|
|
|
|
|
|
@ -39,6 +40,7 @@ namespace storm { |
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
bool WinningRegionQueryInterface<ValueType>::staysInWinningRegion(storm::storage::BitVector const& currentBeliefSupport, uint64_t actionIndex) const { |
|
|
|
STORM_LOG_ASSERT(currentBeliefSupport.getNumberOfSetBits() > 0, "One cannot think one is literally nowhere"); |
|
|
|
std::map<uint32_t, storm::storage::BitVector> successors; |
|
|
|
STORM_LOG_DEBUG("Stays in winning region? (" << currentBeliefSupport << ", " << actionIndex << ")"); |
|
|
|
for (uint64_t oldState : currentBeliefSupport) { |
|
|
@ -57,17 +59,17 @@ namespace storm { |
|
|
|
for (auto const& entry : successors) { |
|
|
|
|
|
|
|
if(!isInWinningRegion(entry.second)) { |
|
|
|
STORM_LOG_DEBUG("Belief support " << entry.second << " is not winning"); |
|
|
|
STORM_LOG_DEBUG("Belief support " << entry.second << " (obs " << entry.first << ") is not winning"); |
|
|
|
return false; |
|
|
|
} else { |
|
|
|
STORM_LOG_DEBUG("Belief support " << entry.second << " is winning"); |
|
|
|
STORM_LOG_DEBUG("Belief support " << entry.second << " (obs " << entry.first << ") is winning"); |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
void WinningRegionQueryInterface<ValueType>::validate() const { |
|
|
|
void WinningRegionQueryInterface<ValueType>::validate(storm::storage::BitVector const& badStates) const { |
|
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
|
for(auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) { |
|
|
|
storm::storage::BitVector states(pomdp.getNumberOfStates()); |
|
|
@ -81,7 +83,63 @@ namespace storm { |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
STORM_LOG_ASSERT(safeActionExists, "Observation " << obs << " with associated states: " << statesPerObservation[obs] << " , support " << states); |
|
|
|
STORM_LOG_THROW(safeActionExists, storm::exceptions::UnexpectedException, "Observation " << obs << " , support " << states); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
void WinningRegionQueryInterface<ValueType>::validateIsMaximal(storm::storage::BitVector const& badStates) const { |
|
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
|
STORM_LOG_DEBUG("Check listed belief supports for observation " << obs << " are maximal"); |
|
|
|
for(auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) { |
|
|
|
storm::storage::BitVector remainders = ~winningBelief; |
|
|
|
for(auto const& additional : remainders) { |
|
|
|
uint64_t addState = statesPerObservation[obs][additional]; |
|
|
|
if (badStates.get(addState)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
storm::storage::BitVector states(pomdp.getNumberOfStates()); |
|
|
|
for (uint64_t offset : winningBelief) { |
|
|
|
states.set(statesPerObservation[obs][offset]); |
|
|
|
} |
|
|
|
states.set(statesPerObservation[obs][additional]); |
|
|
|
assert(states.getNumberOfSetBits() == winningBelief.getNumberOfSetBits() + 1); |
|
|
|
|
|
|
|
bool safeActionExists = false; |
|
|
|
for(uint64_t actionIndex = 0; actionIndex < pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]); ++actionIndex) { |
|
|
|
if (staysInWinningRegion(states,actionIndex)) { |
|
|
|
STORM_LOG_DEBUG("Action " << actionIndex << " from " << states << " is safe. "); |
|
|
|
safeActionExists = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
STORM_LOG_THROW(!safeActionExists,storm::exceptions::UnexpectedException, "Observation " << obs << ", support " << states); |
|
|
|
} |
|
|
|
} |
|
|
|
STORM_LOG_DEBUG("All listed belief supports for observation " << obs << " are maximal. Continue with single states."); |
|
|
|
|
|
|
|
for (uint64_t offset = 0; offset < statesPerObservation[obs].size(); ++offset) { |
|
|
|
if(winningRegion.isWinning(obs,offset)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
uint64_t addState = statesPerObservation[obs][offset]; |
|
|
|
if(badStates.get(addState)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
storm::storage::BitVector states(pomdp.getNumberOfStates()); |
|
|
|
states.set(addState); |
|
|
|
bool safeActionExists = false; |
|
|
|
for(uint64_t actionIndex = 0; actionIndex < pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]); ++actionIndex) { |
|
|
|
if (staysInWinningRegion(states,actionIndex)) { |
|
|
|
safeActionExists = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
STORM_LOG_THROW(!safeActionExists, storm::exceptions::UnexpectedException, "Observation " << obs << " , support " << states); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|