Browse Source

major improvements by introducing real-valued ranking and various related fixes

main
Sebastian Junges 5 years ago
parent
commit
53800c2145
  1. 239
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 25
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

239
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -83,8 +83,9 @@ namespace storm {
} }
template <typename ValueType> template <typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) { bool MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) {
STORM_LOG_INFO("Start intializing solver..."); STORM_LOG_INFO("Start intializing solver...");
bool lookaheadConstraintsRequired; bool lookaheadConstraintsRequired;
if (options.forceLookahead) { if (options.forceLookahead) {
lookaheadConstraintsRequired = true; lookaheadConstraintsRequired = true;
@ -131,8 +132,10 @@ namespace storm {
} }
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
pathVars.push_back(std::vector<storm::expressions::Expression>()); pathVars.push_back(std::vector<storm::expressions::Variable>());
pathVarExpressions.push_back(std::vector<storm::expressions::Expression>());
} }
} }
uint64_t initK = 0; uint64_t initK = 0;
@ -142,15 +145,25 @@ namespace storm {
if (initK < k) { if (initK < k) {
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
if (lookaheadConstraintsRequired) { if (lookaheadConstraintsRequired) {
for (uint64_t i = initK; i < k; ++i) { if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
pathVars[stateId].push_back(expressionManager->declareBooleanVariable( for (uint64_t i = initK; i < k; ++i) {
"P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression()); pathVars[stateId].push_back(expressionManager->declareBooleanVariable(
"P-" + std::to_string(stateId) + "-" + std::to_string(i)));
pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression());
}
} else if (options.pathVariableType == MemlessSearchPathVariables::IntegerRanking) {
pathVars[stateId].push_back(expressionManager->declareIntegerVariable("P-" + std::to_string(stateId)));
pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression());
} else {
assert(options.pathVariableType == MemlessSearchPathVariables::RealRanking);
pathVars[stateId].push_back(expressionManager->declareRationalVariable("P-" + std::to_string(stateId)));
pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression());
} }
} }
} }
} }
assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates()); assert(!lookaheadConstraintsRequired || pathVarExpressions.size() == pomdp.getNumberOfStates());
assert(reachVars.size() == pomdp.getNumberOfStates()); assert(reachVars.size() == pomdp.getNumberOfStates());
assert(reachVarExpressions.size() == pomdp.getNumberOfStates()); assert(reachVarExpressions.size() == pomdp.getNumberOfStates());
@ -193,13 +206,23 @@ namespace storm {
// PAPER COMMENT: 3 // PAPER COMMENT: 3
if (lookaheadConstraintsRequired) { if (lookaheadConstraintsRequired) {
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
if (targetStates.get(state)) { for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
smtSolver->add(pathVars[state][0]); if (targetStates.get(state)) {
} else { smtSolver->add(pathVarExpressions[state][0]);
smtSolver->add(!pathVars[state][0] || followVarExpressions[pomdp.getObservation(state)]); } else {
smtSolver->add(!pathVarExpressions[state][0] || followVarExpressions[pomdp.getObservation(state)]);
}
} }
} else {
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
smtSolver->add(pathVarExpressions[state][0] <= expressionManager->integer(k));
smtSolver->add(pathVarExpressions[state][0] >= expressionManager->integer(0));
}
//assert(false);
} }
} }
// PAPER COMMENT: 4 // PAPER COMMENT: 4
@ -218,6 +241,8 @@ namespace storm {
subexprreachNoSwitch.push_back(!reachVarExpressions[state]); subexprreachNoSwitch.push_back(!reachVarExpressions[state]);
subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]);
subexprreachSwitch.push_back(followVarExpressions[pomdp.getObservation(state)]);
subexprreachNoSwitch.push_back(followVarExpressions[pomdp.getObservation(state)]);
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
if (pomdp.getObservation(entries.getColumn() == pomdp.getObservation(state))) { if (pomdp.getObservation(entries.getColumn() == pomdp.getObservation(state))) {
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn()));
@ -244,50 +269,85 @@ namespace storm {
smtSolver->add(!reachVarExpressions[state]); smtSolver->add(!reachVarExpressions[state]);
smtSolver->add(!continuationVarExpressions[state]); smtSolver->add(!continuationVarExpressions[state]);
if (lookaheadConstraintsRequired) { if (lookaheadConstraintsRequired) {
for (uint64_t j = 1; j < k; ++j) { if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
smtSolver->add(!pathVars[state][j]); for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(!pathVarExpressions[state][j]);
}
} else {
smtSolver->add(pathVarExpressions[state][0] == expressionManager->integer(k));
} }
} }
rowindex += pomdp.getNumberOfChoices(state); rowindex += pomdp.getNumberOfChoices(state);
} else if(!targetStates.get(state)) { } else if(!targetStates.get(state)) {
if (lookaheadConstraintsRequired) { if (lookaheadConstraintsRequired) {
// PAPER COMMENT 6
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
// PAPER COMMENT 7 if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs; // PAPER COMMENT 6
for (uint64_t j = 1; j < k; ++j) { smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state),
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>()); pathVarExpressions.at(state).back()));
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { // PAPER COMMENT 7
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>()); std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>());
}
} }
}
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreach; std::vector<storm::expressions::Expression> subexprreach;
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
for (uint64_t j = 1; j < k; ++j) { for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]); pathsubsubexprs[j - 1][action].push_back(pathVarExpressions[entries.getColumn()][j - 1]);
}
} }
rowindex++;
} }
rowindex++;
}
for (uint64_t j = 1; j < k; ++j) { for (uint64_t j = 1; j < k; ++j) {
std::vector<storm::expressions::Expression> pathsubexprs; std::vector<storm::expressions::Expression> pathsubexprs;
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
}
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]);
smtSolver->add(storm::expressions::iff(pathVarExpressions[state][j],
storm::expressions::disjunction(pathsubexprs)));
}
} else {
std::vector<storm::expressions::Expression> actPathDisjunction;
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); std::vector<storm::expressions::Expression> pathDisjunction;
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
pathDisjunction.push_back(pathVarExpressions[entries.getColumn()][0] < pathVarExpressions[state][0]);
}
actPathDisjunction.push_back(storm::expressions::disjunction(pathDisjunction) && actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action));
rowindex++;
} }
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); // TODO reconsider if this next add is sound
pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]); actPathDisjunction.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); actPathDisjunction.push_back(followVarExpressions[pomdp.getObservation(state)]);
actPathDisjunction.push_back(!reachVarExpressions[state]);
smtSolver->add(storm::expressions::disjunction(actPathDisjunction));
} }
} }
} else { } else {
for (uint64_t j = 1; j < k; ++j) { if (lookaheadConstraintsRequired) {
smtSolver->add(pathVars[state][j]); if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(pathVarExpressions[state][j]);
}
} else {
smtSolver->add(pathVarExpressions[state][0] == expressionManager->integer(0));
//assert(false);
}
} }
smtSolver->add(reachVars[state]);
rowindex += pomdp.getNumberOfChoices(state); rowindex += pomdp.getNumberOfChoices(state);
} }
} }
@ -308,29 +368,7 @@ namespace storm {
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]))); smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs])));
} }
// PAPER COMMENT 10 return lookaheadConstraintsRequired;
// if (!lookaheadConstraintsRequired) {
// uint64_t rowIndex = 0;
// for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
// uint64_t enabledActions = pomdp.getNumberOfChoices(state);
// if (!surelyReachSinkStates.get(state)) {
// std::vector<storm::expressions::Expression> successorVars;
// for (uint64_t act = 0; act < enabledActions; ++act) {
// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) {
// successorVars.push_back(reachVarExpressions[entries.getColumn()]);
// }
// rowIndex++;
// }
// successorVars.push_back(!switchVars[pomdp.getObservation(state)]);
// smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state]));
// } else {
// rowIndex += enabledActions;
// }
// }
// } else {
// STORM_LOG_WARN("Some optimization not implemented yet.");
// }
// TODO: Update found schedulers if k is increased.
} }
template<typename ValueType> template<typename ValueType>
@ -357,9 +395,10 @@ namespace storm {
STORM_LOG_DEBUG("Questionmark states " << (~surelyReachSinkStates & ~targetStates)); STORM_LOG_DEBUG("Questionmark states " << (~surelyReachSinkStates & ~targetStates));
stats.initializeSolverTimer.start(); stats.initializeSolverTimer.start();
// TODO: When do we need to reinitialize? When the solver has been reset. // TODO: When do we need to reinitialize? When the solver has been reset.
initialize(k); bool lookaheadConstraintsRequired = initialize(k);
maxK = k; if(lookaheadConstraintsRequired) {
maxK = k;
}
stats.winningRegionUpdatesTimer.start(); stats.winningRegionUpdatesTimer.start();
storm::storage::BitVector updated(pomdp.getNrObservations()); storm::storage::BitVector updated(pomdp.getNrObservations());
@ -500,11 +539,11 @@ namespace storm {
uint64_t iterations = 0; uint64_t iterations = 0;
while(true) { while(true) {
stats.incrementOuterIterations(); stats.incrementOuterIterations();
// TODO consider what we really want to store about the schedulers.
scheduler.reset(pomdp.getNrObservations(), maximalNrActions); scheduler.reset(pomdp.getNrObservations(), maximalNrActions);
observations.clear(); observations.clear();
observationsAfterSwitch.clear(); observationsAfterSwitch.clear();
coveredStates.clear(); coveredStates = targetStates;
coveredStatesAfterSwitch.clear(); coveredStatesAfterSwitch.clear();
observationUpdated.clear(); observationUpdated.clear();
if (!allOfTheseAssumption.empty()) { if (!allOfTheseAssumption.empty()) {
@ -540,20 +579,23 @@ namespace storm {
obs++; obs++;
} }
// for(uint64_t i : targetStates) {
// assert(model->getBooleanValue(reachVars[i]));
// }
uncoveredStates = ~coveredStates; uncoveredStates = ~coveredStates;
for (uint64_t i : uncoveredStates) { for (uint64_t i : uncoveredStates) {
auto const& rv =reachVars[i]; auto const& rv =reachVars[i];
auto const& rvExpr =reachVarExpressions[i]; auto const& rvExpr =reachVarExpressions[i];
if (model->getBooleanValue(rv)) { if (observationUpdated.get(pomdp.getObservation(i)) && model->getBooleanValue(rv)) {
STORM_LOG_TRACE("New state: " << i); STORM_LOG_TRACE("New state: " << i);
smtSolver->add(rvExpr); smtSolver->add(rvExpr);
assert(!surelyReachSinkStates.get(i)); assert(!surelyReachSinkStates.get(i));
newObservations.set(pomdp.getObservation(i)); newObservations.set(pomdp.getObservation(i));
coveredStates.set(i); coveredStates.set(i);
if(lookaheadConstraintsRequired) {
if (options.pathVariableType == MemlessSearchPathVariables::IntegerRanking) {
smtSolver->add(pathVarExpressions[i][0] == expressionManager->integer(model->getIntegerValue(pathVars[i][0])));
} else if(options.pathVariableType == MemlessSearchPathVariables::RealRanking) {
smtSolver->add(pathVarExpressions[i][0] == expressionManager->rational(model->getRationalValue(pathVars[i][0])));
}
}
} }
} }
@ -594,6 +636,11 @@ namespace storm {
} else { } else {
smtSolver->add(!switchVarExpressions[obs]); smtSolver->add(!switchVarExpressions[obs]);
} }
if (model->getBooleanValue(followVars[obs])) {
smtSolver->add(followVarExpressions[obs]);
} else {
smtSolver->add(!followVarExpressions[obs]);
}
} }
for (auto obs : newObservationsAfterSwitch) { for (auto obs : newObservationsAfterSwitch) {
observationsAfterSwitch.set(obs); observationsAfterSwitch.set(obs);
@ -724,7 +771,7 @@ namespace storm {
if (observationsWithPartialWinners.getNumberOfSetBits() > 0) { if (observationsWithPartialWinners.getNumberOfSetBits() > 0) {
reset(); reset();
return analyze(k, ~targetStates & ~surelyReachSinkStates); return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates);
} }
} }
@ -737,6 +784,10 @@ namespace storm {
STORM_LOG_WARN("Validating every step, for debug purposes only!"); STORM_LOG_WARN("Validating every step, for debug purposes only!");
validator->validate(surelyReachSinkStates); validator->validate(surelyReachSinkStates);
} }
if (stats.getIterations() % options.restartAfterNIterations == options.restartAfterNIterations-1) {
reset();
return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates);
}
stats.updateNewStrategySolverTime.start(); stats.updateNewStrategySolverTime.start();
for(uint64_t observation : updated) { for(uint64_t observation : updated) {
updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]); updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);
@ -749,15 +800,27 @@ namespace storm {
assert(schedulerForObs.size() > obs); assert(schedulerForObs.size() > obs);
(schedulerForObs[obs])++; (schedulerForObs[obs])++;
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs] << " policies for states with observation " << obs); STORM_LOG_DEBUG("We now have " << schedulerForObs[obs] << " policies for states with observation " << obs);
if (winningRegion.observationIsWinning(obs)) {
for (auto const &state : statesForObservation) { for (auto const &state : statesForObservation) {
if (!coveredStates.get(state)) { smtSolver->add(reachVarExpressions[state]);
auto constant = expressionManager->integer(schedulerForObs[obs]); }
// PAPER COMMENT 14: auto constant = expressionManager->integer(schedulerForObs[obs]);
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); smtSolver->add(schedulerVariableExpressions[obs] == constant);
smtSolver->add(!(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] && (schedulerVariableExpressions[obs] == constant))); } else {
auto constant = expressionManager->integer(schedulerForObs[obs]);
for (auto const &state : statesForObservation) {
if (!coveredStates.get(state)) {
// PAPER COMMENT 14:
smtSolver->add(!(continuationVarExpressions[state] &&
(schedulerVariableExpressions[obs] == constant)));
smtSolver->add(!(reachVarExpressions[state] &&
followVarExpressions[pomdp.getObservation(state)] &&
(schedulerVariableExpressions[obs] == constant)));
}
} }
} }
} }
++obs; ++obs;
} }
@ -766,11 +829,21 @@ namespace storm {
smtSolver->push(); smtSolver->push();
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
auto constant = expressionManager->integer(schedulerForObs[obs]); if(winningRegion.observationIsWinning(obs)) {
// PAPER COMMENT 13 auto constant = expressionManager->integer(schedulerForObs[obs]);
smtSolver->add(schedulerVariableExpressions[obs] <= constant); // PAPER COMMENT 13
// PAPER COMMENT 12 // Scheduler variable is already fixed.
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); // PAPER COMMENT 12
// Observation will not be updated.
smtSolver->add(!observationUpdatedExpressions[obs]);
} else {
auto constant = expressionManager->integer(schedulerForObs[obs]);
// PAPER COMMENT 13
smtSolver->add(schedulerVariableExpressions[obs] <= constant);
// PAPER COMMENT 12
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs],
updateForObservationExpressions[obs]));
}
} }
stats.updateNewStrategySolverTime.stop(); stats.updateNewStrategySolverTime.stop();

