Browse Source

Make everything compile again; add/fix the method for memoryless strategy search (CCD16) and work towards an iterative search

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
5bbf54cb78
  1. 9
      src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp
  2. 1
      src/storm-pomdp-cli/settings/modules/POMDPSettings.h
  3. 46
      src/storm-pomdp-cli/storm-pomdp.cpp
  4. 149
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  5. 52
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h
  6. 186
      src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp
  7. 74
      src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h

9
src/storm-pomdp-cli/settings/modules/POMDPSettings.cpp

@ -27,6 +27,7 @@ namespace storm {
const std::string transformBinaryOption = "transformbinary";
const std::string transformSimpleOption = "transformsimple";
const std::string memlessSearchOption = "memlesssearch";
// Method names accepted for the "method" argument of the memless-search option; the
// option's multiple-choice validator is built from this list.
// NOTE(review): the dispatch in storm-pomdp.cpp compares the chosen method against
// "ccd16memless", which is not contained in this list (it has "ccdmemless") — confirm
// which spelling is intended, otherwise that branch cannot be selected.
std::vector<std::string> memlessSearchMethods = {"none", "ccdmemless", "ccdmemory", "iterative"};
POMDPSettings::POMDPSettings() : ModuleSettings(moduleName) {
this->addOption(storm::settings::OptionBuilder(moduleName, exportAsParametricModelOption, false, "Export the parametric file.").addArgument(storm::settings::ArgumentBuilder::createStringArgument("filename", "The name of the file to which to write the model.").build()).build());
@ -46,7 +47,9 @@ namespace storm {
10).addValidatorUnsignedInteger(
storm::settings::ArgumentValidatorFactory::createUnsignedGreaterValidator(
0)).build()).build());
this->addOption(storm::settings::OptionBuilder(moduleName, memlessSearchOption, false, "Search for a qualitative memoryless scheuler").build());
this->addOption(storm::settings::OptionBuilder(moduleName, memlessSearchOption, false, "Search for a qualitative memoryless scheuler").addArgument(storm::settings::ArgumentBuilder::createStringArgument("method", "method name").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(memlessSearchMethods)).setDefaultValueString("none").build()).build());
}
bool POMDPSettings::isExportToParametricSet() const {
@ -86,6 +89,10 @@ namespace storm {
return this->getOption(memlessSearchOption).getHasOptionBeenSet();
}
std::string POMDPSettings::getMemlessSearchMethod() const {
return this->getOption(memlessSearchOption).getArgumentByName("method").getValueAsString();
}
/*!
 * Retrieves the value of the "bound" argument of the memory-bound option.
 *
 * @return The bound as an unsigned integer.
 */
uint64_t POMDPSettings::getMemoryBound() const {
    auto const& boundOption = this->getOption(memoryBoundOption);
    auto const& boundArgument = boundOption.getArgumentByName("bound");
    return boundArgument.getValueAsUnsignedInteger();
}

1
src/storm-pomdp-cli/settings/modules/POMDPSettings.h

@ -33,6 +33,7 @@ namespace storm {
bool isTransformSimpleSet() const;
bool isTransformBinarySet() const;
bool isMemlessSearchSet() const;
std::string getMemlessSearchMethod() const;
std::string getFscApplicationTypeString() const;
uint64_t getMemoryBound() const;

46
src/storm-pomdp-cli/storm-pomdp.cpp

@ -26,9 +26,10 @@
#include "storm/settings/modules/TopologicalEquationSolverSettings.h"
#include "storm/settings/modules/ModelCheckerSettings.h"
#include "storm/settings/modules/MultiplierSettings.h"
#include "storm/settings/modules/TransformationSettings.h"
#include "storm/settings/modules/MultiObjectiveSettings.h"
#include "storm-pomdp-cli/settings/modules/POMDPSettings.h"
#include "storm/analysis/GraphConditions.h"
#include "storm-cli-utilities/cli.h"
@ -44,6 +45,7 @@
#include "storm-pomdp/analysis/QualitativeAnalysis.h"
#include "storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h"
#include "storm-pomdp/analysis/MemlessStrategySearchQualitative.h"
#include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h"
#include "storm/api/storm.h"
#include <typeinfo>
@ -59,6 +61,8 @@ void initializeSettings() {
storm::settings::addModule<storm::settings::modules::CoreSettings>();
storm::settings::addModule<storm::settings::modules::DebugSettings>();
storm::settings::addModule<storm::settings::modules::BuildSettings>();
storm::settings::addModule<storm::settings::modules::TransformationSettings>();
storm::settings::addModule<storm::settings::modules::GmmxxEquationSolverSettings>();
storm::settings::addModule<storm::settings::modules::EigenEquationSolverSettings>();
storm::settings::addModule<storm::settings::modules::NativeEquationSolverSettings>();
@ -79,9 +83,9 @@ void initializeSettings() {
}
template<typename ValueType>
bool extractTargetAndSinkObservationSets(std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> const& pomdp, storm::logic::Formula const& subformula, std::set<uint32_t>& targetObservationSet, storm::storage::BitVector& badStates) {
bool extractTargetAndSinkObservationSets(std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> const& pomdp, storm::logic::Formula const& subformula, std::set<uint32_t>& targetObservationSet, storm::storage::BitVector& targetStates, storm::storage::BitVector& badStates) {
//TODO refactor (use model checker to determine the states, then transform into observations).
//TODO rename into appropriate function name.
bool validFormula = false;
if (subformula.isEventuallyFormula()) {
storm::logic::EventuallyFormula const &eventuallyFormula = subformula.asEventuallyFormula();
@ -94,6 +98,7 @@ bool extractTargetAndSinkObservationSets(std::shared_ptr<storm::models::sparse::
for (size_t state = 0; state < pomdp->getNumberOfStates(); ++state) {
if (labeling.getStateHasLabel(targetLabel, state)) {
targetObservationSet.insert(pomdp->getObservation(state));
targetStates.set(state);
}
}
} else if (subformula2.isAtomicExpressionFormula()) {
@ -106,18 +111,19 @@ bool extractTargetAndSinkObservationSets(std::shared_ptr<storm::models::sparse::
for (size_t state = 0; state < pomdp->getNumberOfStates(); ++state) {
if (labeling.getStateHasLabel(targetLabel, state)) {
targetObservationSet.insert(pomdp->getObservation(state));
targetStates.set(state);
}
}
}
} else if (subformula.isUntilFormula()) {
storm::logic::UntilFormula const &eventuallyFormula = subformula.asUntilFormula();
storm::logic::Formula const &subformula1 = eventuallyFormula.getLeftSubformula();
storm::logic::UntilFormula const &untilFormula = subformula.asUntilFormula();
storm::logic::Formula const &subformula1 = untilFormula.getLeftSubformula();
if (subformula1.isAtomicLabelFormula()) {
storm::logic::AtomicLabelFormula const &alFormula = subformula1.asAtomicLabelFormula();
std::string targetLabel = alFormula.getLabel();
auto labeling = pomdp->getStateLabeling();
for (size_t state = 0; state < pomdp->getNumberOfStates(); ++state) {
if (labeling.getStateHasLabel(targetLabel, state)) {
if (!labeling.getStateHasLabel(targetLabel, state)) {
badStates.set(state);
}
}
@ -128,14 +134,14 @@ bool extractTargetAndSinkObservationSets(std::shared_ptr<storm::models::sparse::
std::string targetLabel = formula3.getLabel();
auto labeling = pomdp->getStateLabeling();
for (size_t state = 0; state < pomdp->getNumberOfStates(); ++state) {
if (labeling.getStateHasLabel(targetLabel, state)) {
if (!labeling.getStateHasLabel(targetLabel, state)) {
badStates.set(state);
}
}
} else {
return false;
}
storm::logic::Formula const &subformula2 = eventuallyFormula.getRightSubformula();
storm::logic::Formula const &subformula2 = untilFormula.getRightSubformula();
if (subformula2.isAtomicLabelFormula()) {
storm::logic::AtomicLabelFormula const &alFormula = subformula2.asAtomicLabelFormula();
validFormula = true;
@ -144,7 +150,9 @@ bool extractTargetAndSinkObservationSets(std::shared_ptr<storm::models::sparse::
for (size_t state = 0; state < pomdp->getNumberOfStates(); ++state) {
if (labeling.getStateHasLabel(targetLabel, state)) {
targetObservationSet.insert(pomdp->getObservation(state));
targetStates.set(state);
}
}
} else if (subformula2.isAtomicExpressionFormula()) {
validFormula = true;
@ -156,7 +164,9 @@ bool extractTargetAndSinkObservationSets(std::shared_ptr<storm::models::sparse::
for (size_t state = 0; state < pomdp->getNumberOfStates(); ++state) {
if (labeling.getStateHasLabel(targetLabel, state)) {
targetObservationSet.insert(pomdp->getObservation(state));
targetStates.set(state);
}
}
}
}
@ -227,9 +237,10 @@ int main(const int argc, const char** argv) {
if (formula->isProbabilityOperatorFormula()) {
std::set<uint32_t> targetObservationSet;
std::set<uint32_t> badObservationSet;
storm::storage::BitVector targetStates(pomdp->getNumberOfStates());
storm::storage::BitVector badStates(pomdp->getNumberOfStates());
bool validFormula = extractTargetAndSinkObservationSets(pomdp, subformula1, targetObservationSet, badObservationSet);
bool validFormula = extractTargetAndSinkObservationSets(pomdp, subformula1, targetObservationSet, targetStates, badStates);
STORM_LOG_THROW(validFormula, storm::exceptions::InvalidPropertyException,
"The formula is not supported by the grid approximation");
STORM_LOG_ASSERT(!targetObservationSet.empty(), "The set of target observations is empty!");
@ -278,11 +289,22 @@ int main(const int argc, const char** argv) {
}
}
if (pomdpSettings.isMemlessSearchSet()) {
// std::cout << std::endl;
// pomdp->writeDotToStream(std::cout);
// std::cout << std::endl;
// std::cout << std::endl;
storm::expressions::ExpressionManager expressionManager;
std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::Z3SmtSolverFactory>();
if (pomdpSettings.getMemlessSearchMethod() == "ccd16memless") {
storm::pomdp::QualitativeStrategySearchNaive<double> memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory);
memlessSearch.findNewStrategyForSomeState(5);
} else if (pomdpSettings.getMemlessSearchMethod() == "iterative") {
storm::pomdp::MemlessStrategySearchQualitative<double> memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory);
memlessSearch.findNewStrategyForSomeState(5);
} else {
STORM_LOG_ERROR("This method is not implemented.");
}
storm::pomdp::MemlessStrategySearchQualitative<double> memlessSearch(*pomdp, targetObservationSet, smtSolverFactory);
memlessSearch.analyze(5);
}
} else if (formula->isRewardOperatorFormula()) {

149
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -6,12 +6,12 @@ namespace storm {
template <typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) {
if (maxK == -1) {
if (maxK == std::numeric_limits<uint64_t>::max()) {
// not initialized at all.
// Create some data structures.
for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
actionSelectionVars.push_back(std::vector<storm::expressions::Expression>());
actionSelectionVars.push_back(std::vector<storm::expressions::Variable>());
actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>());
statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
}
@ -24,91 +24,182 @@ namespace storm {
for (uint64_t i = 0; i < k; ++i) {
pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression());
}
reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)).getExpression());
reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)));
reachVarExpressions.push_back(reachVars.back().getExpression());
statesPerObservation.at(obs).push_back(stateId++);
}
assert(pathVars.size() == pomdp.getNumberOfStates());
assert(reachVars.size() == pomdp.getNumberOfStates());
assert(reachVarExpressions.size() == pomdp.getNumberOfStates());
// Create the action selection variables.
uint64_t obs = 0;
for(auto const& statesForObservation : statesPerObservation) {
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) {
std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a);
actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName).getExpression());
actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName));
actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression());
}
++obs;
}
} else {
assert(false);
}
uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
std::vector<std::vector<storm::expressions::Expression>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs.push_back(std::vector<storm::expressions::Expression>());
}
if (targetObservations.count(pomdp.getObservation(state)) > 0) {
if (targetStates.get(state)) {
smtSolver->add(pathVars[state][0]);
} else {
smtSolver->add(!pathVars[state][0]);
}
if (surelyReachSinkStates.get(state)) {
smtSolver->add(!reachVars[state]);
smtSolver->add(!reachVarExpressions[state]);
} else if(!targetStates.get(state)) {
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>());
}
}
else {
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreach;
subexprreach.push_back(!reachVars.at(state));
subexprreach.push_back(!actionSelectionVars.at(pomdp.getObservation(state)).at(action));
subexprreach.push_back(!reachVarExpressions.at(state));
subexprreach.push_back(!actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action));
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
subexprreach.push_back(reachVars.at(entries.getColumn()));
subexprreach.push_back(reachVarExpressions.at(entries.getColumn()));
}
smtSolver->add(storm::expressions::disjunction(subexprreach));
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs[j - 1].push_back(pathVars[entries.getColumn()][j - 1]);
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
}
}
rowindex++;
}
smtSolver->add(storm::expressions::implies(reachVars.at(state), pathVars.at(state).back()));
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
for (uint64_t j = 1; j < k; ++j) {
std::vector<storm::expressions::Expression> pathsubexprs;
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubexprs.push_back(actionSelectionVars.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1]));
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
}
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
}
}
}
for (auto const& actionVars : actionSelectionVars) {
for (auto const& actionVars : actionSelectionVarExpressions) {
smtSolver->add(storm::expressions::disjunction(actionVars));
}
}
/*!
 * Iteratively searches for action selections under which more and more states have their
 * reachability variable set to true.
 *
 * The solver is first constrained so that at least one state of oneOfTheseStates and every
 * state of allOfTheseStates has its reachability variable set. Each round then
 *  - checks the constraints and prints the satisfying assignment and the selected actions,
 *  - fixes all currently selected actions, and
 *  - additionally requires that one of the states whose reachability variable was false
 *    becomes true.
 *
 * @param k Passed to initialize() when the encoding has not been built yet (k < maxK).
 * @param oneOfTheseStates States of which at least one must get its reachability variable set. Must not be empty.
 * @param allOfTheseStates States that all must get their reachability variable set.
 * @return true if a satisfying assignment sets the reachability variable of every state
 *         (nothing is left to improve); false as soon as the solver reports unsatisfiability.
 * @throws storm::exceptions::UnexpectedException if the solver answers "unknown".
 */
