diff --git a/src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.cpp b/src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.cpp index bda4ac15e..6c9a62d18 100644 --- a/src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.cpp +++ b/src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.cpp @@ -15,11 +15,13 @@ namespace storm { const std::string QualitativePOMDPAnalysisSettings::moduleName = "pomdpQualitative"; const std::string exportSATCallsOption = "exportSATCallsPath"; const std::string lookaheadHorizonOption = "lookaheadHorizon"; + const std::string onlyDeterministicOption = "onlyDeterministic"; QualitativePOMDPAnalysisSettings::QualitativePOMDPAnalysisSettings() : ModuleSettings(moduleName) { this->addOption(storm::settings::OptionBuilder(moduleName, exportSATCallsOption, false, "Export the SAT calls?.").addArgument(storm::settings::ArgumentBuilder::createStringArgument("path", "The name of the file to which to write the model.").build()).build()); this->addOption(storm::settings::OptionBuilder(moduleName, lookaheadHorizonOption, false, "In reachability in combination with a discrete ranking function, a lookahead is necessary.").addArgument(storm::settings::ArgumentBuilder::createUnsignedIntegerArgument("bound", "The lookahead. Use 0 for the number of states.").setDefaultValueUnsignedInteger(0).build()).build()); + this->addOption(storm::settings::OptionBuilder(moduleName, onlyDeterministicOption, false, "Search only for deterministic schedulers").build()); } uint64_t QualitativePOMDPAnalysisSettings::getLookahead() const { @@ -33,6 +35,9 @@ namespace storm { return this->getOption(exportSATCallsOption).getArgumentByName("path").getValueAsString(); } + bool QualitativePOMDPAnalysisSettings::isOnlyDeterministicSet() const { + return this->getOption(onlyDeterministicOption).getHasOptionBeenSet(); + } void QualitativePOMDPAnalysisSettings::finalize() { diff --git a/src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.h b/src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.h index 6a9830e8b..30682e625 100644 --- a/src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.h +++ b/src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.h @@ -23,6 +23,7 @@ namespace storm { uint64_t getLookahead() const; bool isExportSATCallsSet() const; std::string getExportSATCallsPath() const; + bool isOnlyDeterministicSet() const; virtual ~QualitativePOMDPAnalysisSettings() = default; diff --git a/src/storm-pomdp-cli/storm-pomdp.cpp b/src/storm-pomdp-cli/storm-pomdp.cpp index e67e716e6..7ea13d615 100644 --- a/src/storm-pomdp-cli/storm-pomdp.cpp +++ b/src/storm-pomdp-cli/storm-pomdp.cpp @@ -174,6 +174,15 @@ bool extractTargetAndSinkObservationSets(std::shared_ptr +std::set extractObservations(storm::models::sparse::Pomdp const& pomdp, storm::storage::BitVector const& states) { + std::set observations; + for(auto state : states) { + observations.insert(pomdp.getObservation(state)); + } + return observations; +} + /*! * Entry point for the pomdp backend. * @@ -195,6 +204,7 @@ int main(const int argc, const char** argv) { auto const& coreSettings = storm::settings::getModule(); auto const& pomdpSettings = storm::settings::getModule(); + auto const& ioSettings = storm::settings::getModule(); auto const &general = storm::settings::getModule(); auto const &debug = storm::settings::getModule(); auto const& pomdpQualSettings = storm::settings::getModule(); @@ -208,6 +218,9 @@ int main(const int argc, const char** argv) { if (debug.isTraceSet()) { storm::utility::setLogLevel(l3pp::LogLevel::TRACE); } + if (debug.isLogfileSet()) { + storm::utility::initializeFileLogging(); + } // For several engines, no model building step is performed, but the verification is started right away. storm::settings::modules::CoreSettings::Engine engine = coreSettings.getEngine(); @@ -237,17 +250,6 @@ int main(const int argc, const char** argv) { if (formula->isProbabilityOperatorFormula()) { - - std::set targetObservationSet; - storm::storage::BitVector targetStates(pomdp->getNumberOfStates()); - storm::storage::BitVector badStates(pomdp->getNumberOfStates()); - - bool validFormula = extractTargetAndSinkObservationSets(pomdp, subformula1, targetObservationSet, targetStates, badStates); - STORM_LOG_THROW(validFormula, storm::exceptions::InvalidPropertyException, - "The formula is not supported by the grid approximation"); - STORM_LOG_ASSERT(!targetObservationSet.empty(), "The set of target observations is empty!"); - - boost::optional prob1States; boost::optional prob0States; if (pomdpSettings.isSelfloopReductionSet() && !storm::solver::minimize(formula->asProbabilityOperatorFormula().getOptimalityType())) { @@ -271,8 +273,29 @@ int main(const int argc, const char** argv) { storm::pomdp::transformer::KnownProbabilityTransformer kpt = storm::pomdp::transformer::KnownProbabilityTransformer(); pomdp = kpt.transform(*pomdp, *prob0States, *prob1States); } + + if (ioSettings.isExportDotSet()) { + std::shared_ptr> sparseModel = pomdp; + storm::api::exportSparseModelAsDot(sparseModel, ioSettings.getExportDotFilename(), ioSettings.getExportDotMaxWidth()); + } + if (ioSettings.isExportExplicitSet()) { + std::shared_ptr> sparseModel = pomdp; + storm::api::exportSparseModelAsDrn(sparseModel, ioSettings.getExportExplicitFilename()); + } + + + if (pomdpSettings.isGridApproximationSet()) { + std::set targetObservationSet; + storm::storage::BitVector targetStates(pomdp->getNumberOfStates()); + storm::storage::BitVector badStates(pomdp->getNumberOfStates()); + + bool validFormula = extractTargetAndSinkObservationSets(pomdp, subformula1, targetObservationSet, targetStates, badStates); + STORM_LOG_THROW(validFormula, storm::exceptions::InvalidPropertyException, + "The formula is not supported by the grid approximation"); + STORM_LOG_ASSERT(!targetObservationSet.empty(), "The set of target observations is empty!"); + storm::pomdp::modelchecker::ApproximatePOMDPModelchecker checker = storm::pomdp::modelchecker::ApproximatePOMDPModelchecker(); double overRes = storm::utility::one(); double underRes = storm::utility::zero(); @@ -291,10 +314,12 @@ int main(const int argc, const char** argv) { } } if (pomdpSettings.isMemlessSearchSet()) { -// std::cout << std::endl; -// pomdp->writeDotToStream(std::cout); -// std::cout << std::endl; -// std::cout << std::endl; + storm::analysis::QualitativeAnalysis qualitativeAnalysis(*pomdp); + // After preprocessing, this might be done cheaper. + storm::storage::BitVector targetStates = qualitativeAnalysis.analyseProb1(formula->asProbabilityOperatorFormula()); + storm::storage::BitVector surelyNotAlmostSurelyReachTarget = qualitativeAnalysis.analyseProbSmaller1(formula->asProbabilityOperatorFormula()); + std::set targetObservationSet = extractObservations(*pomdp, targetStates); + storm::expressions::ExpressionManager expressionManager; std::shared_ptr smtSolverFactory = std::make_shared(); @@ -302,9 +327,9 @@ int main(const int argc, const char** argv) { if (lookahead == 0) { lookahead = pomdp->getNumberOfStates(); } - storm::pomdp::MemlessSearchOptions options; + options.onlyDeterministicStrategies = pomdpQualSettings.isOnlyDeterministicSet(); uint64_t loglevel = 0; // TODO a big ugly, but we have our own loglevels. if(storm::utility::getLogLevel() == l3pp::LogLevel::INFO) { @@ -322,13 +347,23 @@ int main(const int argc, const char** argv) { options.setExportSATCalls(pomdpQualSettings.getExportSATCallsPath()); } + if (storm::utility::graph::checkIfECWithChoiceExists(pomdp->getTransitionMatrix(), pomdp->getBackwardTransitions(), ~targetStates & ~surelyNotAlmostSurelyReachTarget, storm::storage::BitVector(pomdp->getNumberOfChoices(), true))) { + options.lookaheadRequired = true; + STORM_LOG_DEBUG("Lookahead required."); + } else { + options.lookaheadRequired = false; + STORM_LOG_DEBUG("No lookahead required."); + } + if (pomdpSettings.getMemlessSearchMethod() == "ccd16memless") { - storm::pomdp::QualitativeStrategySearchNaive memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory); + storm::pomdp::QualitativeStrategySearchNaive memlessSearch(*pomdp, targetObservationSet, targetStates, surelyNotAlmostSurelyReachTarget, smtSolverFactory); memlessSearch.findNewStrategyForSomeState(lookahead); } else if (pomdpSettings.getMemlessSearchMethod() == "iterative") { - storm::pomdp::MemlessStrategySearchQualitative memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory, options); + storm::pomdp::MemlessStrategySearchQualitative memlessSearch(*pomdp, targetObservationSet, targetStates, surelyNotAlmostSurelyReachTarget, smtSolverFactory, options); memlessSearch.findNewStrategyForSomeState(lookahead); + memlessSearch.finalizeStatistics(); + memlessSearch.getStatistics().print(); } else { STORM_LOG_ERROR("This method is not implemented."); } diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index d1a67d132..b1e7aea9b 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -19,18 +19,30 @@ namespace storm { STORM_LOG_TRACE(ss.str()); i = 0; STORM_LOG_TRACE("states from which we continue: "); - ss.clear(); + std::stringstream ss2; for (auto rv : continuationVars) { if (model->getBooleanValue(rv)) { - ss << " " << i; + ss2 << " " << i; } ++i; } - STORM_LOG_TRACE(ss.str()); + STORM_LOG_TRACE(ss2.str()); } } + template + void MemlessStrategySearchQualitative::Statistics::print() const { + STORM_PRINT_AND_LOG("Total time: " << totalTimer << std::endl); + STORM_PRINT_AND_LOG("SAT Calls " << satCalls << std::endl); + STORM_PRINT_AND_LOG("SAT Calls time: " << smtCheckTimer << std::endl); + STORM_PRINT_AND_LOG("Outer iterations: " << outerIterations << std::endl); + STORM_PRINT_AND_LOG("Solver initialization time: " << initializeSolverTimer << std::endl); + STORM_PRINT_AND_LOG("Extend partial scheduler time: " << updateExtensionSolverTime << std::endl); + STORM_PRINT_AND_LOG("Update solver with new scheduler time: " << updateNewStrategySolverTime << std::endl); + STORM_PRINT_AND_LOG("Winning regions update time: " << winningRegionUpdatesTimer << std::endl); + } + template MemlessStrategySearchQualitative::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp const& pomdp, std::set const& targetObservationSet, @@ -39,9 +51,9 @@ namespace storm { std::shared_ptr& smtSolverFactory, MemlessSearchOptions const& options) : pomdp(pomdp), - targetStates(targetStates), surelyReachSinkStates(surelyReachSinkStates), targetObservations(targetObservationSet), + targetStates(targetStates), options(options) { this->expressionManager = std::make_shared(); @@ -49,6 +61,7 @@ namespace storm { // Initialize states per observation. for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { statesPerObservation.push_back(std::vector()); // Consider using bitvectors instead. + reachVarExpressionsPerObservation.push_back(std::vector()); } uint64_t state = 0; for (auto obs : pomdp.getObservations()) { @@ -60,10 +73,15 @@ namespace storm { nrStatesPerObservation.push_back(states.size()); } winningRegion = WinningRegion(nrStatesPerObservation); + } template void MemlessStrategySearchQualitative::initialize(uint64_t k) { + STORM_LOG_INFO("Start intializing solver..."); + // TODO fix this + bool lookaheadConstraintsRequired = options.lookaheadRequired; + STORM_LOG_WARN("We have hardcoded that we do not need lookahead"); if (maxK == std::numeric_limits::max()) { // not initialized at all. // Create some data structures. @@ -76,16 +94,19 @@ namespace storm { // declare the reachability variables, // declare the path variables. for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { - pathVars.push_back(std::vector()); - for (uint64_t i = 0; i < k; ++i) { - pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression()); + if(lookaheadConstraintsRequired) { + pathVars.push_back(std::vector()); + for (uint64_t i = 0; i < k; ++i) { + pathVars.back().push_back(expressionManager->declareBooleanVariable("P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression()); + } } reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId))); reachVarExpressions.push_back(reachVars.back().getExpression()); + reachVarExpressionsPerObservation[pomdp.getObservation(stateId)].push_back(reachVarExpressions.back()); continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId))); continuationVarExpressions.push_back(continuationVars.back().getExpression()); } - assert(pathVars.size() == pomdp.getNumberOfStates()); + assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates()); assert(reachVars.size() == pomdp.getNumberOfStates()); assert(reachVarExpressions.size() == pomdp.getNumberOfStates()); @@ -101,7 +122,15 @@ namespace storm { schedulerVariableExpressions.push_back(schedulerVariables.back()); switchVars.push_back(expressionManager->declareBooleanVariable("S-" + std::to_string(obs))); switchVarExpressions.push_back(switchVars.back().getExpression()); - + observationUpdatedVariables.push_back(expressionManager->declareBooleanVariable("U-" + std::to_string(obs))); + observationUpdatedExpressions.push_back(observationUpdatedVariables.back().getExpression()); + if (options.onlyDeterministicStrategies) { + for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) { + for (uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) { + smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]); + } + } + } ++obs; } @@ -109,11 +138,15 @@ namespace storm { smtSolver->add(storm::expressions::disjunction(actionVars)); } - for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { - if (targetStates.get(state)) { - smtSolver->add(pathVars[state][0]); - } else { - smtSolver->add(!pathVars[state][0]); + smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions)); + + if (lookaheadConstraintsRequired) { + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + if (targetStates.get(state)) { + smtSolver->add(pathVars[state][0]); + } else { + smtSolver->add(!pathVars[state][0]); + } } } @@ -152,42 +185,49 @@ namespace storm { uint64_t rowindex = 0; for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { - if (surelyReachSinkStates.get(state)) { smtSolver->add(!reachVarExpressions[state]); - for (uint64_t j = 1; j < k; ++j) { - smtSolver->add(!pathVars[state][j]); - } smtSolver->add(!continuationVarExpressions[state]); - } else if(!targetStates.get(state)) { - std::vector>> pathsubsubexprs; - for (uint64_t j = 1; j < k; ++j) { - pathsubsubexprs.push_back(std::vector>()); - for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { - pathsubsubexprs.back().push_back(std::vector()); + if (lookaheadConstraintsRequired) { + for (uint64_t j = 1; j < k; ++j) { + smtSolver->add(!pathVars[state][j]); } } + rowindex += pomdp.getNumberOfChoices(state); + } else if(!targetStates.get(state)) { + if (lookaheadConstraintsRequired) { + smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); + + std::vector>> pathsubsubexprs; + for (uint64_t j = 1; j < k; ++j) { + pathsubsubexprs.push_back(std::vector>()); + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + pathsubsubexprs.back().push_back(std::vector()); + } + } - for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { - std::vector subexprreach; - for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { - for (uint64_t j = 1; j < k; ++j) { - pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + std::vector subexprreach; + for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { + for (uint64_t j = 1; j < k; ++j) { + pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); + } } + rowindex++; } - rowindex++; - } - smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); - for (uint64_t j = 1; j < k; ++j) { - std::vector pathsubexprs; + for (uint64_t j = 1; j < k; ++j) { + std::vector pathsubexprs; - for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { - pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); + } + pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); + smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); } - pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); - smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); } + } else { + rowindex += pomdp.getNumberOfChoices(state); } } @@ -199,17 +239,47 @@ namespace storm { ++obs; } + for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { + smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]))); + } + if (!lookaheadConstraintsRequired) { + uint64_t rowIndex = 0; + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + uint64_t enabledActions = pomdp.getNumberOfChoices(state); + if (!surelyReachSinkStates.get(state)) { + std::vector successorVars; + for (uint64_t act = 0; act < enabledActions; ++act) { + for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) { + successorVars.push_back(reachVarExpressions[entries.getColumn()]); + } + rowIndex++; + } + successorVars.push_back(!switchVars[pomdp.getObservation(state)]); + smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state])); + } else { + rowIndex += enabledActions; + } + } + } else { + STORM_LOG_WARN("Some optimization not implemented yet."); + } + + + // TODO: Update found schedulers if k is increased. } template bool MemlessStrategySearchQualitative::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { + stats.initializeSolverTimer.start(); if (k < maxK) { initialize(k); maxK = k; } + uint64_t maximalNrActions = 8; + STORM_LOG_WARN("We have hardcoded (an upper bound on) the number of actions"); std::vector atLeastOneOfStates; for (uint64_t state : oneOfTheseStates) { @@ -225,112 +295,117 @@ namespace storm { } smtSolver->push(); - uint64_t obs = 0; - for(auto const& statesForObservation : statesPerObservation) { - smtSolver->add(schedulerVariableExpressions[obs] <= schedulerForObs.size()); - ++obs; - } + std::vector updateForObservationExpressions; for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { + updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob])); schedulerForObs.push_back(std::vector()); } + for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { + auto constant = expressionManager->integer(schedulerForObs[obs].size()); + smtSolver->add(schedulerVariableExpressions[obs] <= constant); + smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); + } + + + InternalObservationScheduler scheduler; scheduler.switchObservations = storm::storage::BitVector(pomdp.getNrObservations()); + storm::storage::BitVector newObservations(pomdp.getNrObservations()); + storm::storage::BitVector newObservationsAfterSwitch(pomdp.getNrObservations()); storm::storage::BitVector observations(pomdp.getNrObservations()); storm::storage::BitVector observationsAfterSwitch(pomdp.getNrObservations()); - storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); + storm::storage::BitVector observationUpdated(pomdp.getNrObservations()); + storm::storage::BitVector coveredStates(pomdp.getNumberOfStates()); + storm::storage::BitVector coveredStatesAfterSwitch(pomdp.getNumberOfStates()); + + stats.initializeSolverTimer.stop(); + STORM_LOG_INFO("Start iterative solver..."); uint64_t iterations = 0; while(true) { - scheduler.clear(); + stats.incrementOuterIterations(); + + scheduler.reset(pomdp.getNrObservations(), maximalNrActions); observations.clear(); observationsAfterSwitch.clear(); - remainingstates.clear(); + coveredStates.clear(); + coveredStatesAfterSwitch.clear(); + observationUpdated.clear(); + bool newSchedulerDiscovered = false; while (true) { ++iterations; - if(options.isExportSATSet()) { - STORM_LOG_DEBUG("Export SMT Solver Call (" <getSmtLibString() << std::endl; - storm::utility::closeFile(filestream); - } - - STORM_LOG_DEBUG("Call to SMT Solver (" <check(); - uint64_t i = 0; - - if (result == storm::solver::SmtSolver::CheckResult::Unknown) { - STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); - } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { - STORM_LOG_DEBUG("Unsatisfiable!"); + bool foundScheduler = this->smtCheck(iterations); + if (!foundScheduler) { break; } - - STORM_LOG_DEBUG("Satisfying assignment: "); - STORM_LOG_TRACE(smtSolver->getModelAsValuation().toString(true)); + newSchedulerDiscovered = true; + stats.updateExtensionSolverTime.start(); auto model = smtSolver->getModel(); + newObservationsAfterSwitch.clear(); + newObservations.clear(); - observations.clear(); - observationsAfterSwitch.clear(); - remainingstates.clear(); - scheduler.clear(); + uint64_t obs = 0; + for (auto ov : observationUpdatedVariables) { + if (!observationUpdated.get(obs) && model->getBooleanValue(ov)) { + STORM_LOG_TRACE("New observation updated: " << obs); + observationUpdated.set(obs); + } + obs++; + } + + uint64_t i = 0; for (auto rv : reachVars) { - if (model->getBooleanValue(rv)) { + if (!coveredStates.get(i) && model->getBooleanValue(rv)) { + STORM_LOG_TRACE("New state: " << i); smtSolver->add(rv.getExpression()); - observations.set(pomdp.getObservation(i)); - } else { - remainingstates.set(i); + newObservations.set(pomdp.getObservation(i)); + coveredStates.set(i); } ++i; } i = 0; for (auto rv : continuationVars) { - if (model->getBooleanValue(rv)) { + if (!coveredStatesAfterSwitch.get(i) && model->getBooleanValue(rv) ) { smtSolver->add(rv.getExpression()); - observationsAfterSwitch.set(pomdp.getObservation(i)); + if (!observationsAfterSwitch.get(pomdp.getObservation(i))) { + newObservationsAfterSwitch.set(pomdp.getObservation(i)); + } + ++i; } - ++i; } if (options.computeTraceOutput()) { detail::printRelevantInfoFromModel(model, reachVars, continuationVars); } - // TODO do not repush everyting to the solver. - std::vector schedulerSoFar; - uint64_t obs = 0; - for (auto const &actionSelectionVarsForObs : actionSelectionVars) { - scheduler.actions.push_back(std::set()); - if (observations.get(obs)) { - for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) { - auto const& asv = actionSelectionVarsForObs[act]; - if (model->getBooleanValue(asv)) { - scheduler.actions.back().insert(act); - schedulerSoFar.push_back(actionSelectionVarExpressions[obs][act]); - } - } - if (model->getBooleanValue(switchVars[obs])) { - scheduler.switchObservations.set(obs); - schedulerSoFar.push_back(switchVarExpressions[obs]); + for (auto obs : newObservations) { + auto const &actionSelectionVarsForObs = actionSelectionVars[obs]; + observations.set(obs); + for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) { + if (model->getBooleanValue(actionSelectionVarsForObs[act])) { + scheduler.actions[obs].set(act); + smtSolver->add(actionSelectionVarExpressions[obs][act]); } else { - schedulerSoFar.push_back(!switchVarExpressions[obs]); + smtSolver->add(!actionSelectionVarExpressions[obs][act]); } } - - if (observationsAfterSwitch.get(obs)) { - scheduler.schedulerRef.push_back(model->getIntegerValue(schedulerVariables[obs])); - schedulerSoFar.push_back(schedulerVariableExpressions[obs] == expressionManager->integer(scheduler.schedulerRef.back())); + if (model->getBooleanValue(switchVars[obs])) { + scheduler.switchObservations.set(obs); + smtSolver->add(switchVarExpressions[obs]); } else { - scheduler.schedulerRef.push_back(0); + smtSolver->add(!switchVarExpressions[obs]); } - obs++; + } + for (auto obs : newObservationsAfterSwitch) { + observationsAfterSwitch.set(obs); + scheduler.schedulerRef[obs] = model->getIntegerValue(schedulerVariables[obs]); + smtSolver->add(schedulerVariableExpressions[obs] == expressionManager->integer(scheduler.schedulerRef.back())); } if(options.computeTraceOutput()) { @@ -341,56 +416,82 @@ namespace storm { } std::vector remainingExpressions; - for (auto index : remainingstates) { - remainingExpressions.push_back(reachVarExpressions[index]); + for (auto index : ~coveredStates) { + if (observationUpdated.get(pomdp.getObservation(index))) { + remainingExpressions.push_back(reachVarExpressions[index]); + } + } + for (auto index : ~observationUpdated) { + remainingExpressions.push_back(observationUpdatedExpressions[index]); + } + + + if (remainingExpressions.empty()) { + stats.updateExtensionSolverTime.stop(); + break; } // Add scheduler - smtSolver->add(storm::expressions::conjunction(schedulerSoFar)); + + //std::cout << storm::expressions::disjunction(remainingExpressions) << std::endl; + smtSolver->add(storm::expressions::disjunction(remainingExpressions)); + stats.updateExtensionSolverTime.stop(); } - if (scheduler.empty()) { + if (!newSchedulerDiscovered) { break; } smtSolver->pop(); if(options.computeDebugOutput()) { - printCoveredStates(remainingstates); + printCoveredStates(~coveredStates); // generates info output, but here we only want it for debug level. // For consistency, all output on info level. STORM_LOG_DEBUG("the scheduler: "); scheduler.printForObservations(observations,observationsAfterSwitch); } - std::vector remainingExpressions; - for (auto index : remainingstates) { - remainingExpressions.push_back(reachVarExpressions[index]); - } + stats.winningRegionUpdatesTimer.start(); + storm::storage::BitVector updated(observations.size()); for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { - storm::storage::BitVector update = storm::storage::BitVector(statesPerObservation[observation].size()); + STORM_LOG_TRACE("consider observation " << observation); + storm::storage::BitVector update(statesPerObservation[observation].size()); uint64_t i = 0; for (uint64_t state : statesPerObservation[observation]) { - if (!remainingstates.get(state)) { + if (coveredStates.get(state)) { update.set(i); } + ++i; + } + if(!update.empty()) { + STORM_LOG_TRACE("Update Winning Region: Observation " << observation << " with update " << update); + bool updateResult = winningRegion.update(observation, update); + STORM_LOG_TRACE("Region changed:" << updateResult); + if (updateResult) { + updated.set(observation); + updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]); + } } - winningRegion.update(observation, update); - ++i; } + STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place"); + stats.winningRegionUpdatesTimer.stop(); - smtSolver->add(storm::expressions::disjunction(remainingExpressions)); + if(options.computeDebugOutput()) { + winningRegion.print(); + } + stats.updateNewStrategySolverTime.start(); uint64_t obs = 0; for (auto const &statesForObservation : statesPerObservation) { - if (observations.get(obs)) { + if (observations.get(obs) && updated.get(obs)) { STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "."); assert(schedulerForObs.size() > obs); schedulerForObs[obs].push_back(finalSchedulers.size()); STORM_LOG_DEBUG("We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs); for (auto const &state : statesForObservation) { - if (remainingstates.get(state)) { + if (!coveredStates.get(state)) { auto constant = expressionManager->integer(schedulerForObs[obs].size()); smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); } @@ -399,14 +500,20 @@ namespace storm { ++obs; } finalSchedulers.push_back(scheduler); + smtSolver->push(); for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { auto constant = expressionManager->integer(schedulerForObs[obs].size()); smtSolver->add(schedulerVariableExpressions[obs] <= constant); + smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); } + stats.updateNewStrategySolverTime.stop(); + } + winningRegion.print(); + return true; } @@ -423,21 +530,53 @@ namespace storm { } + template + void MemlessStrategySearchQualitative::printScheduler(std::vector const& ) { + } template - void MemlessStrategySearchQualitative::printScheduler(std::vector const& ) { + void MemlessStrategySearchQualitative::finalizeStatistics() { } + template + typename MemlessStrategySearchQualitative::Statistics const& MemlessStrategySearchQualitative::getStatistics() const{ + return stats; + } template - storm::expressions::Expression const& MemlessStrategySearchQualitative::getDoneActionExpression(uint64_t obs) const { - return actionSelectionVarExpressions[obs].back(); - } + bool MemlessStrategySearchQualitative::smtCheck(uint64_t iteration) { + if(options.isExportSATSet()) { + STORM_LOG_DEBUG("Export SMT Solver Call (" <getSmtLibString() << std::endl; + storm::utility::closeFile(filestream); + } + STORM_LOG_DEBUG("Call to SMT Solver (" <check(); + stats.smtCheckTimer.stop(); + stats.incrementSmtChecks(); + + if (result == storm::solver::SmtSolver::CheckResult::Unknown) { + STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); + } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { + STORM_LOG_DEBUG("Unsatisfiable!"); + return false; + } + + STORM_LOG_DEBUG("Satisfying assignment: "); + STORM_LOG_TRACE(smtSolver->getModelAsValuation().toString(true)); + return true; + } template class MemlessStrategySearchQualitative; template class MemlessStrategySearchQualitative; + + } } diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index 674f41b0c..c12e1dd57 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -4,6 +4,7 @@ #include "storm/solver/SmtSolver.h" #include "storm/models/sparse/Pomdp.h" #include "storm/utility/solver.h" +#include "storm/utility/Stopwatch.h" #include "storm/exceptions/UnexpectedException.h" #include "storm-pomdp/analysis/WinningRegion.h" @@ -42,20 +43,24 @@ namespace pomdp { return debugLevel > 2; } + bool onlyDeterministicStrategies = false; + bool lookaheadRequired = true; + private: std::string exportSATcalls = ""; uint64_t debugLevel = 0; + }; struct InternalObservationScheduler { - std::vector> actions; + std::vector actions; std::vector schedulerRef; storm::storage::BitVector switchObservations; - void clear() { - actions.clear(); - schedulerRef.clear(); + void reset(uint64_t nrObservations, uint64_t nrActions) { + actions = std::vector(nrObservations, storm::storage::BitVector(nrActions)); + schedulerRef = std::vector(nrObservations, 0); switchObservations.clear(); } @@ -65,8 +70,10 @@ namespace pomdp { void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const { for (uint64_t obs = 0; obs < observations.size(); ++obs) { - if (observations.get(obs)) { + if (observations.get(obs) || observationsAfterSwitch.get(obs)) { STORM_LOG_INFO("For observation: " << obs); + } + if (observations.get(obs)) { std::stringstream ss; ss << "actions:"; for (auto act : actions[obs]) { @@ -90,6 +97,31 @@ namespace pomdp { // Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper. public: + class Statistics { + public: + Statistics() = default; + void print() const; + + storm::utility::Stopwatch totalTimer; + storm::utility::Stopwatch smtCheckTimer; + storm::utility::Stopwatch initializeSolverTimer; + storm::utility::Stopwatch updateExtensionSolverTime; + storm::utility::Stopwatch updateNewStrategySolverTime; + + storm::utility::Stopwatch winningRegionUpdatesTimer; + + void incrementOuterIterations() { + outerIterations++; + } + + void incrementSmtChecks() { + satCalls++; + } + private: + uint64_t satCalls = 0; + uint64_t outerIterations = 0; + }; + MemlessStrategySearchQualitative(storm::models::sparse::Pomdp const& pomdp, std::set const& targetObservationSet, storm::storage::BitVector const& targetStates, @@ -98,19 +130,24 @@ namespace pomdp { MemlessSearchOptions const& options); void analyzeForInitialStates(uint64_t k) { + stats.totalTimer.start(); analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); + stats.totalTimer.stop(); } void findNewStrategyForSomeState(uint64_t k) { std::cout << surelyReachSinkStates << std::endl; std::cout << targetStates << std::endl; std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; + stats.totalTimer.start(); analyze(k, ~surelyReachSinkStates & ~targetStates); + stats.totalTimer.stop(); } bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); - + Statistics const& getStatistics() const; + void finalizeStatistics(); private: storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const; @@ -119,24 +156,30 @@ namespace pomdp { void initialize(uint64_t k); + bool smtCheck(uint64_t iteration); + std::unique_ptr smtSolver; storm::models::sparse::Pomdp const& pomdp; std::shared_ptr expressionManager; uint64_t maxK = std::numeric_limits::max(); + storm::storage::BitVector surelyReachSinkStates; std::set targetObservations; storm::storage::BitVector targetStates; - storm::storage::BitVector surelyReachSinkStates; + std::vector> statesPerObservation; std::vector schedulerVariables; std::vector schedulerVariableExpressions; - std::vector> statesPerObservation; std::vector> actionSelectionVarExpressions; // A_{z,a} - std::vector> actionSelectionVars; + std::vector> actionSelectionVars; // A_{z,a} std::vector reachVars; std::vector reachVarExpressions; + std::vector> reachVarExpressionsPerObservation; + + std::vector observationUpdatedVariables; + std::vector observationUpdatedExpressions; std::vector switchVars; std::vector switchVarExpressions; @@ -149,7 +192,7 @@ namespace pomdp { WinningRegion winningRegion; MemlessSearchOptions options; - + Statistics stats; }; diff --git a/src/storm-pomdp/analysis/QualitativeAnalysis.cpp b/src/storm-pomdp/analysis/QualitativeAnalysis.cpp index 9bcec71a5..0feea0c95 100644 --- a/src/storm-pomdp/analysis/QualitativeAnalysis.cpp +++ b/src/storm-pomdp/analysis/QualitativeAnalysis.cpp @@ -26,6 +26,23 @@ namespace storm { storm::storage::BitVector QualitativeAnalysis::analyseProb1(storm::logic::ProbabilityOperatorFormula const& formula) const { return analyseProb0or1(formula, false); } + + template + storm::storage::BitVector QualitativeAnalysis::analyseProbSmaller1(storm::logic::ProbabilityOperatorFormula const &formula) const { + STORM_LOG_THROW(formula.hasOptimalityType() || formula.hasBound(), storm::exceptions::InvalidPropertyException, "The formula " << formula << " does not specify whether to minimize or maximize."); + bool minimizes = (formula.hasOptimalityType() && storm::solver::minimize(formula.getOptimalityType())) || (formula.hasBound() && storm::logic::isLowerBound(formula.getBound().comparisonType)); + STORM_LOG_THROW(!minimizes,storm::exceptions::NotImplementedException, "This operation is only supported when maximizing"); + std::shared_ptr subformula = formula.getSubformula().asSharedPointer(); + std::shared_ptr untilSubformula; + // If necessary, convert the subformula to a more general case + if (subformula->isEventuallyFormula()) { + untilSubformula = std::make_shared(storm::logic::Formula::getTrueFormula(), subformula->asEventuallyFormula().getSubformula().asSharedPointer()); + } else if(subformula->isUntilFormula()) { + untilSubformula = std::make_shared(subformula->asUntilFormula()); + } + // The vector is sound, but not necessarily complete! + return ~storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(untilSubformula->getLeftSubformula()), checkPropositionalFormula(untilSubformula->getRightSubformula())); + } template storm::storage::BitVector QualitativeAnalysis::analyseProb0or1(storm::logic::ProbabilityOperatorFormula const& formula, bool prob0) const { diff --git a/src/storm-pomdp/analysis/QualitativeAnalysis.h b/src/storm-pomdp/analysis/QualitativeAnalysis.h index ac692e866..e15f6161d 100644 --- a/src/storm-pomdp/analysis/QualitativeAnalysis.h +++ b/src/storm-pomdp/analysis/QualitativeAnalysis.h @@ -10,7 +10,7 @@ namespace storm { QualitativeAnalysis(storm::models::sparse::Pomdp const& pomdp); storm::storage::BitVector analyseProb0(storm::logic::ProbabilityOperatorFormula const& formula) const; storm::storage::BitVector analyseProb1(storm::logic::ProbabilityOperatorFormula const& formula) const; - + storm::storage::BitVector analyseProbSmaller1(storm::logic::ProbabilityOperatorFormula const& formula) const; private: storm::storage::BitVector analyseProb0or1(storm::logic::ProbabilityOperatorFormula const& formula, bool prob0) const; storm::storage::BitVector analyseProb0Max(storm::logic::UntilFormula const& formula) const; diff --git a/src/storm-pomdp/analysis/WinningRegion.cpp b/src/storm-pomdp/analysis/WinningRegion.cpp index 5b0d4d889..596bb0a5d 100644 --- a/src/storm-pomdp/analysis/WinningRegion.cpp +++ b/src/storm-pomdp/analysis/WinningRegion.cpp @@ -1,4 +1,6 @@ #include +#include "storm/storage/expressions/Expression.h" +#include "storm/storage/expressions/ExpressionManager.h" #include "storm-pomdp/analysis/WinningRegion.h" namespace storm { @@ -10,13 +12,13 @@ namespace pomdp { } } - void WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) { + bool WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) { std::vector newWinningSupport = std::vector(); bool changed = false; for (auto const& support : winningRegion[observation]) { if (winning.isSubsetOf(support)) { // This new winning support is already covered. - return; + return false; } if(support.isSubsetOf(winning)) { // This new winning support extends the previous support, thus the previous support is now spurious @@ -33,6 +35,7 @@ namespace pomdp { } else { winningRegion[observation].push_back(winning); } + return true; } @@ -45,6 +48,34 @@ namespace pomdp { return false; } + storm::expressions::Expression WinningRegion::extensionExpression(uint64_t observation, std::vector& varsForStates) const { + std::vector expressionForEntry; + + for(auto const& winningForObservation : winningRegion[observation]) { + if (winningForObservation.full()) { + assert(winningRegion[observation].size() == 1); + return varsForStates.front().getManager().boolean(false); + } + std::vector subexpr; + std::vector leftHandSides; + assert(varsForStates.size() == winningForObservation.size()); + for(uint64_t i = 0; i < varsForStates.size(); ++i) { + if (winningForObservation.get(i)) { + leftHandSides.push_back(varsForStates[i]); + } else { + subexpr.push_back(varsForStates[i]); + } + } + storm::expressions::Expression rightHandSide = storm::expressions::disjunction(subexpr); + for(auto const& lhs : leftHandSides) { + expressionForEntry.push_back(storm::expressions::implies(lhs,rightHandSide)); + } + expressionForEntry.push_back(storm::expressions::disjunction(varsForStates)); + + } + return storm::expressions::conjunction(expressionForEntry); + } + /** * If we observe this observation, do we surely win? * @param observation @@ -62,6 +93,7 @@ namespace pomdp { std::cout << " " << support; } std::cout << std::endl; + observation++; } } diff --git a/src/storm-pomdp/analysis/WinningRegion.h b/src/storm-pomdp/analysis/WinningRegion.h index 0356a9e94..e23dabf70 100644 --- a/src/storm-pomdp/analysis/WinningRegion.h +++ b/src/storm-pomdp/analysis/WinningRegion.h @@ -9,12 +9,14 @@ namespace storm { public: WinningRegion(std::vector const& observationSizes = {}); - void update(uint64_t observation, storm::storage::BitVector const& winning); + bool update(uint64_t observation, storm::storage::BitVector const& winning); bool query(uint64_t observation, storm::storage::BitVector const& currently) const; bool observationIsWinning(uint64_t observation) const; + storm::expressions::Expression extensionExpression(uint64_t observation, std::vector& varsForStates) const; - uint64_t getStorageSize() const; + + uint64_t getStorageSize() const; uint64_t getNumberOfObservations() const; void print() const; private: