Browse Source
make everything compile again, add/fix method for memless strategy search (CCD16) and towards iterative search
main
make everything compile again, add/fix method for memless strategy search (CCD16) and towards iterative search
main
7 changed files with 450 additions and 73 deletions
-
9src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp
-
1src/storm-pomdp-cli/settings/modules/POMDPSettings.h
-
46src/storm-pomdp-cli/storm-pomdp.cpp
-
149src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
-
52src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h
-
186src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp
-
74src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h
@ -0,0 +1,186 @@ |
|||||
|
|
||||
|
|
||||
|
#include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h"
|
||||
|
|
||||
|
|
||||
|
namespace storm { |
||||
|
namespace pomdp { |
||||
|
|
||||
|
template <typename ValueType> |
||||
|
void QualitativeStrategySearchNaive<ValueType>::initialize(uint64_t k) { |
||||
|
if (maxK == std::numeric_limits<uint64_t>::max()) { |
||||
|
// not initialized at all.
|
||||
|
// Create some data structures.
|
||||
|
for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
||||
|
actionSelectionVars.push_back(std::vector<storm::expressions::Variable>()); |
||||
|
actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>()); |
||||
|
statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
|
||||
|
} |
||||
|
|
||||
|
// Fill the states-per-observation mapping,
|
||||
|
// declare the reachability variables,
|
||||
|
// declare the path variables.
|
||||
|
uint64_t stateId = 0; |
||||
|
for(auto obs : pomdp.getObservations()) { |
||||
|
pathVars.push_back(std::vector<storm::expressions::Expression>()); |
||||
|
for (uint64_t i = 0; i < k; ++i) { |
||||
|
pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression()); |
||||
|
} |
||||
|
reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId))); |
||||
|
reachVarExpressions.push_back(reachVars.back().getExpression()); |
||||
|
statesPerObservation.at(obs).push_back(stateId++); |
||||
|
} |
||||
|
assert(pathVars.size() == pomdp.getNumberOfStates()); |
||||
|
|
||||
|
// Create the action selection variables.
|
||||
|
uint64_t obs = 0; |
||||
|
for(auto const& statesForObservation : statesPerObservation) { |
||||
|
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) { |
||||
|
std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a); |
||||
|
actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName)); |
||||
|
actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression()); |
||||
|
} |
||||
|
++obs; |
||||
|
} |
||||
|
} else { |
||||
|
assert(false); |
||||
|
} |
||||
|
|
||||
|
uint64_t rowindex = 0; |
||||
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
||||
|
if (targetStates.get(state)) { |
||||
|
smtSolver->add(pathVars[state][0]); |
||||
|
} else { |
||||
|
smtSolver->add(!pathVars[state][0]); |
||||
|
} |
||||
|
|
||||
|
if (surelyReachSinkStates.get(state)) { |
||||
|
smtSolver->add(!reachVarExpressions[state]); |
||||
|
} else if(!targetStates.get(state)) { |
||||
|
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; |
||||
|
for (uint64_t j = 1; j < k; ++j) { |
||||
|
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>()); |
||||
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
||||
|
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>()); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
||||
|
std::vector<storm::expressions::Expression> subexprreach; |
||||
|
|
||||
|
subexprreach.push_back(!reachVarExpressions.at(state)); |
||||
|
subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); |
||||
|
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
||||
|
subexprreach.push_back(reachVarExpressions.at(entries.getColumn())); |
||||
|
} |
||||
|
smtSolver->add(storm::expressions::disjunction(subexprreach)); |
||||
|
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
||||
|
for (uint64_t j = 1; j < k; ++j) { |
||||
|
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); |
||||
|
} |
||||
|
} |
||||
|
rowindex++; |
||||
|
} |
||||
|
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); |
||||
|
|
||||
|
for (uint64_t j = 1; j < k; ++j) { |
||||
|
std::vector<storm::expressions::Expression> pathsubexprs; |
||||
|
|
||||
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
||||
|
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); |
||||
|
} |
||||
|
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
for (auto const& actionVars : actionSelectionVarExpressions) { |
||||
|
smtSolver->add(storm::expressions::disjunction(actionVars)); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
template <typename ValueType> |
||||
|
bool QualitativeStrategySearchNaive<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { |
||||
|
if (k < maxK) { |
||||
|
initialize(k); |
||||
|
} |
||||
|
|
||||
|
std::vector<storm::expressions::Expression> atLeastOneOfStates; |
||||
|
|
||||
|
for(uint64_t state : oneOfTheseStates) { |
||||
|
atLeastOneOfStates.push_back(reachVarExpressions[state]); |
||||
|
} |
||||
|
assert(atLeastOneOfStates.size() > 0); |
||||
|
smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); |
||||
|
|
||||
|
for(uint64_t state : allOfTheseStates) { |
||||
|
smtSolver->add(reachVarExpressions[state]); |
||||
|
} |
||||
|
|
||||
|
|
||||
|
|
||||
|
std::cout << smtSolver->getSmtLibString() << std::endl; |
||||
|
|
||||
|
auto result = smtSolver->check(); |
||||
|
uint64_t i = 0; |
||||
|
smtSolver->push(); |
||||
|
|
||||
|
|
||||
|
|
||||
|
if (result == storm::solver::SmtSolver::CheckResult::Unknown) { |
||||
|
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); |
||||
|
} else if(result == storm::solver::SmtSolver::CheckResult::Unsat) { |
||||
|
std::cout << std::endl << "Unsatisfiable!" << std::endl; |
||||
|
return false; |
||||
|
} else { |
||||
|
|
||||
|
std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; |
||||
|
auto model = smtSolver->getModel(); |
||||
|
std::cout << "states that are okay" << std::endl; |
||||
|
storm::storage::BitVector observations(pomdp.getNrObservations()); |
||||
|
storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); |
||||
|
for (auto rv : reachVars) { |
||||
|
if (model->getBooleanValue(rv)) { |
||||
|
std::cout << i << " " << std::endl; |
||||
|
observations.set(pomdp.getObservation(i)); |
||||
|
} else { |
||||
|
remainingstates.set(i); |
||||
|
} |
||||
|
//std::cout << i << ": " << model->getBooleanValue(rv) << ", ";
|
||||
|
++i; |
||||
|
} |
||||
|
std::vector <std::set<uint64_t>> scheduler; |
||||
|
for (auto const &actionSelectionVarsForObs : actionSelectionVars) { |
||||
|
uint64_t act = 0; |
||||
|
scheduler.push_back(std::set<uint64_t>()); |
||||
|
for (auto const &asv : actionSelectionVarsForObs) { |
||||
|
if (model->getBooleanValue(asv)) { |
||||
|
scheduler.back().insert(act); |
||||
|
} |
||||
|
act++; |
||||
|
} |
||||
|
} |
||||
|
std::cout << "the scheduler: " << std::endl; |
||||
|
for (uint64_t obs = 0; obs < scheduler.size(); ++obs) { |
||||
|
if (observations.get(obs)) { |
||||
|
std::cout << "observation: " << obs << std::endl; |
||||
|
std::cout << "actions:"; |
||||
|
for (auto act : scheduler[obs]) { |
||||
|
std::cout << " " << act; |
||||
|
} |
||||
|
std::cout << std::endl; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
|
||||
|
return true; |
||||
|
} |
||||
|
|
||||
|
|
||||
|
|
||||
|
} |
||||
|
|
||||
|
template class QualitativeStrategySearchNaive<double>; |
||||
|
template class QualitativeStrategySearchNaive<storm::RationalNumber>; |
||||
|
} |
||||
|
} |
@ -0,0 +1,74 @@ |
|||||
|
#include <vector> |
||||
|
#include "storm/storage/expressions/Expressions.h" |
||||
|
#include "storm/solver/SmtSolver.h" |
||||
|
#include "storm/models/sparse/Pomdp.h" |
||||
|
#include "storm/utility/solver.h" |
||||
|
#include "storm/exceptions/UnexpectedException.h" |
||||
|
|
||||
|
namespace storm { |
||||
|
namespace pomdp { |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
class QualitativeStrategySearchNaive { |
||||
|
// Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper. |
||||
|
|
||||
|
|
||||
|
public: |
||||
|
QualitativeStrategySearchNaive(storm::models::sparse::Pomdp<ValueType> const& pomdp, |
||||
|
std::set<uint32_t> const& targetObservationSet, |
||||
|
storm::storage::BitVector const& targetStates, |
||||
|
storm::storage::BitVector const& surelyReachSinkStates, |
||||
|
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) : |
||||
|
pomdp(pomdp), |
||||
|
targetStates(targetStates), |
||||
|
surelyReachSinkStates(surelyReachSinkStates), |
||||
|
targetObservations(targetObservationSet) { |
||||
|
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>(); |
||||
|
smtSolver = smtSolverFactory->create(*expressionManager); |
||||
|
|
||||
|
} |
||||
|
|
||||
|
void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) { |
||||
|
surelyReachSinkStates = surelyReachSink; |
||||
|
} |
||||
|
|
||||
|
void analyzeForInitialStates(uint64_t k) { |
||||
|
analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); |
||||
|
} |
||||
|
|
||||
|
void findNewStrategyForSomeState(uint64_t k) { |
||||
|
std::cout << surelyReachSinkStates << std::endl; |
||||
|
std::cout << targetStates << std::endl; |
||||
|
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; |
||||
|
analyze(k, ~surelyReachSinkStates & ~targetStates); |
||||
|
|
||||
|
|
||||
|
} |
||||
|
|
||||
|
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); |
||||
|
|
||||
|
private: |
||||
|
void initialize(uint64_t k); |
||||
|
|
||||
|
|
||||
|
std::unique_ptr<storm::solver::SmtSolver> smtSolver; |
||||
|
storm::models::sparse::Pomdp<ValueType> const& pomdp; |
||||
|
std::shared_ptr<storm::expressions::ExpressionManager> expressionManager; |
||||
|
uint64_t maxK = std::numeric_limits<uint64_t>::max(); |
||||
|
|
||||
|
std::set<uint32_t> targetObservations; |
||||
|
storm::storage::BitVector targetStates; |
||||
|
storm::storage::BitVector surelyReachSinkStates; |
||||
|
|
||||
|
std::vector<std::vector<uint64_t>> statesPerObservation; |
||||
|
std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a} |
||||
|
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars; |
||||
|
std::vector<storm::expressions::Variable> reachVars; |
||||
|
std::vector<storm::expressions::Expression> reachVarExpressions; |
||||
|
std::vector<std::vector<storm::expressions::Expression>> pathVars; |
||||
|
|
||||
|
|
||||
|
|
||||
|
}; |
||||
|
} |
||||
|
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue