|
@ -81,7 +81,6 @@ namespace storm { |
|
|
STORM_LOG_INFO("Start intializing solver..."); |
|
|
STORM_LOG_INFO("Start intializing solver..."); |
|
|
// TODO fix this
|
|
|
// TODO fix this
|
|
|
bool lookaheadConstraintsRequired = options.lookaheadRequired; |
|
|
bool lookaheadConstraintsRequired = options.lookaheadRequired; |
|
|
STORM_LOG_WARN("We have hardcoded that we do not need lookahead"); |
|
|
|
|
|
if (maxK == std::numeric_limits<uint64_t>::max()) { |
|
|
if (maxK == std::numeric_limits<uint64_t>::max()) { |
|
|
// not initialized at all.
|
|
|
// not initialized at all.
|
|
|
// Create some data structures.
|
|
|
// Create some data structures.
|
|
@ -134,12 +133,16 @@ namespace storm { |
|
|
++obs; |
|
|
++obs; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// PAPER COMMENT: 1
|
|
|
for (auto const& actionVars : actionSelectionVarExpressions) { |
|
|
for (auto const& actionVars : actionSelectionVarExpressions) { |
|
|
smtSolver->add(storm::expressions::disjunction(actionVars)); |
|
|
smtSolver->add(storm::expressions::disjunction(actionVars)); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Update at least one observation.
|
|
|
|
|
|
// PAPER COMMENT: 2
|
|
|
smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions)); |
|
|
smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions)); |
|
|
|
|
|
|
|
|
|
|
|
// PAPER COMMENT: 3
|
|
|
if (lookaheadConstraintsRequired) { |
|
|
if (lookaheadConstraintsRequired) { |
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
if (targetStates.get(state)) { |
|
|
if (targetStates.get(state)) { |
|
@ -150,6 +153,7 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// PAPER COMMENT: 4
|
|
|
uint64_t rowindex = 0; |
|
|
uint64_t rowindex = 0; |
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
|
|
|
|
@ -185,6 +189,7 @@ namespace storm { |
|
|
|
|
|
|
|
|
uint64_t rowindex = 0; |
|
|
uint64_t rowindex = 0; |
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
|
|
// PAPER COMMENT 5
|
|
|
if (surelyReachSinkStates.get(state)) { |
|
|
if (surelyReachSinkStates.get(state)) { |
|
|
smtSolver->add(!reachVarExpressions[state]); |
|
|
smtSolver->add(!reachVarExpressions[state]); |
|
|
smtSolver->add(!continuationVarExpressions[state]); |
|
|
smtSolver->add(!continuationVarExpressions[state]); |
|
@ -196,8 +201,10 @@ namespace storm { |
|
|
rowindex += pomdp.getNumberOfChoices(state); |
|
|
rowindex += pomdp.getNumberOfChoices(state); |
|
|
} else if(!targetStates.get(state)) { |
|
|
} else if(!targetStates.get(state)) { |
|
|
if (lookaheadConstraintsRequired) { |
|
|
if (lookaheadConstraintsRequired) { |
|
|
|
|
|
// PAPER COMMENT 6
|
|
|
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); |
|
|
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); |
|
|
|
|
|
|
|
|
|
|
|
// PAPER COMMENT 7
|
|
|
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; |
|
|
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; |
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>()); |
|
|
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>()); |
|
@ -231,6 +238,7 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// PAPER COMMENT 8
|
|
|
uint64_t obs = 0; |
|
|
uint64_t obs = 0; |
|
|
for(auto const& statesForObservation : statesPerObservation) { |
|
|
for(auto const& statesForObservation : statesPerObservation) { |
|
|
for(auto const& state : statesForObservation) { |
|
|
for(auto const& state : statesForObservation) { |
|
@ -239,9 +247,11 @@ namespace storm { |
|
|
++obs; |
|
|
++obs; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// PAPER COMMENT 9
|
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]))); |
|
|
smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]))); |
|
|
} |
|
|
} |
|
|
|
|
|
// PAPER COMMENT 10
|
|
|
if (!lookaheadConstraintsRequired) { |
|
|
if (!lookaheadConstraintsRequired) { |
|
|
uint64_t rowIndex = 0; |
|
|
uint64_t rowIndex = 0; |
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
@ -263,9 +273,6 @@ namespace storm { |
|
|
} else { |
|
|
} else { |
|
|
STORM_LOG_WARN("Some optimization not implemented yet."); |
|
|
STORM_LOG_WARN("Some optimization not implemented yet."); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// TODO: Update found schedulers if k is increased.
|
|
|
// TODO: Update found schedulers if k is increased.
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@ -287,15 +294,19 @@ namespace storm { |
|
|
atLeastOneOfStates.push_back(reachVarExpressions[state]); |
|
|
atLeastOneOfStates.push_back(reachVarExpressions[state]); |
|
|
} |
|
|
} |
|
|
assert(atLeastOneOfStates.size() > 0); |
|
|
assert(atLeastOneOfStates.size() > 0); |
|
|
|
|
|
// PAPER COMMENT 11
|
|
|
smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); |
|
|
smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); |
|
|
|
|
|
smtSolver->push(); |
|
|
|
|
|
|
|
|
|
|
|
std::set<storm::expressions::Expression> allOfTheseAssumption; |
|
|
|
|
|
|
|
|
|
|
|
std::vector<storm::expressions::Expression> updateForObservationExpressions; |
|
|
|
|
|
|
|
|
for (uint64_t state : allOfTheseStates) { |
|
|
for (uint64_t state : allOfTheseStates) { |
|
|
assert(reachVarExpressions.size() > state); |
|
|
assert(reachVarExpressions.size() > state); |
|
|
smtSolver->add(reachVarExpressions[state]); |
|
|
|
|
|
|
|
|
allOfTheseAssumption.insert(reachVarExpressions[state]); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
smtSolver->push(); |
|
|
|
|
|
std::vector<storm::expressions::Expression> updateForObservationExpressions; |
|
|
|
|
|
for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { |
|
|
for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { |
|
|
updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob])); |
|
|
updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob])); |
|
|
schedulerForObs.push_back(std::vector<uint64_t>()); |
|
|
schedulerForObs.push_back(std::vector<uint64_t>()); |
|
@ -333,11 +344,19 @@ namespace storm { |
|
|
coveredStates.clear(); |
|
|
coveredStates.clear(); |
|
|
coveredStatesAfterSwitch.clear(); |
|
|
coveredStatesAfterSwitch.clear(); |
|
|
observationUpdated.clear(); |
|
|
observationUpdated.clear(); |
|
|
|
|
|
if (!allOfTheseAssumption.empty()) { |
|
|
|
|
|
bool foundResult = this->smtCheck(iterations, allOfTheseAssumption); |
|
|
|
|
|
if (foundResult) { |
|
|
|
|
|
// Consider storing the scheduler
|
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
bool newSchedulerDiscovered = false; |
|
|
bool newSchedulerDiscovered = false; |
|
|
|
|
|
|
|
|
while (true) { |
|
|
while (true) { |
|
|
++iterations; |
|
|
++iterations; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool foundScheduler = this->smtCheck(iterations); |
|
|
bool foundScheduler = this->smtCheck(iterations); |
|
|
if (!foundScheduler) { |
|
|
if (!foundScheduler) { |
|
|
break; |
|
|
break; |
|
@ -493,6 +512,7 @@ namespace storm { |
|
|
for (auto const &state : statesForObservation) { |
|
|
for (auto const &state : statesForObservation) { |
|
|
if (!coveredStates.get(state)) { |
|
|
if (!coveredStates.get(state)) { |
|
|
auto constant = expressionManager->integer(schedulerForObs[obs].size()); |
|
|
auto constant = expressionManager->integer(schedulerForObs[obs].size()); |
|
|
|
|
|
// PAPER COMMENT 14:
|
|
|
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); |
|
|
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
@ -505,7 +525,9 @@ namespace storm { |
|
|
|
|
|
|
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
auto constant = expressionManager->integer(schedulerForObs[obs].size()); |
|
|
auto constant = expressionManager->integer(schedulerForObs[obs].size()); |
|
|
|
|
|
// PAPER COMMENT 13
|
|
|
smtSolver->add(schedulerVariableExpressions[obs] <= constant); |
|
|
smtSolver->add(schedulerVariableExpressions[obs] <= constant); |
|
|
|
|
|
// PAPER COMMENT 12
|
|
|
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); |
|
|
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); |
|
|
} |
|
|
} |
|
|
stats.updateNewStrategySolverTime.stop(); |
|
|
stats.updateNewStrategySolverTime.stop(); |
|
@ -514,6 +536,21 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
winningRegion.print(); |
|
|
winningRegion.print(); |
|
|
|
|
|
|
|
|
|
|
|
if (!allOfTheseStates.empty()) { |
|
|
|
|
|
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { |
|
|
|
|
|
storm::storage::BitVector check(statesPerObservation[observation].size()); |
|
|
|
|
|
uint64_t i = 0; |
|
|
|
|
|
for (uint64_t state : statesPerObservation[observation]) { |
|
|
|
|
|
if (allOfTheseStates.get(state)) { |
|
|
|
|
|
check.set(i); |
|
|
|
|
|
} |
|
|
|
|
|
++i; |
|
|
|
|
|
} |
|
|
|
|
|
if (!winningRegion.query(observation, check)) { |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@ -546,7 +583,7 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template <typename ValueType> |
|
|
template <typename ValueType> |
|
|
bool MemlessStrategySearchQualitative<ValueType>::smtCheck(uint64_t iteration) { |
|
|
|
|
|
|
|
|
bool MemlessStrategySearchQualitative<ValueType>::smtCheck(uint64_t iteration, std::set<storm::expressions::Expression> const& assumptions) { |
|
|
if(options.isExportSATSet()) { |
|
|
if(options.isExportSATSet()) { |
|
|
STORM_LOG_DEBUG("Export SMT Solver Call (" <<iteration << ")"); |
|
|
STORM_LOG_DEBUG("Export SMT Solver Call (" <<iteration << ")"); |
|
|
std::string filepath = options.getExportSATCallsPath() + "call_" + std::to_string(iteration) + ".smt2"; |
|
|
std::string filepath = options.getExportSATCallsPath() + "call_" + std::to_string(iteration) + ".smt2"; |
|
@ -557,8 +594,13 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
STORM_LOG_DEBUG("Call to SMT Solver (" <<iteration << ")"); |
|
|
STORM_LOG_DEBUG("Call to SMT Solver (" <<iteration << ")"); |
|
|
|
|
|
storm::solver::SmtSolver::CheckResult result; |
|
|
stats.smtCheckTimer.start(); |
|
|
stats.smtCheckTimer.start(); |
|
|
auto result = smtSolver->check(); |
|
|
|
|
|
|
|
|
if (assumptions.empty()) { |
|
|
|
|
|
result = smtSolver->check(); |
|
|
|
|
|
} else { |
|
|
|
|
|
result = smtSolver->checkWithAssumptions(assumptions); |
|
|
|
|
|
} |
|
|
stats.smtCheckTimer.stop(); |
|
|
stats.smtCheckTimer.stop(); |
|
|
stats.incrementSmtChecks(); |
|
|
stats.incrementSmtChecks(); |
|
|
|
|
|
|
|
|