diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index 15e19fab9..3788a2ac7 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -4,6 +4,35 @@ namespace storm { namespace pomdp { + template + MemlessStrategySearchQualitative::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp const& pomdp, + std::set const& targetObservationSet, + storm::storage::BitVector const& targetStates, + storm::storage::BitVector const& surelyReachSinkStates, + std::shared_ptr& smtSolverFactory) : + pomdp(pomdp), + targetStates(targetStates), + surelyReachSinkStates(surelyReachSinkStates), + targetObservations(targetObservationSet) + { + this->expressionManager = std::make_shared(); + smtSolver = smtSolverFactory->create(*expressionManager); + // Initialize states per observation. + for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { + statesPerObservation.push_back(std::vector()); // Consider using bitvectors instead. + } + uint64_t state = 0; + for (auto obs : pomdp.getObservations()) { + statesPerObservation.at(obs).push_back(state++); + } + // Initialize winning region + std::vector nrStatesPerObservation; + for (auto const &states : statesPerObservation) { + nrStatesPerObservation.push_back(states.size()); + } + winningRegion = WinningRegion(nrStatesPerObservation); + } + template void MemlessStrategySearchQualitative::initialize(uint64_t k) { if (maxK == std::numeric_limits::max()) { @@ -12,14 +41,13 @@ namespace storm { for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { actionSelectionVars.push_back(std::vector()); actionSelectionVarExpressions.push_back(std::vector()); - statesPerObservation.push_back(std::vector()); // Consider using bitvectors instead. } // Fill the states-per-observation mapping, // declare the reachability variables, // declare the path variables. uint64_t stateId = 0; - for(auto obs : pomdp.getObservations()) { + for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { pathVars.push_back(std::vector()); for (uint64_t i = 0; i < k; ++i) { pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression()); @@ -28,7 +56,6 @@ namespace storm { reachVarExpressions.push_back(reachVars.back().getExpression()); continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId))); continuationVarExpressions.push_back(continuationVars.back().getExpression()); - statesPerObservation.at(obs).push_back(stateId++); } assert(pathVars.size() == pomdp.getNumberOfStates()); assert(reachVars.size() == pomdp.getNumberOfStates()); @@ -115,14 +142,6 @@ namespace storm { for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { std::vector subexprreach; - -// subexprreach.push_back(!reachVarExpressions.at(state)); -// subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); -// subexprreach.push_back(!switchVarExpressions[pomdp.getObservation(state)]); -// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { -// subexprreach.push_back(reachVarExpressions.at(entries.getColumn())); -// } -// smtSolver->add(storm::expressions::disjunction(subexprreach)); for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { for (uint64_t j = 1; j < k; ++j) { pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); @@ -152,14 +171,7 @@ namespace storm { ++obs; } - // These constraints ensure that the right solver is used. -// obs = 0; -// for(auto const& statesForObservation : statesPerObservation) { -// smtSolver->add(schedulerVariableExpressions[obs] >= schedulerForObs.size()); -// ++obs; -// } - - // TODO updateFoundSchedulers(); + // TODO: Update found schedulers if k is increased. } template @@ -204,7 +216,6 @@ namespace storm { uint64_t iterations = 0; while(true) { scheduler.clear(); - observations.clear(); observationsAfterSwitch.clear(); remainingstates.clear(); @@ -262,10 +273,10 @@ namespace storm { } } + // TODO do not repush everyting to the solver. std::vector schedulerSoFar; uint64_t obs = 0; for (auto const &actionSelectionVarsForObs : actionSelectionVars) { - uint64_t act = 0; scheduler.actions.push_back(std::set()); if (observations.get(obs)) { for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) { @@ -326,6 +337,18 @@ namespace storm { remainingExpressions.push_back(reachVarExpressions[index]); } + for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { + storm::storage::BitVector update = storm::storage::BitVector(statesPerObservation[observation].size()); + uint64_t i = 0; + for (uint64_t state : statesPerObservation[observation]) { + if (!remainingstates.get(state)) { + update.set(i); + } + } + winningRegion.update(observation, update); + ++i; + } + smtSolver->add(storm::expressions::disjunction(remainingExpressions)); uint64_t obs = 0; diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index c3a07e6dd..900cdc0d7 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -5,6 +5,8 @@ #include "storm/utility/solver.h" #include "storm/exceptions/UnexpectedException.h" +#include "storm-pomdp/analysis/WinningRegion.h" + namespace storm { namespace pomdp { @@ -67,25 +69,12 @@ namespace pomdp { class MemlessStrategySearchQualitative { // Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper. - public: MemlessStrategySearchQualitative(storm::models::sparse::Pomdp const& pomdp, std::set const& targetObservationSet, storm::storage::BitVector const& targetStates, storm::storage::BitVector const& surelyReachSinkStates, - std::shared_ptr& smtSolverFactory) : - pomdp(pomdp), - targetStates(targetStates), - surelyReachSinkStates(surelyReachSinkStates), - targetObservations(targetObservationSet) { - this->expressionManager = std::make_shared(); - smtSolver = smtSolverFactory->create(*expressionManager); - - } - - void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) { - surelyReachSinkStates = surelyReachSink; - } + std::shared_ptr& smtSolverFactory); void analyzeForInitialStates(uint64_t k) { analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); @@ -96,8 +85,6 @@ namespace pomdp { std::cout << targetStates << std::endl; std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; analyze(k, ~surelyReachSinkStates & ~targetStates); - - } bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); @@ -137,6 +124,7 @@ namespace pomdp { std::vector finalSchedulers; std::vector> schedulerForObs; + WinningRegion winningRegion; diff --git a/src/storm-pomdp/analysis/WinningRegion.cpp b/src/storm-pomdp/analysis/WinningRegion.cpp new file mode 100644 index 000000000..5b0d4d889 --- /dev/null +++ b/src/storm-pomdp/analysis/WinningRegion.cpp @@ -0,0 +1,86 @@ +#include +#include "storm-pomdp/analysis/WinningRegion.h" + +namespace storm { +namespace pomdp { + WinningRegion::WinningRegion(std::vector const& observationSizes) : observationSizes(observationSizes) + { + for (uint64_t i = 0; i < observationSizes.size(); ++i) { + winningRegion.push_back(std::vector()); + } + } + + void WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) { + std::vector newWinningSupport = std::vector(); + bool changed = false; + for (auto const& support : winningRegion[observation]) { + if (winning.isSubsetOf(support)) { + // This new winning support is already covered. + return; + } + if(support.isSubsetOf(winning)) { + // This new winning support extends the previous support, thus the previous support is now spurious + changed = true; + } else { + newWinningSupport.push_back(support); + } + } + + // only if changed. + if (changed) { + newWinningSupport.push_back(winning); + winningRegion[observation] = newWinningSupport; + } else { + winningRegion[observation].push_back(winning); + } + + } + + bool WinningRegion::query(uint64_t observation, storm::storage::BitVector const& currently) const { + for(storm::storage::BitVector winning : winningRegion[observation]) { + if(currently.isSubsetOf(winning)) { + return true; + } + } + return false; + } + + /** + * If we observe this observation, do we surely win? + * @param observation + * @return yes, if all supports for this observation are winning. + */ + bool WinningRegion::observationIsWinning(uint64_t observation) const { + return winningRegion[observation].size() == 1 && winningRegion[observation].front().full(); + } + + void WinningRegion::print() const { + uint64_t observation = 0; + for (auto const& winningSupport : winningRegion) { + std::cout << "***** observation" << observation << std::endl; + for (auto const& support : winningSupport) { + std::cout << " " << support; + } + std::cout << std::endl; + } + } + + /** + * How many different observations are there? + * @return + */ + uint64_t WinningRegion::getNumberOfObservations() const { + return observationSizes.size(); + } + + uint64_t WinningRegion::getStorageSize() const { + uint64_t result = 0; + for (uint64_t i = 0; i < getNumberOfObservations(); ++i) { + result += winningRegion[i].size() * observationSizes[i]; + } + return result; + } + + +} +} \ No newline at end of file diff --git a/src/storm-pomdp/analysis/WinningRegion.h b/src/storm-pomdp/analysis/WinningRegion.h new file mode 100644 index 000000000..0356a9e94 --- /dev/null +++ b/src/storm-pomdp/analysis/WinningRegion.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include "storm/storage/BitVector.h" + +namespace storm { + namespace pomdp { + class WinningRegion { + public: + WinningRegion(std::vector const& observationSizes = {}); + + void update(uint64_t observation, storm::storage::BitVector const& winning); + bool query(uint64_t observation, storm::storage::BitVector const& currently) const; + + bool observationIsWinning(uint64_t observation) const; + + uint64_t getStorageSize() const; + uint64_t getNumberOfObservations() const; + void print() const; + private: + std::vector> winningRegion; + std::vector observationSizes; + }; + } +} \ No newline at end of file