25
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -13,6 +13,20 @@
namespace storm { namespace storm {
namespace pomdp { namespace pomdp {
/*!
 * Selects which kind of variable is used to encode the path (ranking) constraints
 * of the memoryless strategy search.
 */
enum class MemlessSearchPathVariables {
    BooleanRanking,  // Boolean path variables
    IntegerRanking,  // integer-valued path variables
    RealRanking      // rational-valued path variables
};

/*!
 * Translates a textual option value into the corresponding path variable type.
 *
 * Recognized values are "int" (IntegerRanking), "real" (RealRanking) and "bool" (BooleanRanking).
 * Any other value violates the precondition: it triggers the assertion in builds with assertions
 * enabled and is otherwise treated like "bool".
 *
 * The function is defined in a header, so it must be `inline`; otherwise every translation unit
 * including this header emits its own definition and linking fails with duplicate symbols.
 *
 * @param in The textual representation of the path variable type.
 * @return The path variable type denoted by `in`.
 */
inline MemlessSearchPathVariables pathVariableTypeFromString(std::string const& in) {
    if (in == "int") {
        return MemlessSearchPathVariables::IntegerRanking;
    }
    if (in == "real") {
        return MemlessSearchPathVariables::RealRanking;
    }
    // Only "bool" remains as a valid input; everything else is a caller error.
    assert(in == "bool");
    return MemlessSearchPathVariables::BooleanRanking;
}
class MemlessSearchOptions { class MemlessSearchOptions {
public: public:
@ -45,9 +59,13 @@ namespace pomdp {
} }
bool onlyDeterministicStrategies = false; bool onlyDeterministicStrategies = false;
bool forceLookahead = true; bool forceLookahead = false;
bool validateEveryStep = false; bool validateEveryStep = false;
bool validateResult = false; bool validateResult = false;
MemlessSearchPathVariables pathVariableType = MemlessSearchPathVariables::RealRanking;
uint64_t restartAfterNIterations = 250;
uint64_t extensionCallTimeout = 0u;
uint64_t localIterationMaximum = 600;
private: private:
std::string exportSATcalls = ""; std::string exportSATcalls = "";
@ -198,7 +216,7 @@ namespace pomdp {
void printScheduler(std::vector<InternalObservationScheduler> const& ); void printScheduler(std::vector<InternalObservationScheduler> const& );
void printCoveredStates(storm::storage::BitVector const& remaining) const; void printCoveredStates(storm::storage::BitVector const& remaining) const;
void initialize(uint64_t k); bool initialize(uint64_t k);
bool smtCheck(uint64_t iteration, std::set<storm::expressions::Expression> const& assumptions = {}); bool smtCheck(uint64_t iteration, std::set<storm::expressions::Expression> const& assumptions = {});
@ -230,7 +248,8 @@ namespace pomdp {
std::vector<storm::expressions::Expression> followVarExpressions; std::vector<storm::expressions::Expression> followVarExpressions;
std::vector<storm::expressions::Variable> continuationVars; std::vector<storm::expressions::Variable> continuationVars;
std::vector<storm::expressions::Expression> continuationVarExpressions; std::vector<storm::expressions::Expression> continuationVarExpressions;
std::vector<std::vector<storm::expressions::Expression>> pathVars; std::vector<std::vector<storm::expressions::Variable>> pathVars;
std::vector<std::vector<storm::expressions::Expression>> pathVarExpressions;
std::vector<InternalObservationScheduler> finalSchedulers; std::vector<InternalObservationScheduler> finalSchedulers;
std::vector<uint64_t> schedulerForObs; std::vector<uint64_t> schedulerForObs;

|||||||
100:0
Loading…
Cancel
Save