Browse Source

added options, allow to toggle output

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
aae8774e5f
  1. 19
      src/storm-pomdp-cli/storm-pomdp.cpp
  2. 107
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  3. 40
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

19
src/storm-pomdp-cli/storm-pomdp.cpp

@ -305,12 +305,29 @@ int main(const int argc, const char** argv) {
storm::pomdp::MemlessSearchOptions options; storm::pomdp::MemlessSearchOptions options;
uint64_t loglevel = 0;
// TODO a big ugly, but we have our own loglevels.
if(storm::utility::getLogLevel() == l3pp::LogLevel::INFO) {
loglevel = 1;
}
else if(storm::utility::getLogLevel() == l3pp::LogLevel::DEBUG) {
loglevel = 2;
}
else if(storm::utility::getLogLevel() == l3pp::LogLevel::TRACE) {
loglevel = 3;
}
options.setDebugLevel(loglevel);
if (pomdpQualSettings.isExportSATCallsSet()) {
options.setExportSATCalls(pomdpQualSettings.getExportSATCallsPath());
}
if (pomdpSettings.getMemlessSearchMethod() == "ccd16memless") { if (pomdpSettings.getMemlessSearchMethod() == "ccd16memless") {
storm::pomdp::QualitativeStrategySearchNaive<double> memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory); storm::pomdp::QualitativeStrategySearchNaive<double> memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory);
memlessSearch.findNewStrategyForSomeState(lookahead); memlessSearch.findNewStrategyForSomeState(lookahead);
} else if (pomdpSettings.getMemlessSearchMethod() == "iterative") { } else if (pomdpSettings.getMemlessSearchMethod() == "iterative") {
storm::pomdp::MemlessStrategySearchQualitative<double> memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory);
storm::pomdp::MemlessStrategySearchQualitative<double> memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory, options);
memlessSearch.findNewStrategyForSomeState(lookahead); memlessSearch.findNewStrategyForSomeState(lookahead);
} else { } else {
STORM_LOG_ERROR("This method is not implemented."); STORM_LOG_ERROR("This method is not implemented.");

107
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -1,19 +1,48 @@
#include "storm-pomdp/analysis/MemlessStrategySearchQualitative.h" #include "storm-pomdp/analysis/MemlessStrategySearchQualitative.h"
#include "storm/utility/file.h"
namespace storm { namespace storm {
namespace pomdp { namespace pomdp {
namespace detail {
void printRelevantInfoFromModel(std::shared_ptr<storm::solver::SmtSolver::ModelReference> const& model, std::vector<storm::expressions::Variable> const& reachVars, std::vector<storm::expressions::Variable> const& continuationVars) {
uint64_t i = 0;
std::stringstream ss;
STORM_LOG_TRACE("states which we have now: ");
for (auto rv : reachVars) {
if (model->getBooleanValue(rv)) {
ss << " " << i;
}
++i;
}
STORM_LOG_TRACE(ss.str());
i = 0;
STORM_LOG_TRACE("states from which we continue: ");
ss.clear();
for (auto rv : continuationVars) {
if (model->getBooleanValue(rv)) {
ss << " " << i;
}
++i;
}
STORM_LOG_TRACE(ss.str());
}
}
template <typename ValueType> template <typename ValueType>
MemlessStrategySearchQualitative<ValueType>::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp, MemlessStrategySearchQualitative<ValueType>::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp,
std::set<uint32_t> const& targetObservationSet, std::set<uint32_t> const& targetObservationSet,
storm::storage::BitVector const& targetStates, storm::storage::BitVector const& targetStates,
storm::storage::BitVector const& surelyReachSinkStates, storm::storage::BitVector const& surelyReachSinkStates,
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) :
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory,
MemlessSearchOptions const& options) :
pomdp(pomdp), pomdp(pomdp),
targetStates(targetStates), targetStates(targetStates),
surelyReachSinkStates(surelyReachSinkStates), surelyReachSinkStates(surelyReachSinkStates),
targetObservations(targetObservationSet)
targetObservations(targetObservationSet),
options(options)
{ {
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>(); this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
smtSolver = smtSolverFactory->create(*expressionManager); smtSolver = smtSolverFactory->create(*expressionManager);
@ -46,7 +75,6 @@ namespace storm {
// Fill the states-per-observation mapping, // Fill the states-per-observation mapping,
// declare the reachability variables, // declare the reachability variables,
// declare the path variables. // declare the path variables.
uint64_t stateId = 0;
for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
pathVars.push_back(std::vector<storm::expressions::Expression>()); pathVars.push_back(std::vector<storm::expressions::Expression>());
for (uint64_t i = 0; i < k; ++i) { for (uint64_t i = 0; i < k; ++i) {
@ -222,20 +250,29 @@ namespace storm {
while (true) { while (true) {
++iterations; ++iterations;
std::cout << "Call to SMT Solver (" <<iterations << ")" << std::endl;
std::cout << smtSolver->getSmtLibString() << std::endl;
if(options.isExportSATSet()) {
STORM_LOG_DEBUG("Export SMT Solver Call (" <<iterations << ")");
std::string filepath = options.getExportSATCallsPath() + "call_" + std::to_string(iterations) + ".smt2";
std::ofstream filestream;
storm::utility::openFile(filepath, filestream);
filestream << smtSolver->getSmtLibString() << std::endl;
storm::utility::closeFile(filestream);
}
STORM_LOG_DEBUG("Call to SMT Solver (" <<iterations << ")");
auto result = smtSolver->check(); auto result = smtSolver->check();
uint64_t i = 0; uint64_t i = 0;
if (result == storm::solver::SmtSolver::CheckResult::Unknown) { if (result == storm::solver::SmtSolver::CheckResult::Unknown) {
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
} else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) {
std::cout << std::endl << "Unsatisfiable!" << std::endl;
STORM_LOG_DEBUG("Unsatisfiable!");
break; break;
} }
std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl;
STORM_LOG_DEBUG("Satisfying assignment: ");
STORM_LOG_TRACE(smtSolver->getModelAsValuation().toString(true));
auto model = smtSolver->getModel(); auto model = smtSolver->getModel();
@ -253,24 +290,17 @@ namespace storm {
} }
++i; ++i;
} }
i = 0; i = 0;
std::cout << "states from which we continue" << std::endl;
for (auto rv : continuationVars) { for (auto rv : continuationVars) {
if (model->getBooleanValue(rv)) { if (model->getBooleanValue(rv)) {
smtSolver->add(rv.getExpression()); smtSolver->add(rv.getExpression());
observationsAfterSwitch.set(pomdp.getObservation(i)); observationsAfterSwitch.set(pomdp.getObservation(i));
std::cout << " " << i;
} }
++i; ++i;
} }
std::cout << std::endl;
std::cout << "states that are okay" << std::endl;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (!remainingstates.get(state)) {
std::cout << " " << state;
}
if (options.computeTraceOutput()) {
detail::printRelevantInfoFromModel(model, reachVars, continuationVars);
} }
// TODO do not repush everyting to the solver. // TODO do not repush everyting to the solver.
@ -303,11 +333,12 @@ namespace storm {
obs++; obs++;
} }
std::cout << "the scheduler so far: " << std::endl;
if(options.computeTraceOutput()) {
// generates debug output, but here we only want it for trace level.
// For consistency, all output on debug level.
STORM_LOG_DEBUG("the scheduler so far: ");
scheduler.printForObservations(observations,observationsAfterSwitch); scheduler.printForObservations(observations,observationsAfterSwitch);
}
std::vector<storm::expressions::Expression> remainingExpressions; std::vector<storm::expressions::Expression> remainingExpressions;
for (auto index : remainingstates) { for (auto index : remainingstates) {
@ -322,15 +353,14 @@ namespace storm {
break; break;
} }
smtSolver->pop(); smtSolver->pop();
std::cout << "states that are okay" << std::endl;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (!remainingstates.get(state)) {
std::cout << " " << state;
}
}
std::cout << std::endl;
std::cout << "the scheduler: " << std::endl;
if(options.computeDebugOutput()) {
printCoveredStates(remainingstates);
// generates info output, but here we only want it for debug level.
// For consistency, all output on info level.
STORM_LOG_DEBUG("the scheduler: ");
scheduler.printForObservations(observations,observationsAfterSwitch); scheduler.printForObservations(observations,observationsAfterSwitch);
}
std::vector<storm::expressions::Expression> remainingExpressions; std::vector<storm::expressions::Expression> remainingExpressions;
for (auto index : remainingstates) { for (auto index : remainingstates) {
@ -353,12 +383,11 @@ namespace storm {
uint64_t obs = 0; uint64_t obs = 0;
for (auto const &statesForObservation : statesPerObservation) { for (auto const &statesForObservation : statesPerObservation) {
if (observations.get(obs)) { if (observations.get(obs)) {
std::cout << "We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "." << std::endl;
STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << ".");
assert(schedulerForObs.size() > obs); assert(schedulerForObs.size() > obs);
schedulerForObs[obs].push_back(finalSchedulers.size()); schedulerForObs[obs].push_back(finalSchedulers.size());
std::cout << "We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs << std::endl;
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs);
for (auto const &state : statesForObservation) { for (auto const &state : statesForObservation) {
if (remainingstates.get(state)) { if (remainingstates.get(state)) {
@ -366,7 +395,6 @@ namespace storm {
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
} }
} }
} }
++obs; ++obs;
} }
@ -382,6 +410,21 @@ namespace storm {
return true; return true;
} }
template<typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::printCoveredStates(storm::storage::BitVector const &remaining) const {
STORM_LOG_DEBUG("states that are okay");
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (!remaining.get(state)) {
std::cout << " " << state;
}
}
std::cout << std::endl;
}
template<typename ValueType> template<typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::printScheduler(std::vector<InternalObservationScheduler> const& ) { void MemlessStrategySearchQualitative<ValueType>::printScheduler(std::vector<InternalObservationScheduler> const& ) {

40
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -1,4 +1,5 @@
#include <vector> #include <vector>
#include <sstream>
#include "storm/storage/expressions/Expressions.h" #include "storm/storage/expressions/Expressions.h"
#include "storm/solver/SmtSolver.h" #include "storm/solver/SmtSolver.h"
#include "storm/models/sparse/Pomdp.h" #include "storm/models/sparse/Pomdp.h"
@ -22,10 +23,28 @@ namespace pomdp {
} }
bool isExportSATSet() const { bool isExportSATSet() const {
return exportSATcalls == "";
return exportSATcalls != "";
} }
void setDebugLevel(uint64_t level = 1) {
debugLevel = level;
}
bool computeInfoOutput() const {
return debugLevel > 0;
}
bool computeDebugOutput() const {
return debugLevel > 1;
}
bool computeTraceOutput() const {
return debugLevel > 2;
}
private: private:
std::string exportSATcalls = ""; std::string exportSATcalls = "";
uint64_t debugLevel = 0;
}; };
@ -47,18 +66,19 @@ namespace pomdp {
void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const { void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const {
for (uint64_t obs = 0; obs < observations.size(); ++obs) { for (uint64_t obs = 0; obs < observations.size(); ++obs) {
if (observations.get(obs)) { if (observations.get(obs)) {
std::cout << "observation: " << obs << std::endl;
std::cout << "actions:";
STORM_LOG_INFO("For observation: " << obs);
std::stringstream ss;
ss << "actions:";
for (auto act : actions[obs]) { for (auto act : actions[obs]) {
std::cout << " " << act;
ss << " " << act;
} }
if (switchObservations.get(obs)) { if (switchObservations.get(obs)) {
std::cout << " and switch.";
ss << " and switch.";
} }
std::cout << std::endl;
STORM_LOG_INFO(ss.str());
} }
if (observationsAfterSwitch.get(obs)) { if (observationsAfterSwitch.get(obs)) {
std::cout << "scheduler ref: " << schedulerRef[obs] << std::endl;
STORM_LOG_INFO("scheduler ref: " << schedulerRef[obs]);
} }
} }
@ -74,7 +94,8 @@ namespace pomdp {
std::set<uint32_t> const& targetObservationSet, std::set<uint32_t> const& targetObservationSet,
storm::storage::BitVector const& targetStates, storm::storage::BitVector const& targetStates,
storm::storage::BitVector const& surelyReachSinkStates, storm::storage::BitVector const& surelyReachSinkStates,
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory);
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory,
MemlessSearchOptions const& options);
void analyzeForInitialStates(uint64_t k) { void analyzeForInitialStates(uint64_t k) {
analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates());
@ -94,6 +115,7 @@ namespace pomdp {
storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const; storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const;
void printScheduler(std::vector<InternalObservationScheduler> const& ); void printScheduler(std::vector<InternalObservationScheduler> const& );
void printCoveredStates(storm::storage::BitVector const& remaining) const;
void initialize(uint64_t k); void initialize(uint64_t k);
@ -126,6 +148,8 @@ namespace pomdp {
std::vector<std::vector<uint64_t>> schedulerForObs; std::vector<std::vector<uint64_t>> schedulerForObs;
WinningRegion winningRegion; WinningRegion winningRegion;
MemlessSearchOptions options;
}; };

Loading…
Cancel
Save