|
|
@ -83,8 +83,9 @@ namespace storm { |
|
|
|
} |
|
|
|
|
|
|
|
template <typename ValueType> |
|
|
|
void MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) { |
|
|
|
bool MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) { |
|
|
|
STORM_LOG_INFO("Start intializing solver..."); |
|
|
|
|
|
|
|
bool lookaheadConstraintsRequired; |
|
|
|
if (options.forceLookahead) { |
|
|
|
lookaheadConstraintsRequired = true; |
|
|
@ -131,8 +132,10 @@ namespace storm { |
|
|
|
} |
|
|
|
|
|
|
|
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { |
|
|
|
pathVars.push_back(std::vector<storm::expressions::Expression>()); |
|
|
|
pathVars.push_back(std::vector<storm::expressions::Variable>()); |
|
|
|
pathVarExpressions.push_back(std::vector<storm::expressions::Expression>()); |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
uint64_t initK = 0; |
|
|
@ -142,15 +145,25 @@ namespace storm { |
|
|
|
if (initK < k) { |
|
|
|
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { |
|
|
|
if (lookaheadConstraintsRequired) { |
|
|
|
for (uint64_t i = initK; i < k; ++i) { |
|
|
|
pathVars[stateId].push_back(expressionManager->declareBooleanVariable( |
|
|
|
"P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression()); |
|
|
|
if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) { |
|
|
|
for (uint64_t i = initK; i < k; ++i) { |
|
|
|
pathVars[stateId].push_back(expressionManager->declareBooleanVariable( |
|
|
|
"P-" + std::to_string(stateId) + "-" + std::to_string(i))); |
|
|
|
pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression()); |
|
|
|
} |
|
|
|
} else if (options.pathVariableType == MemlessSearchPathVariables::IntegerRanking) { |
|
|
|
pathVars[stateId].push_back(expressionManager->declareIntegerVariable("P-" + std::to_string(stateId))); |
|
|
|
pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression()); |
|
|
|
} else { |
|
|
|
assert(options.pathVariableType == MemlessSearchPathVariables::RealRanking); |
|
|
|
pathVars[stateId].push_back(expressionManager->declareRationalVariable("P-" + std::to_string(stateId))); |
|
|
|
pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates()); |
|
|
|
assert(!lookaheadConstraintsRequired || pathVarExpressions.size() == pomdp.getNumberOfStates()); |
|
|
|
assert(reachVars.size() == pomdp.getNumberOfStates()); |
|
|
|
assert(reachVarExpressions.size() == pomdp.getNumberOfStates()); |
|
|
|
|
|
|
@ -193,13 +206,23 @@ namespace storm { |
|
|
|
|
|
|
|
// PAPER COMMENT: 3
|
|
|
|
if (lookaheadConstraintsRequired) { |
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
if (targetStates.get(state)) { |
|
|
|
smtSolver->add(pathVars[state][0]); |
|
|
|
} else { |
|
|
|
smtSolver->add(!pathVars[state][0] || followVarExpressions[pomdp.getObservation(state)]); |
|
|
|
if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) { |
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
if (targetStates.get(state)) { |
|
|
|
smtSolver->add(pathVarExpressions[state][0]); |
|
|
|
} else { |
|
|
|
smtSolver->add(!pathVarExpressions[state][0] || followVarExpressions[pomdp.getObservation(state)]); |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
smtSolver->add(pathVarExpressions[state][0] <= expressionManager->integer(k)); |
|
|
|
smtSolver->add(pathVarExpressions[state][0] >= expressionManager->integer(0)); |
|
|
|
|
|
|
|
} |
|
|
|
//assert(false);
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
// PAPER COMMENT: 4
|
|
|
@ -218,6 +241,8 @@ namespace storm { |
|
|
|
subexprreachNoSwitch.push_back(!reachVarExpressions[state]); |
|
|
|
subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); |
|
|
|
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); |
|
|
|
subexprreachSwitch.push_back(followVarExpressions[pomdp.getObservation(state)]); |
|
|
|
subexprreachNoSwitch.push_back(followVarExpressions[pomdp.getObservation(state)]); |
|
|
|
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
|
|
|
if (pomdp.getObservation(entries.getColumn() == pomdp.getObservation(state))) { |
|
|
|
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); |
|
|
@ -244,50 +269,85 @@ namespace storm { |
|
|
|
smtSolver->add(!reachVarExpressions[state]); |
|
|
|
smtSolver->add(!continuationVarExpressions[state]); |
|
|
|
if (lookaheadConstraintsRequired) { |
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
smtSolver->add(!pathVars[state][j]); |
|
|
|
if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) { |
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
smtSolver->add(!pathVarExpressions[state][j]); |
|
|
|
} |
|
|
|
} else { |
|
|
|
smtSolver->add(pathVarExpressions[state][0] == expressionManager->integer(k)); |
|
|
|
} |
|
|
|
} |
|
|
|
rowindex += pomdp.getNumberOfChoices(state); |
|
|
|
} else if(!targetStates.get(state)) { |
|
|
|
if (lookaheadConstraintsRequired) { |
|
|
|
// PAPER COMMENT 6
|
|
|
|
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back())); |
|
|
|
|
|
|
|
// PAPER COMMENT 7
|
|
|
|
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; |
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>()); |
|
|
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|
|
|
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>()); |
|
|
|
|
|
|
|
if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) { |
|
|
|
// PAPER COMMENT 6
|
|
|
|
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), |
|
|
|
pathVarExpressions.at(state).back())); |
|
|
|
// PAPER COMMENT 7
|
|
|
|
|
|
|
|
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; |
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>()); |
|
|
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|
|
|
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|
|
|
std::vector<storm::expressions::Expression> subexprreach; |
|
|
|
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); |
|
|
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|
|
|
std::vector<storm::expressions::Expression> subexprreach; |
|
|
|
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
pathsubsubexprs[j - 1][action].push_back(pathVarExpressions[entries.getColumn()][j - 1]); |
|
|
|
} |
|
|
|
} |
|
|
|
rowindex++; |
|
|
|
} |
|
|
|
rowindex++; |
|
|
|
} |
|
|
|
|
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
std::vector<storm::expressions::Expression> pathsubexprs; |
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
std::vector<storm::expressions::Expression> pathsubexprs; |
|
|
|
|
|
|
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|
|
|
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); |
|
|
|
} |
|
|
|
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); |
|
|
|
pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]); |
|
|
|
smtSolver->add(storm::expressions::iff(pathVarExpressions[state][j], |
|
|
|
storm::expressions::disjunction(pathsubexprs))); |
|
|
|
|
|
|
|
} |
|
|
|
} else { |
|
|
|
|
|
|
|
std::vector<storm::expressions::Expression> actPathDisjunction; |
|
|
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|
|
|
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); |
|
|
|
std::vector<storm::expressions::Expression> pathDisjunction; |
|
|
|
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
|
|
|
pathDisjunction.push_back(pathVarExpressions[entries.getColumn()][0] < pathVarExpressions[state][0]); |
|
|
|
} |
|
|
|
actPathDisjunction.push_back(storm::expressions::disjunction(pathDisjunction) && actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action)); |
|
|
|
rowindex++; |
|
|
|
} |
|
|
|
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); |
|
|
|
pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]); |
|
|
|
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); |
|
|
|
// TODO reconsider if this next add is sound
|
|
|
|
actPathDisjunction.push_back(switchVarExpressions.at(pomdp.getObservation(state))); |
|
|
|
actPathDisjunction.push_back(followVarExpressions[pomdp.getObservation(state)]); |
|
|
|
actPathDisjunction.push_back(!reachVarExpressions[state]); |
|
|
|
smtSolver->add(storm::expressions::disjunction(actPathDisjunction)); |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
smtSolver->add(pathVars[state][j]); |
|
|
|
if (lookaheadConstraintsRequired) { |
|
|
|
if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) { |
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
smtSolver->add(pathVarExpressions[state][j]); |
|
|
|
} |
|
|
|
} else { |
|
|
|
smtSolver->add(pathVarExpressions[state][0] == expressionManager->integer(0)); |
|
|
|
//assert(false);
|
|
|
|
} |
|
|
|
} |
|
|
|
smtSolver->add(reachVars[state]); |
|
|
|
rowindex += pomdp.getNumberOfChoices(state); |
|
|
|
} |
|
|
|
} |
|
|
@ -308,29 +368,7 @@ namespace storm { |
|
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
|
smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]))); |
|
|
|
} |
|
|
|
// PAPER COMMENT 10
|
|
|
|
// if (!lookaheadConstraintsRequired) {
|
|
|
|
// uint64_t rowIndex = 0;
|
|
|
|
// for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
|
|
|
|
// uint64_t enabledActions = pomdp.getNumberOfChoices(state);
|
|
|
|
// if (!surelyReachSinkStates.get(state)) {
|
|
|
|
// std::vector<storm::expressions::Expression> successorVars;
|
|
|
|
// for (uint64_t act = 0; act < enabledActions; ++act) {
|
|
|
|
// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) {
|
|
|
|
// successorVars.push_back(reachVarExpressions[entries.getColumn()]);
|
|
|
|
// }
|
|
|
|
// rowIndex++;
|
|
|
|
// }
|
|
|
|
// successorVars.push_back(!switchVars[pomdp.getObservation(state)]);
|
|
|
|
// smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state]));
|
|
|
|
// } else {
|
|
|
|
// rowIndex += enabledActions;
|
|
|
|
// }
|
|
|
|
// }
|
|
|
|
// } else {
|
|
|
|
// STORM_LOG_WARN("Some optimization not implemented yet.");
|
|
|
|
// }
|
|
|
|
// TODO: Update found schedulers if k is increased.
|
|
|
|
return lookaheadConstraintsRequired; |
|
|
|
} |
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
@ -357,9 +395,10 @@ namespace storm { |
|
|
|
STORM_LOG_DEBUG("Questionmark states " << (~surelyReachSinkStates & ~targetStates)); |
|
|
|
stats.initializeSolverTimer.start(); |
|
|
|
// TODO: When do we need to reinitialize? When the solver has been reset.
|
|
|
|
initialize(k); |
|
|
|
maxK = k; |
|
|
|
|
|
|
|
bool lookaheadConstraintsRequired = initialize(k); |
|
|
|
if(lookaheadConstraintsRequired) { |
|
|
|
maxK = k; |
|
|
|
} |
|
|
|
|
|
|
|
stats.winningRegionUpdatesTimer.start(); |
|
|
|
storm::storage::BitVector updated(pomdp.getNrObservations()); |
|
|
@ -500,11 +539,11 @@ namespace storm { |
|
|
|
uint64_t iterations = 0; |
|
|
|
while(true) { |
|
|
|
stats.incrementOuterIterations(); |
|
|
|
|
|
|
|
// TODO consider what we really want to store about the schedulers.
|
|
|
|
scheduler.reset(pomdp.getNrObservations(), maximalNrActions); |
|
|
|
observations.clear(); |
|
|
|
observationsAfterSwitch.clear(); |
|
|
|
coveredStates.clear(); |
|
|
|
coveredStates = targetStates; |
|
|
|
coveredStatesAfterSwitch.clear(); |
|
|
|
observationUpdated.clear(); |
|
|
|
if (!allOfTheseAssumption.empty()) { |
|
|
@ -540,20 +579,23 @@ namespace storm { |
|
|
|
obs++; |
|
|
|
} |
|
|
|
|
|
|
|
// for(uint64_t i : targetStates) {
|
|
|
|
// assert(model->getBooleanValue(reachVars[i]));
|
|
|
|
// }
|
|
|
|
|
|
|
|
uncoveredStates = ~coveredStates; |
|
|
|
for (uint64_t i : uncoveredStates) { |
|
|
|
auto const& rv =reachVars[i]; |
|
|
|
auto const& rvExpr =reachVarExpressions[i]; |
|
|
|
if (model->getBooleanValue(rv)) { |
|
|
|
if (observationUpdated.get(pomdp.getObservation(i)) && model->getBooleanValue(rv)) { |
|
|
|
STORM_LOG_TRACE("New state: " << i); |
|
|
|
smtSolver->add(rvExpr); |
|
|
|
assert(!surelyReachSinkStates.get(i)); |
|
|
|
newObservations.set(pomdp.getObservation(i)); |
|
|
|
coveredStates.set(i); |
|
|
|
if(lookaheadConstraintsRequired) { |
|
|
|
if (options.pathVariableType == MemlessSearchPathVariables::IntegerRanking) { |
|
|
|
smtSolver->add(pathVarExpressions[i][0] == expressionManager->integer(model->getIntegerValue(pathVars[i][0]))); |
|
|
|
} else if(options.pathVariableType == MemlessSearchPathVariables::RealRanking) { |
|
|
|
smtSolver->add(pathVarExpressions[i][0] == expressionManager->rational(model->getRationalValue(pathVars[i][0]))); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@ -594,6 +636,11 @@ namespace storm { |
|
|
|
} else { |
|
|
|
smtSolver->add(!switchVarExpressions[obs]); |
|
|
|
} |
|
|
|
if (model->getBooleanValue(followVars[obs])) { |
|
|
|
smtSolver->add(followVarExpressions[obs]); |
|
|
|
} else { |
|
|
|
smtSolver->add(!followVarExpressions[obs]); |
|
|
|
} |
|
|
|
} |
|
|
|
for (auto obs : newObservationsAfterSwitch) { |
|
|
|
observationsAfterSwitch.set(obs); |
|
|
@ -724,7 +771,7 @@ namespace storm { |
|
|
|
|
|
|
|
if (observationsWithPartialWinners.getNumberOfSetBits() > 0) { |
|
|
|
reset(); |
|
|
|
return analyze(k, ~targetStates & ~surelyReachSinkStates); |
|
|
|
return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@ -737,6 +784,10 @@ namespace storm { |
|
|
|
STORM_LOG_WARN("Validating every step, for debug purposes only!"); |
|
|
|
validator->validate(surelyReachSinkStates); |
|
|
|
} |
|
|
|
if (stats.getIterations() % options.restartAfterNIterations == options.restartAfterNIterations-1) { |
|
|
|
reset(); |
|
|
|
return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates); |
|
|
|
} |
|
|
|
stats.updateNewStrategySolverTime.start(); |
|
|
|
for(uint64_t observation : updated) { |
|
|
|
updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]); |
|
|
@ -749,15 +800,27 @@ namespace storm { |
|
|
|
assert(schedulerForObs.size() > obs); |
|
|
|
(schedulerForObs[obs])++; |
|
|
|
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs] << " policies for states with observation " << obs); |
|
|
|
|
|
|
|
for (auto const &state : statesForObservation) { |
|
|
|
if (!coveredStates.get(state)) { |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs]); |
|
|
|
// PAPER COMMENT 14:
|
|
|
|
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); |
|
|
|
smtSolver->add(!(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] && (schedulerVariableExpressions[obs] == constant))); |
|
|
|
if (winningRegion.observationIsWinning(obs)) { |
|
|
|
for (auto const &state : statesForObservation) { |
|
|
|
smtSolver->add(reachVarExpressions[state]); |
|
|
|
} |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs]); |
|
|
|
smtSolver->add(schedulerVariableExpressions[obs] == constant); |
|
|
|
} else { |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs]); |
|
|
|
for (auto const &state : statesForObservation) { |
|
|
|
if (!coveredStates.get(state)) { |
|
|
|
|
|
|
|
// PAPER COMMENT 14:
|
|
|
|
smtSolver->add(!(continuationVarExpressions[state] && |
|
|
|
(schedulerVariableExpressions[obs] == constant))); |
|
|
|
smtSolver->add(!(reachVarExpressions[state] && |
|
|
|
followVarExpressions[pomdp.getObservation(state)] && |
|
|
|
(schedulerVariableExpressions[obs] == constant))); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
++obs; |
|
|
|
} |
|
|
@ -766,11 +829,21 @@ namespace storm { |
|
|
|
smtSolver->push(); |
|
|
|
|
|
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs]); |
|
|
|
// PAPER COMMENT 13
|
|
|
|
smtSolver->add(schedulerVariableExpressions[obs] <= constant); |
|
|
|
// PAPER COMMENT 12
|
|
|
|
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); |
|
|
|
if(winningRegion.observationIsWinning(obs)) { |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs]); |
|
|
|
// PAPER COMMENT 13
|
|
|
|
// Scheduler variable is already fixed.
|
|
|
|
// PAPER COMMENT 12
|
|
|
|
// Observation will not be updated.
|
|
|
|
smtSolver->add(!observationUpdatedExpressions[obs]); |
|
|
|
} else { |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs]); |
|
|
|
// PAPER COMMENT 13
|
|
|
|
smtSolver->add(schedulerVariableExpressions[obs] <= constant); |
|
|
|
// PAPER COMMENT 12
|
|
|
|
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], |
|
|
|
updateForObservationExpressions[obs])); |
|
|
|
} |
|
|
|
} |
|
|
|
stats.updateNewStrategySolverTime.stop(); |
|
|
|
|
|
|
|