Browse Source

various improvements and fixes in winning region computation

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
858e2f8a60
  1. 5
      src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.cpp
  2. 1
      src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.h
  3. 71
      src/storm-pomdp-cli/storm-pomdp.cpp
  4. 379
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  5. 63
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h
  6. 17
      src/storm-pomdp/analysis/QualitativeAnalysis.cpp
  7. 2
      src/storm-pomdp/analysis/QualitativeAnalysis.h
  8. 36
      src/storm-pomdp/analysis/WinningRegion.cpp
  9. 6
      src/storm-pomdp/analysis/WinningRegion.h

5
src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.cpp

@ -15,11 +15,13 @@ namespace storm {
const std::string QualitativePOMDPAnalysisSettings::moduleName = "pomdpQualitative"; const std::string QualitativePOMDPAnalysisSettings::moduleName = "pomdpQualitative";
const std::string exportSATCallsOption = "exportSATCallsPath"; const std::string exportSATCallsOption = "exportSATCallsPath";
const std::string lookaheadHorizonOption = "lookaheadHorizon"; const std::string lookaheadHorizonOption = "lookaheadHorizon";
const std::string onlyDeterministicOption = "onlyDeterministic";
QualitativePOMDPAnalysisSettings::QualitativePOMDPAnalysisSettings() : ModuleSettings(moduleName) { QualitativePOMDPAnalysisSettings::QualitativePOMDPAnalysisSettings() : ModuleSettings(moduleName) {
this->addOption(storm::settings::OptionBuilder(moduleName, exportSATCallsOption, false, "Export the SAT calls?.").addArgument(storm::settings::ArgumentBuilder::createStringArgument("path", "The name of the file to which to write the model.").build()).build()); this->addOption(storm::settings::OptionBuilder(moduleName, exportSATCallsOption, false, "Export the SAT calls?.").addArgument(storm::settings::ArgumentBuilder::createStringArgument("path", "The name of the file to which to write the model.").build()).build());
this->addOption(storm::settings::OptionBuilder(moduleName, lookaheadHorizonOption, false, "In reachability in combination with a discrete ranking function, a lookahead is necessary.").addArgument(storm::settings::ArgumentBuilder::createUnsignedIntegerArgument("bound", "The lookahead. Use 0 for the number of states.").setDefaultValueUnsignedInteger(0).build()).build()); this->addOption(storm::settings::OptionBuilder(moduleName, lookaheadHorizonOption, false, "In reachability in combination with a discrete ranking function, a lookahead is necessary.").addArgument(storm::settings::ArgumentBuilder::createUnsignedIntegerArgument("bound", "The lookahead. Use 0 for the number of states.").setDefaultValueUnsignedInteger(0).build()).build());
this->addOption(storm::settings::OptionBuilder(moduleName, onlyDeterministicOption, false, "Search only for deterministic schedulers").build());
} }
uint64_t QualitativePOMDPAnalysisSettings::getLookahead() const { uint64_t QualitativePOMDPAnalysisSettings::getLookahead() const {
@ -33,6 +35,9 @@ namespace storm {
return this->getOption(exportSATCallsOption).getArgumentByName("path").getValueAsString(); return this->getOption(exportSATCallsOption).getArgumentByName("path").getValueAsString();
} }
/*!
 * Queries whether the option registered under onlyDeterministicOption has been set.
 *
 * @return The value reported by getHasOptionBeenSet() for that option.
 */
bool QualitativePOMDPAnalysisSettings::isOnlyDeterministicSet() const {
    auto const& deterministicOnlyOption = this->getOption(onlyDeterministicOption);
    return deterministicOnlyOption.getHasOptionBeenSet();
}
void QualitativePOMDPAnalysisSettings::finalize() { void QualitativePOMDPAnalysisSettings::finalize() {

1
src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.h

@ -23,6 +23,7 @@ namespace storm {
uint64_t getLookahead() const; uint64_t getLookahead() const;
bool isExportSATCallsSet() const; bool isExportSATCallsSet() const;
std::string getExportSATCallsPath() const; std::string getExportSATCallsPath() const;
bool isOnlyDeterministicSet() const;
virtual ~QualitativePOMDPAnalysisSettings() = default; virtual ~QualitativePOMDPAnalysisSettings() = default;

71
src/storm-pomdp-cli/storm-pomdp.cpp

@ -174,6 +174,15 @@ bool extractTargetAndSinkObservationSets(std::shared_ptr<storm::models::sparse::
return validFormula; return validFormula;
} }
/*!
 * Collects the observations of a set of states.
 *
 * Iterates over `states` and looks up the observation of every yielded element in the given POMDP.
 * NOTE(review): the elements yielded by iterating the bit vector are presumably the indices of the set bits -- confirm against BitVector.
 *
 * @param pomdp The POMDP that maps states to observations via getObservation().
 * @param states The states whose observations are collected.
 * @return The set of observations obtained this way (duplicates collapse, as the result is a std::set).
 */
template<typename ValueType>
std::set<uint32_t> extractObservations(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::BitVector const& states) {
    std::set<uint32_t> collectedObservations;
    for (auto stateIndex : states) {
        auto const observationOfState = pomdp.getObservation(stateIndex);
        collectedObservations.insert(observationOfState);
    }
    return collectedObservations;
}
/*! /*!
* Entry point for the pomdp backend. * Entry point for the pomdp backend.
* *
@ -195,6 +204,7 @@ int main(const int argc, const char** argv) {
auto const& coreSettings = storm::settings::getModule<storm::settings::modules::CoreSettings>(); auto const& coreSettings = storm::settings::getModule<storm::settings::modules::CoreSettings>();
auto const& pomdpSettings = storm::settings::getModule<storm::settings::modules::POMDPSettings>(); auto const& pomdpSettings = storm::settings::getModule<storm::settings::modules::POMDPSettings>();
auto const& ioSettings = storm::settings::getModule<storm::settings::modules::IOSettings>();
auto const &general = storm::settings::getModule<storm::settings::modules::GeneralSettings>(); auto const &general = storm::settings::getModule<storm::settings::modules::GeneralSettings>();
auto const &debug = storm::settings::getModule<storm::settings::modules::DebugSettings>(); auto const &debug = storm::settings::getModule<storm::settings::modules::DebugSettings>();
auto const& pomdpQualSettings = storm::settings::getModule<storm::settings::modules::QualitativePOMDPAnalysisSettings>(); auto const& pomdpQualSettings = storm::settings::getModule<storm::settings::modules::QualitativePOMDPAnalysisSettings>();
@ -208,6 +218,9 @@ int main(const int argc, const char** argv) {
if (debug.isTraceSet()) { if (debug.isTraceSet()) {
storm::utility::setLogLevel(l3pp::LogLevel::TRACE); storm::utility::setLogLevel(l3pp::LogLevel::TRACE);
} }
if (debug.isLogfileSet()) {
storm::utility::initializeFileLogging();
}
// For several engines, no model building step is performed, but the verification is started right away. // For several engines, no model building step is performed, but the verification is started right away.
storm::settings::modules::CoreSettings::Engine engine = coreSettings.getEngine(); storm::settings::modules::CoreSettings::Engine engine = coreSettings.getEngine();
@ -237,17 +250,6 @@ int main(const int argc, const char** argv) {
if (formula->isProbabilityOperatorFormula()) { if (formula->isProbabilityOperatorFormula()) {
std::set<uint32_t> targetObservationSet;
storm::storage::BitVector targetStates(pomdp->getNumberOfStates());
storm::storage::BitVector badStates(pomdp->getNumberOfStates());
bool validFormula = extractTargetAndSinkObservationSets(pomdp, subformula1, targetObservationSet, targetStates, badStates);
STORM_LOG_THROW(validFormula, storm::exceptions::InvalidPropertyException,
"The formula is not supported by the grid approximation");
STORM_LOG_ASSERT(!targetObservationSet.empty(), "The set of target observations is empty!");
boost::optional<storm::storage::BitVector> prob1States; boost::optional<storm::storage::BitVector> prob1States;
boost::optional<storm::storage::BitVector> prob0States; boost::optional<storm::storage::BitVector> prob0States;
if (pomdpSettings.isSelfloopReductionSet() && !storm::solver::minimize(formula->asProbabilityOperatorFormula().getOptimalityType())) { if (pomdpSettings.isSelfloopReductionSet() && !storm::solver::minimize(formula->asProbabilityOperatorFormula().getOptimalityType())) {
@ -271,8 +273,29 @@ int main(const int argc, const char** argv) {
storm::pomdp::transformer::KnownProbabilityTransformer<double> kpt = storm::pomdp::transformer::KnownProbabilityTransformer<double>(); storm::pomdp::transformer::KnownProbabilityTransformer<double> kpt = storm::pomdp::transformer::KnownProbabilityTransformer<double>();
pomdp = kpt.transform(*pomdp, *prob0States, *prob1States); pomdp = kpt.transform(*pomdp, *prob0States, *prob1States);
} }
if (ioSettings.isExportDotSet()) {
std::shared_ptr<storm::models::sparse::Model<double>> sparseModel = pomdp;
storm::api::exportSparseModelAsDot(sparseModel, ioSettings.getExportDotFilename(), ioSettings.getExportDotMaxWidth());
}
if (ioSettings.isExportExplicitSet()) {
std::shared_ptr<storm::models::sparse::Model<double>> sparseModel = pomdp;
storm::api::exportSparseModelAsDrn(sparseModel, ioSettings.getExportExplicitFilename());
}
if (pomdpSettings.isGridApproximationSet()) { if (pomdpSettings.isGridApproximationSet()) {
std::set<uint32_t> targetObservationSet;
storm::storage::BitVector targetStates(pomdp->getNumberOfStates());
storm::storage::BitVector badStates(pomdp->getNumberOfStates());
bool validFormula = extractTargetAndSinkObservationSets(pomdp, subformula1, targetObservationSet, targetStates, badStates);
STORM_LOG_THROW(validFormula, storm::exceptions::InvalidPropertyException,
"The formula is not supported by the grid approximation");
STORM_LOG_ASSERT(!targetObservationSet.empty(), "The set of target observations is empty!");
storm::pomdp::modelchecker::ApproximatePOMDPModelchecker<double> checker = storm::pomdp::modelchecker::ApproximatePOMDPModelchecker<double>(); storm::pomdp::modelchecker::ApproximatePOMDPModelchecker<double> checker = storm::pomdp::modelchecker::ApproximatePOMDPModelchecker<double>();
double overRes = storm::utility::one<double>(); double overRes = storm::utility::one<double>();
double underRes = storm::utility::zero<double>(); double underRes = storm::utility::zero<double>();
@ -291,10 +314,12 @@ int main(const int argc, const char** argv) {
} }
} }
if (pomdpSettings.isMemlessSearchSet()) { if (pomdpSettings.isMemlessSearchSet()) {
// std::cout << std::endl;
// pomdp->writeDotToStream(std::cout);
// std::cout << std::endl;
// std::cout << std::endl;
storm::analysis::QualitativeAnalysis<double> qualitativeAnalysis(*pomdp);
// After preprocessing, this might be done cheaper.
storm::storage::BitVector targetStates = qualitativeAnalysis.analyseProb1(formula->asProbabilityOperatorFormula());
storm::storage::BitVector surelyNotAlmostSurelyReachTarget = qualitativeAnalysis.analyseProbSmaller1(formula->asProbabilityOperatorFormula());
std::set<uint32_t> targetObservationSet = extractObservations(*pomdp, targetStates);
storm::expressions::ExpressionManager expressionManager; storm::expressions::ExpressionManager expressionManager;
std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::Z3SmtSolverFactory>(); std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::Z3SmtSolverFactory>();
@ -302,9 +327,9 @@ int main(const int argc, const char** argv) {
if (lookahead == 0) { if (lookahead == 0) {
lookahead = pomdp->getNumberOfStates(); lookahead = pomdp->getNumberOfStates();
} }
storm::pomdp::MemlessSearchOptions options; storm::pomdp::MemlessSearchOptions options;
options.onlyDeterministicStrategies = pomdpQualSettings.isOnlyDeterministicSet();
uint64_t loglevel = 0; uint64_t loglevel = 0;
// TODO a big ugly, but we have our own loglevels. // TODO a big ugly, but we have our own loglevels.
if(storm::utility::getLogLevel() == l3pp::LogLevel::INFO) { if(storm::utility::getLogLevel() == l3pp::LogLevel::INFO) {
@ -322,13 +347,23 @@ int main(const int argc, const char** argv) {
options.setExportSATCalls(pomdpQualSettings.getExportSATCallsPath()); options.setExportSATCalls(pomdpQualSettings.getExportSATCallsPath());
} }
if (storm::utility::graph::checkIfECWithChoiceExists(pomdp->getTransitionMatrix(), pomdp->getBackwardTransitions(), ~targetStates & ~surelyNotAlmostSurelyReachTarget, storm::storage::BitVector(pomdp->getNumberOfChoices(), true))) {
options.lookaheadRequired = true;
STORM_LOG_DEBUG("Lookahead required.");
} else {
options.lookaheadRequired = false;
STORM_LOG_DEBUG("No lookahead required.");
}
if (pomdpSettings.getMemlessSearchMethod() == "ccd16memless") { if (pomdpSettings.getMemlessSearchMethod() == "ccd16memless") {
storm::pomdp::QualitativeStrategySearchNaive<double> memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory);
storm::pomdp::QualitativeStrategySearchNaive<double> memlessSearch(*pomdp, targetObservationSet, targetStates, surelyNotAlmostSurelyReachTarget, smtSolverFactory);
memlessSearch.findNewStrategyForSomeState(lookahead); memlessSearch.findNewStrategyForSomeState(lookahead);
} else if (pomdpSettings.getMemlessSearchMethod() == "iterative") { } else if (pomdpSettings.getMemlessSearchMethod() == "iterative") {
storm::pomdp::MemlessStrategySearchQualitative<double> memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory, options);
storm::pomdp::MemlessStrategySearchQualitative<double> memlessSearch(*pomdp, targetObservationSet, targetStates, surelyNotAlmostSurelyReachTarget, smtSolverFactory, options);
memlessSearch.findNewStrategyForSomeState(lookahead); memlessSearch.findNewStrategyForSomeState(lookahead);
memlessSearch.finalizeStatistics();
memlessSearch.getStatistics().print();
} else { } else {
STORM_LOG_ERROR("This method is not implemented."); STORM_LOG_ERROR("This method is not implemented.");
} }

379
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -19,18 +19,30 @@ namespace storm {
STORM_LOG_TRACE(ss.str()); STORM_LOG_TRACE(ss.str());
i = 0; i = 0;
STORM_LOG_TRACE("states from which we continue: "); STORM_LOG_TRACE("states from which we continue: ");
ss.clear();
std::stringstream ss2;
for (auto rv : continuationVars) { for (auto rv : continuationVars) {
if (model->getBooleanValue(rv)) { if (model->getBooleanValue(rv)) {
ss << " " << i;
ss2 << " " << i;
} }
++i; ++i;
} }
STORM_LOG_TRACE(ss.str());
STORM_LOG_TRACE(ss2.str());
} }
} }
/*!
 * Emits the collected statistics, one STORM_PRINT_AND_LOG statement per entry.
 *
 * The entries are written in a fixed order: total time, number of SAT calls, SAT call time,
 * outer iterations, solver initialization time, partial-scheduler extension time,
 * new-strategy solver update time, and winning region update time.
 * NOTE(review): presumably STORM_PRINT_AND_LOG writes to standard output and to the log -- confirm against the macro definition.
 * The method is const; it only reads the statistics members.
 */
template <typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::Statistics::print() const {
// Overall timer.
STORM_PRINT_AND_LOG("Total time: " << totalTimer << std::endl);
// Counter of solver invocations and the timer associated with them.
STORM_PRINT_AND_LOG("SAT Calls " << satCalls << std::endl);
STORM_PRINT_AND_LOG("SAT Calls time: " << smtCheckTimer << std::endl);
STORM_PRINT_AND_LOG("Outer iterations: " << outerIterations << std::endl);
// Timers for the individual phases of the search.
STORM_PRINT_AND_LOG("Solver initialization time: " << initializeSolverTimer << std::endl);
STORM_PRINT_AND_LOG("Extend partial scheduler time: " << updateExtensionSolverTime << std::endl);
STORM_PRINT_AND_LOG("Update solver with new scheduler time: " << updateNewStrategySolverTime << std::endl);
STORM_PRINT_AND_LOG("Winning regions update time: " << winningRegionUpdatesTimer << std::endl);
}
template <typename ValueType> template <typename ValueType>
MemlessStrategySearchQualitative<ValueType>::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp, MemlessStrategySearchQualitative<ValueType>::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp,
std::set<uint32_t> const& targetObservationSet, std::set<uint32_t> const& targetObservationSet,
@ -39,9 +51,9 @@ namespace storm {
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory, std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory,
MemlessSearchOptions const& options) : MemlessSearchOptions const& options) :
pomdp(pomdp), pomdp(pomdp),
targetStates(targetStates),
surelyReachSinkStates(surelyReachSinkStates), surelyReachSinkStates(surelyReachSinkStates),
targetObservations(targetObservationSet), targetObservations(targetObservationSet),
targetStates(targetStates),
options(options) options(options)
{ {
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>(); this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
@ -49,6 +61,7 @@ namespace storm {
// Initialize states per observation. // Initialize states per observation.
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead. statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
reachVarExpressionsPerObservation.push_back(std::vector<storm::expressions::Expression>());
} }
uint64_t state = 0; uint64_t state = 0;
for (auto obs : pomdp.getObservations()) { for (auto obs : pomdp.getObservations()) {
@ -60,10 +73,15 @@ namespace storm {
nrStatesPerObservation.push_back(states.size()); nrStatesPerObservation.push_back(states.size());
} }
winningRegion = WinningRegion(nrStatesPerObservation); winningRegion = WinningRegion(nrStatesPerObservation);
} }
template <typename ValueType> template <typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) { void MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) {
STORM_LOG_INFO("Start intializing solver...");
// TODO fix this
bool lookaheadConstraintsRequired = options.lookaheadRequired;
STORM_LOG_WARN("We have hardcoded that we do not need lookahead");
if (maxK == std::numeric_limits<uint64_t>::max()) { if (maxK == std::numeric_limits<uint64_t>::max()) {
// not initialized at all. // not initialized at all.
// Create some data structures. // Create some data structures.
@ -76,16 +94,19 @@ namespace storm {
// declare the reachability variables, // declare the reachability variables,
// declare the path variables. // declare the path variables.
for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
pathVars.push_back(std::vector<storm::expressions::Expression>());
for (uint64_t i = 0; i < k; ++i) {
pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression());
if(lookaheadConstraintsRequired) {
pathVars.push_back(std::vector<storm::expressions::Expression>());
for (uint64_t i = 0; i < k; ++i) {
pathVars.back().push_back(expressionManager->declareBooleanVariable("P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression());
}
} }
reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId))); reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)));
reachVarExpressions.push_back(reachVars.back().getExpression()); reachVarExpressions.push_back(reachVars.back().getExpression());
reachVarExpressionsPerObservation[pomdp.getObservation(stateId)].push_back(reachVarExpressions.back());
continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId))); continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId)));
continuationVarExpressions.push_back(continuationVars.back().getExpression()); continuationVarExpressions.push_back(continuationVars.back().getExpression());
} }
assert(pathVars.size() == pomdp.getNumberOfStates());
assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates());
assert(reachVars.size() == pomdp.getNumberOfStates()); assert(reachVars.size() == pomdp.getNumberOfStates());
assert(reachVarExpressions.size() == pomdp.getNumberOfStates()); assert(reachVarExpressions.size() == pomdp.getNumberOfStates());
@ -101,7 +122,15 @@ namespace storm {
schedulerVariableExpressions.push_back(schedulerVariables.back()); schedulerVariableExpressions.push_back(schedulerVariables.back());
switchVars.push_back(expressionManager->declareBooleanVariable("S-" + std::to_string(obs))); switchVars.push_back(expressionManager->declareBooleanVariable("S-" + std::to_string(obs)));
switchVarExpressions.push_back(switchVars.back().getExpression()); switchVarExpressions.push_back(switchVars.back().getExpression());
observationUpdatedVariables.push_back(expressionManager->declareBooleanVariable("U-" + std::to_string(obs)));
observationUpdatedExpressions.push_back(observationUpdatedVariables.back().getExpression());
if (options.onlyDeterministicStrategies) {
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) {
for (uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) {
smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]);
}
}
}
++obs; ++obs;
} }
@ -109,11 +138,15 @@ namespace storm {
smtSolver->add(storm::expressions::disjunction(actionVars)); smtSolver->add(storm::expressions::disjunction(actionVars));
} }
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state)) {
smtSolver->add(pathVars[state][0]);
} else {
smtSolver->add(!pathVars[state][0]);
smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions));
if (lookaheadConstraintsRequired) {
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state)) {
smtSolver->add(pathVars[state][0]);
} else {
smtSolver->add(!pathVars[state][0]);
}
} }
} }
@ -152,42 +185,49 @@ namespace storm {
uint64_t rowindex = 0; uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (surelyReachSinkStates.get(state)) { if (surelyReachSinkStates.get(state)) {
smtSolver->add(!reachVarExpressions[state]); smtSolver->add(!reachVarExpressions[state]);
for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(!pathVars[state][j]);
}
smtSolver->add(!continuationVarExpressions[state]); smtSolver->add(!continuationVarExpressions[state]);
} else if(!targetStates.get(state)) {
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>());
if (lookaheadConstraintsRequired) {
for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(!pathVars[state][j]);
} }
} }
rowindex += pomdp.getNumberOfChoices(state);
} else if(!targetStates.get(state)) {
if (lookaheadConstraintsRequired) {
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>());
}
}
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreach;
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreach;
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
}
} }
rowindex++;
} }
rowindex++;
}
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
for (uint64_t j = 1; j < k; ++j) {
std::vector<storm::expressions::Expression> pathsubexprs;
for (uint64_t j = 1; j < k; ++j) {
std::vector<storm::expressions::Expression> pathsubexprs;
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
}
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
} }
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
} }
} else {
rowindex += pomdp.getNumberOfChoices(state);
} }
} }
@ -199,17 +239,47 @@ namespace storm {
++obs; ++obs;
} }
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs])));
}
if (!lookaheadConstraintsRequired) {
uint64_t rowIndex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
uint64_t enabledActions = pomdp.getNumberOfChoices(state);
if (!surelyReachSinkStates.get(state)) {
std::vector<storm::expressions::Expression> successorVars;
for (uint64_t act = 0; act < enabledActions; ++act) {
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) {
successorVars.push_back(reachVarExpressions[entries.getColumn()]);
}
rowIndex++;
}
successorVars.push_back(!switchVars[pomdp.getObservation(state)]);
smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state]));
} else {
rowIndex += enabledActions;
}
}
} else {
STORM_LOG_WARN("Some optimization not implemented yet.");
}
// TODO: Update found schedulers if k is increased. // TODO: Update found schedulers if k is increased.
} }
template <typename ValueType> template <typename ValueType>
bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
stats.initializeSolverTimer.start();
if (k < maxK) { if (k < maxK) {
initialize(k); initialize(k);
maxK = k; maxK = k;
} }
uint64_t maximalNrActions = 8;
STORM_LOG_WARN("We have hardcoded (an upper bound on) the number of actions");
std::vector<storm::expressions::Expression> atLeastOneOfStates; std::vector<storm::expressions::Expression> atLeastOneOfStates;
for (uint64_t state : oneOfTheseStates) { for (uint64_t state : oneOfTheseStates) {
@ -225,112 +295,117 @@ namespace storm {
} }
smtSolver->push(); smtSolver->push();
uint64_t obs = 0;
for(auto const& statesForObservation : statesPerObservation) {
smtSolver->add(schedulerVariableExpressions[obs] <= schedulerForObs.size());
++obs;
}
std::vector<storm::expressions::Expression> updateForObservationExpressions;
for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) {
updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob]));
schedulerForObs.push_back(std::vector<uint64_t>()); schedulerForObs.push_back(std::vector<uint64_t>());
} }
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
auto constant = expressionManager->integer(schedulerForObs[obs].size());
smtSolver->add(schedulerVariableExpressions[obs] <= constant);
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
}
InternalObservationScheduler scheduler; InternalObservationScheduler scheduler;
scheduler.switchObservations = storm::storage::BitVector(pomdp.getNrObservations()); scheduler.switchObservations = storm::storage::BitVector(pomdp.getNrObservations());
storm::storage::BitVector newObservations(pomdp.getNrObservations());
storm::storage::BitVector newObservationsAfterSwitch(pomdp.getNrObservations());
storm::storage::BitVector observations(pomdp.getNrObservations()); storm::storage::BitVector observations(pomdp.getNrObservations());
storm::storage::BitVector observationsAfterSwitch(pomdp.getNrObservations()); storm::storage::BitVector observationsAfterSwitch(pomdp.getNrObservations());
storm::storage::BitVector remainingstates(pomdp.getNumberOfStates());
storm::storage::BitVector observationUpdated(pomdp.getNrObservations());
storm::storage::BitVector coveredStates(pomdp.getNumberOfStates());
storm::storage::BitVector coveredStatesAfterSwitch(pomdp.getNumberOfStates());
stats.initializeSolverTimer.stop();
STORM_LOG_INFO("Start iterative solver...");
uint64_t iterations = 0; uint64_t iterations = 0;
while(true) { while(true) {
scheduler.clear();
stats.incrementOuterIterations();
scheduler.reset(pomdp.getNrObservations(), maximalNrActions);
observations.clear(); observations.clear();
observationsAfterSwitch.clear(); observationsAfterSwitch.clear();
remainingstates.clear();
coveredStates.clear();
coveredStatesAfterSwitch.clear();
observationUpdated.clear();
bool newSchedulerDiscovered = false;
while (true) { while (true) {
++iterations; ++iterations;
if(options.isExportSATSet()) {
STORM_LOG_DEBUG("Export SMT Solver Call (" <<iterations << ")");
std::string filepath = options.getExportSATCallsPath() + "call_" + std::to_string(iterations) + ".smt2";
std::ofstream filestream;
storm::utility::openFile(filepath, filestream);
filestream << smtSolver->getSmtLibString() << std::endl;
storm::utility::closeFile(filestream);
}
STORM_LOG_DEBUG("Call to SMT Solver (" <<iterations << ")");
auto result = smtSolver->check();
uint64_t i = 0;
if (result == storm::solver::SmtSolver::CheckResult::Unknown) {
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
} else if (result == storm::solver::SmtSolver::CheckResult::Unsat) {
STORM_LOG_DEBUG("Unsatisfiable!");
bool foundScheduler = this->smtCheck(iterations);
if (!foundScheduler) {
break; break;
} }
STORM_LOG_DEBUG("Satisfying assignment: ");
STORM_LOG_TRACE(smtSolver->getModelAsValuation().toString(true));
newSchedulerDiscovered = true;
stats.updateExtensionSolverTime.start();
auto model = smtSolver->getModel(); auto model = smtSolver->getModel();
newObservationsAfterSwitch.clear();
newObservations.clear();
observations.clear();
observationsAfterSwitch.clear();
remainingstates.clear();
scheduler.clear();
uint64_t obs = 0;
for (auto ov : observationUpdatedVariables) {
if (!observationUpdated.get(obs) && model->getBooleanValue(ov)) {
STORM_LOG_TRACE("New observation updated: " << obs);
observationUpdated.set(obs);
}
obs++;
}
uint64_t i = 0;
for (auto rv : reachVars) { for (auto rv : reachVars) {
if (model->getBooleanValue(rv)) {
if (!coveredStates.get(i) && model->getBooleanValue(rv)) {
STORM_LOG_TRACE("New state: " << i);
smtSolver->add(rv.getExpression()); smtSolver->add(rv.getExpression());
observations.set(pomdp.getObservation(i));
} else {
remainingstates.set(i);
newObservations.set(pomdp.getObservation(i));
coveredStates.set(i);
} }
++i; ++i;
} }
i = 0; i = 0;
for (auto rv : continuationVars) { for (auto rv : continuationVars) {
if (model->getBooleanValue(rv)) {
if (!coveredStatesAfterSwitch.get(i) && model->getBooleanValue(rv) ) {
smtSolver->add(rv.getExpression()); smtSolver->add(rv.getExpression());
observationsAfterSwitch.set(pomdp.getObservation(i));
if (!observationsAfterSwitch.get(pomdp.getObservation(i))) {
newObservationsAfterSwitch.set(pomdp.getObservation(i));
}
++i;
} }
++i;
} }
if (options.computeTraceOutput()) { if (options.computeTraceOutput()) {
detail::printRelevantInfoFromModel(model, reachVars, continuationVars); detail::printRelevantInfoFromModel(model, reachVars, continuationVars);
} }
// TODO do not repush everything to the solver.
std::vector<storm::expressions::Expression> schedulerSoFar;
uint64_t obs = 0;
for (auto const &actionSelectionVarsForObs : actionSelectionVars) {
scheduler.actions.push_back(std::set<uint64_t>());
if (observations.get(obs)) {
for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) {
auto const& asv = actionSelectionVarsForObs[act];
if (model->getBooleanValue(asv)) {
scheduler.actions.back().insert(act);
schedulerSoFar.push_back(actionSelectionVarExpressions[obs][act]);
}
}
if (model->getBooleanValue(switchVars[obs])) {
scheduler.switchObservations.set(obs);
schedulerSoFar.push_back(switchVarExpressions[obs]);
for (auto obs : newObservations) {
auto const &actionSelectionVarsForObs = actionSelectionVars[obs];
observations.set(obs);
for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) {
if (model->getBooleanValue(actionSelectionVarsForObs[act])) {
scheduler.actions[obs].set(act);
smtSolver->add(actionSelectionVarExpressions[obs][act]);
} else { } else {
schedulerSoFar.push_back(!switchVarExpressions[obs]);
smtSolver->add(!actionSelectionVarExpressions[obs][act]);
} }
} }
if (observationsAfterSwitch.get(obs)) {
scheduler.schedulerRef.push_back(model->getIntegerValue(schedulerVariables[obs]));
schedulerSoFar.push_back(schedulerVariableExpressions[obs] == expressionManager->integer(scheduler.schedulerRef.back()));
if (model->getBooleanValue(switchVars[obs])) {
scheduler.switchObservations.set(obs);
smtSolver->add(switchVarExpressions[obs]);
} else { } else {
scheduler.schedulerRef.push_back(0);
smtSolver->add(!switchVarExpressions[obs]);
} }
obs++;
}
for (auto obs : newObservationsAfterSwitch) {
observationsAfterSwitch.set(obs);
scheduler.schedulerRef[obs] = model->getIntegerValue(schedulerVariables[obs]);
smtSolver->add(schedulerVariableExpressions[obs] == expressionManager->integer(scheduler.schedulerRef.back()));
} }
if(options.computeTraceOutput()) { if(options.computeTraceOutput()) {
@ -341,56 +416,82 @@ namespace storm {
} }
std::vector<storm::expressions::Expression> remainingExpressions; std::vector<storm::expressions::Expression> remainingExpressions;
for (auto index : remainingstates) {
remainingExpressions.push_back(reachVarExpressions[index]);
for (auto index : ~coveredStates) {
if (observationUpdated.get(pomdp.getObservation(index))) {
remainingExpressions.push_back(reachVarExpressions[index]);
}
}
for (auto index : ~observationUpdated) {
remainingExpressions.push_back(observationUpdatedExpressions[index]);
}
if (remainingExpressions.empty()) {
stats.updateExtensionSolverTime.stop();
break;
} }
// Add scheduler // Add scheduler
smtSolver->add(storm::expressions::conjunction(schedulerSoFar));
//std::cout << storm::expressions::disjunction(remainingExpressions) << std::endl;
smtSolver->add(storm::expressions::disjunction(remainingExpressions)); smtSolver->add(storm::expressions::disjunction(remainingExpressions));
stats.updateExtensionSolverTime.stop();
} }
if (scheduler.empty()) {
if (!newSchedulerDiscovered) {
break; break;
} }
smtSolver->pop(); smtSolver->pop();
if(options.computeDebugOutput()) { if(options.computeDebugOutput()) {
printCoveredStates(remainingstates);
printCoveredStates(~coveredStates);
// generates info output, but here we only want it for debug level. // generates info output, but here we only want it for debug level.
// For consistency, all output on info level. // For consistency, all output on info level.
STORM_LOG_DEBUG("the scheduler: "); STORM_LOG_DEBUG("the scheduler: ");
scheduler.printForObservations(observations,observationsAfterSwitch); scheduler.printForObservations(observations,observationsAfterSwitch);
} }
std::vector<storm::expressions::Expression> remainingExpressions;
for (auto index : remainingstates) {
remainingExpressions.push_back(reachVarExpressions[index]);
}
stats.winningRegionUpdatesTimer.start();
storm::storage::BitVector updated(observations.size());
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
storm::storage::BitVector update = storm::storage::BitVector(statesPerObservation[observation].size());
STORM_LOG_TRACE("consider observation " << observation);
storm::storage::BitVector update(statesPerObservation[observation].size());
uint64_t i = 0; uint64_t i = 0;
for (uint64_t state : statesPerObservation[observation]) { for (uint64_t state : statesPerObservation[observation]) {
if (!remainingstates.get(state)) {
if (coveredStates.get(state)) {
update.set(i); update.set(i);
} }
++i;
}
if(!update.empty()) {
STORM_LOG_TRACE("Update Winning Region: Observation " << observation << " with update " << update);
bool updateResult = winningRegion.update(observation, update);
STORM_LOG_TRACE("Region changed:" << updateResult);
if (updateResult) {
updated.set(observation);
updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);
}
} }
winningRegion.update(observation, update);
++i;
} }
STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place");
stats.winningRegionUpdatesTimer.stop();
smtSolver->add(storm::expressions::disjunction(remainingExpressions));
if(options.computeDebugOutput()) {
winningRegion.print();
}
stats.updateNewStrategySolverTime.start();
uint64_t obs = 0; uint64_t obs = 0;
for (auto const &statesForObservation : statesPerObservation) { for (auto const &statesForObservation : statesPerObservation) {
if (observations.get(obs)) {
if (observations.get(obs) && updated.get(obs)) {
STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "."); STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << ".");
assert(schedulerForObs.size() > obs); assert(schedulerForObs.size() > obs);
schedulerForObs[obs].push_back(finalSchedulers.size()); schedulerForObs[obs].push_back(finalSchedulers.size());
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs); STORM_LOG_DEBUG("We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs);
for (auto const &state : statesForObservation) { for (auto const &state : statesForObservation) {
if (remainingstates.get(state)) {
if (!coveredStates.get(state)) {
auto constant = expressionManager->integer(schedulerForObs[obs].size()); auto constant = expressionManager->integer(schedulerForObs[obs].size());
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
} }
@ -399,14 +500,20 @@ namespace storm {
++obs; ++obs;
} }
finalSchedulers.push_back(scheduler); finalSchedulers.push_back(scheduler);
smtSolver->push(); smtSolver->push();
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
auto constant = expressionManager->integer(schedulerForObs[obs].size()); auto constant = expressionManager->integer(schedulerForObs[obs].size());
smtSolver->add(schedulerVariableExpressions[obs] <= constant); smtSolver->add(schedulerVariableExpressions[obs] <= constant);
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
} }
stats.updateNewStrategySolverTime.stop();
} }
winningRegion.print();
return true; return true;
} }
@ -423,21 +530,53 @@ namespace storm {
} }
// Currently a no-op: the scheduler-list parameter is unnamed and the body is empty,
// so calling this function has no observable effect.
template<typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::printScheduler(std::vector<InternalObservationScheduler> const& ) {
}
template<typename ValueType> template<typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::printScheduler(std::vector<InternalObservationScheduler> const& ) {
void MemlessStrategySearchQualitative<ValueType>::finalizeStatistics() {
} }
// Read-only access to the statistics object owned by this search instance.
template<typename ValueType>
auto MemlessStrategySearchQualitative<ValueType>::getStatistics() const -> Statistics const& {
    return this->stats;
}
template <typename ValueType> template <typename ValueType>
storm::expressions::Expression const& MemlessStrategySearchQualitative<ValueType>::getDoneActionExpression(uint64_t obs) const {
return actionSelectionVarExpressions[obs].back();
}
bool MemlessStrategySearchQualitative<ValueType>::smtCheck(uint64_t iteration) {
    // Runs one satisfiability check on the current solver state.
    // Returns true if the solver reports a satisfying assignment and false if it reports Unsat;
    // an Unknown result raises storm::exceptions::UnexpectedException.

    // If exporting is enabled, write the solver state to "<exportPath>call_<iteration>.smt2" first.
    if (options.isExportSATSet()) {
        STORM_LOG_DEBUG("Export SMT Solver Call (" <<iteration << ")");
        std::string const exportFile = options.getExportSATCallsPath() + "call_" + std::to_string(iteration) + ".smt2";
        std::ofstream exportStream;
        storm::utility::openFile(exportFile, exportStream);
        exportStream << smtSolver->getSmtLibString() << std::endl;
        storm::utility::closeFile(exportStream);
    }

    // Only the solver call itself is timed; every call is counted.
    STORM_LOG_DEBUG("Call to SMT Solver (" <<iteration << ")");
    stats.smtCheckTimer.start();
    auto const checkResult = smtSolver->check();
    stats.smtCheckTimer.stop();
    stats.incrementSmtChecks();

    using CheckResult = storm::solver::SmtSolver::CheckResult;
    STORM_LOG_THROW(checkResult != CheckResult::Unknown, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
    if (checkResult == CheckResult::Unsat) {
        STORM_LOG_DEBUG("Unsatisfiable!");
        return false;
    }

    STORM_LOG_DEBUG("Satisfying assignment: ");
    STORM_LOG_TRACE(smtSolver->getModelAsValuation().toString(true));
    return true;
}
template class MemlessStrategySearchQualitative<double>; template class MemlessStrategySearchQualitative<double>;
template class MemlessStrategySearchQualitative<storm::RationalNumber>; template class MemlessStrategySearchQualitative<storm::RationalNumber>;
} }
} }

