Browse Source

various improvements and fixes in winning region computation

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
858e2f8a60
  1. 5
      src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.cpp
  2. 1
      src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.h
  3. 71
      src/storm-pomdp-cli/storm-pomdp.cpp
  4. 311
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  5. 63
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h
  6. 17
      src/storm-pomdp/analysis/QualitativeAnalysis.cpp
  7. 2
      src/storm-pomdp/analysis/QualitativeAnalysis.h
  8. 36
      src/storm-pomdp/analysis/WinningRegion.cpp
  9. 4
      src/storm-pomdp/analysis/WinningRegion.h

5
src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.cpp

@ -15,11 +15,13 @@ namespace storm {
const std::string QualitativePOMDPAnalysisSettings::moduleName = "pomdpQualitative";
const std::string exportSATCallsOption = "exportSATCallsPath";
const std::string lookaheadHorizonOption = "lookaheadHorizon";
const std::string onlyDeterministicOption = "onlyDeterministic";
/*!
 * Creates the settings module for the qualitative POMDP analysis and registers its options:
 *  - exportSATCallsPath: string argument "path"; the value is later used as a prefix for the
 *    exported "call_<n>.smt2" files.
 *  - lookaheadHorizon: unsigned integer argument "bound" (default 0, meaning "number of states").
 *  - onlyDeterministic: flag without arguments.
 */
QualitativePOMDPAnalysisSettings::QualitativePOMDPAnalysisSettings() : ModuleSettings(moduleName) {
    // The argument description previously talked about "the model", which was a copy-paste
    // leftover: this path receives the SAT calls, not a model.
    this->addOption(
        storm::settings::OptionBuilder(moduleName, exportSATCallsOption, false, "Export the SAT calls?.")
            .addArgument(storm::settings::ArgumentBuilder::createStringArgument("path", "The path (prefix) to which the SAT calls are written.").build())
            .build());

    this->addOption(
        storm::settings::OptionBuilder(moduleName, lookaheadHorizonOption, false, "In reachability in combination with a discrete ranking function, a lookahead is necessary.")
            .addArgument(storm::settings::ArgumentBuilder::createUnsignedIntegerArgument("bound", "The lookahead. Use 0 for the number of states.")
                             .setDefaultValueUnsignedInteger(0)
                             .build())
            .build());

    this->addOption(
        storm::settings::OptionBuilder(moduleName, onlyDeterministicOption, false, "Search only for deterministic schedulers")
            .build());
}
uint64_t QualitativePOMDPAnalysisSettings::getLookahead() const {
@ -33,6 +35,9 @@ namespace storm {
return this->getOption(exportSATCallsOption).getArgumentByName("path").getValueAsString();
}
/*!
 * Reports whether the "onlyDeterministic" option has been set.
 * @return true iff the option was set.
 */
bool QualitativePOMDPAnalysisSettings::isOnlyDeterministicSet() const {
    auto const& deterministicFlag = this->getOption(onlyDeterministicOption);
    return deterministicFlag.getHasOptionBeenSet();
}
void QualitativePOMDPAnalysisSettings::finalize() {

1
src/storm-pomdp-cli/settings/modules/QualitativePOMDPAnalysisSettings.h

@ -23,6 +23,7 @@ namespace storm {
/// Retrieves the lookahead bound.
/// NOTE(review): presumably the "bound" argument of the lookaheadHorizon option — the definition is not fully visible here, confirm.
uint64_t getLookahead() const;
/// Retrieves whether exporting the SAT calls was requested.
/// NOTE(review): presumably tied to the exportSATCallsPath option — confirm against the definition.
bool isExportSATCallsSet() const;
/// Retrieves the path given for exporting the SAT calls.
/// NOTE(review): presumably the "path" argument of the exportSATCallsPath option — confirm against the definition.
std::string getExportSATCallsPath() const;
/// Retrieves whether the onlyDeterministic option has been set.
bool isOnlyDeterministicSet() const;
/// Defaulted virtual destructor.
virtual ~QualitativePOMDPAnalysisSettings() = default;

71
src/storm-pomdp-cli/storm-pomdp.cpp

@ -174,6 +174,15 @@ bool extractTargetAndSinkObservationSets(std::shared_ptr<storm::models::sparse::
return validFormula;
}
/*!
 * Collects the observations of a set of states.
 *
 * @param pomdp The POMDP that maps states to observations.
 * @param states The states to consider; every index yielded by iterating this bit vector is looked up.
 * @return The set containing pomdp.getObservation(s) for each such state s.
 */
template<typename ValueType>
std::set<uint32_t> extractObservations(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::BitVector const& states) {
    std::set<uint32_t> result;
    for (auto const stateIndex : states) {
        auto const observation = pomdp.getObservation(stateIndex);
        result.insert(observation);
    }
    return result;
}
/*!
* Entry point for the pomdp backend.
*
@ -195,6 +204,7 @@ int main(const int argc, const char** argv) {
auto const& coreSettings = storm::settings::getModule<storm::settings::modules::CoreSettings>();
auto const& pomdpSettings = storm::settings::getModule<storm::settings::modules::POMDPSettings>();
auto const& ioSettings = storm::settings::getModule<storm::settings::modules::IOSettings>();
auto const &general = storm::settings::getModule<storm::settings::modules::GeneralSettings>();
auto const &debug = storm::settings::getModule<storm::settings::modules::DebugSettings>();
auto const& pomdpQualSettings = storm::settings::getModule<storm::settings::modules::QualitativePOMDPAnalysisSettings>();
@ -208,6 +218,9 @@ int main(const int argc, const char** argv) {
if (debug.isTraceSet()) {
storm::utility::setLogLevel(l3pp::LogLevel::TRACE);
}
if (debug.isLogfileSet()) {
storm::utility::initializeFileLogging();
}
// For several engines, no model building step is performed, but the verification is started right away.
storm::settings::modules::CoreSettings::Engine engine = coreSettings.getEngine();
@ -237,17 +250,6 @@ int main(const int argc, const char** argv) {
if (formula->isProbabilityOperatorFormula()) {
std::set<uint32_t> targetObservationSet;
storm::storage::BitVector targetStates(pomdp->getNumberOfStates());
storm::storage::BitVector badStates(pomdp->getNumberOfStates());
bool validFormula = extractTargetAndSinkObservationSets(pomdp, subformula1, targetObservationSet, targetStates, badStates);
STORM_LOG_THROW(validFormula, storm::exceptions::InvalidPropertyException,
"The formula is not supported by the grid approximation");
STORM_LOG_ASSERT(!targetObservationSet.empty(), "The set of target observations is empty!");
boost::optional<storm::storage::BitVector> prob1States;
boost::optional<storm::storage::BitVector> prob0States;
if (pomdpSettings.isSelfloopReductionSet() && !storm::solver::minimize(formula->asProbabilityOperatorFormula().getOptimalityType())) {
@ -271,8 +273,29 @@ int main(const int argc, const char** argv) {
storm::pomdp::transformer::KnownProbabilityTransformer<double> kpt = storm::pomdp::transformer::KnownProbabilityTransformer<double>();
pomdp = kpt.transform(*pomdp, *prob0States, *prob1States);
}
if (ioSettings.isExportDotSet()) {
std::shared_ptr<storm::models::sparse::Model<double>> sparseModel = pomdp;
storm::api::exportSparseModelAsDot(sparseModel, ioSettings.getExportDotFilename(), ioSettings.getExportDotMaxWidth());
}
if (ioSettings.isExportExplicitSet()) {
std::shared_ptr<storm::models::sparse::Model<double>> sparseModel = pomdp;
storm::api::exportSparseModelAsDrn(sparseModel, ioSettings.getExportExplicitFilename());
}
if (pomdpSettings.isGridApproximationSet()) {
std::set<uint32_t> targetObservationSet;
storm::storage::BitVector targetStates(pomdp->getNumberOfStates());
storm::storage::BitVector badStates(pomdp->getNumberOfStates());
bool validFormula = extractTargetAndSinkObservationSets(pomdp, subformula1, targetObservationSet, targetStates, badStates);
STORM_LOG_THROW(validFormula, storm::exceptions::InvalidPropertyException,
"The formula is not supported by the grid approximation");
STORM_LOG_ASSERT(!targetObservationSet.empty(), "The set of target observations is empty!");
storm::pomdp::modelchecker::ApproximatePOMDPModelchecker<double> checker = storm::pomdp::modelchecker::ApproximatePOMDPModelchecker<double>();
double overRes = storm::utility::one<double>();
double underRes = storm::utility::zero<double>();
@ -291,10 +314,12 @@ int main(const int argc, const char** argv) {
}
}
if (pomdpSettings.isMemlessSearchSet()) {
// std::cout << std::endl;
// pomdp->writeDotToStream(std::cout);
// std::cout << std::endl;
// std::cout << std::endl;
storm::analysis::QualitativeAnalysis<double> qualitativeAnalysis(*pomdp);
// After preprocessing, this might be done cheaper.
storm::storage::BitVector targetStates = qualitativeAnalysis.analyseProb1(formula->asProbabilityOperatorFormula());
storm::storage::BitVector surelyNotAlmostSurelyReachTarget = qualitativeAnalysis.analyseProbSmaller1(formula->asProbabilityOperatorFormula());
std::set<uint32_t> targetObservationSet = extractObservations(*pomdp, targetStates);
storm::expressions::ExpressionManager expressionManager;
std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::Z3SmtSolverFactory>();
@ -302,9 +327,9 @@ int main(const int argc, const char** argv) {
if (lookahead == 0) {
lookahead = pomdp->getNumberOfStates();
}
storm::pomdp::MemlessSearchOptions options;
options.onlyDeterministicStrategies = pomdpQualSettings.isOnlyDeterministicSet();
uint64_t loglevel = 0;
// TODO a big ugly, but we have our own loglevels.
if(storm::utility::getLogLevel() == l3pp::LogLevel::INFO) {
@ -322,13 +347,23 @@ int main(const int argc, const char** argv) {
options.setExportSATCalls(pomdpQualSettings.getExportSATCallsPath());
}
if (storm::utility::graph::checkIfECWithChoiceExists(pomdp->getTransitionMatrix(), pomdp->getBackwardTransitions(), ~targetStates & ~surelyNotAlmostSurelyReachTarget, storm::storage::BitVector(pomdp->getNumberOfChoices(), true))) {
options.lookaheadRequired = true;
STORM_LOG_DEBUG("Lookahead required.");
} else {
options.lookaheadRequired = false;
STORM_LOG_DEBUG("No lookahead required.");
}
if (pomdpSettings.getMemlessSearchMethod() == "ccd16memless") {
storm::pomdp::QualitativeStrategySearchNaive<double> memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory);
storm::pomdp::QualitativeStrategySearchNaive<double> memlessSearch(*pomdp, targetObservationSet, targetStates, surelyNotAlmostSurelyReachTarget, smtSolverFactory);
memlessSearch.findNewStrategyForSomeState(lookahead);
} else if (pomdpSettings.getMemlessSearchMethod() == "iterative") {
storm::pomdp::MemlessStrategySearchQualitative<double> memlessSearch(*pomdp, targetObservationSet, targetStates, badStates, smtSolverFactory, options);
storm::pomdp::MemlessStrategySearchQualitative<double> memlessSearch(*pomdp, targetObservationSet, targetStates, surelyNotAlmostSurelyReachTarget, smtSolverFactory, options);
memlessSearch.findNewStrategyForSomeState(lookahead);
memlessSearch.finalizeStatistics();
memlessSearch.getStatistics().print();
} else {
STORM_LOG_ERROR("This method is not implemented.");
}

311
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -19,18 +19,30 @@ namespace storm {
STORM_LOG_TRACE(ss.str());
i = 0;
STORM_LOG_TRACE("states from which we continue: ");
ss.clear();
std::stringstream ss2;
for (auto rv : continuationVars) {
if (model->getBooleanValue(rv)) {
ss << " " << i;
ss2 << " " << i;
}
++i;
}
STORM_LOG_TRACE(ss.str());
STORM_LOG_TRACE(ss2.str());
}
}
/*!
 * Writes all collected timers and counters via STORM_PRINT_AND_LOG, one entry per line.
 */
template <typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::Statistics::print() const {
    // Overall figures.
    STORM_PRINT_AND_LOG("Total time: " << this->totalTimer << std::endl);
    STORM_PRINT_AND_LOG("SAT Calls " << this->satCalls << std::endl);
    STORM_PRINT_AND_LOG("SAT Calls time: " << this->smtCheckTimer << std::endl);
    STORM_PRINT_AND_LOG("Outer iterations: " << this->outerIterations << std::endl);

    // Time spent in the individual phases.
    STORM_PRINT_AND_LOG("Solver initialization time: " << this->initializeSolverTimer << std::endl);
    STORM_PRINT_AND_LOG("Extend partial scheduler time: " << this->updateExtensionSolverTime << std::endl);
    STORM_PRINT_AND_LOG("Update solver with new scheduler time: " << this->updateNewStrategySolverTime << std::endl);
    STORM_PRINT_AND_LOG("Winning regions update time: " << this->winningRegionUpdatesTimer << std::endl);
}
template <typename ValueType>
MemlessStrategySearchQualitative<ValueType>::MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp,
std::set<uint32_t> const& targetObservationSet,
@ -39,9 +51,9 @@ namespace storm {
std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory,
MemlessSearchOptions const& options) :
pomdp(pomdp),
targetStates(targetStates),
surelyReachSinkStates(surelyReachSinkStates),
targetObservations(targetObservationSet),
targetStates(targetStates),
options(options)
{
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
@ -49,6 +61,7 @@ namespace storm {
// Initialize states per observation.
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
reachVarExpressionsPerObservation.push_back(std::vector<storm::expressions::Expression>());
}
uint64_t state = 0;
for (auto obs : pomdp.getObservations()) {
@ -60,10 +73,15 @@ namespace storm {
nrStatesPerObservation.push_back(states.size());
}
winningRegion = WinningRegion(nrStatesPerObservation);
}
template <typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) {
STORM_LOG_INFO("Start intializing solver...");
// TODO fix this
bool lookaheadConstraintsRequired = options.lookaheadRequired;
STORM_LOG_WARN("We have hardcoded that we do not need lookahead");
if (maxK == std::numeric_limits<uint64_t>::max()) {
// not initialized at all.
// Create some data structures.
@ -76,16 +94,19 @@ namespace storm {
// declare the reachability variables,
// declare the path variables.
for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
if(lookaheadConstraintsRequired) {
pathVars.push_back(std::vector<storm::expressions::Expression>());
for (uint64_t i = 0; i < k; ++i) {
pathVars.back().push_back(expressionManager->declareBooleanVariable("P-"+std::to_string(stateId)+"-"+std::to_string(i)).getExpression());
pathVars.back().push_back(expressionManager->declareBooleanVariable("P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression());
}
}
reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)));
reachVarExpressions.push_back(reachVars.back().getExpression());
reachVarExpressionsPerObservation[pomdp.getObservation(stateId)].push_back(reachVarExpressions.back());
continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId)));
continuationVarExpressions.push_back(continuationVars.back().getExpression());
}
assert(pathVars.size() == pomdp.getNumberOfStates());
assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates());
assert(reachVars.size() == pomdp.getNumberOfStates());
assert(reachVarExpressions.size() == pomdp.getNumberOfStates());
@ -101,7 +122,15 @@ namespace storm {
schedulerVariableExpressions.push_back(schedulerVariables.back());
switchVars.push_back(expressionManager->declareBooleanVariable("S-" + std::to_string(obs)));
switchVarExpressions.push_back(switchVars.back().getExpression());
observationUpdatedVariables.push_back(expressionManager->declareBooleanVariable("U-" + std::to_string(obs)));
observationUpdatedExpressions.push_back(observationUpdatedVariables.back().getExpression());
if (options.onlyDeterministicStrategies) {
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) {
for (uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) {
smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]);
}
}
}
++obs;
}
@ -109,6 +138,9 @@ namespace storm {
smtSolver->add(storm::expressions::disjunction(actionVars));
}
smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions));
if (lookaheadConstraintsRequired) {
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state)) {
smtSolver->add(pathVars[state][0]);
@ -116,6 +148,7 @@ namespace storm {
smtSolver->add(!pathVars[state][0]);
}
}
}
uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
@ -152,14 +185,19 @@ namespace storm {
uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (surelyReachSinkStates.get(state)) {
smtSolver->add(!reachVarExpressions[state]);
smtSolver->add(!continuationVarExpressions[state]);
if (lookaheadConstraintsRequired) {
for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(!pathVars[state][j]);
}
smtSolver->add(!continuationVarExpressions[state]);
}
rowindex += pomdp.getNumberOfChoices(state);
} else if(!targetStates.get(state)) {
if (lookaheadConstraintsRequired) {
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
@ -177,7 +215,6 @@ namespace storm {
}
rowindex++;
}
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
for (uint64_t j = 1; j < k; ++j) {
std::vector<storm::expressions::Expression> pathsubexprs;
@ -189,6 +226,9 @@ namespace storm {
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
}
}
} else {
rowindex += pomdp.getNumberOfChoices(state);
}
}
uint64_t obs = 0;
@ -199,17 +239,47 @@ namespace storm {
++obs;
}
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs])));
}
if (!lookaheadConstraintsRequired) {
uint64_t rowIndex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
uint64_t enabledActions = pomdp.getNumberOfChoices(state);
if (!surelyReachSinkStates.get(state)) {
std::vector<storm::expressions::Expression> successorVars;
for (uint64_t act = 0; act < enabledActions; ++act) {
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) {
successorVars.push_back(reachVarExpressions[entries.getColumn()]);
}
rowIndex++;
}
successorVars.push_back(!switchVars[pomdp.getObservation(state)]);
smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state]));
} else {
rowIndex += enabledActions;
}
}
} else {
STORM_LOG_WARN("Some optimization not implemented yet.");
}
// TODO: Update found schedulers if k is increased.
}
template <typename ValueType>
bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
stats.initializeSolverTimer.start();
if (k < maxK) {
initialize(k);
maxK = k;
}
uint64_t maximalNrActions = 8;
STORM_LOG_WARN("We have hardcoded (an upper bound on) the number of actions");
std::vector<storm::expressions::Expression> atLeastOneOfStates;
for (uint64_t state : oneOfTheseStates) {
@ -225,112 +295,117 @@ namespace storm {
}
smtSolver->push();
uint64_t obs = 0;
for(auto const& statesForObservation : statesPerObservation) {
smtSolver->add(schedulerVariableExpressions[obs] <= schedulerForObs.size());
++obs;
}
std::vector<storm::expressions::Expression> updateForObservationExpressions;
for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) {
updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob]));
schedulerForObs.push_back(std::vector<uint64_t>());
}
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
auto constant = expressionManager->integer(schedulerForObs[obs].size());
smtSolver->add(schedulerVariableExpressions[obs] <= constant);
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
}
InternalObservationScheduler scheduler;
scheduler.switchObservations = storm::storage::BitVector(pomdp.getNrObservations());
storm::storage::BitVector newObservations(pomdp.getNrObservations());
storm::storage::BitVector newObservationsAfterSwitch(pomdp.getNrObservations());
storm::storage::BitVector observations(pomdp.getNrObservations());
storm::storage::BitVector observationsAfterSwitch(pomdp.getNrObservations());
storm::storage::BitVector remainingstates(pomdp.getNumberOfStates());
storm::storage::BitVector observationUpdated(pomdp.getNrObservations());
storm::storage::BitVector coveredStates(pomdp.getNumberOfStates());
storm::storage::BitVector coveredStatesAfterSwitch(pomdp.getNumberOfStates());
stats.initializeSolverTimer.stop();
STORM_LOG_INFO("Start iterative solver...");
uint64_t iterations = 0;
while(true) {
scheduler.clear();
stats.incrementOuterIterations();
scheduler.reset(pomdp.getNrObservations(), maximalNrActions);
observations.clear();
observationsAfterSwitch.clear();
remainingstates.clear();
coveredStates.clear();
coveredStatesAfterSwitch.clear();
observationUpdated.clear();
bool newSchedulerDiscovered = false;
while (true) {
++iterations;
if(options.isExportSATSet()) {
STORM_LOG_DEBUG("Export SMT Solver Call (" <<iterations << ")");
std::string filepath = options.getExportSATCallsPath() + "call_" + std::to_string(iterations) + ".smt2";
std::ofstream filestream;
storm::utility::openFile(filepath, filestream);
filestream << smtSolver->getSmtLibString() << std::endl;
storm::utility::closeFile(filestream);
}
STORM_LOG_DEBUG("Call to SMT Solver (" <<iterations << ")");
auto result = smtSolver->check();
uint64_t i = 0;
if (result == storm::solver::SmtSolver::CheckResult::Unknown) {
STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
} else if (result == storm::solver::SmtSolver::CheckResult::Unsat) {
STORM_LOG_DEBUG("Unsatisfiable!");
bool foundScheduler = this->smtCheck(iterations);
if (!foundScheduler) {
break;
}
STORM_LOG_DEBUG("Satisfying assignment: ");
STORM_LOG_TRACE(smtSolver->getModelAsValuation().toString(true));
newSchedulerDiscovered = true;
stats.updateExtensionSolverTime.start();
auto model = smtSolver->getModel();
newObservationsAfterSwitch.clear();
newObservations.clear();
observations.clear();
observationsAfterSwitch.clear();
remainingstates.clear();
scheduler.clear();
uint64_t obs = 0;
for (auto ov : observationUpdatedVariables) {
if (!observationUpdated.get(obs) && model->getBooleanValue(ov)) {
STORM_LOG_TRACE("New observation updated: " << obs);
observationUpdated.set(obs);
}
obs++;
}
uint64_t i = 0;
for (auto rv : reachVars) {
if (model->getBooleanValue(rv)) {
if (!coveredStates.get(i) && model->getBooleanValue(rv)) {
STORM_LOG_TRACE("New state: " << i);
smtSolver->add(rv.getExpression());
observations.set(pomdp.getObservation(i));
} else {
remainingstates.set(i);
newObservations.set(pomdp.getObservation(i));
coveredStates.set(i);
}
++i;
}
i = 0;
for (auto rv : continuationVars) {
if (model->getBooleanValue(rv)) {
if (!coveredStatesAfterSwitch.get(i) && model->getBooleanValue(rv) ) {
smtSolver->add(rv.getExpression());
observationsAfterSwitch.set(pomdp.getObservation(i));
if (!observationsAfterSwitch.get(pomdp.getObservation(i))) {
newObservationsAfterSwitch.set(pomdp.getObservation(i));
}
++i;
}
}
if (options.computeTraceOutput()) {
detail::printRelevantInfoFromModel(model, reachVars, continuationVars);
}
// TODO do not repush everyting to the solver.
std::vector<storm::expressions::Expression> schedulerSoFar;
uint64_t obs = 0;
for (auto const &actionSelectionVarsForObs : actionSelectionVars) {
scheduler.actions.push_back(std::set<uint64_t>());
if (observations.get(obs)) {
for (auto obs : newObservations) {
auto const &actionSelectionVarsForObs = actionSelectionVars[obs];
observations.set(obs);
for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) {
auto const& asv = actionSelectionVarsForObs[act];
if (model->getBooleanValue(asv)) {
scheduler.actions.back().insert(act);
schedulerSoFar.push_back(actionSelectionVarExpressions[obs][act]);
if (model->getBooleanValue(actionSelectionVarsForObs[act])) {
scheduler.actions[obs].set(act);
smtSolver->add(actionSelectionVarExpressions[obs][act]);
} else {
smtSolver->add(!actionSelectionVarExpressions[obs][act]);
}
}
if (model->getBooleanValue(switchVars[obs])) {
scheduler.switchObservations.set(obs);
schedulerSoFar.push_back(switchVarExpressions[obs]);
smtSolver->add(switchVarExpressions[obs]);
} else {
schedulerSoFar.push_back(!switchVarExpressions[obs]);
smtSolver->add(!switchVarExpressions[obs]);
}
}
if (observationsAfterSwitch.get(obs)) {
scheduler.schedulerRef.push_back(model->getIntegerValue(schedulerVariables[obs]));
schedulerSoFar.push_back(schedulerVariableExpressions[obs] == expressionManager->integer(scheduler.schedulerRef.back()));
} else {
scheduler.schedulerRef.push_back(0);
}
obs++;
for (auto obs : newObservationsAfterSwitch) {
observationsAfterSwitch.set(obs);
scheduler.schedulerRef[obs] = model->getIntegerValue(schedulerVariables[obs]);
smtSolver->add(schedulerVariableExpressions[obs] == expressionManager->integer(scheduler.schedulerRef.back()));
}
if(options.computeTraceOutput()) {
@ -341,56 +416,82 @@ namespace storm {
}
std::vector<storm::expressions::Expression> remainingExpressions;
for (auto index : remainingstates) {
for (auto index : ~coveredStates) {
if (observationUpdated.get(pomdp.getObservation(index))) {
remainingExpressions.push_back(reachVarExpressions[index]);
}
}
for (auto index : ~observationUpdated) {
remainingExpressions.push_back(observationUpdatedExpressions[index]);
}
if (remainingExpressions.empty()) {
stats.updateExtensionSolverTime.stop();
break;
}
// Add scheduler
smtSolver->add(storm::expressions::conjunction(schedulerSoFar));
//std::cout << storm::expressions::disjunction(remainingExpressions) << std::endl;
smtSolver->add(storm::expressions::disjunction(remainingExpressions));
stats.updateExtensionSolverTime.stop();
}
if (scheduler.empty()) {
if (!newSchedulerDiscovered) {
break;
}
smtSolver->pop();
if(options.computeDebugOutput()) {
printCoveredStates(remainingstates);
printCoveredStates(~coveredStates);
// generates info output, but here we only want it for debug level.
// For consistency, all output on info level.
STORM_LOG_DEBUG("the scheduler: ");
scheduler.printForObservations(observations,observationsAfterSwitch);
}
std::vector<storm::expressions::Expression> remainingExpressions;
for (auto index : remainingstates) {
remainingExpressions.push_back(reachVarExpressions[index]);
}
stats.winningRegionUpdatesTimer.start();
storm::storage::BitVector updated(observations.size());
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
storm::storage::BitVector update = storm::storage::BitVector(statesPerObservation[observation].size());
STORM_LOG_TRACE("consider observation " << observation);
storm::storage::BitVector update(statesPerObservation[observation].size());
uint64_t i = 0;
for (uint64_t state : statesPerObservation[observation]) {
if (!remainingstates.get(state)) {
if (coveredStates.get(state)) {
update.set(i);
}
}
winningRegion.update(observation, update);
++i;
}
if(!update.empty()) {
STORM_LOG_TRACE("Update Winning Region: Observation " << observation << " with update " << update);
bool updateResult = winningRegion.update(observation, update);
STORM_LOG_TRACE("Region changed:" << updateResult);
if (updateResult) {
updated.set(observation);
updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);
}
}
}
STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place");
stats.winningRegionUpdatesTimer.stop();
smtSolver->add(storm::expressions::disjunction(remainingExpressions));
if(options.computeDebugOutput()) {
winningRegion.print();
}
stats.updateNewStrategySolverTime.start();
uint64_t obs = 0;
for (auto const &statesForObservation : statesPerObservation) {
if (observations.get(obs)) {
if (observations.get(obs) && updated.get(obs)) {
STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << ".");
assert(schedulerForObs.size() > obs);
schedulerForObs[obs].push_back(finalSchedulers.size());
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs);
for (auto const &state : statesForObservation) {
if (remainingstates.get(state)) {
if (!coveredStates.get(state)) {
auto constant = expressionManager->integer(schedulerForObs[obs].size());
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
}
@ -399,14 +500,20 @@ namespace storm {
++obs;
}
finalSchedulers.push_back(scheduler);
smtSolver->push();
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
auto constant = expressionManager->integer(schedulerForObs[obs].size());
smtSolver->add(schedulerVariableExpressions[obs] <= constant);
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
}
stats.updateNewStrategySolverTime.stop();
}
winningRegion.print();
return true;
}
@ -423,21 +530,53 @@ namespace storm {
}
// Currently a no-op: the scheduler list argument is unnamed and the body is empty,
// so calling this prints nothing.
template<typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::printScheduler(std::vector<InternalObservationScheduler> const& ) {
}
template<typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::printScheduler(std::vector<InternalObservationScheduler> const& ) {
void MemlessStrategySearchQualitative<ValueType>::finalizeStatistics() {
}
/*!
 * Grants read-only access to the statistics object of this search.
 * @return A const reference to the internally stored statistics.
 */
template<typename ValueType>
typename MemlessStrategySearchQualitative<ValueType>::Statistics const& MemlessStrategySearchQualitative<ValueType>::getStatistics() const {
    return this->stats;
}
template <typename ValueType>
storm::expressions::Expression const& MemlessStrategySearchQualitative<ValueType>::getDoneActionExpression(uint64_t obs) const {
return actionSelectionVarExpressions[obs].back();
bool MemlessStrategySearchQualitative<ValueType>::smtCheck(uint64_t iteration) {
    // Runs one satisfiability check on the current solver state.
    // Returns true if the solver reports a satisfying assignment and false if it reports Unsat;
    // throws storm::exceptions::UnexpectedException if the solver answers Unknown.

    // If requested, dump the solver state to "<exportPath>call_<iteration>.smt2" before checking.
    if (options.isExportSATSet()) {
        STORM_LOG_DEBUG("Export SMT Solver Call (" <<iteration << ")");
        std::string const exportFile = options.getExportSATCallsPath() + "call_" + std::to_string(iteration) + ".smt2";
        std::ofstream exportStream;
        storm::utility::openFile(exportFile, exportStream);
        exportStream << smtSolver->getSmtLibString() << std::endl;
        storm::utility::closeFile(exportStream);
    }

    STORM_LOG_DEBUG("Call to SMT Solver (" <<iteration << ")");

    // Only the solver call itself is timed; every call is counted.
    stats.smtCheckTimer.start();
    auto const checkResult = smtSolver->check();
    stats.smtCheckTimer.stop();
    stats.incrementSmtChecks();

    STORM_LOG_THROW(checkResult != storm::solver::SmtSolver::CheckResult::Unknown, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");

    if (checkResult == storm::solver::SmtSolver::CheckResult::Unsat) {
        STORM_LOG_DEBUG("Unsatisfiable!");
        return false;
    }

    STORM_LOG_DEBUG("Satisfying assignment: ");
    STORM_LOG_TRACE(smtSolver->getModelAsValuation().toString(true));
    return true;
}
template class MemlessStrategySearchQualitative<double>;
template class MemlessStrategySearchQualitative<storm::RationalNumber>;
}
}

63
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -4,6 +4,7 @@
#include "storm/solver/SmtSolver.h"
#include "storm/models/sparse/Pomdp.h"
#include "storm/utility/solver.h"
#include "storm/utility/Stopwatch.h"
#include "storm/exceptions/UnexpectedException.h"
#include "storm-pomdp/analysis/WinningRegion.h"
@ -42,20 +43,24 @@ namespace pomdp {
return debugLevel > 2;
}
bool onlyDeterministicStrategies = false;
bool lookaheadRequired = true;
private:
std::string exportSATcalls = "";
uint64_t debugLevel = 0;
};
struct InternalObservationScheduler {
std::vector<std::set<uint64_t>> actions;
std::vector<storm::storage::BitVector> actions;
std::vector<uint64_t> schedulerRef;
storm::storage::BitVector switchObservations;
void clear() {
actions.clear();
schedulerRef.clear();
void reset(uint64_t nrObservations, uint64_t nrActions) {
actions = std::vector<storm::storage::BitVector>(nrObservations, storm::storage::BitVector(nrActions));
schedulerRef = std::vector<uint64_t>(nrObservations, 0);
switchObservations.clear();
}
@ -65,8 +70,10 @@ namespace pomdp {
void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const {
for (uint64_t obs = 0; obs < observations.size(); ++obs) {
if (observations.get(obs)) {
if (observations.get(obs) || observationsAfterSwitch.get(obs)) {
STORM_LOG_INFO("For observation: " << obs);
}
if (observations.get(obs)) {
std::stringstream ss;
ss << "actions:";
for (auto act : actions[obs]) {
@ -90,6 +97,31 @@ namespace pomdp {
// Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
public:
/*!
 * Bookkeeping for one strategy search: stopwatches for the individual phases
 * plus counters for SAT calls and outer iterations.
 */
class Statistics {
public:
    Statistics() = default;

    /// Prints all timers and counters (defined out of line).
    void print() const;

    // Stopwatches are public so the search can start and stop them directly.
    storm::utility::Stopwatch totalTimer;
    storm::utility::Stopwatch smtCheckTimer;
    storm::utility::Stopwatch initializeSolverTimer;
    storm::utility::Stopwatch updateExtensionSolverTime;
    storm::utility::Stopwatch updateNewStrategySolverTime;
    storm::utility::Stopwatch winningRegionUpdatesTimer;

    /// Counts one more outer iteration.
    void incrementOuterIterations() {
        ++outerIterations;
    }

    /// Counts one more SAT/SMT solver call.
    void incrementSmtChecks() {
        ++satCalls;
    }

private:
    uint64_t satCalls{0};
    uint64_t outerIterations{0};
};
MemlessStrategySearchQualitative(storm::models::sparse::Pomdp<ValueType> const& pomdp,
std::set<uint32_t> const& targetObservationSet,
storm::storage::BitVector const& targetStates,
@ -98,19 +130,24 @@ namespace pomdp {
MemlessSearchOptions const& options);
/**
 * Runs analyze() with the initial states of the POMDP passed as both state-set arguments.
 * The call is enclosed by stats.totalTimer; the boolean result of analyze() is not propagated.
 * @param k forwarded unchanged to analyze()
 */
void analyzeForInitialStates(uint64_t k) {
    stats.totalTimer.start();
    auto const& initialStates = pomdp.getInitialStates();
    analyze(k, initialStates, initialStates);
    stats.totalTimer.stop();
}
void findNewStrategyForSomeState(uint64_t k) {
std::cout << surelyReachSinkStates << std::endl;
std::cout << targetStates << std::endl;
std::cout << (~surelyReachSinkStates & ~targetStates) << std::endl;
stats.totalTimer.start();
analyze(k, ~surelyReachSinkStates & ~targetStates);
stats.totalTimer.stop();
}
bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
Statistics const& getStatistics() const;
void finalizeStatistics();
private:
storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const;
@ -119,24 +156,30 @@ namespace pomdp {
void initialize(uint64_t k);
bool smtCheck(uint64_t iteration);
std::unique_ptr<storm::solver::SmtSolver> smtSolver;
storm::models::sparse::Pomdp<ValueType> const& pomdp;
std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;
uint64_t maxK = std::numeric_limits<uint64_t>::max();
storm::storage::BitVector surelyReachSinkStates;
std::set<uint32_t> targetObservations;
storm::storage::BitVector targetStates;
storm::storage::BitVector surelyReachSinkStates;
std::vector<std::vector<uint64_t>> statesPerObservation;
std::vector<storm::expressions::Variable> schedulerVariables;
std::vector<storm::expressions::Expression> schedulerVariableExpressions;
std::vector<std::vector<uint64_t>> statesPerObservation;
std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars;
std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars; // A_{z,a}
std::vector<storm::expressions::Variable> reachVars;
std::vector<storm::expressions::Expression> reachVarExpressions;
std::vector<std::vector<storm::expressions::Expression>> reachVarExpressionsPerObservation;
std::vector<storm::expressions::Variable> observationUpdatedVariables;
std::vector<storm::expressions::Expression> observationUpdatedExpressions;
std::vector<storm::expressions::Variable> switchVars;
std::vector<storm::expressions::Expression> switchVarExpressions;
@ -149,7 +192,7 @@ namespace pomdp {
WinningRegion winningRegion;
MemlessSearchOptions options;
Statistics stats;
};

17
src/storm-pomdp/analysis/QualitativeAnalysis.cpp

@ -27,6 +27,23 @@ namespace storm {
return analyseProb0or1(formula, false);
}
/**
 * Returns the complement of the state set computed by storm::utility::graph::performProb1E for the
 * (until-normalised) path formula of the given probability operator formula.
 * Only maximizing properties are supported. An eventually-subformula is rewritten to "true U phi";
 * an until-subformula is used as is.
 * @param formula probability operator formula; must specify an optimality type or a bound
 * @return complement of the performProb1E result; per the note in the body, sound but not necessarily complete
 * @throws storm::exceptions::InvalidPropertyException if neither optimality type nor bound is given,
 *         or if the subformula is neither an eventually nor an until formula
 * @throws storm::exceptions::NotImplementedException if the property minimizes
 */
template<typename ValueType>
storm::storage::BitVector QualitativeAnalysis<ValueType>::analyseProbSmaller1(storm::logic::ProbabilityOperatorFormula const &formula) const {
    STORM_LOG_THROW(formula.hasOptimalityType() || formula.hasBound(), storm::exceptions::InvalidPropertyException, "The formula " << formula << " does not specify whether to minimize or maximize.");
    bool minimizes = (formula.hasOptimalityType() && storm::solver::minimize(formula.getOptimalityType())) || (formula.hasBound() && storm::logic::isLowerBound(formula.getBound().comparisonType));
    STORM_LOG_THROW(!minimizes,storm::exceptions::NotImplementedException, "This operation is only supported when maximizing");
    std::shared_ptr<storm::logic::Formula const> subformula = formula.getSubformula().asSharedPointer();
    std::shared_ptr<storm::logic::UntilFormula> untilSubformula;
    // If necessary, convert the subformula to a more general case
    if (subformula->isEventuallyFormula()) {
        untilSubformula = std::make_shared<storm::logic::UntilFormula>(storm::logic::Formula::getTrueFormula(), subformula->asEventuallyFormula().getSubformula().asSharedPointer());
    } else if(subformula->isUntilFormula()) {
        untilSubformula = std::make_shared<storm::logic::UntilFormula>(subformula->asUntilFormula());
    }
    // Any other kind of subformula leaves the pointer empty; fail explicitly instead of dereferencing null below.
    STORM_LOG_THROW(untilSubformula != nullptr, storm::exceptions::InvalidPropertyException, "The formula " << formula << " is not supported: expected an eventually or until subformula.");
    // The vector is sound, but not necessarily complete!
    return ~storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), pomdp.getBackwardTransitions(), checkPropositionalFormula(untilSubformula->getLeftSubformula()), checkPropositionalFormula(untilSubformula->getRightSubformula()));
}
template<typename ValueType>
storm::storage::BitVector QualitativeAnalysis<ValueType>::analyseProb0or1(storm::logic::ProbabilityOperatorFormula const& formula, bool prob0) const {
// check whether the property is minimizing or maximizing

2
src/storm-pomdp/analysis/QualitativeAnalysis.h

@ -10,7 +10,7 @@ namespace storm {
QualitativeAnalysis(storm::models::sparse::Pomdp<ValueType> const& pomdp);
storm::storage::BitVector analyseProb0(storm::logic::ProbabilityOperatorFormula const& formula) const;
storm::storage::BitVector analyseProb1(storm::logic::ProbabilityOperatorFormula const& formula) const;
storm::storage::BitVector analyseProbSmaller1(storm::logic::ProbabilityOperatorFormula const& formula) const;
private:
storm::storage::BitVector analyseProb0or1(storm::logic::ProbabilityOperatorFormula const& formula, bool prob0) const;
storm::storage::BitVector analyseProb0Max(storm::logic::UntilFormula const& formula) const;

36
src/storm-pomdp/analysis/WinningRegion.cpp

@ -1,4 +1,6 @@
#include <iostream>
#include "storm/storage/expressions/Expression.h"
#include "storm/storage/expressions/ExpressionManager.h"
#include "storm-pomdp/analysis/WinningRegion.h"
namespace storm {
@ -10,13 +12,13 @@ namespace pomdp {
}
}
void WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) {
bool WinningRegion::update(uint64_t observation, storm::storage::BitVector const& winning) {
std::vector<storm::storage::BitVector> newWinningSupport = std::vector<storm::storage::BitVector>();
bool changed = false;
for (auto const& support : winningRegion[observation]) {
if (winning.isSubsetOf(support)) {
// This new winning support is already covered.
return;
return false;
}
if(support.isSubsetOf(winning)) {
// This new winning support extends the previous support, thus the previous support is now spurious
@ -33,6 +35,7 @@ namespace pomdp {
} else {
winningRegion[observation].push_back(winning);
}
return true;
}
@ -45,6 +48,34 @@ namespace pomdp {
return false;
}
/**
 * Builds a constraint over the per-state expressions of one observation from the winning supports
 * stored for that observation.
 *
 * If a stored support is full, the constant false expression is returned. Otherwise, for every stored
 * support, each state expression inside the support implies the disjunction of the state expressions
 * outside the support, and the disjunction of all state expressions is added as well. The result is
 * the conjunction of all these constraints.
 * NOTE(review): presumably this forces a candidate support to be non-empty and not covered by an
 * already known winning support; confirm against the caller.
 *
 * @param observation index into the stored winning region
 * @param varsForStates one expression per state of the observation (size is asserted to match each support)
 * @return the conjunction described above, or false if the observation has a full winning support
 */
storm::expressions::Expression WinningRegion::extensionExpression(uint64_t observation, std::vector<storm::expressions::Expression>& varsForStates) const {
    std::vector<storm::expressions::Expression> conjuncts;

    for (auto const& knownSupport : winningRegion[observation]) {
        if (knownSupport.full()) {
            // A full support is expected to be the only entry for this observation.
            assert(winningRegion[observation].size() == 1);
            return varsForStates.front().getManager().boolean(false);
        }
        assert(varsForStates.size() == knownSupport.size());

        // Split the state expressions by membership in the known support.
        std::vector<storm::expressions::Expression> outsideSupport;
        std::vector<storm::expressions::Expression> insideSupport;
        for (uint64_t state = 0; state < varsForStates.size(); ++state) {
            auto& bucket = knownSupport.get(state) ? insideSupport : outsideSupport;
            bucket.push_back(varsForStates[state]);
        }

        storm::expressions::Expression someStateOutside = storm::expressions::disjunction(outsideSupport);
        for (auto const& stateInside : insideSupport) {
            conjuncts.push_back(storm::expressions::implies(stateInside, someStateOutside));
        }
        // Added once per stored support, exactly as many times as there are entries.
        conjuncts.push_back(storm::expressions::disjunction(varsForStates));
    }

    return storm::expressions::conjunction(conjuncts);
}
/**
* If we observe this observation, do we surely win?
* @param observation
@ -62,6 +93,7 @@ namespace pomdp {
std::cout << " " << support;
}
std::cout << std::endl;
observation++;
}
}

4
src/storm-pomdp/analysis/WinningRegion.h

@ -9,10 +9,12 @@ namespace storm {
public:
WinningRegion(std::vector<uint64_t> const& observationSizes = {});
void update(uint64_t observation, storm::storage::BitVector const& winning);
bool update(uint64_t observation, storm::storage::BitVector const& winning);
bool query(uint64_t observation, storm::storage::BitVector const& currently) const;
bool observationIsWinning(uint64_t observation) const;
storm::expressions::Expression extensionExpression(uint64_t observation, std::vector<storm::expressions::Expression>& varsForStates) const;
uint64_t getStorageSize() const;
uint64_t getNumberOfObservations() const;

Loading…
Cancel
Save