From e22cbdb91bac40495b86b521c3974433035af5fb Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Tue, 14 Apr 2020 14:01:48 -0700 Subject: [PATCH] support for computing the winning region or from initial state, some documentation --- .../MemlessStrategySearchQualitative.cpp | 60 ++++++++++++++++--- .../MemlessStrategySearchQualitative.h | 16 ++++- 2 files changed, 65 insertions(+), 11 deletions(-) diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index b1e7aea9b..ce94c5113 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -81,7 +81,6 @@ namespace storm { STORM_LOG_INFO("Start intializing solver..."); // TODO fix this bool lookaheadConstraintsRequired = options.lookaheadRequired; - STORM_LOG_WARN("We have hardcoded that we do not need lookahead"); if (maxK == std::numeric_limits::max()) { // not initialized at all. // Create some data structures. @@ -134,12 +133,16 @@ namespace storm { ++obs; } + // PAPER COMMENT: 1 for (auto const& actionVars : actionSelectionVarExpressions) { smtSolver->add(storm::expressions::disjunction(actionVars)); } + // Update at least one observation. + // PAPER COMMENT: 2 smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions)); + // PAPER COMMENT: 3 if (lookaheadConstraintsRequired) { for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { if (targetStates.get(state)) { @@ -150,6 +153,7 @@ namespace storm { } } + // PAPER COMMENT: 4 uint64_t rowindex = 0; for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { @@ -185,6 +189,7 @@ namespace storm { uint64_t rowindex = 0; for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + // PAPER COMMENT 5 if (surelyReachSinkStates.get(state)) { smtSolver->add(!reachVarExpressions[state]); smtSolver->add(!continuationVarExpressions[state]); @@ -196,8 +201,10 @@ namespace storm { rowindex += pomdp.getNumberOfChoices(state); } else if(!targetStates.get(state)) { if (lookaheadConstraintsRequired) { + // PAPER COMMENT 6 smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); + // PAPER COMMENT 7 std::vector>> pathsubsubexprs; for (uint64_t j = 1; j < k; ++j) { pathsubsubexprs.push_back(std::vector>()); @@ -231,6 +238,7 @@ namespace storm { } } + // PAPER COMMENT 8 uint64_t obs = 0; for(auto const& statesForObservation : statesPerObservation) { for(auto const& state : statesForObservation) { @@ -239,9 +247,11 @@ namespace storm { ++obs; } + // PAPER COMMENT 9 for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]))); } + // PAPER COMMENT 10 if (!lookaheadConstraintsRequired) { uint64_t rowIndex = 0; for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { @@ -263,9 +273,6 @@ namespace storm { } else { STORM_LOG_WARN("Some optimization not implemented yet."); } - - - // TODO: Update found schedulers if k is increased. } @@ -287,15 +294,19 @@ namespace storm { atLeastOneOfStates.push_back(reachVarExpressions[state]); } assert(atLeastOneOfStates.size() > 0); + // PAPER COMMENT 11 smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); + smtSolver->push(); + + std::set allOfTheseAssumption; + + std::vector updateForObservationExpressions; for (uint64_t state : allOfTheseStates) { assert(reachVarExpressions.size() > state); - smtSolver->add(reachVarExpressions[state]); + allOfTheseAssumption.insert(reachVarExpressions[state]); } - smtSolver->push(); - std::vector updateForObservationExpressions; for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob])); schedulerForObs.push_back(std::vector()); @@ -333,11 +344,19 @@ namespace storm { coveredStates.clear(); coveredStatesAfterSwitch.clear(); observationUpdated.clear(); + if (!allOfTheseAssumption.empty()) { + bool foundResult = this->smtCheck(iterations, allOfTheseAssumption); + if (foundResult) { + // Consider storing the scheduler + return true; + } + } bool newSchedulerDiscovered = false; while (true) { ++iterations; + bool foundScheduler = this->smtCheck(iterations); if (!foundScheduler) { break; @@ -493,6 +512,7 @@ namespace storm { for (auto const &state : statesForObservation) { if (!coveredStates.get(state)) { auto constant = expressionManager->integer(schedulerForObs[obs].size()); + // PAPER COMMENT 14: smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); } } @@ -505,7 +525,9 @@ namespace storm { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { auto constant = expressionManager->integer(schedulerForObs[obs].size()); + // PAPER COMMENT 13 smtSolver->add(schedulerVariableExpressions[obs] <= constant); + // PAPER COMMENT 12 smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); } stats.updateNewStrategySolverTime.stop(); @@ -514,6 +536,21 @@ namespace storm { } winningRegion.print(); + if (!allOfTheseStates.empty()) { + for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { + storm::storage::BitVector check(statesPerObservation[observation].size()); + uint64_t i = 0; + for (uint64_t state : statesPerObservation[observation]) { + if (allOfTheseStates.get(state)) { + check.set(i); + } + ++i; + } + if (!winningRegion.query(observation, check)) { + return false; + } + } + } return true; } @@ -546,7 +583,7 @@ namespace storm { } template - bool MemlessStrategySearchQualitative::smtCheck(uint64_t iteration) { + bool MemlessStrategySearchQualitative::smtCheck(uint64_t iteration, std::set const& assumptions) { if(options.isExportSATSet()) { STORM_LOG_DEBUG("Export SMT Solver Call (" <check(); + if (assumptions.empty()) { + result = smtSolver->check(); + } else { + result = smtSolver->checkWithAssumptions(assumptions); + } stats.smtCheckTimer.stop(); stats.incrementSmtChecks(); diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index f24e21628..05ed91b8d 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -132,8 +132,20 @@ namespace pomdp { void analyzeForInitialStates(uint64_t k) { stats.totalTimer.start(); - analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); + STORM_LOG_TRACE("Bad states: " << surelyReachSinkStates); + STORM_LOG_TRACE("Target states: " << targetStates); + STORM_LOG_TRACE("Questionmark states: " << (~surelyReachSinkStates & ~targetStates)); + bool result = analyze(k, ~surelyReachSinkStates & ~targetStates, pomdp.getInitialStates()); stats.totalTimer.stop(); + if (result) { + STORM_PRINT_AND_LOG("From initial state, one can almost-surely reach the target."); + } else { + if (k == pomdp.getNumberOfStates()) { + STORM_PRINT_AND_LOG("From initial state, one cannot almost-surely reach the target."); + } else { + STORM_PRINT_AND_LOG("From initial state, one may not almost-surely reach the target."); + } + } } void findNewStrategyForSomeState(uint64_t k) { @@ -157,7 +169,7 @@ namespace pomdp { void initialize(uint64_t k); - bool smtCheck(uint64_t iteration); + bool smtCheck(uint64_t iteration, std::set const& assumptions = {}); std::unique_ptr smtSolver;