|
@ -1,19 +1,48 @@ |
|
|
#include "storm-pomdp/analysis/MemlessStrategySearchQualitative.h"
|
|
|
#include "storm-pomdp/analysis/MemlessStrategySearchQualitative.h"
|
|
|
|
|
|
#include "storm/utility/file.h"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace storm { |
|
|
namespace storm { |
|
|
namespace pomdp { |
|
|
namespace pomdp { |
|
|
|
|
|
|
|
|
|
|
|
namespace detail { |
|
|
|
|
|
void printRelevantInfoFromModel(std::shared_ptr<storm::solver::SmtSolver::ModelReference> const& model, std::vector<storm::expressions::Variable> const& reachVars, std::vector<storm::expressions::Variable> const& continuationVars) { |
|
|
|
|
|
uint64_t i = 0; |
|
|
|
|
|
std::stringstream ss; |
|
|
|
|
|
STORM_LOG_TRACE("states which we have now: "); |
|
|
|
|
|
for (auto rv : reachVars) { |
|
|
|
|
|
if (model->getBooleanValue(rv)) { |
|
|
|
|
|
ss << " " << i; |
|
|
|
|
|
} |
|
|
|
|
|
++i; |
|
|
|
|
|
} |
|
|
|
|
|
STORM_LOG_TRACE(ss.str()); |
|
|
|
|
|
i = 0; |
|
|
|
|
|
STORM_LOG_TRACE("states from which we continue: "); |
|
|
|
|
|
ss.clear(); |
|
|
|
|
|
for (auto rv : continuationVars) { |
|
|
|
|
|
if (model->getBooleanValue(rv)) { |
|
|
|
|
|
ss << " " << i; |
|
|
|
|
|
} |
|
|
|
|
|
++i; |
|
|
|
|
|
} |
|
|
|
|
|
STORM_LOG_TRACE(ss.str()); |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
template <typename ValueType> |
|
|
template <typename ValueType> |
|
|
MemlessStrategySearchQualitative<ValueType>::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp, |
|
|
MemlessStrategySearchQualitative<ValueType>::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp, |
|
|
std::set<uint32_t> const& targetObservationSet, |
|
|
|
|
|
|
|
|
std::set<uint32_t> const& targetObservationSet, |
|
|
storm::storage::BitVector const& targetStates, |
|
|
storm::storage::BitVector const& targetStates, |
|
|
storm::storage::BitVector const& surelyReachSinkStates, |
|
|
|
|
|
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) : |
|
|
|
|
|
|
|
|
storm::storage::BitVector const& surelyReachSinkStates, |
|
|
|
|
|
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory, |
|
|
|
|
|
MemlessSearchOptions const& options) : |
|
|
pomdp(pomdp), |
|
|
pomdp(pomdp), |
|
|
targetStates(targetStates), |
|
|
targetStates(targetStates), |
|
|
surelyReachSinkStates(surelyReachSinkStates), |
|
|
surelyReachSinkStates(surelyReachSinkStates), |
|
|
targetObservations(targetObservationSet) |
|
|
|
|
|
|
|
|
targetObservations(targetObservationSet), |
|
|
|
|
|
options(options) |
|
|
{ |
|
|
{ |
|
|
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>(); |
|
|
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>(); |
|
|
smtSolver = smtSolverFactory->create(*expressionManager); |
|
|
smtSolver = smtSolverFactory->create(*expressionManager); |
|
@ -46,7 +75,6 @@ namespace storm { |
|
|
// Fill the states-per-observation mapping,
|
|
|
// Fill the states-per-observation mapping,
|
|
|
// declare the reachability variables,
|
|
|
// declare the reachability variables,
|
|
|
// declare the path variables.
|
|
|
// declare the path variables.
|
|
|
uint64_t stateId = 0; |
|
|
|
|
|
for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { |
|
|
for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { |
|
|
pathVars.push_back(std::vector<storm::expressions::Expression>()); |
|
|
pathVars.push_back(std::vector<storm::expressions::Expression>()); |
|
|
for (uint64_t i = 0; i < k; ++i) { |
|
|
for (uint64_t i = 0; i < k; ++i) { |
|
@ -222,20 +250,29 @@ namespace storm { |
|
|
|
|
|
|
|
|
while (true) { |
|
|
while (true) { |
|
|
++iterations; |
|
|
++iterations; |
|
|
std::cout << "Call to SMT Solver (" <<iterations << ")" << std::endl; |
|
|
|
|
|
std::cout << smtSolver->getSmtLibString() << std::endl; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if(options.isExportSATSet()) { |
|
|
|
|
|
STORM_LOG_DEBUG("Export SMT Solver Call (" <<iterations << ")"); |
|
|
|
|
|
std::string filepath = options.getExportSATCallsPath() + "call_" + std::to_string(iterations) + ".smt2"; |
|
|
|
|
|
std::ofstream filestream; |
|
|
|
|
|
storm::utility::openFile(filepath, filestream); |
|
|
|
|
|
filestream << smtSolver->getSmtLibString() << std::endl; |
|
|
|
|
|
storm::utility::closeFile(filestream); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
STORM_LOG_DEBUG("Call to SMT Solver (" <<iterations << ")"); |
|
|
auto result = smtSolver->check(); |
|
|
auto result = smtSolver->check(); |
|
|
uint64_t i = 0; |
|
|
uint64_t i = 0; |
|
|
|
|
|
|
|
|
if (result == storm::solver::SmtSolver::CheckResult::Unknown) { |
|
|
if (result == storm::solver::SmtSolver::CheckResult::Unknown) { |
|
|
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); |
|
|
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); |
|
|
} else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { |
|
|
} else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { |
|
|
std::cout << std::endl << "Unsatisfiable!" << std::endl; |
|
|
|
|
|
|
|
|
STORM_LOG_DEBUG("Unsatisfiable!"); |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; |
|
|
|
|
|
|
|
|
STORM_LOG_DEBUG("Satisfying assignment: "); |
|
|
|
|
|
STORM_LOG_TRACE(smtSolver->getModelAsValuation().toString(true)); |
|
|
auto model = smtSolver->getModel(); |
|
|
auto model = smtSolver->getModel(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -253,24 +290,17 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
++i; |
|
|
++i; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
i = 0; |
|
|
i = 0; |
|
|
std::cout << "states from which we continue" << std::endl; |
|
|
|
|
|
for (auto rv : continuationVars) { |
|
|
for (auto rv : continuationVars) { |
|
|
if (model->getBooleanValue(rv)) { |
|
|
if (model->getBooleanValue(rv)) { |
|
|
smtSolver->add(rv.getExpression()); |
|
|
smtSolver->add(rv.getExpression()); |
|
|
observationsAfterSwitch.set(pomdp.getObservation(i)); |
|
|
observationsAfterSwitch.set(pomdp.getObservation(i)); |
|
|
std::cout << " " << i; |
|
|
|
|
|
} |
|
|
} |
|
|
++i; |
|
|
++i; |
|
|
} |
|
|
} |
|
|
std::cout << std::endl; |
|
|
|
|
|
|
|
|
|
|
|
std::cout << "states that are okay" << std::endl; |
|
|
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
|
|
if (!remainingstates.get(state)) { |
|
|
|
|
|
std::cout << " " << state; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
if (options.computeTraceOutput()) { |
|
|
|
|
|
detail::printRelevantInfoFromModel(model, reachVars, continuationVars); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// TODO do not repush everyting to the solver.
|
|
|
// TODO do not repush everyting to the solver.
|
|
@ -303,11 +333,12 @@ namespace storm { |
|
|
obs++; |
|
|
obs++; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::cout << "the scheduler so far: " << std::endl; |
|
|
|
|
|
scheduler.printForObservations(observations,observationsAfterSwitch); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if(options.computeTraceOutput()) { |
|
|
|
|
|
// generates debug output, but here we only want it for trace level.
|
|
|
|
|
|
// For consistency, all output on debug level.
|
|
|
|
|
|
STORM_LOG_DEBUG("the scheduler so far: "); |
|
|
|
|
|
scheduler.printForObservations(observations,observationsAfterSwitch); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::vector<storm::expressions::Expression> remainingExpressions; |
|
|
std::vector<storm::expressions::Expression> remainingExpressions; |
|
|
for (auto index : remainingstates) { |
|
|
for (auto index : remainingstates) { |
|
@ -322,15 +353,14 @@ namespace storm { |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
smtSolver->pop(); |
|
|
smtSolver->pop(); |
|
|
std::cout << "states that are okay" << std::endl; |
|
|
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
|
|
if (!remainingstates.get(state)) { |
|
|
|
|
|
std::cout << " " << state; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if(options.computeDebugOutput()) { |
|
|
|
|
|
printCoveredStates(remainingstates); |
|
|
|
|
|
// generates info output, but here we only want it for debug level.
|
|
|
|
|
|
// For consistency, all output on info level.
|
|
|
|
|
|
STORM_LOG_DEBUG("the scheduler: "); |
|
|
|
|
|
scheduler.printForObservations(observations,observationsAfterSwitch); |
|
|
} |
|
|
} |
|
|
std::cout << std::endl; |
|
|
|
|
|
std::cout << "the scheduler: " << std::endl; |
|
|
|
|
|
scheduler.printForObservations(observations,observationsAfterSwitch); |
|
|
|
|
|
|
|
|
|
|
|
std::vector<storm::expressions::Expression> remainingExpressions; |
|
|
std::vector<storm::expressions::Expression> remainingExpressions; |
|
|
for (auto index : remainingstates) { |
|
|
for (auto index : remainingstates) { |
|
@ -353,12 +383,11 @@ namespace storm { |
|
|
|
|
|
|
|
|
uint64_t obs = 0; |
|
|
uint64_t obs = 0; |
|
|
for (auto const &statesForObservation : statesPerObservation) { |
|
|
for (auto const &statesForObservation : statesPerObservation) { |
|
|
|
|
|
|
|
|
if (observations.get(obs)) { |
|
|
if (observations.get(obs)) { |
|
|
std::cout << "We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "." << std::endl; |
|
|
|
|
|
|
|
|
STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "."); |
|
|
assert(schedulerForObs.size() > obs); |
|
|
assert(schedulerForObs.size() > obs); |
|
|
schedulerForObs[obs].push_back(finalSchedulers.size()); |
|
|
schedulerForObs[obs].push_back(finalSchedulers.size()); |
|
|
std::cout << "We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs << std::endl; |
|
|
|
|
|
|
|
|
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs); |
|
|
|
|
|
|
|
|
for (auto const &state : statesForObservation) { |
|
|
for (auto const &state : statesForObservation) { |
|
|
if (remainingstates.get(state)) { |
|
|
if (remainingstates.get(state)) { |
|
@ -366,7 +395,6 @@ namespace storm { |
|
|
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); |
|
|
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
} |
|
|
} |
|
|
++obs; |
|
|
++obs; |
|
|
} |
|
|
} |
|
@ -382,6 +410,21 @@ namespace storm { |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
|
|
void MemlessStrategySearchQualitative<ValueType>::printCoveredStates(storm::storage::BitVector const &remaining) const { |
|
|
|
|
|
|
|
|
|
|
|
STORM_LOG_DEBUG("states that are okay"); |
|
|
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
|
|
if (!remaining.get(state)) { |
|
|
|
|
|
std::cout << " " << state; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
std::cout << std::endl; |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
|
void MemlessStrategySearchQualitative<ValueType>::printScheduler(std::vector<InternalObservationScheduler> const& ) { |
|
|
void MemlessStrategySearchQualitative<ValueType>::printScheduler(std::vector<InternalObservationScheduler> const& ) { |
|
|
|
|
|
|
|
|