Browse Source

Fixed implementation from CCD16

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
6608f9f00d
  1. 81
      src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp
  2. 28
      src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h

81
src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp

@ -1,4 +1,4 @@
#include "storm/utility/file.h"
#include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h" #include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h"
@ -46,7 +46,27 @@ namespace storm {
assert(false); assert(false);
} }
for (auto const& actionVars : actionSelectionVarExpressions) {
smtSolver->add(storm::expressions::disjunction(actionVars));
}
uint64_t rowindex = 0; uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreach;
subexprreach.push_back(!reachVarExpressions[state]);
subexprreach.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
subexprreach.push_back(reachVarExpressions.at(entries.getColumn()));
smtSolver->add(storm::expressions::disjunction(subexprreach));
subexprreach.pop_back();
}
rowindex++;
}
}
rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state)) { if (targetStates.get(state)) {
smtSolver->add(pathVars[state][0]); smtSolver->add(pathVars[state][0]);
@ -56,7 +76,10 @@ namespace storm {
if (surelyReachSinkStates.get(state)) { if (surelyReachSinkStates.get(state)) {
smtSolver->add(!reachVarExpressions[state]); smtSolver->add(!reachVarExpressions[state]);
rowindex += pomdp.getNumberOfChoices(state);
} else if(!targetStates.get(state)) { } else if(!targetStates.get(state)) {
std::cout << state << " is not a target state" << std::endl;
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) { for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>()); pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
@ -67,13 +90,6 @@ namespace storm {
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreach; std::vector<storm::expressions::Expression> subexprreach;
subexprreach.push_back(!reachVarExpressions.at(state));
subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action));
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
subexprreach.push_back(reachVarExpressions.at(entries.getColumn()));
}
smtSolver->add(storm::expressions::disjunction(subexprreach));
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
for (uint64_t j = 1; j < k; ++j) { for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
@ -81,7 +97,6 @@ namespace storm {
} }
rowindex++; rowindex++;
} }
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
for (uint64_t j = 1; j < k; ++j) { for (uint64_t j = 1; j < k; ++j) {
std::vector<storm::expressions::Expression> pathsubexprs; std::vector<storm::expressions::Expression> pathsubexprs;
@ -91,16 +106,21 @@ namespace storm {
} }
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
} }
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
} else {
rowindex += pomdp.getNumberOfChoices(state);
} }
} }
for (auto const& actionVars : actionSelectionVarExpressions) {
smtSolver->add(storm::expressions::disjunction(actionVars));
}
} }
template <typename ValueType> template <typename ValueType>
bool QualitativeStrategySearchNaive<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { bool QualitativeStrategySearchNaive<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
STORM_LOG_TRACE("Use lookahead of "<<k);
if (k < maxK) { if (k < maxK) {
initialize(k); initialize(k);
} }
@ -119,24 +139,23 @@ namespace storm {
std::cout << smtSolver->getSmtLibString() << std::endl;
STORM_LOG_TRACE(smtSolver->getSmtLibString());
STORM_LOG_DEBUG("Call to SMT Solver");
stats.smtCheckTimer.start();
auto result = smtSolver->check(); auto result = smtSolver->check();
uint64_t i = 0;
smtSolver->push();
stats.smtCheckTimer.stop();
if (result == storm::solver::SmtSolver::CheckResult::Unknown) { if (result == storm::solver::SmtSolver::CheckResult::Unknown) {
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
} else if (result == storm::solver::SmtSolver::CheckResult::Unsat) { } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) {
std::cout << std::endl << "Unsatisfiable!" << std::endl;
STORM_LOG_DEBUG("Unsatisfiable!");
return false; return false;
} else {
}
std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl;
STORM_LOG_DEBUG("Satisfying assignment: ");
auto model = smtSolver->getModel(); auto model = smtSolver->getModel();
std::cout << "states that are okay" << std::endl;
size_t i = 0;
storm::storage::BitVector observations(pomdp.getNrObservations()); storm::storage::BitVector observations(pomdp.getNrObservations());
storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); storm::storage::BitVector remainingstates(pomdp.getNumberOfStates());
for (auto rv : reachVars) { for (auto rv : reachVars) {
@ -146,7 +165,6 @@ namespace storm {
} else { } else {
remainingstates.set(i); remainingstates.set(i);
} }
//std::cout << i << ": " << model->getBooleanValue(rv) << ", ";
++i; ++i;
} }
std::vector <std::set<uint64_t>> scheduler; std::vector <std::set<uint64_t>> scheduler;
@ -160,15 +178,17 @@ namespace storm {
act++; act++;
} }
} }
std::cout << "the scheduler: " << std::endl;
// TODO move this into a print scheduler function.
//STORM_LOG_TRACE("the scheduler: ");
for (uint64_t obs = 0; obs < scheduler.size(); ++obs) { for (uint64_t obs = 0; obs < scheduler.size(); ++obs) {
if (observations.get(obs)) { if (observations.get(obs)) {
std::cout << "observation: " << obs << std::endl;
std::cout << "actions:";
for (auto act : scheduler[obs]) {
std::cout << " " << act;
}
std::cout << std::endl;
//STORM_LOG_TRACE("observation: " << obs);
//std::cout << "actions:";
//for (auto act : scheduler[obs]) {
// std::cout << " " << act;
//}
//std::cout << std::endl;
} }
} }
@ -177,9 +197,6 @@ namespace storm {
} }
}
template class QualitativeStrategySearchNaive<double>; template class QualitativeStrategySearchNaive<double>;
template class QualitativeStrategySearchNaive<storm::RationalNumber>; template class QualitativeStrategySearchNaive<storm::RationalNumber>;
} }

