diff --git a/src/storm-pomdp-cli/storm-pomdp.cpp b/src/storm-pomdp-cli/storm-pomdp.cpp index ce4ddd951..41b07c351 100644 --- a/src/storm-pomdp-cli/storm-pomdp.cpp +++ b/src/storm-pomdp-cli/storm-pomdp.cpp @@ -108,6 +108,7 @@ namespace storm { } else if(storm::utility::getLogLevel() == l3pp::LogLevel::TRACE) { loglevel = 3; + options.validateEveryStep = true; } options.setDebugLevel(loglevel); diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index d09410545..744da8a95 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -75,6 +75,10 @@ namespace storm { nrStatesPerObservation.push_back(states.size()); } winningRegion = WinningRegion(nrStatesPerObservation); + if(options.validateEveryStep) { + STORM_LOG_WARN("The validator should only be created when the option is set."); + validator = std::make_shared>(pomdp, winningRegion); + } } @@ -586,14 +590,13 @@ namespace storm { } // TODO temporarily switched off due to intiialization issues when restartin. STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place"); - - - - if(options.computeDebugOutput()) { winningRegion.print(); } - + if(options.validateEveryStep) { + STORM_LOG_WARN("Validating every step, for debug purposes only!"); + validator->validate(); + } stats.updateNewStrategySolverTime.start(); for(uint64_t observation : updated) { updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]); diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index acdb9ac6e..c3cb08bad 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -8,6 +8,7 @@ #include "storm/exceptions/UnexpectedException.h" #include "storm-pomdp/analysis/WinningRegion.h" +#include "storm-pomdp/analysis/WinningRegionQueryInterface.h" namespace storm { namespace pomdp { @@ -45,6 +46,7 @@ namespace pomdp { bool onlyDeterministicStrategies = false; bool forceLookahead = true; + bool validateEveryStep = false; private: std::string exportSATcalls = ""; @@ -236,6 +238,7 @@ namespace pomdp { Statistics stats; std::shared_ptr& smtSolverFactory; + std::shared_ptr> validator; }; diff --git a/src/storm-pomdp/analysis/WinningRegion.cpp b/src/storm-pomdp/analysis/WinningRegion.cpp index 1fd3ae950..c34ec262b 100644 --- a/src/storm-pomdp/analysis/WinningRegion.cpp +++ b/src/storm-pomdp/analysis/WinningRegion.cpp @@ -130,9 +130,25 @@ namespace pomdp { * @return */ uint64_t WinningRegion::getNumberOfObservations() const { + assert(winningRegion.size() == observationSizes.size()); return observationSizes.size(); } + bool WinningRegion::empty() const { + for (auto const& ob : winningRegion) { + if (!ob.empty()) { + return false; + } + } + return true; + } + + std::vector const& WinningRegion::getWinningSetsPerObservation(uint64_t observation) const { + + assert(observation < getNumberOfObservations()); + return winningRegion[observation]; + } + uint64_t WinningRegion::getStorageSize() const { uint64_t result = 0; for (uint64_t i = 0; i < getNumberOfObservations(); ++i) { diff --git a/src/storm-pomdp/analysis/WinningRegion.h b/src/storm-pomdp/analysis/WinningRegion.h index 29e3f511f..23c294a3c 100644 --- a/src/storm-pomdp/analysis/WinningRegion.h +++ b/src/storm-pomdp/analysis/WinningRegion.h @@ -12,6 +12,8 @@ namespace storm { bool update(uint64_t observation, storm::storage::BitVector const& winning); bool query(uint64_t observation, storm::storage::BitVector const& currently) const; + std::vector const& getWinningSetsPerObservation(uint64_t observation) const; + void setObservationIsWinning(uint64_t observation); bool observationIsWinning(uint64_t observation) const; @@ -20,6 +22,7 @@ namespace storm { uint64_t getStorageSize() const; uint64_t getNumberOfObservations() const; + bool empty() const; void print() const; private: std::vector> winningRegion; diff --git a/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp b/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp index cdb3afa88..167a09e26 100644 --- a/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp +++ b/src/storm-pomdp/analysis/WinningRegionQueryInterface.cpp @@ -40,7 +40,7 @@ namespace storm { template bool WinningRegionQueryInterface::staysInWinningRegion(storm::storage::BitVector const& currentBeliefSupport, uint64_t actionIndex) const { std::map successors; - + STORM_LOG_DEBUG("Stays in winning region? (" << currentBeliefSupport << ", " << actionIndex << ")"); for (uint64_t oldState : currentBeliefSupport) { uint64_t row = pomdp.getTransitionMatrix().getRowGroupIndices()[oldState] + actionIndex; for (auto const& successor : pomdp.getTransitionMatrix().getRow(row)) { @@ -55,14 +55,39 @@ namespace storm { } for (auto const& entry : successors) { + if(!isInWinningRegion(entry.second)) { + STORM_LOG_DEBUG("Belief support " << entry.second << " is not winning"); return false; + } else { + STORM_LOG_DEBUG("Belief support " << entry.second << " is winning"); } } return true; } + template + void WinningRegionQueryInterface::validate() const { + for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { + for(auto const& winningBelief : winningRegion.getWinningSetsPerObservation(obs)) { + storm::storage::BitVector states(pomdp.getNumberOfStates()); + for (uint64_t offset : winningBelief) { + states.set(statesPerObservation[obs][offset]); + } + bool safeActionExists = false; + for(uint64_t actionIndex = 0; actionIndex < pomdp.getTransitionMatrix().getRowGroupSize(statesPerObservation[obs][0]); ++actionIndex) { + if (staysInWinningRegion(states,actionIndex)) { + safeActionExists = true; + break; + } + } + STORM_LOG_ASSERT(safeActionExists, "Observation " << obs << " with associated states: " << statesPerObservation[obs] << " , support " << states); + } + } + } + template class WinningRegionQueryInterface; + template class WinningRegionQueryInterface; } } diff --git a/src/storm-pomdp/analysis/WinningRegionQueryInterface.h b/src/storm-pomdp/analysis/WinningRegionQueryInterface.h index 4dad3193d..44b9cc53e 100644 --- a/src/storm-pomdp/analysis/WinningRegionQueryInterface.h +++ b/src/storm-pomdp/analysis/WinningRegionQueryInterface.h @@ -1,3 +1,4 @@ +#pragma once #include "storm/models/sparse/Pomdp.h" #include "storm-pomdp/analysis/WinningRegion.h" @@ -13,6 +14,7 @@ namespace storm { bool staysInWinningRegion(storm::storage::BitVector const& beliefSupport, uint64_t actionIndex) const; + void validate() const; private: storm::models::sparse::Pomdp const& pomdp; WinningRegion const& winningRegion;