From 556a884e74565c7e2bcf8838e91c3536f8704361 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Sun, 10 May 2020 18:27:14 -0700 Subject: [PATCH] use target state to initialise winning region, better timers and slight improvements in partial scheduler extension --- .../MemlessStrategySearchQualitative.cpp | 226 +++++++++++++----- .../MemlessStrategySearchQualitative.h | 11 +- src/storm-pomdp/analysis/WinningRegion.cpp | 31 ++- src/storm-pomdp/analysis/WinningRegion.h | 1 + 4 files changed, 187 insertions(+), 82 deletions(-) diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index 5326cacf5..9130da993 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -41,9 +41,11 @@ namespace storm { STORM_PRINT_AND_LOG("SAT Calls time: " << smtCheckTimer << std::endl); STORM_PRINT_AND_LOG("Outer iterations: " << outerIterations << std::endl); STORM_PRINT_AND_LOG("Solver initialization time: " << initializeSolverTimer << std::endl); - STORM_PRINT_AND_LOG("Extend partial scheduler time: " << updateExtensionSolverTime << std::endl); + STORM_PRINT_AND_LOG("Obtain partial scheduler time: " << evaluateExtensionSolverTime << std::endl); + STORM_PRINT_AND_LOG("Update solver to extend partial scheduler time: " << encodeExtensionSolverTime << std::endl); STORM_PRINT_AND_LOG("Update solver with new scheduler time: " << updateNewStrategySolverTime << std::endl); STORM_PRINT_AND_LOG("Winning regions update time: " << winningRegionUpdatesTimer << std::endl); + STORM_PRINT_AND_LOG("Graph search time: " << graphSearchTime << std::endl); } template @@ -155,15 +157,21 @@ namespace storm { uint64_t obs = 0; - if (options.onlyDeterministicStrategies) { - for(auto const& statesForObservation : statesPerObservation) { - for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) { - for 
(uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) { - smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]); + + for(auto const& statesForObservation : statesPerObservation) { + if ( pomdp.getNumberOfChoices(statesForObservation.front()) == 1) { + ++obs; + continue; + } + if (options.onlyDeterministicStrategies || statesForObservation.size() == 1) { + for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()) - 1; ++a) { + for (uint64_t b = a + 1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) { + smtSolver->add(!(actionSelectionVarExpressions[obs][a]) || + !(actionSelectionVarExpressions[obs][b])); } } - ++obs; } + ++obs; } // PAPER COMMENT: 1 @@ -302,38 +310,107 @@ namespace storm { smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]))); } // PAPER COMMENT 10 - if (!lookaheadConstraintsRequired) { - uint64_t rowIndex = 0; - for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { - uint64_t enabledActions = pomdp.getNumberOfChoices(state); - if (!surelyReachSinkStates.get(state)) { - std::vector successorVars; - for (uint64_t act = 0; act < enabledActions; ++act) { - for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) { - successorVars.push_back(reachVarExpressions[entries.getColumn()]); - } - rowIndex++; - } - successorVars.push_back(!switchVars[pomdp.getObservation(state)]); - smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state])); - } else { - rowIndex += enabledActions; - } +// if (!lookaheadConstraintsRequired) { +// uint64_t rowIndex = 0; +// for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { +// uint64_t enabledActions = pomdp.getNumberOfChoices(state); +// if (!surelyReachSinkStates.get(state)) { +// std::vector successorVars; +// for 
(uint64_t act = 0; act < enabledActions; ++act) { +// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) { +// successorVars.push_back(reachVarExpressions[entries.getColumn()]); +// } +// rowIndex++; +// } +// successorVars.push_back(!switchVars[pomdp.getObservation(state)]); +// smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state])); +// } else { +// rowIndex += enabledActions; +// } +// } +// } else { +// STORM_LOG_WARN("Some optimization not implemented yet."); +// } + // TODO: Update found schedulers if k is increased. + } + + template + uint64_t MemlessStrategySearchQualitative::getOffsetFromObservation(uint64_t state, uint64_t observation) const { + if(!useFindOffset) { + STORM_LOG_WARN("This code is slow and should only be used for debugging."); + useFindOffset = true; + } + uint64_t offset = 0; + for(uint64_t s : statesPerObservation[observation]) { + if (s == state) { + return offset; } - } else { - STORM_LOG_WARN("Some optimization not implemented yet."); + ++offset; } - // TODO: Update found schedulers if k is increased. + assert(false); // State should have occurred. + return 0; } template bool MemlessStrategySearchQualitative::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { + std::cout << "Surely reach sink states: " << surelyReachSinkStates << std::endl; + std::cout << "Target states " << targetStates << std::endl; + std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; stats.initializeSolverTimer.start(); // TODO: When do we need to reinitialize? When the solver has been reset. 
initialize(k); maxK = k; + stats.winningRegionUpdatesTimer.start(); + storm::storage::BitVector updated(pomdp.getNrObservations()); + // TODO CODE DUPLICATION WITH UPDATE, PUT IN PROCEDURE + storm::storage::BitVector potentialWinner(pomdp.getNrObservations()); + storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations()); + for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { + if (winningRegion.observationIsWinning(observation)) { + continue; + } + bool observationIsWinning = true; + for (uint64_t state : statesPerObservation[observation]) { + if(!targetStates.get(state)) { + observationIsWinning = false; + observationsWithPartialWinners.set(observation); + } else { + potentialWinner.set(observation); + } + } + if(observationIsWinning) { + STORM_LOG_TRACE("Observation " << observation << " is winning."); + stats.incrementGraphBasedWinningObservations(); + winningRegion.setObservationIsWinning(observation); + updated.set(observation); + } + } + STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations()); + observationsWithPartialWinners &= potentialWinner; + for (auto const& observation : observationsWithPartialWinners) { + + uint64_t nrStatesForObs = statesPerObservation[observation].size(); + storm::storage::BitVector update(nrStatesForObs); + for (uint64_t i = 0; i < nrStatesForObs; ++i ) { + uint64_t state = statesPerObservation[observation][i]; + if(targetStates.get(state)) { + update.set(i); + } + } + assert(!update.empty()); + STORM_LOG_TRACE("Extend winning region for observation " << observation << " with target states/offsets" << update); + winningRegion.addTargetStates(observation, update); + assert(winningRegion.query(observation,update));// "Cannot continue: No scheduler known for state " << i << " (observation " << obs << ")."); + + updated.set(observation); + } + for (auto const& state : targetStates) { + 
STORM_LOG_ASSERT(winningRegion.isWinning(pomdp.getObservation(state),getOffsetFromObservation(state,pomdp.getObservation(state))), "Target state " << state << " , observation " << pomdp.getObservation(state) << " is not reflected as winning."); + } + stats.winningRegionUpdatesTimer.stop(); + uint64_t maximalNrActions = 8; STORM_LOG_WARN("We have hardcoded (an upper bound on) the number of actions"); @@ -415,6 +492,7 @@ namespace storm { storm::storage::BitVector observations(pomdp.getNrObservations()); storm::storage::BitVector observationsAfterSwitch(pomdp.getNrObservations()); storm::storage::BitVector observationUpdated(pomdp.getNrObservations()); + storm::storage::BitVector uncoveredStates(pomdp.getNumberOfStates()); storm::storage::BitVector coveredStates(pomdp.getNumberOfStates()); storm::storage::BitVector coveredStatesAfterSwitch(pomdp.getNumberOfStates()); @@ -449,15 +527,14 @@ namespace storm { break; } newSchedulerDiscovered = true; - stats.updateExtensionSolverTime.start(); - auto model = smtSolver->getModel(); + stats.evaluateExtensionSolverTime.start(); + auto const& model = smtSolver->getModel(); newObservationsAfterSwitch.clear(); newObservations.clear(); uint64_t obs = 0; for (auto const& ov : observationUpdatedVariables) { - if (!observationUpdated.get(obs) && model->getBooleanValue(ov)) { STORM_LOG_TRACE("New observation updated: " << obs); observationUpdated.set(obs); @@ -465,32 +542,43 @@ namespace storm { obs++; } - uint64_t i = 0; - for (auto const& rv : reachVars) { - if (!coveredStates.get(i) && model->getBooleanValue(rv)) { +// for(uint64_t i : targetStates) { +// assert(model->getBooleanValue(reachVars[i])); +// } + + uncoveredStates = ~coveredStates; + for (uint64_t i : uncoveredStates) { + auto const& rv =reachVars[i]; + auto const& rvExpr =reachVarExpressions[i]; + if (model->getBooleanValue(rv)) { STORM_LOG_TRACE("New state: " << i); - smtSolver->add(rv.getExpression()); + smtSolver->add(rvExpr); assert(!surelyReachSinkStates.get(i)); 
newObservations.set(pomdp.getObservation(i)); coveredStates.set(i); } - ++i; } - i = 0; - for (auto const& rv : continuationVars) { - if (!coveredStatesAfterSwitch.get(i) && model->getBooleanValue(rv) ) { - smtSolver->add(rv.getExpression()); - if (!observationsAfterSwitch.get(pomdp.getObservation(i))) { - newObservationsAfterSwitch.set(pomdp.getObservation(i)); + + storm::storage::BitVector uncoveredStatesAfterSwitch(~coveredStatesAfterSwitch); + for (uint64_t i : uncoveredStatesAfterSwitch) { + auto const& cv = continuationVars[i]; + if (model->getBooleanValue(cv)) { + uint64_t obs = pomdp.getObservation(i); + STORM_LOG_ASSERT(winningRegion.isWinning(obs,getOffsetFromObservation(i,obs)), "Cannot continue: No scheduler known for state " << i << " (observation " << obs << ")."); + auto const& cvExpr =continuationVarExpressions[i]; + smtSolver->add(cvExpr); + if (!observationsAfterSwitch.get(obs)) { + newObservationsAfterSwitch.set(obs); } - ++i; + } } + stats.evaluateExtensionSolverTime.stop(); if (options.computeTraceOutput()) { detail::printRelevantInfoFromModel(model, reachVars, continuationVars); } - + stats.encodeExtensionSolverTime.start(); for (auto obs : newObservations) { auto const &actionSelectionVarsForObs = actionSelectionVars[obs]; observations.set(obs); @@ -534,16 +622,11 @@ namespace storm { if (remainingExpressions.empty()) { - stats.updateExtensionSolverTime.stop(); + stats.encodeExtensionSolverTime.stop(); break; } - // Add scheduler - - //std::cout << storm::expressions::disjunction(remainingExpressions) << std::endl; - smtSolver->add(storm::expressions::disjunction(remainingExpressions)); - stats.updateExtensionSolverTime.stop(); - + stats.encodeExtensionSolverTime.stop(); } if (!newSchedulerDiscovered) { break; @@ -591,45 +674,58 @@ namespace storm { } stats.winningRegionUpdatesTimer.stop(); if (newTargetObservations>0) { + stats.graphSearchTime.start(); storm::analysis::QualitativeAnalysisOnGraphs graphanalysis(pomdp); uint64_t 
targetStatesBefore = targetStates.getNumberOfSetBits(); STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits()); - storm::storage::BitVector newtargetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates); - uint64_t targetStatesAfter = newtargetStates.getNumberOfSetBits(); - STORM_LOG_INFO("Target states after graph based analysis " << newtargetStates.getNumberOfSetBits()); + targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates); + uint64_t targetStatesAfter = targetStates.getNumberOfSetBits(); + STORM_LOG_INFO("Target states after graph based analysis " << targetStates.getNumberOfSetBits()); + stats.graphSearchTime.stop(); if (targetStatesAfter - targetStatesBefore > 0) { stats.winningRegionUpdatesTimer.start(); - + // TODO CODE DUPLICATION WITH INIT, PUT IN PROCEDURE + storm::storage::BitVector potentialWinner(pomdp.getNrObservations()); + storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations()); for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { if (winningRegion.observationIsWinning(observation)) { continue; } bool observationIsWinning = true; for (uint64_t state : statesPerObservation[observation]) { - if(!newtargetStates.get(state)) { + if(!targetStates.get(state)) { observationIsWinning = false; - break; + observationsWithPartialWinners.set(observation); + } else { + potentialWinner.set(observation); } } if(observationIsWinning) { stats.incrementGraphBasedWinningObservations(); winningRegion.setObservationIsWinning(observation); - for(auto const& state : statesPerObservation[observation]) { - targetStates.set(state); - } updated.set(observation); } } STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations()); - uint64_t nonWinObTargetStates =0; - for (uint64_t state : targetStates) { - if (!winningRegion.observationIsWinning(pomdp.getObservation(state))) { - 
nonWinObTargetStates++; + observationsWithPartialWinners &= potentialWinner; + for (auto const& observation : observationsWithPartialWinners) { + uint64_t nrStatesForObs = statesPerObservation[observation].size(); + storm::storage::BitVector update(nrStatesForObs); + for (uint64_t i = 0; i < nrStatesForObs; ++i ) { + uint64_t state = statesPerObservation[observation][i]; + if(targetStates.get(state)) { + update.set(i); + } } + assert(!update.empty()); + STORM_LOG_TRACE("Extend winning region for observation " << observation << " with target states/offsets" << update); + winningRegion.addTargetStates(observation, update); + assert(winningRegion.query(observation,update));// + updated.set(observation); } stats.winningRegionUpdatesTimer.stop(); - if (nonWinObTargetStates > 0) { - std::cout << "Non winning target states " << nonWinObTargetStates << std::endl; + + if (observationsWithPartialWinners.getNumberOfSetBits() > 0) { STORM_LOG_WARN("This case has been barely tested and likely contains bug"); reset(); return analyze(k, ~targetStates & ~surelyReachSinkStates); diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index 6106844e1..47124663d 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -108,8 +108,10 @@ namespace pomdp { storm::utility::Stopwatch totalTimer; storm::utility::Stopwatch smtCheckTimer; storm::utility::Stopwatch initializeSolverTimer; - storm::utility::Stopwatch updateExtensionSolverTime; + storm::utility::Stopwatch evaluateExtensionSolverTime; + storm::utility::Stopwatch encodeExtensionSolverTime; storm::utility::Stopwatch updateNewStrategySolverTime; + storm::utility::Stopwatch graphSearchTime; storm::utility::Stopwatch winningRegionUpdatesTimer; @@ -168,9 +170,6 @@ namespace pomdp { } void computeWinningRegion(uint64_t k) { - std::cout << surelyReachSinkStates << std::endl; - 
std::cout << targetStates << std::endl; - std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; stats.totalTimer.start(); analyze(k, ~surelyReachSinkStates & ~targetStates); stats.totalTimer.stop(); @@ -180,6 +179,8 @@ namespace pomdp { return winningRegion; } + uint64_t getOffsetFromObservation(uint64_t state, uint64_t observation) const; + bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); Statistics const& getStatistics() const; @@ -241,6 +242,8 @@ namespace pomdp { std::shared_ptr& smtSolverFactory; std::shared_ptr> validator; + mutable bool useFindOffset = false; + }; } diff --git a/src/storm-pomdp/analysis/WinningRegion.cpp b/src/storm-pomdp/analysis/WinningRegion.cpp index 26e2c333e..c2408f20e 100644 --- a/src/storm-pomdp/analysis/WinningRegion.cpp +++ b/src/storm-pomdp/analysis/WinningRegion.cpp @@ -18,19 +18,24 @@ namespace pomdp { winningRegion[observation] = { storm::storage::BitVector(observationSizes[observation], true) }; } -// void WinningRegion::addTargetState(uint64_t observation, uint64_t offset) { -// std::vector newWinningSupport = std::vector(); -// bool changed = true; -// for (auto const& support : winningRegion[observation]) { -// newWinningSupport.push_back(storm::storage::BitVector(support)); -// if(!support.get(offset)) { -// changed = true; -// newWinningSupport.back().set(offset); -// } -// } -// -// -// } + void WinningRegion::addTargetStates(uint64_t observation, storm::storage::BitVector const& offsets) { + assert(!offsets.empty()); + if(winningRegion[observation].empty()) { + winningRegion[observation].push_back(offsets); + return; + } + std::vector newWinningSupport = std::vector(); + + for (auto const& support : winningRegion[observation]) { + newWinningSupport.push_back(support | offsets); + } + // TODO it may be worthwhile to check whether something changed. 
If nothing changed, there is no need for the next routine. + // TODO the following code is a bit naive. + winningRegion[observation].clear(); // This prevents some overhead. + for (auto const& newWinning : newWinningSupport) { + update(observation, newWinning); + } + } bool WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) { std::vector newWinningSupport = std::vector(); diff --git a/src/storm-pomdp/analysis/WinningRegion.h b/src/storm-pomdp/analysis/WinningRegion.h index aaf51d839..bc81cf907 100644 --- a/src/storm-pomdp/analysis/WinningRegion.h +++ b/src/storm-pomdp/analysis/WinningRegion.h @@ -19,6 +19,7 @@ namespace storm { std::vector const& getWinningSetsPerObservation(uint64_t observation) const; + void addTargetStates(uint64_t observation, storm::storage::BitVector const& offsets); void setObservationIsWinning(uint64_t observation); bool observationIsWinning(uint64_t observation) const;