diff --git a/src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp b/src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp index 2a637d436..5a26cda46 100644 --- a/src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp +++ b/src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp @@ -27,6 +27,7 @@ namespace storm { const std::string transformBinaryOption = "transformbinary"; const std::string transformSimpleOption = "transformsimple"; const std::string memlessSearchOption = "memlesssearch"; + std::vector memlessSearchMethods = {"none", "ccdmemless", "ccdmemory", "iterative"}; POMDPSettings::POMDPSettings() : ModuleSettings(moduleName) { this->addOption(storm::settings::OptionBuilder(moduleName, exportAsParametricModelOption, false, "Export the parametric file.").addArgument(storm::settings::ArgumentBuilder::createStringArgument("filename", "The name of the file to which to write the model.").build()).build()); @@ -46,7 +47,9 @@ namespace storm { 10).addValidatorUnsignedInteger( storm::settings::ArgumentValidatorFactory::createUnsignedGreaterValidator( 0)).build()).build()); - this->addOption(storm::settings::OptionBuilder(moduleName, memlessSearchOption, false, "Search for a qualitative memoryless scheuler").build()); + + this->addOption(storm::settings::OptionBuilder(moduleName, memlessSearchOption, false, "Search for a qualitative memoryless scheuler").addArgument(storm::settings::ArgumentBuilder::createStringArgument("method", "method name").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(memlessSearchMethods)).setDefaultValueString("none").build()).build()); + } bool POMDPSettings::isExportToParametricSet() const { @@ -86,6 +89,10 @@ namespace storm { return this->getOption(memlessSearchOption).getHasOptionBeenSet(); } + std::string POMDPSettings::getMemlessSearchMethod() const { + return this->getOption(memlessSearchOption).getArgumentByName("method").getValueAsString(); + } + uint64_t POMDPSettings::getMemoryBound() const { return this->getOption(memoryBoundOption).getArgumentByName("bound").getValueAsUnsignedInteger(); } diff --git a/src/storm-pomdp-cli/settings/modules/POMDPSettings.h b/src/storm-pomdp-cli/settings/modules/POMDPSettings.h index 9f7332774..bd3b13fa8 100644 --- a/src/storm-pomdp-cli/settings/modules/POMDPSettings.h +++ b/src/storm-pomdp-cli/settings/modules/POMDPSettings.h @@ -33,6 +33,7 @@ namespace storm { bool isTransformSimpleSet() const; bool isTransformBinarySet() const; bool isMemlessSearchSet() const; + std::string getMemlessSearchMethod() const; std::string getFscApplicationTypeString() const; uint64_t getMemoryBound() const; diff --git a/src/storm-pomdp-cli/storm-pomdp.cpp b/src/storm-pomdp-cli/storm-pomdp.cpp index 3927d0728..3fa2eb9d4 100644 --- a/src/storm-pomdp-cli/storm-pomdp.cpp +++ b/src/storm-pomdp-cli/storm-pomdp.cpp @@ -26,9 +26,10 @@ #include "storm/settings/modules/TopologicalEquationSolverSettings.h" #include "storm/settings/modules/ModelCheckerSettings.h" #include "storm/settings/modules/MultiplierSettings.h" + +#include "storm/settings/modules/TransformationSettings.h" #include "storm/settings/modules/MultiObjectiveSettings.h" #include "storm-pomdp-cli/settings/modules/POMDPSettings.h" - #include "storm/analysis/GraphConditions.h" #include "storm-cli-utilities/cli.h" @@ -44,6 +45,7 @@ #include "storm-pomdp/analysis/QualitativeAnalysis.h" #include "storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h" #include "storm-pomdp/analysis/MemlessStrategySearchQualitative.h" +#include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h" #include "storm/api/storm.h" #include @@ -59,6 +61,8 @@ void initializeSettings() { storm::settings::addModule(); storm::settings::addModule(); storm::settings::addModule(); + + storm::settings::addModule(); storm::settings::addModule(); storm::settings::addModule(); storm::settings::addModule(); @@ -79,9 +83,9 @@ void initializeSettings() { } template -bool extractTargetAndSinkObservationSets(std::shared_ptr> const& pomdp, storm::logic::Formula const& subformula, std::set& targetObservationSet, storm::storage::BitVector& badStates) { +bool extractTargetAndSinkObservationSets(std::shared_ptr> const& pomdp, storm::logic::Formula const& subformula, std::set& targetObservationSet, storm::storage::BitVector& targetStates, storm::storage::BitVector& badStates) { //TODO refactor (use model checker to determine the states, then transform into observations). - + //TODO rename into appropriate function name. bool validFormula = false; if (subformula.isEventuallyFormula()) { storm::logic::EventuallyFormula const &eventuallyFormula = subformula.asEventuallyFormula(); @@ -94,6 +98,7 @@ bool extractTargetAndSinkObservationSets(std::shared_ptrgetNumberOfStates(); ++state) { if (labeling.getStateHasLabel(targetLabel, state)) { targetObservationSet.insert(pomdp->getObservation(state)); + targetStates.set(state); } } } else if (subformula2.isAtomicExpressionFormula()) { @@ -106,18 +111,19 @@ bool extractTargetAndSinkObservationSets(std::shared_ptrgetNumberOfStates(); ++state) { if (labeling.getStateHasLabel(targetLabel, state)) { targetObservationSet.insert(pomdp->getObservation(state)); + targetStates.set(state); } } } } else if (subformula.isUntilFormula()) { - storm::logic::UntilFormula const &eventuallyFormula = subformula.asUntilFormula(); - storm::logic::Formula const &subformula1 = eventuallyFormula.getLeftSubformula(); + storm::logic::UntilFormula const &untilFormula = subformula.asUntilFormula(); + storm::logic::Formula const &subformula1 = untilFormula.getLeftSubformula(); if (subformula1.isAtomicLabelFormula()) { storm::logic::AtomicLabelFormula const &alFormula = subformula1.asAtomicLabelFormula(); std::string targetLabel = alFormula.getLabel(); auto labeling = pomdp->getStateLabeling(); for (size_t state = 0; state < pomdp->getNumberOfStates(); ++state) { - if (labeling.getStateHasLabel(targetLabel, state)) { + if (!labeling.getStateHasLabel(targetLabel, state)) { badStates.set(state); } } @@ -128,14 +134,14 @@ bool extractTargetAndSinkObservationSets(std::shared_ptrgetStateLabeling(); for (size_t state = 0; state < pomdp->getNumberOfStates(); ++state) { - if (labeling.getStateHasLabel(targetLabel, state)) { + if (!labeling.getStateHasLabel(targetLabel, state)) { badStates.set(state); } } } else { return false; } - storm::logic::Formula const &subformula2 = eventuallyFormula.getRightSubformula(); + storm::logic::Formula const &subformula2 = untilFormula.getRightSubformula(); if (subformula2.isAtomicLabelFormula()) { storm::logic::AtomicLabelFormula const &alFormula = subformula2.asAtomicLabelFormula(); validFormula = true; @@ -144,7 +150,9 @@ bool extractTargetAndSinkObservationSets(std::shared_ptrgetNumberOfStates(); ++state) { if (labeling.getStateHasLabel(targetLabel, state)) { targetObservationSet.insert(pomdp->getObservation(state)); + targetStates.set(state); } + } } else if (subformula2.isAtomicExpressionFormula()) { validFormula = true; @@ -156,7 +164,9 @@ bool extractTargetAndSinkObservationSets(std::shared_ptrgetNumberOfStates(); ++state) { if (labeling.getStateHasLabel(targetLabel, state)) { targetObservationSet.insert(pomdp->getObservation(state)); + targetStates.set(state); } + } } } @@ -227,9 +237,10 @@ int main(const int argc, const char** argv) { if (formula->isProbabilityOperatorFormula()) { std::set targetObservationSet; - std::set badObservationSet; + storm::storage::BitVector targetStates(pomdp->getNumberOfStates()); + storm::storage::BitVector badStates(pomdp->getNumberOfStates()); - bool validFormula = extractTargetAndSinkObservationSets(pomdp, subformula1, targetObservationSet, badObservationSet); + bool validFormula = extractTargetAndSinkObservationSets(pomdp, subformula1, targetObservationSet, targetStates, badStates); STORM_LOG_THROW(validFormula, storm::exceptions::InvalidPropertyException, "The formula is not supported by the grid approximation"); STORM_LOG_ASSERT(!targetObservationSet.empty(), "The set of target observations is empty!"); @@ -278,11 +289,22 @@ int main(const int argc, const char** argv) { } } if (pomdpSettings.isMemlessSearchSet()) { +// std::cout << std::endl; +// pomdp->writeDotToStream(std::cout); +// std::cout << std::endl; +// std::cout << std::endl; storm::expressions::ExpressionManager expressionManager; std::shared_ptr smtSolverFactory = std::make_shared(); + if (pomdpSettings.getMemlessSearchMethod() == "ccd16memless") { + storm::pomdp::QualitativeStrategySearchNaive memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory); + memlessSearch.findNewStrategyForSomeState(5); + } else if (pomdpSettings.getMemlessSearchMethod() == "iterative") { + storm::pomdp::MemlessStrategySearchQualitative memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory); + memlessSearch.findNewStrategyForSomeState(5); + } else { + STORM_LOG_ERROR("This method is not implemented."); + } - storm::pomdp::MemlessStrategySearchQualitative memlessSearch(*pomdp, targetObservationSet, smtSolverFactory); - memlessSearch.analyze(5); } } else if (formula->isRewardOperatorFormula()) { diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index 2ed1230e0..675c91aab 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -6,12 +6,12 @@ namespace storm { template void MemlessStrategySearchQualitative::initialize(uint64_t k) { - if (maxK == -1) { + if (maxK == std::numeric_limits::max()) { // not initialized at all. - // Create some data structures. for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { - actionSelectionVars.push_back(std::vector()); + actionSelectionVars.push_back(std::vector()); + actionSelectionVarExpressions.push_back(std::vector()); statesPerObservation.push_back(std::vector()); // Consider using bitvectors instead. } @@ -24,91 +24,182 @@ namespace storm { for (uint64_t i = 0; i < k; ++i) { pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression()); } - reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)).getExpression()); - + reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId))); + reachVarExpressions.push_back(reachVars.back().getExpression()); statesPerObservation.at(obs).push_back(stateId++); - } assert(pathVars.size() == pomdp.getNumberOfStates()); + assert(reachVars.size() == pomdp.getNumberOfStates()); + assert(reachVarExpressions.size() == pomdp.getNumberOfStates()); // Create the action selection variables. uint64_t obs = 0; for(auto const& statesForObservation : statesPerObservation) { for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) { std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a); - actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName).getExpression()); + actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName)); + actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression()); } ++obs; } - - } else { - assert(false); - } uint64_t rowindex = 0; for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { - std::vector> pathsubsubexprs; - for (uint64_t j = 1; j < k; ++j) { - pathsubsubexprs.push_back(std::vector()); - } - - if (targetObservations.count(pomdp.getObservation(state)) > 0) { + if (targetStates.get(state)) { smtSolver->add(pathVars[state][0]); } else { smtSolver->add(!pathVars[state][0]); } if (surelyReachSinkStates.get(state)) { - smtSolver->add(!reachVars[state]); - } - else { + smtSolver->add(!reachVarExpressions[state]); + } else if(!targetStates.get(state)) { + std::vector>> pathsubsubexprs; + for (uint64_t j = 1; j < k; ++j) { + pathsubsubexprs.push_back(std::vector>()); + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + pathsubsubexprs.back().push_back(std::vector()); + } + } + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { std::vector subexprreach; - subexprreach.push_back(!reachVars.at(state)); - subexprreach.push_back(!actionSelectionVars.at(pomdp.getObservation(state)).at(action)); + subexprreach.push_back(!reachVarExpressions.at(state)); + subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { - subexprreach.push_back(reachVars.at(entries.getColumn())); + subexprreach.push_back(reachVarExpressions.at(entries.getColumn())); } smtSolver->add(storm::expressions::disjunction(subexprreach)); for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { for (uint64_t j = 1; j < k; ++j) { - pathsubsubexprs[j - 1].push_back(pathVars[entries.getColumn()][j - 1]); + pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); } } rowindex++; } - smtSolver->add(storm::expressions::implies(reachVars.at(state), pathVars.at(state).back())); - + smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); for (uint64_t j = 1; j < k; ++j) { std::vector pathsubexprs; for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { - pathsubexprs.push_back(actionSelectionVars.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1])); + pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); } smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); } } - } - for (auto const& actionVars : actionSelectionVars) { + for (auto const& actionVars : actionSelectionVarExpressions) { smtSolver->add(storm::expressions::disjunction(actionVars)); } + } + + template + bool MemlessStrategySearchQualitative::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { + if (k < maxK) { + initialize(k); + } + + std::vector atLeastOneOfStates; + for (uint64_t state : oneOfTheseStates) { + STORM_LOG_ASSERT(reachVarExpressions.size() > state, "state id " << state << " exceeds number of states (" << reachVarExpressions.size() << ")" ); + atLeastOneOfStates.push_back(reachVarExpressions[state]); + } + assert(atLeastOneOfStates.size() > 0); + smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); + for (uint64_t state : allOfTheseStates) { + assert(reachVarExpressions.size() > state); + smtSolver->add(reachVarExpressions[state]); + } + std::cout << smtSolver->getSmtLibString() << std::endl; - //for (auto const& ) - } + std::vector> scheduler; + while (true) { + + auto result = smtSolver->check(); + uint64_t i = 0; + + if (result == storm::solver::SmtSolver::CheckResult::Unknown) { + STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); + } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { + std::cout << std::endl << "Unsatisfiable!" << std::endl; + return false; + } + + std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; + auto model = smtSolver->getModel(); + std::cout << "states that are okay" << std::endl; + + + storm::storage::BitVector observations(pomdp.getNrObservations()); + storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); + for (auto rv : reachVars) { + if (model->getBooleanValue(rv)) { + std::cout << i << " " << std::endl; + observations.set(pomdp.getObservation(i)); + } else { + remainingstates.set(i); + } + //std::cout << i << ": " << model->getBooleanValue(rv) << ", "; + ++i; + } + + scheduler.clear(); + + std::vector schedulerSoFar; + uint64_t obs = 0; + for (auto const &actionSelectionVarsForObs : actionSelectionVars) { + uint64_t act = 0; + scheduler.push_back(std::set()); + for (auto const &asv : actionSelectionVarsForObs) { + if (model->getBooleanValue(asv)) { + scheduler.back().insert(act); + schedulerSoFar.push_back(actionSelectionVarExpressions[obs][act]); + } + act++; + } + obs++; + } + + std::cout << "the scheduler: " << std::endl; + for (uint64_t obs = 0; obs < scheduler.size(); ++obs) { + if (observations.get(obs)) { + std::cout << "observation: " << obs << std::endl; + std::cout << "actions:"; + for (auto act : scheduler[obs]) { + std::cout << " " << act; + } + std::cout << std::endl; + } + } + + + std::vector remainingExpressions; + for (auto index : remainingstates) { + remainingExpressions.push_back(reachVarExpressions[index]); + } + + smtSolver->push(); + // Add scheduler + smtSolver->add(storm::expressions::conjunction(schedulerSoFar)); + smtSolver->add(storm::expressions::disjunction(remainingExpressions)); + + } + + } template class MemlessStrategySearchQualitative; + template class MemlessStrategySearchQualitative; } } diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index ca1a84aa6..5aa69b3ce 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -3,6 +3,7 @@ #include "storm/solver/SmtSolver.h" #include "storm/models/sparse/Pomdp.h" #include "storm/utility/solver.h" +#include "storm/exceptions/UnexpectedException.h" namespace storm { namespace pomdp { @@ -15,8 +16,12 @@ namespace pomdp { public: MemlessStrategySearchQualitative(storm::models::sparse::Pomdp const& pomdp, std::set const& targetObservationSet, + storm::storage::BitVector const& targetStates, + storm::storage::BitVector const& surelyReachSinkStates, std::shared_ptr& smtSolverFactory) : pomdp(pomdp), + targetStates(targetStates), + surelyReachSinkStates(surelyReachSinkStates), targetObservations(targetObservationSet) { this->expressionManager = std::make_shared(); smtSolver = smtSolverFactory->create(*expressionManager); @@ -27,49 +32,40 @@ namespace pomdp { surelyReachSinkStates = surelyReachSink; } - void analyze(uint64_t k) { - if (k < maxK) { - initialize(k); - } - std::cout << smtSolver->getSmtLibString() << std::endl; - for (uint64_t state : pomdp.getInitialStates()) { - smtSolver->add(reachVars[state]); - } - auto result = smtSolver->check(); - switch(result) { - case storm::solver::SmtSolver::CheckResult::Sat: - std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; - - case storm::solver::SmtSolver::CheckResult::Unsat: - // std::cout << std::endl << "Unsatisfiability core: {" << std::endl; - // for (auto const& expr : solver->getUnsatCore()) { - // std::cout << "\t " << expr << std::endl; - // } - // std::cout << "}" << std::endl; - - default: - std::cout<< "oops." << std::endl; - // STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); - } - //std::cout << "get model:" << std::endl; - //std::cout << smtSolver->getModel().toString() << std::endl; + void analyzeForInitialStates(uint64_t k) { + analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); } + void findNewStrategyForSomeState(uint64_t k) { + std::cout << surelyReachSinkStates << std::endl; + std::cout << targetStates << std::endl; + std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; + analyze(k, ~surelyReachSinkStates & ~targetStates); + + + } + + bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); + private: void initialize(uint64_t k); + std::unique_ptr smtSolver; storm::models::sparse::Pomdp const& pomdp; std::shared_ptr expressionManager; - uint64_t maxK = -1; + uint64_t maxK = std::numeric_limits::max(); std::set targetObservations; + storm::storage::BitVector targetStates; storm::storage::BitVector surelyReachSinkStates; std::vector> statesPerObservation; - std::vector> actionSelectionVars; // A_{z,a} - std::vector reachVars; + std::vector> actionSelectionVarExpressions; // A_{z,a} + std::vector> actionSelectionVars; + std::vector reachVars; + std::vector reachVarExpressions; std::vector> pathVars; diff --git a/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp b/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp new file mode 100644 index 000000000..f1ec0b8e8 --- /dev/null +++ b/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp @@ -0,0 +1,186 @@ + + +#include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h" + + +namespace storm { + namespace pomdp { + + template + void QualitativeStrategySearchNaive::initialize(uint64_t k) { + if (maxK == std::numeric_limits::max()) { + // not initialized at all. + // Create some data structures. + for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { + actionSelectionVars.push_back(std::vector()); + actionSelectionVarExpressions.push_back(std::vector()); + statesPerObservation.push_back(std::vector()); // Consider using bitvectors instead. + } + + // Fill the states-per-observation mapping, + // declare the reachability variables, + // declare the path variables. + uint64_t stateId = 0; + for(auto obs : pomdp.getObservations()) { + pathVars.push_back(std::vector()); + for (uint64_t i = 0; i < k; ++i) { + pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression()); + } + reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId))); + reachVarExpressions.push_back(reachVars.back().getExpression()); + statesPerObservation.at(obs).push_back(stateId++); + } + assert(pathVars.size() == pomdp.getNumberOfStates()); + + // Create the action selection variables. + uint64_t obs = 0; + for(auto const& statesForObservation : statesPerObservation) { + for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) { + std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a); + actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName)); + actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression()); + } + ++obs; + } + } else { + assert(false); + } + + uint64_t rowindex = 0; + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + if (targetStates.get(state)) { + smtSolver->add(pathVars[state][0]); + } else { + smtSolver->add(!pathVars[state][0]); + } + + if (surelyReachSinkStates.get(state)) { + smtSolver->add(!reachVarExpressions[state]); + } else if(!targetStates.get(state)) { + std::vector>> pathsubsubexprs; + for (uint64_t j = 1; j < k; ++j) { + pathsubsubexprs.push_back(std::vector>()); + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + pathsubsubexprs.back().push_back(std::vector()); + } + } + + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + std::vector subexprreach; + + subexprreach.push_back(!reachVarExpressions.at(state)); + subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); + for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { + subexprreach.push_back(reachVarExpressions.at(entries.getColumn())); + } + smtSolver->add(storm::expressions::disjunction(subexprreach)); + for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { + for (uint64_t j = 1; j < k; ++j) { + pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); + } + } + rowindex++; + } + smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); + + for (uint64_t j = 1; j < k; ++j) { + std::vector pathsubexprs; + + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); + } + smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); + } + } + } + + for (auto const& actionVars : actionSelectionVarExpressions) { + smtSolver->add(storm::expressions::disjunction(actionVars)); + } + } + + template + bool QualitativeStrategySearchNaive::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { + if (k < maxK) { + initialize(k); + } + + std::vector atLeastOneOfStates; + + for(uint64_t state : oneOfTheseStates) { + atLeastOneOfStates.push_back(reachVarExpressions[state]); + } + assert(atLeastOneOfStates.size() > 0); + smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); + + for(uint64_t state : allOfTheseStates) { + smtSolver->add(reachVarExpressions[state]); + } + + + + std::cout << smtSolver->getSmtLibString() << std::endl; + + auto result = smtSolver->check(); + uint64_t i = 0; + smtSolver->push(); + + + + if (result == storm::solver::SmtSolver::CheckResult::Unknown) { + STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); + } else if(result == storm::solver::SmtSolver::CheckResult::Unsat) { + std::cout << std::endl << "Unsatisfiable!" << std::endl; + return false; + } else { + + std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; + auto model = smtSolver->getModel(); + std::cout << "states that are okay" << std::endl; + storm::storage::BitVector observations(pomdp.getNrObservations()); + storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); + for (auto rv : reachVars) { + if (model->getBooleanValue(rv)) { + std::cout << i << " " << std::endl; + observations.set(pomdp.getObservation(i)); + } else { + remainingstates.set(i); + } + //std::cout << i << ": " << model->getBooleanValue(rv) << ", "; + ++i; + } + std::vector > scheduler; + for (auto const &actionSelectionVarsForObs : actionSelectionVars) { + uint64_t act = 0; + scheduler.push_back(std::set()); + for (auto const &asv : actionSelectionVarsForObs) { + if (model->getBooleanValue(asv)) { + scheduler.back().insert(act); + } + act++; + } + } + std::cout << "the scheduler: " << std::endl; + for (uint64_t obs = 0; obs < scheduler.size(); ++obs) { + if (observations.get(obs)) { + std::cout << "observation: " << obs << std::endl; + std::cout << "actions:"; + for (auto act : scheduler[obs]) { + std::cout << " " << act; + } + std::cout << std::endl; + } + } + + + return true; + } + + + + } + + template class QualitativeStrategySearchNaive; + template class QualitativeStrategySearchNaive; + } +} diff --git a/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h b/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h new file mode 100644 index 000000000..5020fc9ab --- /dev/null +++ b/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h @@ -0,0 +1,74 @@ +#include +#include "storm/storage/expressions/Expressions.h" +#include "storm/solver/SmtSolver.h" +#include "storm/models/sparse/Pomdp.h" +#include "storm/utility/solver.h" +#include "storm/exceptions/UnexpectedException.h" + +namespace storm { + namespace pomdp { + + template + class QualitativeStrategySearchNaive { + // Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper. + + + public: + QualitativeStrategySearchNaive(storm::models::sparse::Pomdp const& pomdp, + std::set const& targetObservationSet, + storm::storage::BitVector const& targetStates, + storm::storage::BitVector const& surelyReachSinkStates, + std::shared_ptr& smtSolverFactory) : + pomdp(pomdp), + targetStates(targetStates), + surelyReachSinkStates(surelyReachSinkStates), + targetObservations(targetObservationSet) { + this->expressionManager = std::make_shared(); + smtSolver = smtSolverFactory->create(*expressionManager); + + } + + void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) { + surelyReachSinkStates = surelyReachSink; + } + + void analyzeForInitialStates(uint64_t k) { + analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); + } + + void findNewStrategyForSomeState(uint64_t k) { + std::cout << surelyReachSinkStates << std::endl; + std::cout << targetStates << std::endl; + std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; + analyze(k, ~surelyReachSinkStates & ~targetStates); + + + } + + bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); + + private: + void initialize(uint64_t k); + + + std::unique_ptr smtSolver; + storm::models::sparse::Pomdp const& pomdp; + std::shared_ptr expressionManager; + uint64_t maxK = std::numeric_limits::max(); + + std::set targetObservations; + storm::storage::BitVector targetStates; + storm::storage::BitVector surelyReachSinkStates; + + std::vector> statesPerObservation; + std::vector> actionSelectionVarExpressions; // A_{z,a} + std::vector> actionSelectionVars; + std::vector reachVars; + std::vector reachVarExpressions; + std::vector> pathVars; + + + + }; + } +}