template <typename ValueType>
bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
    if (k < maxK) {
        initialize(k);
    }

    std::vector<storm::expressions::Expression> atLeastOneOfStates;
    for (uint64_t state : oneOfTheseStates) {
        STORM_LOG_ASSERT(reachVarExpressions.size() > state, "state id " << state << " exceeds number of states (" << reachVarExpressions.size() << ")" );
        atLeastOneOfStates.push_back(reachVarExpressions[state]);
    }
    // A disjunction cannot be built from an empty list.
    assert(atLeastOneOfStates.size() > 0);
    smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates));

    for (uint64_t state : allOfTheseStates) {
        assert(reachVarExpressions.size() > state);
        smtSolver->add(reachVarExpressions[state]);
    }

    std::cout << smtSolver->getSmtLibString() << std::endl;

    while (true) {
        auto result = smtSolver->check();
        if (result == storm::solver::SmtSolver::CheckResult::Unknown) {
            STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
        } else if (result == storm::solver::SmtSolver::CheckResult::Unsat) {
            std::cout << std::endl << "Unsatisfiable!" << std::endl;
            return false;
        }

        std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl;
        auto model = smtSolver->getModel();

        // Split the states by the value of their reachability variable in the model.
        std::cout << "states that are okay" << std::endl;
        storm::storage::BitVector observations(pomdp.getNrObservations());
        storm::storage::BitVector remainingstates(pomdp.getNumberOfStates());
        for (uint64_t state = 0; state < reachVars.size(); ++state) {
            if (model->getBooleanValue(reachVars[state])) {
                std::cout << state << " " << std::endl;
                observations.set(pomdp.getObservation(state));
            } else {
                remainingstates.set(state);
            }
        }

        // Collect, per observation, the actions selected by the model.
        std::vector<std::set<uint64_t>> scheduler;
        std::vector<storm::expressions::Expression> schedulerSoFar;
        for (uint64_t obs = 0; obs < actionSelectionVars.size(); ++obs) {
            scheduler.push_back(std::set<uint64_t>());
            for (uint64_t act = 0; act < actionSelectionVars[obs].size(); ++act) {
                if (model->getBooleanValue(actionSelectionVars[obs][act])) {
                    scheduler.back().insert(act);
                    schedulerSoFar.push_back(actionSelectionVarExpressions[obs][act]);
                }
            }
        }

        std::cout << "the scheduler: " << std::endl;
        for (uint64_t obs = 0; obs < scheduler.size(); ++obs) {
            if (observations.get(obs)) {
                std::cout << "observation: " << obs << std::endl;
                std::cout << "actions:";
                for (auto act : scheduler[obs]) {
                    std::cout << " " << act;
                }
                std::cout << std::endl;
            }
        }

        std::vector<storm::expressions::Expression> remainingExpressions;
        for (auto index : remainingstates) {
            remainingExpressions.push_back(reachVarExpressions[index]);
        }
        if (remainingExpressions.empty()) {
            // Every reachability variable is already true: nothing is left to improve, and an
            // empty disjunction must not be handed to the expression builder.
            return true;
        }

        // NOTE(review): this scope is never popped within this function — confirm whether a
        // matching pop() is intended once the search is extended.
        smtSolver->push();
        // Keep all actions selected so far and demand progress on a remaining state.
        smtSolver->add(storm::expressions::conjunction(schedulerSoFar));
        smtSolver->add(storm::expressions::disjunction(remainingExpressions));
    }
}
template class MemlessStrategySearchQualitative<double>;
template class MemlessStrategySearchQualitative<storm::RationalNumber>;
}
}

