From 6608f9f00d65ce84695f7802a555a04b23a3a2f3 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Tue, 14 Apr 2020 14:28:14 -0700 Subject: [PATCH] Fixed implementation from CCD16 --- .../QualitativeStrategySearchNaive.cpp | 131 ++++++++++-------- .../analysis/QualitativeStrategySearchNaive.h | 30 ++-- 2 files changed, 93 insertions(+), 68 deletions(-) diff --git a/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp b/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp index f1ec0b8e8..1abe193e5 100644 --- a/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp +++ b/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp @@ -1,4 +1,4 @@ - +#include "storm/utility/file.h" #include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h" @@ -46,7 +46,27 @@ namespace storm { assert(false); } + for (auto const& actionVars : actionSelectionVarExpressions) { + smtSolver->add(storm::expressions::disjunction(actionVars)); + } + + uint64_t rowindex = 0; + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + std::vector subexprreach; + subexprreach.push_back(!reachVarExpressions[state]); + subexprreach.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); + for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { + subexprreach.push_back(reachVarExpressions.at(entries.getColumn())); + smtSolver->add(storm::expressions::disjunction(subexprreach)); + subexprreach.pop_back(); + } + rowindex++; + } + } + + rowindex = 0; for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { if (targetStates.get(state)) { smtSolver->add(pathVars[state][0]); @@ -56,7 +76,10 @@ namespace storm { if (surelyReachSinkStates.get(state)) { smtSolver->add(!reachVarExpressions[state]); + rowindex += pomdp.getNumberOfChoices(state); } else if(!targetStates.get(state)) { + std::cout << state << " is not a target state" << std::endl; + std::vector>> pathsubsubexprs; for (uint64_t j = 1; j < k; ++j) { pathsubsubexprs.push_back(std::vector>()); @@ -67,13 +90,6 @@ namespace storm { for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { std::vector subexprreach; - - subexprreach.push_back(!reachVarExpressions.at(state)); - subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); - for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { - subexprreach.push_back(reachVarExpressions.at(entries.getColumn())); - } - smtSolver->add(storm::expressions::disjunction(subexprreach)); for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { for (uint64_t j = 1; j < k; ++j) { pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); @@ -81,7 +97,6 @@ namespace storm { } rowindex++; } - smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); for (uint64_t j = 1; j < k; ++j) { std::vector pathsubexprs; @@ -91,16 +106,21 @@ namespace storm { } smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); } + + + smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); + + } else { + rowindex += pomdp.getNumberOfChoices(state); } } - for (auto const& actionVars : actionSelectionVarExpressions) { - smtSolver->add(storm::expressions::disjunction(actionVars)); - } } template bool QualitativeStrategySearchNaive::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { + + STORM_LOG_TRACE("Use lookahead of "<getSmtLibString() << std::endl; + STORM_LOG_TRACE(smtSolver->getSmtLibString()); + STORM_LOG_DEBUG("Call to SMT Solver"); + stats.smtCheckTimer.start(); auto result = smtSolver->check(); - uint64_t i = 0; - smtSolver->push(); - - + stats.smtCheckTimer.stop(); if (result == storm::solver::SmtSolver::CheckResult::Unknown) { STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); - } else if(result == storm::solver::SmtSolver::CheckResult::Unsat) { - std::cout << std::endl << "Unsatisfiable!" << std::endl; + } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { + STORM_LOG_DEBUG("Unsatisfiable!"); return false; - } else { + } - std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; - auto model = smtSolver->getModel(); - std::cout << "states that are okay" << std::endl; - storm::storage::BitVector observations(pomdp.getNrObservations()); - storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); - for (auto rv : reachVars) { - if (model->getBooleanValue(rv)) { - std::cout << i << " " << std::endl; - observations.set(pomdp.getObservation(i)); - } else { - remainingstates.set(i); - } - //std::cout << i << ": " << model->getBooleanValue(rv) << ", "; - ++i; - } - std::vector > scheduler; - for (auto const &actionSelectionVarsForObs : actionSelectionVars) { - uint64_t act = 0; - scheduler.push_back(std::set()); - for (auto const &asv : actionSelectionVarsForObs) { - if (model->getBooleanValue(asv)) { - scheduler.back().insert(act); - } - act++; - } + STORM_LOG_DEBUG("Satisfying assignment: "); + auto model = smtSolver->getModel(); + size_t i = 0; + storm::storage::BitVector observations(pomdp.getNrObservations()); + storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); + for (auto rv : reachVars) { + if (model->getBooleanValue(rv)) { + std::cout << i << " " << std::endl; + observations.set(pomdp.getObservation(i)); + } else { + remainingstates.set(i); } - std::cout << "the scheduler: " << std::endl; - for (uint64_t obs = 0; obs < scheduler.size(); ++obs) { - if (observations.get(obs)) { - std::cout << "observation: " << obs << std::endl; - std::cout << "actions:"; - for (auto act : scheduler[obs]) { - std::cout << " " << act; - } - std::cout << std::endl; + ++i; + } + std::vector > scheduler; + for (auto const &actionSelectionVarsForObs : actionSelectionVars) { + uint64_t act = 0; + scheduler.push_back(std::set()); + for (auto const &asv : actionSelectionVarsForObs) { + if (model->getBooleanValue(asv)) { + scheduler.back().insert(act); } + act++; } - - - return true; } + // TODO move this into a print scheduler function. + //STORM_LOG_TRACE("the scheduler: "); + for (uint64_t obs = 0; obs < scheduler.size(); ++obs) { + if (observations.get(obs)) { + //STORM_LOG_TRACE("observation: " << obs); + //std::cout << "actions:"; + //for (auto act : scheduler[obs]) { + // std::cout << " " << act; + //} + //std::cout << std::endl; + } + } + return true; } + template class QualitativeStrategySearchNaive; template class QualitativeStrategySearchNaive; } diff --git a/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h b/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h index 97dc0f679..4fee165f3 100644 --- a/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h +++ b/src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h @@ -3,6 +3,7 @@ #include "storm/solver/SmtSolver.h" #include "storm/models/sparse/Pomdp.h" #include "storm/utility/solver.h" +#include "storm/utility/Stopwatch.h" #include "storm/exceptions/UnexpectedException.h" namespace storm { @@ -10,8 +11,15 @@ namespace storm { template class QualitativeStrategySearchNaive { - // Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper. + // Implements to the Chatterjee, Chmelik, Davies (AAAI-16) paper. + class Statistics { + public: + Statistics() = default; + storm::utility::Stopwatch totalTimer; + storm::utility::Stopwatch smtCheckTimer; + storm::utility::Stopwatch initializeSolverTimer; + }; public: QualitativeStrategySearchNaive(storm::models::sparse::Pomdp const& pomdp, @@ -33,16 +41,15 @@ namespace storm { } void analyzeForInitialStates(uint64_t k) { - analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); - } - - void findNewStrategyForSomeState(uint64_t k) { - std::cout << surelyReachSinkStates << std::endl; - std::cout << targetStates << std::endl; - std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; - analyze(k, ~surelyReachSinkStates & ~targetStates); - - + STORM_LOG_TRACE("Bad states: " << surelyReachSinkStates); + STORM_LOG_TRACE("Target states: " << targetStates); + STORM_LOG_TRACE("Questionmark states: " << (~surelyReachSinkStates & ~targetStates)); + bool result = analyze(k, ~surelyReachSinkStates & ~targetStates, pomdp.getInitialStates()); + if (result) { + STORM_PRINT_AND_LOG("From initial state, one can almost-surely reach the target."); + } else { + STORM_PRINT_AND_LOG("From initial state, one may not almost-surely reach the target ."); + } } bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); @@ -51,6 +58,7 @@ namespace storm { void initialize(uint64_t k); + Statistics stats; std::unique_ptr smtSolver; storm::models::sparse::Pomdp const& pomdp; std::shared_ptr expressionManager;