28
src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h

@ -3,6 +3,7 @@
#include "storm/solver/SmtSolver.h" #include "storm/solver/SmtSolver.h"
#include "storm/models/sparse/Pomdp.h" #include "storm/models/sparse/Pomdp.h"
#include "storm/utility/solver.h" #include "storm/utility/solver.h"
#include "storm/utility/Stopwatch.h"
#include "storm/exceptions/UnexpectedException.h" #include "storm/exceptions/UnexpectedException.h"
namespace storm { namespace storm {
@ -10,8 +11,15 @@ namespace storm {
template<typename ValueType> template<typename ValueType>
class QualitativeStrategySearchNaive { class QualitativeStrategySearchNaive {
// Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
// Implements to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
class Statistics {
public:
Statistics() = default;
storm::utility::Stopwatch totalTimer;
storm::utility::Stopwatch smtCheckTimer;
storm::utility::Stopwatch initializeSolverTimer;
};
public: public:
QualitativeStrategySearchNaive(storm::models::sparse::Pomdp<ValueType> const& pomdp, QualitativeStrategySearchNaive(storm::models::sparse::Pomdp<ValueType> const& pomdp,
@ -33,16 +41,15 @@ namespace storm {
} }
void analyzeForInitialStates(uint64_t k) { void analyzeForInitialStates(uint64_t k) {
analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates());
STORM_LOG_TRACE("Bad states: " << surelyReachSinkStates);
STORM_LOG_TRACE("Target states: " << targetStates);
STORM_LOG_TRACE("Questionmark states: " << (~surelyReachSinkStates & ~targetStates));
bool result = analyze(k, ~surelyReachSinkStates & ~targetStates, pomdp.getInitialStates());
if (result) {
STORM_PRINT_AND_LOG("From initial state, one can almost-surely reach the target.");
} else {
STORM_PRINT_AND_LOG("From initial state, one may not almost-surely reach the target .");
} }
void findNewStrategyForSomeState(uint64_t k) {
std::cout << surelyReachSinkStates << std::endl;
std::cout << targetStates << std::endl;
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
analyze(k, ~surelyReachSinkStates & ~targetStates);
} }
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
@ -51,6 +58,7 @@ namespace storm {
void initialize(uint64_t k); void initialize(uint64_t k);
Statistics stats;
std::unique_ptr<storm::solver::SmtSolver> smtSolver; std::unique_ptr<storm::solver::SmtSolver> smtSolver;
storm::models::sparse::Pomdp<ValueType> const& pomdp; storm::models::sparse::Pomdp<ValueType> const& pomdp;
std::shared_ptr<storm::expressions::ExpressionManager> expressionManager; std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;

Loading…
Cancel
Save