52
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -3,6 +3,7 @@
#include "storm/solver/SmtSolver.h"
#include "storm/models/sparse/Pomdp.h"
#include "storm/utility/solver.h"
#include "storm/exceptions/UnexpectedException.h"
namespace storm {
namespace pomdp {
@ -15,8 +16,12 @@ namespace pomdp {
public:
MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp,
std::set<uint32_t> const& targetObservationSet,
storm::storage::BitVector const& targetStates,
storm::storage::BitVector const& surelyReachSinkStates,
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) :
pomdp(pomdp),
targetStates(targetStates),
surelyReachSinkStates(surelyReachSinkStates),
targetObservations(targetObservationSet) {
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
smtSolver = smtSolverFactory->create(*expressionManager);
@ -27,49 +32,40 @@ namespace pomdp {
surelyReachSinkStates = surelyReachSink;
}
void analyze(uint64_t k) {
if (k < maxK) {
initialize(k);
void analyzeForInitialStates(uint64_t k) {
analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates());
}
std::cout << smtSolver->getSmtLibString() << std::endl;
for (uint64_t state : pomdp.getInitialStates()) {
smtSolver->add(reachVars[state]);
}
auto result = smtSolver->check();
switch(result) {
case storm::solver::SmtSolver::CheckResult::Sat:
std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl;
case storm::solver::SmtSolver::CheckResult::Unsat:
// std::cout << std::endl << "Unsatisfiability core: {" << std::endl;
// for (auto const& expr : solver->getUnsatCore()) {
// std::cout << "\t " << expr << std::endl;
// }
// std::cout << "}" << std::endl;
default:
std::cout<< "oops." << std::endl;
// STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
}
//std::cout << "get model:" << std::endl;
//std::cout << smtSolver->getModel().toString() << std::endl;
void findNewStrategyForSomeState(uint64_t k) {
std::cout << surelyReachSinkStates << std::endl;
std::cout << targetStates << std::endl;
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
analyze(k, ~surelyReachSinkStates & ~targetStates);
}
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
private:
void initialize(uint64_t k);
std::unique_ptr<storm::solver::SmtSolver> smtSolver;
storm::models::sparse::Pomdp<ValueType> const& pomdp;
std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;
uint64_t maxK = -1;
uint64_t maxK = std::numeric_limits<uint64_t>::max();
std::set<uint32_t> targetObservations;
storm::storage::BitVector targetStates;
storm::storage::BitVector surelyReachSinkStates;
std::vector<std::vector<uint64_t>> statesPerObservation;
std::vector<std::vector<storm::expressions::Expression>> actionSelectionVars; // A_{z,a}
std::vector<storm::expressions::Expression> reachVars;
std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars;
std::vector<storm::expressions::Variable> reachVars;
std::vector<storm::expressions::Expression> reachVarExpressions;
std::vector<std::vector<storm::expressions::Expression>> pathVars;

186
src/storm-pomdp/analysis/QualitativeStrategySearchNaive.cpp

@ -0,0 +1,186 @@
#include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h"
namespace storm {
namespace pomdp {
/*!
 * Declares the SMT variables and adds the constraints of the strategy-search encoding.
 *
 * Variables:
 *  - "P-<state>-<i>" for i < k (pathVars),
 *  - "C-<state>" (reachVars / reachVarExpressions),
 *  - "A-<obs>-<a>" for every choice a of the first state with observation obs (actionSelectionVars).
 *
 * Constraints, per state s with observation z:
 *  - P-s-0 holds iff s is a target state,
 *  - C-s is false for states in surelyReachSinkStates,
 *  - for the remaining non-target states and each action a:
 *      !C-s || !A-z-a || (C-s' for some successor s' under a),
 *    C-s implies P-s-(k-1), and
 *    P-s-j iff some action a has A-z-a and P-s'-(j-1) for some successor s' under a.
 * Finally, every observation needs at least one selected action.
 *
 * @param k Number of path variables per state.
 */
template <typename ValueType>
void QualitativeStrategySearchNaive<ValueType>::initialize(uint64_t k) {
    // NOTE(review): maxK is not updated in this function, so a second call ends up in the
    // assert(false) branch below — confirm whether that is intended.
    if (maxK == std::numeric_limits<uint64_t>::max()) {
        // not initialized at all.
        // Create some data structures.
        for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
            actionSelectionVars.push_back(std::vector<storm::expressions::Variable>());
            actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>());
            statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
        }

        // Fill the states-per-observation mapping,
        // declare the reachability variables,
        // declare the path variables.
        uint64_t stateId = 0;
        for (auto obs : pomdp.getObservations()) {
            pathVars.push_back(std::vector<storm::expressions::Expression>());
            for (uint64_t i = 0; i < k; ++i) {
                pathVars.back().push_back(expressionManager->declareBooleanVariable("P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression());
            }
            reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)));
            reachVarExpressions.push_back(reachVars.back().getExpression());
            statesPerObservation.at(obs).push_back(stateId++);
        }
        assert(pathVars.size() == pomdp.getNumberOfStates());

        // Create the action selection variables.
        uint64_t obs = 0;
        for (auto const& statesForObservation : statesPerObservation) {
            for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) {
                std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a);
                actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName));
                actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression());
            }
            ++obs;
        }
    } else {
        assert(false);
    }

    // Index of the first transition-matrix row of the current state. It is advanced for
    // EVERY state: target and sink states own rows too, and skipping them would shift the
    // rows read for all later states.
    uint64_t rowindex = 0;
    for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
        uint64_t const numberOfChoices = pomdp.getNumberOfChoices(state);
        uint64_t const firstRowOfState = rowindex;
        rowindex += numberOfChoices;

        if (targetStates.get(state)) {
            smtSolver->add(pathVars[state][0]);
        } else {
            smtSolver->add(!pathVars[state][0]);
        }

        if (surelyReachSinkStates.get(state)) {
            smtSolver->add(!reachVarExpressions[state]);
        } else if (!targetStates.get(state)) {
            auto const& actionExprs = actionSelectionVarExpressions.at(pomdp.getObservation(state));

            // pathsubsubexprs[j - 1][action]: path variables (step j - 1) of the successors of `action`.
            std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs(
                k > 1 ? k - 1 : 0, std::vector<std::vector<storm::expressions::Expression>>(numberOfChoices));

            for (uint64_t action = 0; action < numberOfChoices; ++action) {
                uint64_t const row = firstRowOfState + action;

                std::vector<storm::expressions::Expression> subexprreach;
                subexprreach.push_back(!reachVarExpressions.at(state));
                subexprreach.push_back(!actionExprs.at(action));
                for (auto const& entries : pomdp.getTransitionMatrix().getRow(row)) {
                    subexprreach.push_back(reachVarExpressions.at(entries.getColumn()));
                }
                smtSolver->add(storm::expressions::disjunction(subexprreach));

                for (auto const& entries : pomdp.getTransitionMatrix().getRow(row)) {
                    for (uint64_t j = 1; j < k; ++j) {
                        pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
                    }
                }
            }

            smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));

            for (uint64_t j = 1; j < k; ++j) {
                std::vector<storm::expressions::Expression> pathsubexprs;
                for (uint64_t action = 0; action < numberOfChoices; ++action) {
                    pathsubexprs.push_back(actionExprs.at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
                }
                smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
            }
        }
    }

    // Every observation has to select at least one action.
    for (auto const& actionVars : actionSelectionVarExpressions) {
        smtSolver->add(storm::expressions::disjunction(actionVars));
    }
}
/*!
 * Runs a single SMT query for a strategy and prints the outcome.
 *
 * Builds the encoding via initialize() if k < maxK, requires that at least one state of
 * oneOfTheseStates and all states of allOfTheseStates have their reachability variable set,
 * and checks the resulting constraints once.
 *
 * @param k Passed to initialize().
 * @param oneOfTheseStates States of which at least one must get its reachability variable set. Must not be empty.
 * @param allOfTheseStates States that all must get their reachability variable set.
 * @return true if the solver found a satisfying assignment, false if it reported unsatisfiability.
 * @throws storm::exceptions::UnexpectedException if the solver answers "unknown".
 */
