From f00a208e9cbccfaf63ce146f47bf4585caef1fa3 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Sun, 10 May 2020 18:24:29 -0700 Subject: [PATCH] validate whether a winning region is maximal --- .../MemlessStrategySearchQualitative.cpp | 12 ++-- src/storm-pomdp/analysis/WinningRegion.h | 5 ++ .../analysis/WinningRegionQueryInterface.cpp | 66 +++++++++++++++++-- .../analysis/WinningRegionQueryInterface.h | 4 +- 4 files changed, 78 insertions(+), 9 deletions(-) diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index 6552071ef..5326cacf5 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -75,11 +75,10 @@ namespace storm { nrStatesPerObservation.push_back(states.size()); } winningRegion = WinningRegion(nrStatesPerObservation); - if(options.validateEveryStep) { + if(options.validateResult || options.validateEveryStep) { STORM_LOG_WARN("The validator should only be created when the option is set."); validator = std::make_shared>(pomdp, winningRegion); } - } template @@ -638,14 +637,13 @@ namespace storm { } } - // TODO temporarily switched off due to intiialization issues when restartin. STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place"); if(options.computeDebugOutput()) { winningRegion.print(); } if(options.validateEveryStep) { STORM_LOG_WARN("Validating every step, for debug purposes only!"); - validator->validate(); + validator->validate(surelyReachSinkStates); } stats.updateNewStrategySolverTime.start(); for(uint64_t observation : updated) { @@ -686,6 +684,12 @@ namespace storm { STORM_LOG_INFO("... after iteration " << stats.getIterations() << " so far " << stats.getChecks() << " checks." ); } + if(options.validateResult) { + STORM_LOG_WARN("Validating result is a winning region, only for debugging purposes."); + validator->validate(surelyReachSinkStates); + STORM_LOG_WARN("Validating result is a maximal winning region, only for debugging purposes."); + validator->validateIsMaximal(surelyReachSinkStates); + } winningRegion.print(); if (!allOfTheseStates.empty()) { diff --git a/src/storm-pomdp/analysis/WinningRegion.h b/src/storm-pomdp/analysis/WinningRegion.h index 7b95fed85..aaf51d839 100644 --- a/src/storm-pomdp/analysis/WinningRegion.h +++ b/src/storm-pomdp/analysis/WinningRegion.h @@ -11,6 +11,11 @@ namespace storm { bool update(uint64_t observation, storm::storage::BitVector const& winning); bool query(uint64_t observation, storm::storage::BitVector const& currently) const; + bool isWinning(uint64_t observation, uint64_t offset) const { + storm::storage::BitVector currently(observationSizes[observation]); + currently.set(offset); + return query(observation,currently); + } std::vector const& getWinningSetsPerObservation(uint64_t observation) const; diff --git a/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp b/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp index 167a09e26..79db31a74 100644 --- a/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp +++ b/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp @@ -1,3 +1,4 @@ +#include #include "storm/storage/expressions/Expression.h" #include "storm-pomdp/analysis/WinningRegionQueryInterface.h" @@ -39,6 +40,7 @@ namespace storm { template bool WinningRegionQueryInterface::staysInWinningRegion(storm::storage::BitVector const& currentBeliefSupport, uint64_t actionIndex) const { + STORM_LOG_ASSERT(currentBeliefSupport.getNumberOfSetBits() > 0, "One cannot think one is literally nowhere"); std::map successors; STORM_LOG_DEBUG("Stays in winning region? (" << currentBeliefSupport << ", " << actionIndex << ")"); for (uint64_t oldState : currentBeliefSupport) { @@ -57,17 +59,17 @@ namespace storm { for (auto const& entry : successors) { if(!isInWinningRegion(entry.second)) { - STORM_LOG_DEBUG("Belief support " << entry.second << " is not winning"); + STORM_LOG_DEBUG("Belief support " << entry.second << " (obs " << entry.first << ") is not winning"); return false; } else { - STORM_LOG_DEBUG("Belief support " << entry.second << " is winning"); + STORM_LOG_DEBUG("Belief support " << entry.second << " (obs " << entry.first << ") is winning"); } } return true; } template - void WinningRegionQueryInterface::validate() const { + void WinningRegionQueryInterface::validate(storm::storage::BitVector const& badStates) const { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for(auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) { storm::storage::BitVector states(pomdp.getNumberOfStates()); @@ -81,7 +83,63 @@ namespace storm { break; } } - STORM_LOG_ASSERT(safeActionExists, "Observation " << obs << " with associated states: " << statesPerObservation[obs] << " , support " << states); + STORM_LOG_THROW(safeActionExists, storm::exceptions::UnexpectedException, "Observation " << obs << " , support " << states); + } + } + } + + template + void WinningRegionQueryInterface::validateIsMaximal(storm::storage::BitVector const& badStates) const { + for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { + STORM_LOG_DEBUG("Check listed belief supports for observation " << obs << " are maximal"); + for(auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) { + storm::storage::BitVector remainders = ~winningBelief; + for(auto const& additional : remainders) { + uint64_t addState = statesPerObservation[obs][additional]; + if (badStates.get(addState)) { + continue; + } + + storm::storage::BitVector states(pomdp.getNumberOfStates()); + for (uint64_t offset : winningBelief) { + states.set(statesPerObservation[obs][offset]); + } + states.set(statesPerObservation[obs][additional]); + assert(states.getNumberOfSetBits() == winningBelief.getNumberOfSetBits() + 1); + + bool safeActionExists = false; + for(uint64_t actionIndex = 0; actionIndex < pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]); ++actionIndex) { + if (staysInWinningRegion(states,actionIndex)) { + STORM_LOG_DEBUG("Action " << actionIndex << " from " << states << " is safe. "); + safeActionExists = true; + break; + } + } + + STORM_LOG_THROW(!safeActionExists,storm::exceptions::UnexpectedException, "Observation " << obs << ", support " << states); + } + } + STORM_LOG_DEBUG("All listed belief supports for observation " << obs << " are maximal. Continue with single states."); + + for (uint64_t offset = 0; offset < statesPerObservation[obs].size(); ++offset) { + if(winningRegion.isWinning(obs,offset)) { + continue; + } + uint64_t addState = statesPerObservation[obs][offset]; + if(badStates.get(addState)) { + continue; + } + storm::storage::BitVector states(pomdp.getNumberOfStates()); + states.set(addState); + bool safeActionExists = false; + for(uint64_t actionIndex = 0; actionIndex < pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]); ++actionIndex) { + if (staysInWinningRegion(states,actionIndex)) { + safeActionExists = true; + break; + } + } + + STORM_LOG_THROW(!safeActionExists, storm::exceptions::UnexpectedException, "Observation " << obs << " , support " << states); } } } diff --git a/src/storm-pomdp/analysis/WinningRegionQueryInterface.h b/src/storm-pomdp/analysis/WinningRegionQueryInterface.h index 44b9cc53e..3ccf45478 100644 --- a/src/storm-pomdp/analysis/WinningRegionQueryInterface.h +++ b/src/storm-pomdp/analysis/WinningRegionQueryInterface.h @@ -14,7 +14,9 @@ namespace storm { bool staysInWinningRegion(storm::storage::BitVector const& beliefSupport, uint64_t actionIndex) const; - void validate() const; + void validate(storm::storage::BitVector const& badStates) const; + + void validateIsMaximal(storm::storage::BitVector const& badStates) const; private: storm::models::sparse::Pomdp const& pomdp; WinningRegion const& winningRegion;