Browse Source

support for computing the winning region or from initial state, some documentation

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
e22cbdb91b
  1. 60
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 16
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

60
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -81,7 +81,6 @@ namespace storm {
STORM_LOG_INFO("Start intializing solver..."); STORM_LOG_INFO("Start intializing solver...");
// TODO fix this // TODO fix this
bool lookaheadConstraintsRequired = options.lookaheadRequired; bool lookaheadConstraintsRequired = options.lookaheadRequired;
STORM_LOG_WARN("We have hardcoded that we do not need lookahead");
if (maxK == std::numeric_limits<uint64_t>::max()) { if (maxK == std::numeric_limits<uint64_t>::max()) {
// not initialized at all. // not initialized at all.
// Create some data structures. // Create some data structures.
@ -134,12 +133,16 @@ namespace storm {
++obs; ++obs;
} }
// PAPER COMMENT: 1
for (auto const& actionVars : actionSelectionVarExpressions) { for (auto const& actionVars : actionSelectionVarExpressions) {
smtSolver->add(storm::expressions::disjunction(actionVars)); smtSolver->add(storm::expressions::disjunction(actionVars));
} }
// Update at least one observation.
// PAPER COMMENT: 2
smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions)); smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions));
// PAPER COMMENT: 3
if (lookaheadConstraintsRequired) { if (lookaheadConstraintsRequired) {
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state)) { if (targetStates.get(state)) {
@ -150,6 +153,7 @@ namespace storm {
} }
} }
// PAPER COMMENT: 4
uint64_t rowindex = 0; uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
@ -185,6 +189,7 @@ namespace storm {
uint64_t rowindex = 0; uint64_t rowindex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
// PAPER COMMENT 5
if (surelyReachSinkStates.get(state)) { if (surelyReachSinkStates.get(state)) {
smtSolver->add(!reachVarExpressions[state]); smtSolver->add(!reachVarExpressions[state]);
smtSolver->add(!continuationVarExpressions[state]); smtSolver->add(!continuationVarExpressions[state]);
@ -196,8 +201,10 @@ namespace storm {
rowindex += pomdp.getNumberOfChoices(state); rowindex += pomdp.getNumberOfChoices(state);
} else if(!targetStates.get(state)) { } else if(!targetStates.get(state)) {
if (lookaheadConstraintsRequired) { if (lookaheadConstraintsRequired) {
// PAPER COMMENT 6
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
// PAPER COMMENT 7
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) { for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>()); pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
@ -231,6 +238,7 @@ namespace storm {
} }
} }
// PAPER COMMENT 8
uint64_t obs = 0; uint64_t obs = 0;
for(auto const& statesForObservation : statesPerObservation) { for(auto const& statesForObservation : statesPerObservation) {
for(auto const& state : statesForObservation) { for(auto const& state : statesForObservation) {
@ -239,9 +247,11 @@ namespace storm {
++obs; ++obs;
} }
// PAPER COMMENT 9
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]))); smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs])));
} }
// PAPER COMMENT 10
if (!lookaheadConstraintsRequired) { if (!lookaheadConstraintsRequired) {
uint64_t rowIndex = 0; uint64_t rowIndex = 0;
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
@ -263,9 +273,6 @@ namespace storm {
} else { } else {
STORM_LOG_WARN("Some optimization not implemented yet."); STORM_LOG_WARN("Some optimization not implemented yet.");
} }
// TODO: Update found schedulers if k is increased. // TODO: Update found schedulers if k is increased.
} }
@ -287,15 +294,19 @@ namespace storm {
atLeastOneOfStates.push_back(reachVarExpressions[state]); atLeastOneOfStates.push_back(reachVarExpressions[state]);
} }
assert(atLeastOneOfStates.size() > 0); assert(atLeastOneOfStates.size() > 0);
// PAPER COMMENT 11
smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates));
smtSolver->push();
std::set<storm::expressions::Expression> allOfTheseAssumption;
std::vector<storm::expressions::Expression> updateForObservationExpressions;
for (uint64_t state : allOfTheseStates) { for (uint64_t state : allOfTheseStates) {
assert(reachVarExpressions.size() > state); assert(reachVarExpressions.size() > state);
smtSolver->add(reachVarExpressions[state]);
allOfTheseAssumption.insert(reachVarExpressions[state]);
} }
smtSolver->push();
std::vector<storm::expressions::Expression> updateForObservationExpressions;
for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) {
updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob])); updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob]));
schedulerForObs.push_back(std::vector<uint64_t>()); schedulerForObs.push_back(std::vector<uint64_t>());
@ -333,11 +344,19 @@ namespace storm {
coveredStates.clear(); coveredStates.clear();
coveredStatesAfterSwitch.clear(); coveredStatesAfterSwitch.clear();
observationUpdated.clear(); observationUpdated.clear();
if (!allOfTheseAssumption.empty()) {
bool foundResult = this->smtCheck(iterations, allOfTheseAssumption);
if (foundResult) {
// Consider storing the scheduler
return true;
}
}
bool newSchedulerDiscovered = false; bool newSchedulerDiscovered = false;
while (true) { while (true) {
++iterations; ++iterations;
bool foundScheduler = this->smtCheck(iterations); bool foundScheduler = this->smtCheck(iterations);
if (!foundScheduler) { if (!foundScheduler) {
break; break;
@ -493,6 +512,7 @@ namespace storm {
for (auto const &state : statesForObservation) { for (auto const &state : statesForObservation) {
if (!coveredStates.get(state)) { if (!coveredStates.get(state)) {
auto constant = expressionManager->integer(schedulerForObs[obs].size()); auto constant = expressionManager->integer(schedulerForObs[obs].size());
// PAPER COMMENT 14:
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
} }
} }
@ -505,7 +525,9 @@ namespace storm {
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
auto constant = expressionManager->integer(schedulerForObs[obs].size()); auto constant = expressionManager->integer(schedulerForObs[obs].size());
// PAPER COMMENT 13
smtSolver->add(schedulerVariableExpressions[obs] <= constant); smtSolver->add(schedulerVariableExpressions[obs] <= constant);
// PAPER COMMENT 12
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
} }
stats.updateNewStrategySolverTime.stop(); stats.updateNewStrategySolverTime.stop();
@ -514,6 +536,21 @@ namespace storm {
} }
winningRegion.print(); winningRegion.print();
if (!allOfTheseStates.empty()) {
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
storm::storage::BitVector check(statesPerObservation[observation].size());
uint64_t i = 0;
for (uint64_t state : statesPerObservation[observation]) {
if (allOfTheseStates.get(state)) {
check.set(i);
}
++i;
}
if (!winningRegion.query(observation, check)) {
return false;
}
}
}
return true; return true;
} }
@ -546,7 +583,7 @@ namespace storm {
} }
template <typename ValueType> template <typename ValueType>
bool MemlessStrategySearchQualitative<ValueType>::smtCheck(uint64_t iteration) {
bool MemlessStrategySearchQualitative<ValueType>::smtCheck(uint64_t iteration, std::set<storm::expressions::Expression> const& assumptions) {
if(options.isExportSATSet()) { if(options.isExportSATSet()) {
STORM_LOG_DEBUG("Export SMT Solver Call (" <<iteration << ")"); STORM_LOG_DEBUG("Export SMT Solver Call (" <<iteration << ")");
std::string filepath = options.getExportSATCallsPath() + "call_" + std::to_string(iteration) + ".smt2"; std::string filepath = options.getExportSATCallsPath() + "call_" + std::to_string(iteration) + ".smt2";
@ -557,8 +594,13 @@ namespace storm {
} }
STORM_LOG_DEBUG("Call to SMT Solver (" <<iteration << ")"); STORM_LOG_DEBUG("Call to SMT Solver (" <<iteration << ")");
storm::solver::SmtSolver::CheckResult result;
stats.smtCheckTimer.start(); stats.smtCheckTimer.start();
auto result = smtSolver->check();
if (assumptions.empty()) {
result = smtSolver->check();
} else {
result = smtSolver->checkWithAssumptions(assumptions);
}
stats.smtCheckTimer.stop(); stats.smtCheckTimer.stop();
stats.incrementSmtChecks(); stats.incrementSmtChecks();

16
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -132,8 +132,20 @@ namespace pomdp {
void analyzeForInitialStates(uint64_t k) { void analyzeForInitialStates(uint64_t k) {
stats.totalTimer.start(); stats.totalTimer.start();
analyze(k, pomdp.getInitialStates(), pomdp.getInitialStates());
STORM_LOG_TRACE("Bad states: " << surelyReachSinkStates);
STORM_LOG_TRACE("Target states: " << targetStates);
STORM_LOG_TRACE("Questionmark states: " << (~surelyReachSinkStates & ~targetStates));
bool result = analyze(k, ~surelyReachSinkStates & ~targetStates, pomdp.getInitialStates());
stats.totalTimer.stop(); stats.totalTimer.stop();
if (result) {
STORM_PRINT_AND_LOG("From initial state, one can almost-surely reach the target.");
} else {
if (k == pomdp.getNumberOfStates()) {
STORM_PRINT_AND_LOG("From initial state, one cannot almost-surely reach the target.");
} else {
STORM_PRINT_AND_LOG("From initial state, one may not almost-surely reach the target.");
}
}
} }
void findNewStrategyForSomeState(uint64_t k) { void findNewStrategyForSomeState(uint64_t k) {
@ -157,7 +169,7 @@ namespace pomdp {
void initialize(uint64_t k); void initialize(uint64_t k);
bool smtCheck(uint64_t iteration);
bool smtCheck(uint64_t iteration, std::set<storm::expressions::Expression> const& assumptions = {});
std::unique_ptr<storm::solver::SmtSolver> smtSolver; std::unique_ptr<storm::solver::SmtSolver> smtSolver;

Loading…
Cancel
Save