template <typename ValueType>
bool QualitativeStrategySearchNaive<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
    if (k < maxK) {
        initialize(k);
    }

    // At least one of the candidate states has to be covered ...
    std::vector<storm::expressions::Expression> candidateReachExprs;
    for (uint64_t candidate : oneOfTheseStates) {
        candidateReachExprs.push_back(reachVarExpressions[candidate]);
    }
    assert(!candidateReachExprs.empty());
    smtSolver->add(storm::expressions::disjunction(candidateReachExprs));

    // ... and each of the mandatory states as well.
    for (uint64_t mandatory : allOfTheseStates) {
        smtSolver->add(reachVarExpressions[mandatory]);
    }

    std::cout << smtSolver->getSmtLibString() << std::endl;

    auto const outcome = smtSolver->check();
    smtSolver->push();

    if (outcome == storm::solver::SmtSolver::CheckResult::Unknown) {
        STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
    }
    if (outcome == storm::solver::SmtSolver::CheckResult::Unsat) {
        std::cout << std::endl << "Unsatisfiable!" << std::endl;
        return false;
    }

    std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl;
    auto model = smtSolver->getModel();

    // Report the states whose reachability variable is true and remember their observations.
    std::cout << "states that are okay" << std::endl;
    storm::storage::BitVector coveredObservations(pomdp.getNrObservations());
    storm::storage::BitVector uncoveredStates(pomdp.getNumberOfStates());
    for (uint64_t state = 0; state < reachVars.size(); ++state) {
        if (model->getBooleanValue(reachVars[state])) {
            std::cout << state << " " << std::endl;
            coveredObservations.set(pomdp.getObservation(state));
        } else {
            uncoveredStates.set(state);
        }
    }

    // Read off the selected actions for every observation.
    std::vector<std::set<uint64_t>> selectedActions;
    for (auto const& varsOfObservation : actionSelectionVars) {
        std::set<uint64_t> chosen;
        for (uint64_t act = 0; act < varsOfObservation.size(); ++act) {
            if (model->getBooleanValue(varsOfObservation[act])) {
                chosen.insert(act);
            }
        }
        selectedActions.push_back(chosen);
    }

    std::cout << "the scheduler: " << std::endl;
    for (uint64_t obs = 0; obs < selectedActions.size(); ++obs) {
        if (!coveredObservations.get(obs)) {
            continue;
        }
        std::cout << "observation: " << obs << std::endl;
        std::cout << "actions:";
        for (auto act : selectedActions[obs]) {
            std::cout << " " << act;
        }
        std::cout << std::endl;
    }
    return true;
}
template class QualitativeStrategySearchNaive<double>;
template class QualitativeStrategySearchNaive<storm::RationalNumber>;
}
}

