diff --git a/src/storm-pomdp-cli/storm-pomdp.cpp b/src/storm-pomdp-cli/storm-pomdp.cpp index c16354973..e67e716e6 100644 --- a/src/storm-pomdp-cli/storm-pomdp.cpp +++ b/src/storm-pomdp-cli/storm-pomdp.cpp @@ -305,12 +305,29 @@ int main(const int argc, const char** argv) { storm::pomdp::MemlessSearchOptions options; + uint64_t loglevel = 0; + // TODO a big ugly, but we have our own loglevels. + if(storm::utility::getLogLevel() == l3pp::LogLevel::INFO) { + loglevel = 1; + } + else if(storm::utility::getLogLevel() == l3pp::LogLevel::DEBUG) { + loglevel = 2; + } + else if(storm::utility::getLogLevel() == l3pp::LogLevel::TRACE) { + loglevel = 3; + } + options.setDebugLevel(loglevel); + + if (pomdpQualSettings.isExportSATCallsSet()) { + options.setExportSATCalls(pomdpQualSettings.getExportSATCallsPath()); + } + if (pomdpSettings.getMemlessSearchMethod() == "ccd16memless") { storm::pomdp::QualitativeStrategySearchNaive memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory); memlessSearch.findNewStrategyForSomeState(lookahead); } else if (pomdpSettings.getMemlessSearchMethod() == "iterative") { - storm::pomdp::MemlessStrategySearchQualitative memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory); + storm::pomdp::MemlessStrategySearchQualitative memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory, options); memlessSearch.findNewStrategyForSomeState(lookahead); } else { STORM_LOG_ERROR("This method is not implemented."); diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index 3788a2ac7..d1a67d132 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -1,19 +1,48 @@ #include "storm-pomdp/analysis/MemlessStrategySearchQualitative.h" +#include "storm/utility/file.h" namespace storm { namespace pomdp { + namespace detail { + void printRelevantInfoFromModel(std::shared_ptr const& model, std::vector const& reachVars, std::vector const& continuationVars) { + uint64_t i = 0; + std::stringstream ss; + STORM_LOG_TRACE("states which we have now: "); + for (auto rv : reachVars) { + if (model->getBooleanValue(rv)) { + ss << " " << i; + } + ++i; + } + STORM_LOG_TRACE(ss.str()); + i = 0; + STORM_LOG_TRACE("states from which we continue: "); + ss.clear(); + for (auto rv : continuationVars) { + if (model->getBooleanValue(rv)) { + ss << " " << i; + } + ++i; + } + STORM_LOG_TRACE(ss.str()); + + } + } + template MemlessStrategySearchQualitative::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp const& pomdp, - std::set const& targetObservationSet, + std::set const& targetObservationSet, storm::storage::BitVector const& targetStates, - storm::storage::BitVector const& surelyReachSinkStates, - std::shared_ptr& smtSolverFactory) : + storm::storage::BitVector const& surelyReachSinkStates, + std::shared_ptr& smtSolverFactory, + MemlessSearchOptions const& options) : pomdp(pomdp), targetStates(targetStates), surelyReachSinkStates(surelyReachSinkStates), - targetObservations(targetObservationSet) + targetObservations(targetObservationSet), + options(options) { this->expressionManager = std::make_shared(); smtSolver = smtSolverFactory->create(*expressionManager); @@ -46,7 +75,6 @@ namespace storm { // Fill the states-per-observation mapping, // declare the reachability variables, // declare the path variables. - uint64_t stateId = 0; for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { pathVars.push_back(std::vector()); for (uint64_t i = 0; i < k; ++i) { @@ -222,20 +250,29 @@ namespace storm { while (true) { ++iterations; - std::cout << "Call to SMT Solver (" <getSmtLibString() << std::endl; + if(options.isExportSATSet()) { + STORM_LOG_DEBUG("Export SMT Solver Call (" <getSmtLibString() << std::endl; + storm::utility::closeFile(filestream); + } + + STORM_LOG_DEBUG("Call to SMT Solver (" <check(); uint64_t i = 0; if (result == storm::solver::SmtSolver::CheckResult::Unknown) { STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { - std::cout << std::endl << "Unsatisfiable!" << std::endl; + STORM_LOG_DEBUG("Unsatisfiable!"); break; } - std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; + STORM_LOG_DEBUG("Satisfying assignment: "); + STORM_LOG_TRACE(smtSolver->getModelAsValuation().toString(true)); auto model = smtSolver->getModel(); @@ -253,24 +290,17 @@ namespace storm { } ++i; } - i = 0; - std::cout << "states from which we continue" << std::endl; for (auto rv : continuationVars) { if (model->getBooleanValue(rv)) { smtSolver->add(rv.getExpression()); observationsAfterSwitch.set(pomdp.getObservation(i)); - std::cout << " " << i; } ++i; } - std::cout << std::endl; - std::cout << "states that are okay" << std::endl; - for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { - if (!remainingstates.get(state)) { - std::cout << " " << state; - } + if (options.computeTraceOutput()) { + detail::printRelevantInfoFromModel(model, reachVars, continuationVars); } // TODO do not repush everyting to the solver. @@ -303,11 +333,12 @@ namespace storm { obs++; } - std::cout << "the scheduler so far: " << std::endl; - scheduler.printForObservations(observations,observationsAfterSwitch); - - - + if(options.computeTraceOutput()) { + // generates debug output, but here we only want it for trace level. + // For consistency, all output on debug level. + STORM_LOG_DEBUG("the scheduler so far: "); + scheduler.printForObservations(observations,observationsAfterSwitch); + } std::vector remainingExpressions; for (auto index : remainingstates) { @@ -322,15 +353,14 @@ namespace storm { break; } smtSolver->pop(); - std::cout << "states that are okay" << std::endl; - for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { - if (!remainingstates.get(state)) { - std::cout << " " << state; - } + + if(options.computeDebugOutput()) { + printCoveredStates(remainingstates); + // generates info output, but here we only want it for debug level. + // For consistency, all output on info level. + STORM_LOG_DEBUG("the scheduler: "); + scheduler.printForObservations(observations,observationsAfterSwitch); } - std::cout << std::endl; - std::cout << "the scheduler: " << std::endl; - scheduler.printForObservations(observations,observationsAfterSwitch); std::vector remainingExpressions; for (auto index : remainingstates) { @@ -353,12 +383,11 @@ namespace storm { uint64_t obs = 0; for (auto const &statesForObservation : statesPerObservation) { - if (observations.get(obs)) { - std::cout << "We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "." << std::endl; + STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "."); assert(schedulerForObs.size() > obs); schedulerForObs[obs].push_back(finalSchedulers.size()); - std::cout << "We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs << std::endl; + STORM_LOG_DEBUG("We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs); for (auto const &state : statesForObservation) { if (remainingstates.get(state)) { @@ -366,7 +395,6 @@ namespace storm { smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); } } - } ++obs; } @@ -382,6 +410,21 @@ namespace storm { return true; } + template + void MemlessStrategySearchQualitative::printCoveredStates(storm::storage::BitVector const &remaining) const { + + STORM_LOG_DEBUG("states that are okay"); + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + if (!remaining.get(state)) { + std::cout << " " << state; + } + } + std::cout << std::endl; + + } + + + template void MemlessStrategySearchQualitative::printScheduler(std::vector const& ) { diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h index 900cdc0d7..674f41b0c 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -1,4 +1,5 @@ #include +#include #include "storm/storage/expressions/Expressions.h" #include "storm/solver/SmtSolver.h" #include "storm/models/sparse/Pomdp.h" @@ -22,10 +23,28 @@ namespace pomdp { } bool isExportSATSet() const { - return exportSATcalls == ""; + return exportSATcalls != ""; } + + void setDebugLevel(uint64_t level = 1) { + debugLevel = level; + } + + bool computeInfoOutput() const { + return debugLevel > 0; + } + + bool computeDebugOutput() const { + return debugLevel > 1; + } + + bool computeTraceOutput() const { + return debugLevel > 2; + } + private: std::string exportSATcalls = ""; + uint64_t debugLevel = 0; }; @@ -47,18 +66,19 @@ namespace pomdp { void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const { for (uint64_t obs = 0; obs < observations.size(); ++obs) { if (observations.get(obs)) { - std::cout << "observation: " << obs << std::endl; - std::cout << "actions:"; + STORM_LOG_INFO("For observation: " << obs); + std::stringstream ss; + ss << "actions:"; for (auto act : actions[obs]) { - std::cout << " " << act; + ss << " " << act; } if (switchObservations.get(obs)) { - std::cout << " and switch."; + ss << " and switch."; } - std::cout << std::endl; + STORM_LOG_INFO(ss.str()); } if (observationsAfterSwitch.get(obs)) { - std::cout << "scheduler ref: " << schedulerRef[obs] << std::endl; + STORM_LOG_INFO("scheduler ref: " << schedulerRef[obs]); } } @@ -74,7 +94,8 @@ namespace pomdp { std::set const& targetObservationSet, storm::storage::BitVector const& targetStates, storm::storage::BitVector const& surelyReachSinkStates, - std::shared_ptr& smtSolverFactory); + std::shared_ptr& smtSolverFactory, + MemlessSearchOptions const& options); void analyzeForInitialStates(uint64_t k) { analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); @@ -94,6 +115,7 @@ namespace pomdp { storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const; void printScheduler(std::vector const& ); + void printCoveredStates(storm::storage::BitVector const& remaining) const; void initialize(uint64_t k); @@ -126,6 +148,8 @@ namespace pomdp { std::vector> schedulerForObs; WinningRegion winningRegion; + MemlessSearchOptions options; + };