diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index 327d374b7..6552071ef 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -171,11 +171,11 @@ namespace storm { obs = 0; for (auto const& actionVars : actionSelectionVarExpressions) { std::vector actExprs = actionVars; - //actExprs.push_back(followVarExpressions[obs]); + actExprs.push_back(followVarExpressions[obs]); smtSolver->add(storm::expressions::disjunction(actExprs)); - //for (auto const& av : actionVars) { - // smtSolver->add(!followVarExpressions[obs] || !av); - //} + for (auto const& av : actionVars) { + smtSolver->add(!followVarExpressions[obs] || !av); + } ++obs; } @@ -191,7 +191,7 @@ namespace storm { if (targetStates.get(state)) { smtSolver->add(pathVars[state][0]); } else { - smtSolver->add(!pathVars[state][0]); + smtSolver->add(!pathVars[state][0] || followVarExpressions[pomdp.getObservation(state)]); } } } @@ -213,7 +213,11 @@ namespace storm { subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { - subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); + if (pomdp.getObservation(entries.getColumn() == pomdp.getObservation(state))) { + subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); + } else { + subexprreachSwitch.push_back(reachVarExpressions.at(entries.getColumn())); + } smtSolver->add(storm::expressions::disjunction(subexprreachSwitch)); subexprreachSwitch.pop_back(); subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn())); @@ -270,10 +274,14 @@ namespace storm { pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); } pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); + pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]); smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); } } } else { + for (uint64_t j = 1; j < k; ++j) { + smtSolver->add(pathVars[state][j]); + } rowindex += pomdp.getNumberOfChoices(state); } } @@ -284,6 +292,7 @@ namespace storm { for(auto const& state : statesForObservation) { if (!targetStates.get(state)) { smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0); + smtSolver->add(!reachVarExpressions[state] || !followVarExpressions[obs] || schedulerVariableExpressions[obs] > 0); } } ++obs; @@ -335,9 +344,10 @@ namespace storm { STORM_LOG_ASSERT(reachVarExpressions.size() > state, "state id " << state << " exceeds number of states (" << reachVarExpressions.size() << ")" ); atLeastOneOfStates.push_back(reachVarExpressions[state]); } - assert(atLeastOneOfStates.size() > 0); // PAPER COMMENT 11 - smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); + if (!atLeastOneOfStates.empty()) { + smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); + } smtSolver->push(); std::set allOfTheseAssumption; @@ -349,18 +359,54 @@ namespace storm { allOfTheseAssumption.insert(reachVarExpressions[state]); } - for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { - updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob])); - schedulerForObs.push_back(std::vector()); + if (winningRegion.empty()) { + // Keep it simple here to help bughunting if necessary. + for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { + updateForObservationExpressions.push_back( + storm::expressions::disjunction(reachVarExpressionsPerObservation[ob])); + schedulerForObs.push_back(0); + } + } else { + + uint64_t obs = 0; + for (auto const &statesForObservation : statesPerObservation) { + schedulerForObs.push_back(0); + for (auto const& winningSet : winningRegion.getWinningSetsPerObservation(obs)) { + for (auto const &stateOffset : ~winningSet) { + uint64_t state = statesForObservation[stateOffset]; + assert(obs < schedulerForObs.size()); + ++(schedulerForObs[obs]); + auto constant = expressionManager->integer(schedulerForObs[obs]); + // PAPER COMMENT 14: + smtSolver->add(!(continuationVarExpressions[state] && + (schedulerVariableExpressions[obs] == constant))); + smtSolver->add(!(reachVarExpressions[state] && + followVarExpressions[pomdp.getObservation(state)] && + (schedulerVariableExpressions[obs] == constant))); + + } + } + if (winningRegion.getWinningSetsPerObservation(obs).empty()) { + updateForObservationExpressions.push_back( + storm::expressions::disjunction(reachVarExpressionsPerObservation[obs])); + } else { + updateForObservationExpressions.push_back(winningRegion.extensionExpression(obs, reachVarExpressionsPerObservation[obs])); + } + ++obs; + + } + } + for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { - auto constant = expressionManager->integer(schedulerForObs[obs].size()); + auto constant = expressionManager->integer(schedulerForObs[obs]); smtSolver->add(schedulerVariableExpressions[obs] <= constant); smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); } + assert(pomdp.getNrObservations() == schedulerForObs.size()); InternalObservationScheduler scheduler; @@ -549,9 +595,9 @@ namespace storm { storm::analysis::QualitativeAnalysisOnGraphs graphanalysis(pomdp); uint64_t targetStatesBefore = targetStates.getNumberOfSetBits(); STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits()); - targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates); - uint64_t targetStatesAfter = targetStates.getNumberOfSetBits(); - STORM_LOG_INFO("Target states after graph based analysis " << targetStates.getNumberOfSetBits()); + storm::storage::BitVector newtargetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates); + uint64_t targetStatesAfter = newtargetStates.getNumberOfSetBits(); + STORM_LOG_INFO("Target states after graph based analysis " << newtargetStates.getNumberOfSetBits()); if (targetStatesAfter - targetStatesBefore > 0) { stats.winningRegionUpdatesTimer.start(); @@ -561,7 +607,7 @@ namespace storm { } bool observationIsWinning = true; for (uint64_t state : statesPerObservation[observation]) { - if(!targetStates.get(state)) { + if(!newtargetStates.get(state)) { observationIsWinning = false; break; } @@ -569,6 +615,9 @@ namespace storm { if(observationIsWinning) { stats.incrementGraphBasedWinningObservations(); winningRegion.setObservationIsWinning(observation); + for(auto const& state : statesPerObservation[observation]) { + targetStates.set(state); + } updated.set(observation); } } @@ -608,14 +657,15 @@ namespace storm { if (observations.get(obs) && updated.get(obs)) { STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "."); assert(schedulerForObs.size() > obs); - schedulerForObs[obs].push_back(finalSchedulers.size()); - STORM_LOG_DEBUG("We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs); + (schedulerForObs[obs])++; + STORM_LOG_DEBUG("We now have " << schedulerForObs[obs] << " policies for states with observation " << obs); for (auto const &state : statesForObservation) { if (!coveredStates.get(state)) { - auto constant = expressionManager->integer(schedulerForObs[obs].size()); + auto constant = expressionManager->integer(schedulerForObs[obs]); // PAPER COMMENT 14: smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); + smtSolver->add(!(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] && (schedulerVariableExpressions[obs] == constant))); } } } @@ -626,7 +676,7 @@ namespace storm { smtSolver->push(); for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { - auto constant = expressionManager->integer(schedulerForObs[obs].size()); + auto constant = expressionManager->integer(schedulerForObs[obs]); // PAPER COMMENT 13 smtSolver->add(schedulerVariableExpressions[obs] <= constant); // PAPER COMMENT 12 diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index c3cb08bad..c7fcae2e8 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -231,7 +231,7 @@ namespace pomdp { std::vector> pathVars; std::vector finalSchedulers; - std::vector> schedulerForObs; + std::vector schedulerForObs; WinningRegion winningRegion; MemlessSearchOptions options;