63
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -4,6 +4,7 @@
#include "storm/solver/SmtSolver.h" #include "storm/solver/SmtSolver.h"
#include "storm/models/sparse/Pomdp.h" #include "storm/models/sparse/Pomdp.h"
#include "storm/utility/solver.h" #include "storm/utility/solver.h"
#include "storm/utility/Stopwatch.h"
#include "storm/exceptions/UnexpectedException.h" #include "storm/exceptions/UnexpectedException.h"
#include "storm-pomdp/analysis/WinningRegion.h" #include "storm-pomdp/analysis/WinningRegion.h"
@ -42,20 +43,24 @@ namespace pomdp {
return debugLevel > 2; return debugLevel > 2;
} }
bool onlyDeterministicStrategies = false;
bool lookaheadRequired = true;
private: private:
std::string exportSATcalls = ""; std::string exportSATcalls = "";
uint64_t debugLevel = 0; uint64_t debugLevel = 0;
}; };
struct InternalObservationScheduler { struct InternalObservationScheduler {
std::vector<std::set<uint64_t>> actions;
std::vector<storm::storage::BitVector> actions;
std::vector<uint64_t> schedulerRef; std::vector<uint64_t> schedulerRef;
storm::storage::BitVector switchObservations; storm::storage::BitVector switchObservations;
void clear() {
actions.clear();
schedulerRef.clear();
void reset(uint64_t nrObservations, uint64_t nrActions) {
actions = std::vector<storm::storage::BitVector>(nrObservations, storm::storage::BitVector(nrActions));
schedulerRef = std::vector<uint64_t>(nrObservations, 0);
switchObservations.clear(); switchObservations.clear();
} }
@ -65,8 +70,10 @@ namespace pomdp {
void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const { void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const {
for (uint64_t obs = 0; obs < observations.size(); ++obs) { for (uint64_t obs = 0; obs < observations.size(); ++obs) {
if (observations.get(obs)) {
if (observations.get(obs) || observationsAfterSwitch.get(obs)) {
STORM_LOG_INFO("For observation: " << obs); STORM_LOG_INFO("For observation: " << obs);
}
if (observations.get(obs)) {
std::stringstream ss; std::stringstream ss;
ss << "actions:"; ss << "actions:";
for (auto act : actions[obs]) { for (auto act : actions[obs]) {
@ -90,6 +97,31 @@ namespace pomdp {
// Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper. // Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
public: public:
// Timers and counters collected while searching for strategies.
class Statistics {
public:
    Statistics() = default;

    // Defined out of line.
    void print() const;

    // The stopwatches are public; the search starts and stops them directly.
    storm::utility::Stopwatch totalTimer;                   // runs around the analyze(...) call in the public entry points
    storm::utility::Stopwatch smtCheckTimer;                // runs around smtSolver->check() in smtCheck
    storm::utility::Stopwatch initializeSolverTimer;        // presumably covers solver set-up in initialize() -- TODO confirm
    storm::utility::Stopwatch updateExtensionSolverTime;    // runs while a found scheduler is being extended
    storm::utility::Stopwatch updateNewStrategySolverTime;  // runs while constraints for a new strategy are added
    storm::utility::Stopwatch winningRegionUpdatesTimer;    // runs while the winning region is updated

    // Adds one to the outer-iteration counter.
    void incrementOuterIterations() {
        ++outerIterations;
    }

    // Adds one to the counter of solver checks.
    void incrementSmtChecks() {
        ++satCalls;
    }

private:
    uint64_t satCalls{0};
    uint64_t outerIterations{0};
};
MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp, MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp,
std::set<uint32_t> const& targetObservationSet, std::set<uint32_t> const& targetObservationSet,
storm::storage::BitVector const& targetStates, storm::storage::BitVector const& targetStates,
@ -98,19 +130,24 @@ namespace pomdp {
MemlessSearchOptions const& options); MemlessSearchOptions const& options);
void analyzeForInitialStates(uint64_t k) { void analyzeForInitialStates(uint64_t k) {
stats.totalTimer.start();
analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates()); analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates());
stats.totalTimer.stop();
} }
void findNewStrategyForSomeState(uint64_t k) { void findNewStrategyForSomeState(uint64_t k) {
std::cout << surelyReachSinkStates << std::endl; std::cout << surelyReachSinkStates << std::endl;
std::cout << targetStates << std::endl; std::cout << targetStates << std::endl;
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl; std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
stats.totalTimer.start();
analyze(k, ~surelyReachSinkStates & ~targetStates); analyze(k, ~surelyReachSinkStates & ~targetStates);
stats.totalTimer.stop();
} }
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector()); bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
Statistics const& getStatistics() const;
void finalizeStatistics();
private: private:
storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const; storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const;
@ -119,24 +156,30 @@ namespace pomdp {
void initialize(uint64_t k); void initialize(uint64_t k);
bool smtCheck(uint64_t iteration);
std::unique_ptr<storm::solver::SmtSolver> smtSolver; std::unique_ptr<storm::solver::SmtSolver> smtSolver;
storm::models::sparse::Pomdp<ValueType> const& pomdp; storm::models::sparse::Pomdp<ValueType> const& pomdp;
std::shared_ptr<storm::expressions::ExpressionManager> expressionManager; std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;
uint64_t maxK = std::numeric_limits<uint64_t>::max(); uint64_t maxK = std::numeric_limits<uint64_t>::max();
storm::storage::BitVector surelyReachSinkStates;
std::set<uint32_t> targetObservations; std::set<uint32_t> targetObservations;
storm::storage::BitVector targetStates; storm::storage::BitVector targetStates;
storm::storage::BitVector surelyReachSinkStates;
std::vector<std::vector<uint64_t>> statesPerObservation;
std::vector<storm::expressions::Variable> schedulerVariables; std::vector<storm::expressions::Variable> schedulerVariables;
std::vector<storm::expressions::Expression> schedulerVariableExpressions; std::vector<storm::expressions::Expression> schedulerVariableExpressions;
std::vector<std::vector<uint64_t>> statesPerObservation;
std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a} std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars;
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars; // A_{z,a}
std::vector<storm::expressions::Variable> reachVars; std::vector<storm::expressions::Variable> reachVars;
std::vector<storm::expressions::Expression> reachVarExpressions; std::vector<storm::expressions::Expression> reachVarExpressions;
std::vector<std::vector<storm::expressions::Expression>> reachVarExpressionsPerObservation;
std::vector<storm::expressions::Variable> observationUpdatedVariables;
std::vector<storm::expressions::Expression> observationUpdatedExpressions;
std::vector<storm::expressions::Variable> switchVars; std::vector<storm::expressions::Variable> switchVars;
std::vector<storm::expressions::Expression> switchVarExpressions; std::vector<storm::expressions::Expression> switchVarExpressions;
@ -149,7 +192,7 @@ namespace pomdp {
WinningRegion winningRegion; WinningRegion winningRegion;
MemlessSearchOptions options; MemlessSearchOptions options;
Statistics stats;
}; };

17
src/storm-pomdp/analysis/QualitativeAnalysis.cpp

@ -26,6 +26,23 @@ namespace storm {
storm::storage::BitVector QualitativeAnalysis<ValueType>::analyseProb1(storm::logic::ProbabilityOperatorFormula const& formula) const { storm::storage::BitVector QualitativeAnalysis<ValueType>::analyseProb1(storm::logic::ProbabilityOperatorFormula const& formula) const {
return analyseProb0or1(formula, false); return analyseProb0or1(formula, false);
} }
// Computes a set of states for which the (maximal) probability of satisfying the until/eventually
// subformula of `formula` is smaller than one, as the complement of performProb1E on the underlying
// transition structure.
//
// Throws InvalidPropertyException if the formula has neither an optimality type nor a bound, and
// NotImplementedException if the formula minimizes or if its subformula is neither an eventually
// nor an until formula.
template<typename ValueType>
storm::storage::BitVector QualitativeAnalysis<ValueType>::analyseProbSmaller1(storm::logic::ProbabilityOperatorFormula const &formula) const {
    STORM_LOG_THROW(formula.hasOptimalityType() || formula.hasBound(), storm::exceptions::InvalidPropertyException, "The formula " << formula << " does not specify whether to minimize or maximize.");
    bool minimizes = (formula.hasOptimalityType() && storm::solver::minimize(formula.getOptimalityType())) || (formula.hasBound() && storm::logic::isLowerBound(formula.getBound().comparisonType));
    STORM_LOG_THROW(!minimizes,storm::exceptions::NotImplementedException, "This operation is only supported when maximizing");

    std::shared_ptr<storm::logic::Formula const> subformula = formula.getSubformula().asSharedPointer();
    std::shared_ptr<storm::logic::UntilFormula> untilSubformula;
    // If necessary, convert the subformula to a more general case
    if (subformula->isEventuallyFormula()) {
        untilSubformula = std::make_shared<storm::logic::UntilFormula>(storm::logic::Formula::getTrueFormula(), subformula->asEventuallyFormula().getSubformula().asSharedPointer());
    } else if(subformula->isUntilFormula()) {
        untilSubformula = std::make_shared<storm::logic::UntilFormula>(subformula->asUntilFormula());
    } else {
        // Without this guard untilSubformula would stay null and be dereferenced below.
        STORM_LOG_THROW(false, storm::exceptions::NotImplementedException, "The formula " << formula << " has an unsupported subformula: only eventually and until formulas are handled.");
    }

    // The vector is sound, but not necessarily complete!
    return ~storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(untilSubformula->getLeftSubformula()), checkPropositionalFormula(untilSubformula->getRightSubformula()));
}
template<typename ValueType> template<typename ValueType>
storm::storage::BitVector QualitativeAnalysis<ValueType>::analyseProb0or1(storm::logic::ProbabilityOperatorFormula const& formula, bool prob0) const { storm::storage::BitVector QualitativeAnalysis<ValueType>::analyseProb0or1(storm::logic::ProbabilityOperatorFormula const& formula, bool prob0) const {

2
src/storm-pomdp/analysis/QualitativeAnalysis.h

@ -10,7 +10,7 @@ namespace storm {
QualitativeAnalysis(storm::models::sparse::Pomdp<ValueType> const& pomdp); QualitativeAnalysis(storm::models::sparse::Pomdp<ValueType> const& pomdp);
storm::storage::BitVector analyseProb0(storm::logic::ProbabilityOperatorFormula const& formula) const; storm::storage::BitVector analyseProb0(storm::logic::ProbabilityOperatorFormula const& formula) const;
storm::storage::BitVector analyseProb1(storm::logic::ProbabilityOperatorFormula const& formula) const; storm::storage::BitVector analyseProb1(storm::logic::ProbabilityOperatorFormula const& formula) const;
storm::storage::BitVector analyseProbSmaller1(storm::logic::ProbabilityOperatorFormula const& formula) const;
private: private:
storm::storage::BitVector analyseProb0or1(storm::logic::ProbabilityOperatorFormula const& formula, bool prob0) const; storm::storage::BitVector analyseProb0or1(storm::logic::ProbabilityOperatorFormula const& formula, bool prob0) const;
storm::storage::BitVector analyseProb0Max(storm::logic::UntilFormula const& formula) const; storm::storage::BitVector analyseProb0Max(storm::logic::UntilFormula const& formula) const;

36
src/storm-pomdp/analysis/WinningRegion.cpp

@ -1,4 +1,6 @@
#include <iostream> #include <iostream>
#include "storm/storage/expressions/Expression.h"
#include "storm/storage/expressions/ExpressionManager.h"
#include "storm-pomdp/analysis/WinningRegion.h" #include "storm-pomdp/analysis/WinningRegion.h"
namespace storm { namespace storm {
@ -10,13 +12,13 @@ namespace pomdp {
} }
} }
void WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) {
bool WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) {
std::vector<storm::storage::BitVector> newWinningSupport = std::vector<storm::storage::BitVector>(); std::vector<storm::storage::BitVector> newWinningSupport = std::vector<storm::storage::BitVector>();
bool changed = false; bool changed = false;
for (auto const& support : winningRegion[observation]) { for (auto const& support : winningRegion[observation]) {
if (winning.isSubsetOf(support)) { if (winning.isSubsetOf(support)) {
// This new winning support is already covered. // This new winning support is already covered.
return;
return false;
} }
if(support.isSubsetOf(winning)) { if(support.isSubsetOf(winning)) {
// This new winning support extends the previous support, thus the previous support is now spurious // This new winning support extends the previous support, thus the previous support is now spurious
@ -33,6 +35,7 @@ namespace pomdp {
} else { } else {
winningRegion[observation].push_back(winning); winningRegion[observation].push_back(winning);
} }
return true;
} }
@ -45,6 +48,34 @@ namespace pomdp {
return false; return false;
} }
storm::expressions::Expression WinningRegion::extensionExpression(uint64_t observation, std::vector<storm::expressions::Expression>& varsForStates) const {
    // Builds a constraint over the per-state expressions in varsForStates (one per state of the
    // observation). For every stored winning support it requires that
    //   - any selected state inside the support implies that some state outside it is selected, and
    //   - at least one state is selected at all.
    // If a stored support already contains every state, the constant `false` is returned.
    std::vector<storm::expressions::Expression> conjuncts;
    for (auto const& support : winningRegion[observation]) {
        if (support.full()) {
            // A full support subsumes all others, so it must be the only entry.
            assert(winningRegion[observation].size() == 1);
            return varsForStates.front().getManager().boolean(false);
        }
        assert(varsForStates.size() == support.size());

        // Partition the state expressions by membership in the current support.
        std::vector<storm::expressions::Expression> outsideSupport;
        std::vector<storm::expressions::Expression> insideSupport;
        for (uint64_t stateIndex = 0; stateIndex < varsForStates.size(); ++stateIndex) {
            auto& bucket = support.get(stateIndex) ? insideSupport : outsideSupport;
            bucket.push_back(varsForStates[stateIndex]);
        }

        storm::expressions::Expression someStateOutside = storm::expressions::disjunction(outsideSupport);
        for (auto const& stateInside : insideSupport) {
            conjuncts.push_back(storm::expressions::implies(stateInside, someStateOutside));
        }
        conjuncts.push_back(storm::expressions::disjunction(varsForStates));
    }
    return storm::expressions::conjunction(conjuncts);
}
/** /**
* If we observe this observation, do we surely win? * If we observe this observation, do we surely win?
* @param observation * @param observation
@ -62,6 +93,7 @@ namespace pomdp {
std::cout << " " << support; std::cout << " " << support;
} }
std::cout << std::endl; std::cout << std::endl;
observation++;
} }
} }

6
src/storm-pomdp/analysis/WinningRegion.h

@ -9,12 +9,14 @@ namespace storm {
public: public:
WinningRegion(std::vector<uint64_t> const& observationSizes = {}); WinningRegion(std::vector<uint64_t> const& observationSizes = {});
void update(uint64_t observation, storm::storage::BitVector const& winning);
bool update(uint64_t observation, storm::storage::BitVector const& winning);
bool query(uint64_t observation, storm::storage::BitVector const& currently) const; bool query(uint64_t observation, storm::storage::BitVector const& currently) const;
bool observationIsWinning(uint64_t observation) const; bool observationIsWinning(uint64_t observation) const;
storm::expressions::Expression extensionExpression(uint64_t observation, std::vector<storm::expressions::Expression>& varsForStates) const;
uint64_t getStorageSize() const;
uint64_t getStorageSize() const;
uint64_t getNumberOfObservations() const; uint64_t getNumberOfObservations() const;
void print() const; void print() const;
private: private:

Loading…
Cancel
Save