From 53800c2145329443ec7300323fa2a480cf9abf67 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Sun, 24 May 2020 11:18:42 -0700 Subject: [PATCH] major improvements by introducing real-valued ranking and various related fixes --- .../MemlessStrategySearchQualitative.cpp | 239 ++++++++++++------ .../MemlessStrategySearchQualitative.h | 25 +- 2 files changed, 178 insertions(+), 86 deletions(-) diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index ce7bd71d1..eb05be797 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -83,8 +83,9 @@ namespace storm { } template - void MemlessStrategySearchQualitative::initialize(uint64_t k) { + bool MemlessStrategySearchQualitative::initialize(uint64_t k) { STORM_LOG_INFO("Start intializing solver..."); + bool lookaheadConstraintsRequired; if (options.forceLookahead) { lookaheadConstraintsRequired = true; @@ -131,8 +132,10 @@ namespace storm { } for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { - pathVars.push_back(std::vector()); + pathVars.push_back(std::vector()); + pathVarExpressions.push_back(std::vector()); } + } uint64_t initK = 0; @@ -142,15 +145,25 @@ namespace storm { if (initK < k) { for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { if (lookaheadConstraintsRequired) { - for (uint64_t i = initK; i < k; ++i) { - pathVars[stateId].push_back(expressionManager->declareBooleanVariable( - "P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression()); + if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) { + for (uint64_t i = initK; i < k; ++i) { + pathVars[stateId].push_back(expressionManager->declareBooleanVariable( + "P-" + std::to_string(stateId) + "-" + std::to_string(i))); + pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression()); + } + } else if (options.pathVariableType == MemlessSearchPathVariables::IntegerRanking) { + pathVars[stateId].push_back(expressionManager->declareIntegerVariable("P-" + std::to_string(stateId))); + pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression()); + } else { + assert(options.pathVariableType == MemlessSearchPathVariables::RealRanking); + pathVars[stateId].push_back(expressionManager->declareRationalVariable("P-" + std::to_string(stateId))); + pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression()); } } } } - assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates()); + assert(!lookaheadConstraintsRequired || pathVarExpressions.size() == pomdp.getNumberOfStates()); assert(reachVars.size() == pomdp.getNumberOfStates()); assert(reachVarExpressions.size() == pomdp.getNumberOfStates()); @@ -193,13 +206,23 @@ namespace storm { // PAPER COMMENT: 3 if (lookaheadConstraintsRequired) { - for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { - if (targetStates.get(state)) { - smtSolver->add(pathVars[state][0]); - } else { - smtSolver->add(!pathVars[state][0] || followVarExpressions[pomdp.getObservation(state)]); + if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) { + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + if (targetStates.get(state)) { + smtSolver->add(pathVarExpressions[state][0]); + } else { + smtSolver->add(!pathVarExpressions[state][0] || followVarExpressions[pomdp.getObservation(state)]); + } } + } else { + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + smtSolver->add(pathVarExpressions[state][0] <= expressionManager->integer(k)); + smtSolver->add(pathVarExpressions[state][0] >= expressionManager->integer(0)); + + } + //assert(false); } + } // PAPER COMMENT: 4 @@ -218,6 +241,8 @@ namespace storm { subexprreachNoSwitch.push_back(!reachVarExpressions[state]); subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); + subexprreachSwitch.push_back(followVarExpressions[pomdp.getObservation(state)]); + subexprreachNoSwitch.push_back(followVarExpressions[pomdp.getObservation(state)]); for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { if (pomdp.getObservation(entries.getColumn() == pomdp.getObservation(state))) { subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); @@ -244,50 +269,85 @@ namespace storm { smtSolver->add(!reachVarExpressions[state]); smtSolver->add(!continuationVarExpressions[state]); if (lookaheadConstraintsRequired) { - for (uint64_t j = 1; j < k; ++j) { - smtSolver->add(!pathVars[state][j]); + if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) { + for (uint64_t j = 1; j < k; ++j) { + smtSolver->add(!pathVarExpressions[state][j]); + } + } else { + smtSolver->add(pathVarExpressions[state][0] == expressionManager->integer(k)); } } rowindex += pomdp.getNumberOfChoices(state); } else if(!targetStates.get(state)) { if (lookaheadConstraintsRequired) { - // PAPER COMMENT 6 - smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); - // PAPER COMMENT 7 - std::vector>> pathsubsubexprs; - for (uint64_t j = 1; j < k; ++j) { - pathsubsubexprs.push_back(std::vector>()); - for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { - pathsubsubexprs.back().push_back(std::vector()); + + if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) { + // PAPER COMMENT 6 + smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), + pathVarExpressions.at(state).back())); + // PAPER COMMENT 7 + + std::vector>> pathsubsubexprs; + for (uint64_t j = 1; j < k; ++j) { + pathsubsubexprs.push_back(std::vector>()); + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + pathsubsubexprs.back().push_back(std::vector()); + } } - } - for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { - std::vector subexprreach; - for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { - for (uint64_t j = 1; j < k; ++j) { - pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + std::vector subexprreach; + for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { + for (uint64_t j = 1; j < k; ++j) { + pathsubsubexprs[j - 1][action].push_back(pathVarExpressions[entries.getColumn()][j - 1]); + } } + rowindex++; } - rowindex++; - } - for (uint64_t j = 1; j < k; ++j) { - std::vector pathsubexprs; + for (uint64_t j = 1; j < k; ++j) { + std::vector pathsubexprs; + + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); + } + pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); + pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]); + smtSolver->add(storm::expressions::iff(pathVarExpressions[state][j], + storm::expressions::disjunction(pathsubexprs))); + + } + } else { + std::vector actPathDisjunction; for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { - pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); + std::vector pathDisjunction; + for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { + pathDisjunction.push_back(pathVarExpressions[entries.getColumn()][0] < pathVarExpressions[state][0]); + } + actPathDisjunction.push_back(storm::expressions::disjunction(pathDisjunction) && actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); + rowindex++; } - pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); - pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]); - smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); + // TODO reconsider if this next add is sound + actPathDisjunction.push_back(switchVarExpressions.at(pomdp.getObservation(state))); + actPathDisjunction.push_back(followVarExpressions[pomdp.getObservation(state)]); + actPathDisjunction.push_back(!reachVarExpressions[state]); + smtSolver->add(storm::expressions::disjunction(actPathDisjunction)); } } } else { - for (uint64_t j = 1; j < k; ++j) { - smtSolver->add(pathVars[state][j]); + if (lookaheadConstraintsRequired) { + if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) { + for (uint64_t j = 1; j < k; ++j) { + smtSolver->add(pathVarExpressions[state][j]); + } + } else { + smtSolver->add(pathVarExpressions[state][0] == expressionManager->integer(0)); + //assert(false); + } } + smtSolver->add(reachVars[state]); rowindex += pomdp.getNumberOfChoices(state); } } @@ -308,29 +368,7 @@ namespace storm { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]))); } - // PAPER COMMENT 10 -// if (!lookaheadConstraintsRequired) { -// uint64_t rowIndex = 0; -// for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { -// uint64_t enabledActions = pomdp.getNumberOfChoices(state); -// if (!surelyReachSinkStates.get(state)) { -// std::vector successorVars; -// for (uint64_t act = 0; act < enabledActions; ++act) { -// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) { -// successorVars.push_back(reachVarExpressions[entries.getColumn()]); -// } -// rowIndex++; -// } -// successorVars.push_back(!switchVars[pomdp.getObservation(state)]); -// smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state])); -// } else { -// rowIndex += enabledActions; -// } -// } -// } else { -// STORM_LOG_WARN("Some optimization not implemented yet."); -// } - // TODO: Update found schedulers if k is increased. + return lookaheadConstraintsRequired; } template @@ -357,9 +395,10 @@ namespace storm { STORM_LOG_DEBUG("Questionmark states " << (~surelyReachSinkStates & ~targetStates)); stats.initializeSolverTimer.start(); // TODO: When do we need to reinitialize? When the solver has been reset. - initialize(k); - maxK = k; - + bool lookaheadConstraintsRequired = initialize(k); + if(lookaheadConstraintsRequired) { + maxK = k; + } stats.winningRegionUpdatesTimer.start(); storm::storage::BitVector updated(pomdp.getNrObservations()); @@ -500,11 +539,11 @@ namespace storm { uint64_t iterations = 0; while(true) { stats.incrementOuterIterations(); - + // TODO consider what we really want to store about the schedulers. scheduler.reset(pomdp.getNrObservations(), maximalNrActions); observations.clear(); observationsAfterSwitch.clear(); - coveredStates.clear(); + coveredStates = targetStates; coveredStatesAfterSwitch.clear(); observationUpdated.clear(); if (!allOfTheseAssumption.empty()) { @@ -540,20 +579,23 @@ namespace storm { obs++; } -// for(uint64_t i : targetStates) { -// assert(model->getBooleanValue(reachVars[i])); -// } - uncoveredStates = ~coveredStates; for (uint64_t i : uncoveredStates) { auto const& rv =reachVars[i]; auto const& rvExpr =reachVarExpressions[i]; - if (model->getBooleanValue(rv)) { + if (observationUpdated.get(pomdp.getObservation(i)) && model->getBooleanValue(rv)) { STORM_LOG_TRACE("New state: " << i); smtSolver->add(rvExpr); assert(!surelyReachSinkStates.get(i)); newObservations.set(pomdp.getObservation(i)); coveredStates.set(i); + if(lookaheadConstraintsRequired) { + if (options.pathVariableType == MemlessSearchPathVariables::IntegerRanking) { + smtSolver->add(pathVarExpressions[i][0] == expressionManager->integer(model->getIntegerValue(pathVars[i][0]))); + } else if(options.pathVariableType == MemlessSearchPathVariables::RealRanking) { + smtSolver->add(pathVarExpressions[i][0] == expressionManager->rational(model->getRationalValue(pathVars[i][0]))); + } + } } } @@ -594,6 +636,11 @@ namespace storm { } else { smtSolver->add(!switchVarExpressions[obs]); } + if (model->getBooleanValue(followVars[obs])) { + smtSolver->add(followVarExpressions[obs]); + } else { + smtSolver->add(!followVarExpressions[obs]); + } } for (auto obs : newObservationsAfterSwitch) { observationsAfterSwitch.set(obs); @@ -724,7 +771,7 @@ namespace storm { if (observationsWithPartialWinners.getNumberOfSetBits() > 0) { reset(); - return analyze(k, ~targetStates & ~surelyReachSinkStates); + return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates); } } @@ -737,6 +784,10 @@ namespace storm { STORM_LOG_WARN("Validating every step, for debug purposes only!"); validator->validate(surelyReachSinkStates); } + if (stats.getIterations() % options.restartAfterNIterations == options.restartAfterNIterations-1) { + reset(); + return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates); + } stats.updateNewStrategySolverTime.start(); for(uint64_t observation : updated) { updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]); @@ -749,15 +800,27 @@ namespace storm { assert(schedulerForObs.size() > obs); (schedulerForObs[obs])++; STORM_LOG_DEBUG("We now have " << schedulerForObs[obs] << " policies for states with observation " << obs); - - for (auto const &state : statesForObservation) { - if (!coveredStates.get(state)) { - auto constant = expressionManager->integer(schedulerForObs[obs]); - // PAPER COMMENT 14: - smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); - smtSolver->add(!(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] && (schedulerVariableExpressions[obs] == constant))); + if (winningRegion.observationIsWinning(obs)) { + for (auto const &state : statesForObservation) { + smtSolver->add(reachVarExpressions[state]); + } + auto constant = expressionManager->integer(schedulerForObs[obs]); + smtSolver->add(schedulerVariableExpressions[obs] == constant); + } else { + auto constant = expressionManager->integer(schedulerForObs[obs]); + for (auto const &state : statesForObservation) { + if (!coveredStates.get(state)) { + + // PAPER COMMENT 14: + smtSolver->add(!(continuationVarExpressions[state] && + (schedulerVariableExpressions[obs] == constant))); + smtSolver->add(!(reachVarExpressions[state] && + followVarExpressions[pomdp.getObservation(state)] && + (schedulerVariableExpressions[obs] == constant))); + } } } + } ++obs; } @@ -766,11 +829,21 @@ namespace storm { smtSolver->push(); for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { - auto constant = expressionManager->integer(schedulerForObs[obs]); - // PAPER COMMENT 13 - smtSolver->add(schedulerVariableExpressions[obs] <= constant); - // PAPER COMMENT 12 - smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); + if(winningRegion.observationIsWinning(obs)) { + auto constant = expressionManager->integer(schedulerForObs[obs]); + // PAPER COMMENT 13 + // Scheduler variable is already fixed. + // PAPER COMMENT 12 + // Observation will not be updated. + smtSolver->add(!observationUpdatedExpressions[obs]); + } else { + auto constant = expressionManager->integer(schedulerForObs[obs]); + // PAPER COMMENT 13 + smtSolver->add(schedulerVariableExpressions[obs] <= constant); + // PAPER COMMENT 12 + smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], + updateForObservationExpressions[obs])); + } } stats.updateNewStrategySolverTime.stop(); diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index 47124663d..c60117602 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -13,6 +13,20 @@ namespace storm { namespace pomdp { + enum class MemlessSearchPathVariables { + BooleanRanking, IntegerRanking, RealRanking + }; + MemlessSearchPathVariables pathVariableTypeFromString(std::string const& in) { + if(in == "int") { + return MemlessSearchPathVariables::IntegerRanking; + } else if (in == "real") { + return MemlessSearchPathVariables::RealRanking; + } else { + assert(in == "bool"); + return MemlessSearchPathVariables::BooleanRanking; + } + } + class MemlessSearchOptions { public: @@ -45,9 +59,13 @@ namespace pomdp { } bool onlyDeterministicStrategies = false; - bool forceLookahead = true; + bool forceLookahead = false; bool validateEveryStep = false; bool validateResult = false; + MemlessSearchPathVariables pathVariableType = MemlessSearchPathVariables::RealRanking; + uint64_t restartAfterNIterations = 250; + uint64_t extensionCallTimeout = 0u; + uint64_t localIterationMaximum = 600; private: std::string exportSATcalls = ""; @@ -198,7 +216,7 @@ namespace pomdp { void printScheduler(std::vector const& ); void printCoveredStates(storm::storage::BitVector const& remaining) const; - void initialize(uint64_t k); + bool initialize(uint64_t k); bool smtCheck(uint64_t iteration, std::set const& assumptions = {}); @@ -230,7 +248,8 @@ namespace pomdp { std::vector followVarExpressions; std::vector continuationVars; std::vector continuationVarExpressions; - std::vector> pathVars; + std::vector> pathVars; + std::vector> pathVarExpressions; std::vector finalSchedulers; std::vector schedulerForObs;