74
src/storm-pomdp/analysis/QualitativeStrategySearchNaive.h

@ -0,0 +1,74 @@
#include <vector>
#include "storm/storage/expressions/Expressions.h"
#include "storm/solver/SmtSolver.h"
#include "storm/models/sparse/Pomdp.h"
#include "storm/utility/solver.h"
#include "storm/exceptions/UnexpectedException.h"
namespace storm {
namespace pomdp {
template<typename ValueType>
class QualitativeStrategySearchNaive {
// Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
public:
QualitativeStrategySearchNaive(storm::models::sparse::Pomdp<ValueType> const& pomdp,
std::set<uint32_t> const& targetObservationSet,
storm::storage::BitVector const& targetStates,
storm::storage::BitVector const& surelyReachSinkStates,
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory) :
pomdp(pomdp),
targetStates(targetStates),
surelyReachSinkStates(surelyReachSinkStates),
targetObservations(targetObservationSet) {
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
smtSolver = smtSolverFactory->create(*expressionManager);
}
void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) {
surelyReachSinkStates = surelyReachSink;
}
void analyzeForInitialStates(uint64_t k) {
analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates());
}
void findNewStrategyForSomeState(uint64_t k) {
std::cout << surelyReachSinkStates << std::endl;
std::cout << targetStates << std::endl;
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
analyze(k, ~surelyReachSinkStates & ~targetStates);
}
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
private:
void initialize(uint64_t k);
std::unique_ptr<storm::solver::SmtSolver> smtSolver;
storm::models::sparse::Pomdp<ValueType> const& pomdp;
std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;
uint64_t maxK = std::numeric_limits<uint64_t>::max();
std::set<uint32_t> targetObservations;
storm::storage::BitVector targetStates;
storm::storage::BitVector surelyReachSinkStates;
std::vector<std::vector<uint64_t>> statesPerObservation;
std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars;
std::vector<storm::expressions::Variable> reachVars;
std::vector<storm::expressions::Expression> reachVarExpressions;
std::vector<std::vector<storm::expressions::Expression>> pathVars;
};
}
}
Loading…
Cancel
Save