Browse Source

validate whether a winning region is maximal

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
f00a208e9c
  1. 12
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 5
      src/storm-pomdp/analysis/WinningRegion.h
  3. 66
      src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp
  4. 4
      src/storm-pomdp/analysis/WinningRegionQueryInterface.h

12
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -75,11 +75,10 @@ namespace storm {
nrStatesPerObservation.push_back(states.size()); nrStatesPerObservation.push_back(states.size());
} }
winningRegion = WinningRegion(nrStatesPerObservation); winningRegion = WinningRegion(nrStatesPerObservation);
if(options.validateEveryStep) {
if(options.validateResult || options.validateEveryStep) {
STORM_LOG_WARN("The validator should only be created when the option is set."); STORM_LOG_WARN("The validator should only be created when the option is set.");
validator = std::make_shared<WinningRegionQueryInterface<ValueType>>(pomdp, winningRegion); validator = std::make_shared<WinningRegionQueryInterface<ValueType>>(pomdp, winningRegion);
} }
} }
template <typename ValueType> template <typename ValueType>
@ -638,14 +637,13 @@ namespace storm {
} }
} }
// TODO temporarily switched off due to intiialization issues when restartin.
STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place"); STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place");
if(options.computeDebugOutput()) { if(options.computeDebugOutput()) {
winningRegion.print(); winningRegion.print();
} }
if(options.validateEveryStep) { if(options.validateEveryStep) {
STORM_LOG_WARN("Validating every step, for debug purposes only!"); STORM_LOG_WARN("Validating every step, for debug purposes only!");
validator->validate();
validator->validate(surelyReachSinkStates);
} }
stats.updateNewStrategySolverTime.start(); stats.updateNewStrategySolverTime.start();
for(uint64_t observation : updated) { for(uint64_t observation : updated) {
@ -686,6 +684,12 @@ namespace storm {
STORM_LOG_INFO("... after iteration " << stats.getIterations() << " so far " << stats.getChecks() << " checks." ); STORM_LOG_INFO("... after iteration " << stats.getIterations() << " so far " << stats.getChecks() << " checks." );
} }
if(options.validateResult) {
STORM_LOG_WARN("Validating result is a winning region, only for debugging purposes.");
validator->validate(surelyReachSinkStates);
STORM_LOG_WARN("Validating result is a maximal winning region, only for debugging purposes.");
validator->validateIsMaximal(surelyReachSinkStates);
}
winningRegion.print(); winningRegion.print();
if (!allOfTheseStates.empty()) { if (!allOfTheseStates.empty()) {

5
src/storm-pomdp/analysis/WinningRegion.h

@ -11,6 +11,11 @@ namespace storm {
bool update(uint64_t observation, storm::storage::BitVector const& winning); bool update(uint64_t observation, storm::storage::BitVector const& winning);
bool query(uint64_t observation, storm::storage::BitVector const& currently) const; bool query(uint64_t observation, storm::storage::BitVector const& currently) const;
bool isWinning(uint64_t observation, uint64_t offset) const {
storm::storage::BitVector currently(observationSizes[observation]);
currently.set(offset);
return query(observation,currently);
}
std::vector<storm::storage::BitVector> const& getWinningSetsPerObservation(uint64_t observation) const; std::vector<storm::storage::BitVector> const& getWinningSetsPerObservation(uint64_t observation) const;

66
src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp

@ -1,3 +1,4 @@
#include <storm/exceptions/UnexpectedException.h>
#include "storm/storage/expressions/Expression.h" #include "storm/storage/expressions/Expression.h"
#include "storm-pomdp/analysis/WinningRegionQueryInterface.h" #include "storm-pomdp/analysis/WinningRegionQueryInterface.h"
@ -39,6 +40,7 @@ namespace storm {
template<typename ValueType> template<typename ValueType>
bool WinningRegionQueryInterface<ValueType>::staysInWinningRegion(storm::storage::BitVector const& currentBeliefSupport, uint64_t actionIndex) const { bool WinningRegionQueryInterface<ValueType>::staysInWinningRegion(storm::storage::BitVector const& currentBeliefSupport, uint64_t actionIndex) const {
STORM_LOG_ASSERT(currentBeliefSupport.getNumberOfSetBits() > 0, "One cannot think one is literally nowhere");
std::map<uint32_t, storm::storage::BitVector> successors; std::map<uint32_t, storm::storage::BitVector> successors;
STORM_LOG_DEBUG("Stays in winning region? (" << currentBeliefSupport << ", " << actionIndex << ")"); STORM_LOG_DEBUG("Stays in winning region? (" << currentBeliefSupport << ", " << actionIndex << ")");
for (uint64_t oldState : currentBeliefSupport) { for (uint64_t oldState : currentBeliefSupport) {
@ -57,17 +59,17 @@ namespace storm {
for (auto const& entry : successors) { for (auto const& entry : successors) {
if(!isInWinningRegion(entry.second)) { if(!isInWinningRegion(entry.second)) {
STORM_LOG_DEBUG("Belief support " << entry.second << " is not winning");
STORM_LOG_DEBUG("Belief support " << entry.second << " (obs " << entry.first << ") is not winning");
return false; return false;
} else { } else {
STORM_LOG_DEBUG("Belief support " << entry.second << " is winning");
STORM_LOG_DEBUG("Belief support " << entry.second << " (obs " << entry.first << ") is winning");
} }
} }
return true; return true;
} }
template<typename ValueType> template<typename ValueType>
void WinningRegionQueryInterface<ValueType>::validate() const {
void WinningRegionQueryInterface<ValueType>::validate(storm::storage::BitVector const& badStates) const {
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
for(auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) { for(auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) {
storm::storage::BitVector states(pomdp.getNumberOfStates()); storm::storage::BitVector states(pomdp.getNumberOfStates());
@ -81,7 +83,63 @@ namespace storm {
break; break;
} }
} }
STORM_LOG_ASSERT(safeActionExists, "Observation " << obs << " with associated states: " << statesPerObservation[obs] << " , support " << states);
STORM_LOG_THROW(safeActionExists, storm::exceptions::UnexpectedException, "Observation " << obs << " , support " << states);
}
}
}
template<typename ValueType>
void WinningRegionQueryInterface<ValueType>::validateIsMaximal(storm::storage::BitVector const& badStates) const {
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
STORM_LOG_DEBUG("Check listed belief supports for observation " << obs << " are maximal");
for(auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) {
storm::storage::BitVector remainders = ~winningBelief;
for(auto const& additional : remainders) {
uint64_t addState = statesPerObservation[obs][additional];
if (badStates.get(addState)) {
continue;
}
storm::storage::BitVector states(pomdp.getNumberOfStates());
for (uint64_t offset : winningBelief) {
states.set(statesPerObservation[obs][offset]);
}
states.set(statesPerObservation[obs][additional]);
assert(states.getNumberOfSetBits() == winningBelief.getNumberOfSetBits() + 1);
bool safeActionExists = false;
for(uint64_t actionIndex = 0; actionIndex < pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]); ++actionIndex) {
if (staysInWinningRegion(states,actionIndex)) {
STORM_LOG_DEBUG("Action " << actionIndex << " from " << states << " is safe. ");
safeActionExists = true;
break;
}
}
STORM_LOG_THROW(!safeActionExists,storm::exceptions::UnexpectedException, "Observation " << obs << ", support " << states);
}
}
STORM_LOG_DEBUG("All listed belief supports for observation " << obs << " are maximal. Continue with single states.");
for (uint64_t offset = 0; offset < statesPerObservation[obs].size(); ++offset) {
if(winningRegion.isWinning(obs,offset)) {
continue;
}
uint64_t addState = statesPerObservation[obs][offset];
if(badStates.get(addState)) {
continue;
}
storm::storage::BitVector states(pomdp.getNumberOfStates());
states.set(addState);
bool safeActionExists = false;
for(uint64_t actionIndex = 0; actionIndex < pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]); ++actionIndex) {
if (staysInWinningRegion(states,actionIndex)) {
safeActionExists = true;
break;
}
}
STORM_LOG_THROW(!safeActionExists, storm::exceptions::UnexpectedException, "Observation " << obs << " , support " << states);
} }
} }
} }

4
src/storm-pomdp/analysis/WinningRegionQueryInterface.h

@ -14,7 +14,9 @@ namespace storm {
bool staysInWinningRegion(storm::storage::BitVector const& beliefSupport, uint64_t actionIndex) const; bool staysInWinningRegion(storm::storage::BitVector const& beliefSupport, uint64_t actionIndex) const;
void validate() const;
void validate(storm::storage::BitVector const& badStates) const;
void validateIsMaximal(storm::storage::BitVector const& badStates) const;
private: private:
storm::models::sparse::Pomdp<ValueType> const& pomdp; storm::models::sparse::Pomdp<ValueType> const& pomdp;
WinningRegion const& winningRegion; WinningRegion const& winningRegion;

Loading…
Cancel
Save