From 9a5b01b6f75da8ce8ba94f531958388504da5266 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Sat, 21 Dec 2019 17:32:21 +0100 Subject: [PATCH] a new encoding for almost sure reachability --- .../MemlessStrategySearchQualitative.cpp | 298 ++++++++++++++---- .../MemlessStrategySearchQualitative.h | 51 +++ 2 files changed, 286 insertions(+), 63 deletions(-) diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index 675c91aab..15e19fab9 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -26,6 +26,8 @@ namespace storm { } reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId))); reachVarExpressions.push_back(reachVars.back().getExpression()); + continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId))); + continuationVarExpressions.push_back(continuationVars.back().getExpression()); statesPerObservation.at(obs).push_back(stateId++); } assert(pathVars.size() == pomdp.getNumberOfStates()); @@ -40,22 +42,68 @@ namespace storm { actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName)); actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression()); } + schedulerVariables.push_back(expressionManager->declareBitVectorVariable("scheduler-obs-" + std::to_string(obs), statesPerObservation.size())); + schedulerVariableExpressions.push_back(schedulerVariables.back()); + switchVars.push_back(expressionManager->declareBooleanVariable("S-" + std::to_string(obs))); + switchVarExpressions.push_back(switchVars.back().getExpression()); + ++obs; } + + for (auto const& actionVars : actionSelectionVarExpressions) { + smtSolver->add(storm::expressions::disjunction(actionVars)); + } + + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + if (targetStates.get(state)) { + smtSolver->add(pathVars[state][0]); + } else { + smtSolver->add(!pathVars[state][0]); + } + } + + uint64_t rowindex = 0; + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + std::vector subexprreachSwitch; + std::vector subexprreachNoSwitch; + subexprreachSwitch.push_back(!reachVarExpressions[state]); + subexprreachSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); + subexprreachSwitch.push_back(!switchVarExpressions[pomdp.getObservation(state)]); + subexprreachNoSwitch.push_back(!reachVarExpressions[state]); + subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); + subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); + for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { + subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); + smtSolver->add(storm::expressions::disjunction(subexprreachSwitch)); + subexprreachSwitch.pop_back(); + subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn())); + smtSolver->add(storm::expressions::disjunction(subexprreachNoSwitch)); + subexprreachNoSwitch.pop_back(); + } + + rowindex++; + } + } + + smtSolver->push(); } else { + smtSolver->pop(); + smtSolver->pop(); + smtSolver->push(); assert(false); } uint64_t rowindex = 0; for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { - if (targetStates.get(state)) { - smtSolver->add(pathVars[state][0]); - } else { - smtSolver->add(!pathVars[state][0]); - } if (surelyReachSinkStates.get(state)) { smtSolver->add(!reachVarExpressions[state]); + for (uint64_t j = 1; j < k; ++j) { + smtSolver->add(!pathVars[state][j]); + } + smtSolver->add(!continuationVarExpressions[state]); } else if(!targetStates.get(state)) { std::vector>> pathsubsubexprs; for (uint64_t j = 1; j < k; ++j) { @@ -68,12 +116,13 @@ namespace storm { for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { std::vector subexprreach; - subexprreach.push_back(!reachVarExpressions.at(state)); - subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); - for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { - subexprreach.push_back(reachVarExpressions.at(entries.getColumn())); - } - smtSolver->add(storm::expressions::disjunction(subexprreach)); +// subexprreach.push_back(!reachVarExpressions.at(state)); +// subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); +// subexprreach.push_back(!switchVarExpressions[pomdp.getObservation(state)]); +// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { +// subexprreach.push_back(reachVarExpressions.at(entries.getColumn())); +// } +// smtSolver->add(storm::expressions::disjunction(subexprreach)); for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { for (uint64_t j = 1; j < k; ++j) { pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); @@ -89,22 +138,38 @@ namespace storm { for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); } + pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); } } } - for (auto const& actionVars : actionSelectionVarExpressions) { - smtSolver->add(storm::expressions::disjunction(actionVars)); + uint64_t obs = 0; + for(auto const& statesForObservation : statesPerObservation) { + for(auto const& state : statesForObservation) { + smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0); + } + ++obs; } + + // These constraints ensure that the right solver is used. +// obs = 0; +// for(auto const& statesForObservation : statesPerObservation) { +// smtSolver->add(schedulerVariableExpressions[obs] >= schedulerForObs.size()); +// ++obs; +// } + + // TODO updateFoundSchedulers(); } template bool MemlessStrategySearchQualitative::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { if (k < maxK) { initialize(k); + maxK = k; } + std::vector atLeastOneOfStates; for (uint64_t state : oneOfTheseStates) { @@ -119,83 +184,190 @@ namespace storm { smtSolver->add(reachVarExpressions[state]); } - std::cout << smtSolver->getSmtLibString() << std::endl; + smtSolver->push(); + uint64_t obs = 0; + for(auto const& statesForObservation : statesPerObservation) { + smtSolver->add(schedulerVariableExpressions[obs] <= schedulerForObs.size()); + ++obs; + } + for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { + schedulerForObs.push_back(std::vector()); + } - std::vector> scheduler; + InternalObservationScheduler scheduler; + scheduler.switchObservations = storm::storage::BitVector(pomdp.getNrObservations()); + storm::storage::BitVector observations(pomdp.getNrObservations()); + storm::storage::BitVector observationsAfterSwitch(pomdp.getNrObservations()); + storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); - while (true) { + uint64_t iterations = 0; + while(true) { + scheduler.clear(); - auto result = smtSolver->check(); - uint64_t i = 0; + observations.clear(); + observationsAfterSwitch.clear(); + remainingstates.clear(); - if (result == storm::solver::SmtSolver::CheckResult::Unknown) { - STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); - } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { - std::cout << std::endl << "Unsatisfiable!" << std::endl; - return false; - } + while (true) { + ++iterations; + std::cout << "Call to SMT Solver (" <getSmtLibString() << std::endl; - std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; - auto model = smtSolver->getModel(); - std::cout << "states that are okay" << std::endl; + auto result = smtSolver->check(); + uint64_t i = 0; + if (result == storm::solver::SmtSolver::CheckResult::Unknown) { + STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); + } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { + std::cout << std::endl << "Unsatisfiable!" << std::endl; + break; + } - storm::storage::BitVector observations(pomdp.getNrObservations()); - storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); - for (auto rv : reachVars) { - if (model->getBooleanValue(rv)) { - std::cout << i << " " << std::endl; - observations.set(pomdp.getObservation(i)); - } else { - remainingstates.set(i); + std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; + auto model = smtSolver->getModel(); + + + observations.clear(); + observationsAfterSwitch.clear(); + remainingstates.clear(); + scheduler.clear(); + + for (auto rv : reachVars) { + if (model->getBooleanValue(rv)) { + smtSolver->add(rv.getExpression()); + observations.set(pomdp.getObservation(i)); + } else { + remainingstates.set(i); + } + ++i; } - //std::cout << i << ": " << model->getBooleanValue(rv) << ", "; - ++i; - } - scheduler.clear(); + i = 0; + std::cout << "states from which we continue" << std::endl; + for (auto rv : continuationVars) { + if (model->getBooleanValue(rv)) { + smtSolver->add(rv.getExpression()); + observationsAfterSwitch.set(pomdp.getObservation(i)); + std::cout << " " << i; + } + ++i; + } + std::cout << std::endl; - std::vector schedulerSoFar; - uint64_t obs = 0; - for (auto const &actionSelectionVarsForObs : actionSelectionVars) { - uint64_t act = 0; - scheduler.push_back(std::set()); - for (auto const &asv : actionSelectionVarsForObs) { - if (model->getBooleanValue(asv)) { - scheduler.back().insert(act); - schedulerSoFar.push_back(actionSelectionVarExpressions[obs][act]); + std::cout << "states that are okay" << std::endl; + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + if (!remainingstates.get(state)) { + std::cout << " " << state; } - act++; } - obs++; - } - std::cout << "the scheduler: " << std::endl; - for (uint64_t obs = 0; obs < scheduler.size(); ++obs) { - if (observations.get(obs)) { - std::cout << "observation: " << obs << std::endl; - std::cout << "actions:"; - for (auto act : scheduler[obs]) { - std::cout << " " << act; + std::vector schedulerSoFar; + uint64_t obs = 0; + for (auto const &actionSelectionVarsForObs : actionSelectionVars) { + uint64_t act = 0; + scheduler.actions.push_back(std::set()); + if (observations.get(obs)) { + for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) { + auto const& asv = actionSelectionVarsForObs[act]; + if (model->getBooleanValue(asv)) { + scheduler.actions.back().insert(act); + schedulerSoFar.push_back(actionSelectionVarExpressions[obs][act]); + } + } + if (model->getBooleanValue(switchVars[obs])) { + scheduler.switchObservations.set(obs); + schedulerSoFar.push_back(switchVarExpressions[obs]); + } else { + schedulerSoFar.push_back(!switchVarExpressions[obs]); + } } - std::cout << std::endl; + + if (observationsAfterSwitch.get(obs)) { + scheduler.schedulerRef.push_back(model->getIntegerValue(schedulerVariables[obs])); + schedulerSoFar.push_back(schedulerVariableExpressions[obs] == expressionManager->integer(scheduler.schedulerRef.back())); + } else { + scheduler.schedulerRef.push_back(0); + } + obs++; } - } + std::cout << "the scheduler so far: " << std::endl; + scheduler.printForObservations(observations,observationsAfterSwitch); + + + + + std::vector remainingExpressions; + for (auto index : remainingstates) { + remainingExpressions.push_back(reachVarExpressions[index]); + } + // Add scheduler + smtSolver->add(storm::expressions::conjunction(schedulerSoFar)); + smtSolver->add(storm::expressions::disjunction(remainingExpressions)); + + } + if (scheduler.empty()) { + break; + } + smtSolver->pop(); + std::cout << "states that are okay" << std::endl; + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + if (!remainingstates.get(state)) { + std::cout << " " << state; + } + } + std::cout << std::endl; + std::cout << "the scheduler: " << std::endl; + scheduler.printForObservations(observations,observationsAfterSwitch); std::vector remainingExpressions; for (auto index : remainingstates) { remainingExpressions.push_back(reachVarExpressions[index]); } - smtSolver->push(); - // Add scheduler - smtSolver->add(storm::expressions::conjunction(schedulerSoFar)); smtSolver->add(storm::expressions::disjunction(remainingExpressions)); + uint64_t obs = 0; + for (auto const &statesForObservation : statesPerObservation) { + + if (observations.get(obs)) { + std::cout << "We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "." << std::endl; + assert(schedulerForObs.size() > obs); + schedulerForObs[obs].push_back(finalSchedulers.size()); + std::cout << "We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs << std::endl; + + for (auto const &state : statesForObservation) { + if (remainingstates.get(state)) { + auto constant = expressionManager->integer(schedulerForObs[obs].size()); + smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); + } + } + + } + ++obs; + } + finalSchedulers.push_back(scheduler); + smtSolver->push(); + + for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { + auto constant = expressionManager->integer(schedulerForObs[obs].size()); + smtSolver->add(schedulerVariableExpressions[obs] <= constant); + } + } + return true; + } + template + void MemlessStrategySearchQualitative::printScheduler(std::vector const& ) { + + } + + + template + storm::expressions::Expression const& MemlessStrategySearchQualitative::getDoneActionExpression(uint64_t obs) const { + return actionSelectionVarExpressions[obs].back(); } diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index 5aa69b3ce..d1388d8e2 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -8,6 +8,42 @@ namespace storm { namespace pomdp { + struct InternalObservationScheduler { + std::vector> actions; + std::vector schedulerRef; + storm::storage::BitVector switchObservations; + + void clear() { + actions.clear(); + schedulerRef.clear(); + switchObservations.clear(); + } + + bool empty() const { + return actions.empty(); + } + + void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const { + for (uint64_t obs = 0; obs < observations.size(); ++obs) { + if (observations.get(obs)) { + std::cout << "observation: " << obs << std::endl; + std::cout << "actions:"; + for (auto act : actions[obs]) { + std::cout << " " << act; + } + if (switchObservations.get(obs)) { + std::cout << " and switch."; + } + std::cout << std::endl; + } + if (observationsAfterSwitch.get(obs)) { + std::cout << "scheduler ref: " << schedulerRef[obs] << std::endl; + } + + } + } + }; + template class MemlessStrategySearchQualitative { // Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper. @@ -49,6 +85,10 @@ namespace pomdp { private: + storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const; + + void printScheduler(std::vector const& ); + void initialize(uint64_t k); @@ -61,13 +101,24 @@ namespace pomdp { storm::storage::BitVector targetStates; storm::storage::BitVector surelyReachSinkStates; + std::vector schedulerVariables; + std::vector schedulerVariableExpressions; std::vector> statesPerObservation; std::vector> actionSelectionVarExpressions; // A_{z,a} std::vector> actionSelectionVars; + std::vector reachVars; std::vector reachVarExpressions; + + std::vector switchVars; + std::vector switchVarExpressions; + std::vector continuationVars; + std::vector continuationVarExpressions; std::vector> pathVars; + std::vector finalSchedulers; + std::vector> schedulerForObs; + };