Browse Source
make everything compile again, add/fix method for memless strategy search (CCD16) and towards iterative search
tempestpy_adaptions
make everything compile again, add/fix method for memless strategy search (CCD16) and towards iterative search
tempestpy_adaptions
Sebastian Junges
5 years ago
7 changed files with 450 additions and 73 deletions
-
9src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp
-
1src/storm-pomdp-cli/settings/modules/POMDPSettings.h
-
46src/storm-pomdp-cli/storm-pomdp.cpp
-
149src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
-
52src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h
-
186src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp
-
74src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h
@ -0,0 +1,186 @@ |
|||
|
|||
|
|||
#include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h"
|
|||
|
|||
|
|||
namespace storm { |
|||
namespace pomdp { |
|||
|
|||
template <typename ValueType> |
|||
void QualitativeStrategySearchNaive<ValueType>::initialize(uint64_t k) { |
|||
if (maxK == std::numeric_limits<uint64_t>::max()) { |
|||
// not initialized at all.
|
|||
// Create some data structures.
|
|||
for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|||
actionSelectionVars.push_back(std::vector<storm::expressions::Variable>()); |
|||
actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>()); |
|||
statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
|
|||
} |
|||
|
|||
// Fill the states-per-observation mapping,
|
|||
// declare the reachability variables,
|
|||
// declare the path variables.
|
|||
uint64_t stateId = 0; |
|||
for(auto obs : pomdp.getObservations()) { |
|||
pathVars.push_back(std::vector<storm::expressions::Expression>()); |
|||
for (uint64_t i = 0; i < k; ++i) { |
|||
pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression()); |
|||
} |
|||
reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId))); |
|||
reachVarExpressions.push_back(reachVars.back().getExpression()); |
|||
statesPerObservation.at(obs).push_back(stateId++); |
|||
} |
|||
assert(pathVars.size() == pomdp.getNumberOfStates()); |
|||
|
|||
// Create the action selection variables.
|
|||
uint64_t obs = 0; |
|||
for(auto const& statesForObservation : statesPerObservation) { |
|||
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) { |
|||
std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a); |
|||
actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName)); |
|||
actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression()); |
|||
} |
|||
++obs; |
|||
} |
|||
} else { |
|||
assert(false); |
|||
} |
|||
|
|||
uint64_t rowindex = 0; |
|||
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|||
if (targetStates.get(state)) { |
|||
smtSolver->add(pathVars[state][0]); |
|||
} else { |
|||
smtSolver->add(!pathVars[state][0]); |
|||
} |
|||
|
|||
if (surelyReachSinkStates.get(state)) { |
|||
smtSolver->add(!reachVarExpressions[state]); |
|||
} else if(!targetStates.get(state)) { |
|||
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; |
|||
for (uint64_t j = 1; j < k; ++j) { |
|||
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>()); |
|||
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|||
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>()); |
|||
} |
|||
} |
|||
|
|||
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|||
std::vector<storm::expressions::Expression> subexprreach; |
|||
|
|||
subexprreach.push_back(!reachVarExpressions.at(state)); |
|||
subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); |
|||
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
|||
subexprreach.push_back(reachVarExpressions.at(entries.getColumn())); |
|||
} |
|||
smtSolver->add(storm::expressions::disjunction(subexprreach)); |
|||
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
|||
for (uint64_t j = 1; j < k; ++j) { |
|||
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); |
|||
} |
|||
} |
|||
rowindex++; |
|||
} |
|||
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); |
|||
|
|||
for (uint64_t j = 1; j < k; ++j) { |
|||
std::vector<storm::expressions::Expression> pathsubexprs; |
|||
|
|||
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|||
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); |
|||
} |
|||
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); |
|||
} |
|||
} |
|||
} |
|||
|
|||
for (auto const& actionVars : actionSelectionVarExpressions) { |
|||
smtSolver->add(storm::expressions::disjunction(actionVars)); |
|||
} |
|||
} |
|||
|
|||
template <typename ValueType> |
|||
bool QualitativeStrategySearchNaive<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { |
|||
if (k < maxK) { |
|||
initialize(k); |
|||
} |
|||
|
|||
std::vector<storm::expressions::Expression> atLeastOneOfStates; |
|||
|
|||
for(uint64_t state : oneOfTheseStates) { |
|||
atLeastOneOfStates.push_back(reachVarExpressions[state]); |
|||
} |
|||
assert(atLeastOneOfStates.size() > 0); |
|||
smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); |
|||
|
|||
for(uint64_t state : allOfTheseStates) { |
|||
smtSolver->add(reachVarExpressions[state]); |
|||
} |
|||
|
|||
|
|||
|
|||
std::cout << smtSolver->getSmtLibString() << std::endl; |
|||
|
|||
auto result = smtSolver->check(); |
|||
uint64_t i = 0; |
|||
smtSolver->push(); |
|||
|
|||
|
|||
|
|||
if (result == storm::solver::SmtSolver::CheckResult::Unknown) { |
|||
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); |
|||
} else if(result == storm::solver::SmtSolver::CheckResult::Unsat) { |
|||
std::cout << std::endl << "Unsatisfiable!" << std::endl; |
|||
return false; |
|||
} else { |
|||
|
|||
std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; |
|||
auto model = smtSolver->getModel(); |
|||
std::cout << "states that are okay" << std::endl; |
|||
storm::storage::BitVector observations(pomdp.getNrObservations()); |
|||
storm::storage::BitVector remainingstates(pomdp.getNumberOfStates()); |
|||
for (auto rv : reachVars) { |
|||
if (model->getBooleanValue(rv)) { |
|||
std::cout << i << " " << std::endl; |
|||
observations.set(pomdp.getObservation(i)); |
|||
} else { |
|||
remainingstates.set(i); |
|||
} |
|||
//std::cout << i << ": " << model->getBooleanValue(rv) << ", ";
|
|||
++i; |
|||
} |
|||
std::vector <std::set<uint64_t>> scheduler; |
|||
for (auto const &actionSelectionVarsForObs : actionSelectionVars) { |
|||
uint64_t act = 0; |
|||
scheduler.push_back(std::set<uint64_t>()); |
|||
for (auto const &asv : actionSelectionVarsForObs) { |
|||
if (model->getBooleanValue(asv)) { |
|||
scheduler.back().insert(act); |
|||
} |
|||
act++; |
|||
} |
|||
} |
|||
std::cout << "the scheduler: " << std::endl; |
|||
for (uint64_t obs = 0; obs < scheduler.size(); ++obs) { |
|||
if (observations.get(obs)) { |
|||
std::cout << "observation: " << obs << std::endl; |
|||
std::cout << "actions:"; |
|||
for (auto act : scheduler[obs]) { |
|||
std::cout << " " << act; |
|||
} |
|||
std::cout << std::endl; |
|||
} |
|||
} |
|||
|
|||
|
|||
return true; |
|||
} |
|||
|
|||
|
|||
|
|||
} |
|||
|
|||
template class QualitativeStrategySearchNaive<double>; |
|||
template class QualitativeStrategySearchNaive<storm::RationalNumber>; |
|||
} |
|||
} |
@ -0,0 +1,74 @@ |
|||
#include <vector> |
|||
#include "storm/storage/expressions/Expressions.h" |
|||
#include "storm/solver/SmtSolver.h" |
|||
#include "storm/models/sparse/Pomdp.h" |
|||
#include "storm/utility/solver.h" |
|||
#include "storm/exceptions/UnexpectedException.h" |
|||
|
|||
namespace storm { |
|||
namespace pomdp { |
|||
|
|||
template<typename ValueType> |
|||
class QualitativeStrategySearchNaive { |
|||
// Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper. |
|||
|
|||
|
|||
public: |
|||
QualitativeStrategySearchNaive(storm::models::sparse::Pomdp<ValueType> const& pomdp, |
|||
std::set<uint32_t> const& targetObservationSet, |
|||
storm::storage::BitVector const& targetStates, |
|||
storm::storage::BitVector const& surelyReachSinkStates, |
|||
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) : |
|||
pomdp(pomdp), |
|||
targetStates(targetStates), |
|||
surelyReachSinkStates(surelyReachSinkStates), |
|||
targetObservations(targetObservationSet) { |
|||
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>(); |
|||
smtSolver = smtSolverFactory->create(*expressionManager); |
|||
|
|||
} |
|||
|
|||
void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) { |
|||
surelyReachSinkStates = surelyReachSink; |
|||
} |
|||
|
|||
void analyzeForInitialStates(uint64_t k) { |
|||
analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); |
|||
} |
|||
|
|||
void findNewStrategyForSomeState(uint64_t k) { |
|||
std::cout << surelyReachSinkStates << std::endl; |
|||
std::cout << targetStates << std::endl; |
|||
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; |
|||
analyze(k, ~surelyReachSinkStates & ~targetStates); |
|||
|
|||
|
|||
} |
|||
|
|||
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); |
|||
|
|||
private: |
|||
void initialize(uint64_t k); |
|||
|
|||
|
|||
std::unique_ptr<storm::solver::SmtSolver> smtSolver; |
|||
storm::models::sparse::Pomdp<ValueType> const& pomdp; |
|||
std::shared_ptr<storm::expressions::ExpressionManager> expressionManager; |
|||
uint64_t maxK = std::numeric_limits<uint64_t>::max(); |
|||
|
|||
std::set<uint32_t> targetObservations; |
|||
storm::storage::BitVector targetStates; |
|||
storm::storage::BitVector surelyReachSinkStates; |
|||
|
|||
std::vector<std::vector<uint64_t>> statesPerObservation; |
|||
std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a} |
|||
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars; |
|||
std::vector<storm::expressions::Variable> reachVars; |
|||
std::vector<storm::expressions::Expression> reachVarExpressions; |
|||
std::vector<std::vector<storm::expressions::Expression>> pathVars; |
|||
|
|||
|
|||
|
|||
}; |
|||
} |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue