Browse Source

major improvements by introducing real-valued ranking and various related fixes

main
Sebastian Junges 5 years ago
parent
commit
53800c2145
  1. 239
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 25
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

239
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -83,8 +83,9 @@ namespace storm {
}
template <typename ValueType>
void MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) {
bool MemlessStrategySearchQualitative<ValueType>::initialize(uint64_t k) {
STORM_LOG_INFO("Start intializing solver...");
bool lookaheadConstraintsRequired;
if (options.forceLookahead) {
lookaheadConstraintsRequired = true;
@ -131,8 +132,10 @@ namespace storm {
}
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
pathVars.push_back(std::vector<storm::expressions::Expression>());
pathVars.push_back(std::vector<storm::expressions::Variable>());
pathVarExpressions.push_back(std::vector<storm::expressions::Expression>());
}
}
uint64_t initK = 0;
@ -142,15 +145,25 @@ namespace storm {
if (initK < k) {
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
if (lookaheadConstraintsRequired) {
for (uint64_t i = initK; i < k; ++i) {
pathVars[stateId].push_back(expressionManager->declareBooleanVariable(
"P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression());
if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
for (uint64_t i = initK; i < k; ++i) {
pathVars[stateId].push_back(expressionManager->declareBooleanVariable(
"P-" + std::to_string(stateId) + "-" + std::to_string(i)));
pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression());
}
} else if (options.pathVariableType == MemlessSearchPathVariables::IntegerRanking) {
pathVars[stateId].push_back(expressionManager->declareIntegerVariable("P-" + std::to_string(stateId)));
pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression());
} else {
assert(options.pathVariableType == MemlessSearchPathVariables::RealRanking);
pathVars[stateId].push_back(expressionManager->declareRationalVariable("P-" + std::to_string(stateId)));
pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression());
}
}
}
}
assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates());
assert(!lookaheadConstraintsRequired || pathVarExpressions.size() == pomdp.getNumberOfStates());
assert(reachVars.size() == pomdp.getNumberOfStates());
assert(reachVarExpressions.size() == pomdp.getNumberOfStates());
@ -193,13 +206,23 @@ namespace storm {
// PAPER COMMENT: 3
if (lookaheadConstraintsRequired) {
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state)) {
smtSolver->add(pathVars[state][0]);
} else {
smtSolver->add(!pathVars[state][0] || followVarExpressions[pomdp.getObservation(state)]);
if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (targetStates.get(state)) {
smtSolver->add(pathVarExpressions[state][0]);
} else {
smtSolver->add(!pathVarExpressions[state][0] || followVarExpressions[pomdp.getObservation(state)]);
}
}
} else {
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
smtSolver->add(pathVarExpressions[state][0] <= expressionManager->integer(k));
smtSolver->add(pathVarExpressions[state][0] >= expressionManager->integer(0));
}
//assert(false);
}
}
// PAPER COMMENT: 4
@ -218,6 +241,8 @@ namespace storm {
subexprreachNoSwitch.push_back(!reachVarExpressions[state]);
subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]);
subexprreachSwitch.push_back(followVarExpressions[pomdp.getObservation(state)]);
subexprreachNoSwitch.push_back(followVarExpressions[pomdp.getObservation(state)]);
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
if (pomdp.getObservation(entries.getColumn() == pomdp.getObservation(state))) {
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn()));
@ -244,50 +269,85 @@ namespace storm {
smtSolver->add(!reachVarExpressions[state]);
smtSolver->add(!continuationVarExpressions[state]);
if (lookaheadConstraintsRequired) {
for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(!pathVars[state][j]);
if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(!pathVarExpressions[state][j]);
}
} else {
smtSolver->add(pathVarExpressions[state][0] == expressionManager->integer(k));
}
}
rowindex += pomdp.getNumberOfChoices(state);
} else if(!targetStates.get(state)) {
if (lookaheadConstraintsRequired) {
// PAPER COMMENT 6
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
// PAPER COMMENT 7
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>());
if(options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
// PAPER COMMENT 6
smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state),
pathVarExpressions.at(state).back()));
// PAPER COMMENT 7
std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>());
}
}
}
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreach;
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
std::vector<storm::expressions::Expression> subexprreach;
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
for (uint64_t j = 1; j < k; ++j) {
pathsubsubexprs[j - 1][action].push_back(pathVarExpressions[entries.getColumn()][j - 1]);
}
}
rowindex++;
}
rowindex++;
}
for (uint64_t j = 1; j < k; ++j) {
std::vector<storm::expressions::Expression> pathsubexprs;
for (uint64_t j = 1; j < k; ++j) {
std::vector<storm::expressions::Expression> pathsubexprs;
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
}
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]);
smtSolver->add(storm::expressions::iff(pathVarExpressions[state][j],
storm::expressions::disjunction(pathsubexprs)));
}
} else {
std::vector<storm::expressions::Expression> actPathDisjunction;
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
std::vector<storm::expressions::Expression> pathDisjunction;
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
pathDisjunction.push_back(pathVarExpressions[entries.getColumn()][0] < pathVarExpressions[state][0]);
}
actPathDisjunction.push_back(storm::expressions::disjunction(pathDisjunction) && actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action));
rowindex++;
}
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]);
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
// TODO reconsider if this next add is sound
actPathDisjunction.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
actPathDisjunction.push_back(followVarExpressions[pomdp.getObservation(state)]);
actPathDisjunction.push_back(!reachVarExpressions[state]);
smtSolver->add(storm::expressions::disjunction(actPathDisjunction));
}
}
} else {
for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(pathVars[state][j]);
if (lookaheadConstraintsRequired) {
if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(pathVarExpressions[state][j]);
}
} else {
smtSolver->add(pathVarExpressions[state][0] == expressionManager->integer(0));
//assert(false);
}
}
smtSolver->add(reachVars[state]);
rowindex += pomdp.getNumberOfChoices(state);
}
}
@ -308,29 +368,7 @@ namespace storm {
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
smtSolver->add(storm::expressions::implies(switchVarExpressions[obs], storm::expressions::disjunction(reachVarExpressionsPerObservation[obs])));
}
// PAPER COMMENT 10
// if (!lookaheadConstraintsRequired) {
// uint64_t rowIndex = 0;
// for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
// uint64_t enabledActions = pomdp.getNumberOfChoices(state);
// if (!surelyReachSinkStates.get(state)) {
// std::vector<storm::expressions::Expression> successorVars;
// for (uint64_t act = 0; act < enabledActions; ++act) {
// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) {
// successorVars.push_back(reachVarExpressions[entries.getColumn()]);
// }
// rowIndex++;
// }
// successorVars.push_back(!switchVars[pomdp.getObservation(state)]);
// smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state]));
// } else {
// rowIndex += enabledActions;
// }
// }
// } else {
// STORM_LOG_WARN("Some optimization not implemented yet.");
// }
// TODO: Update found schedulers if k is increased.
return lookaheadConstraintsRequired;
}
template<typename ValueType>
@ -357,9 +395,10 @@ namespace storm {
STORM_LOG_DEBUG("Questionmark states " << (~surelyReachSinkStates & ~targetStates));
stats.initializeSolverTimer.start();
// TODO: When do we need to reinitialize? When the solver has been reset.
initialize(k);
maxK = k;
bool lookaheadConstraintsRequired = initialize(k);
if(lookaheadConstraintsRequired) {
maxK = k;
}
stats.winningRegionUpdatesTimer.start();
storm::storage::BitVector updated(pomdp.getNrObservations());
@ -500,11 +539,11 @@ namespace storm {
uint64_t iterations = 0;
while(true) {
stats.incrementOuterIterations();
// TODO consider what we really want to store about the schedulers.
scheduler.reset(pomdp.getNrObservations(), maximalNrActions);
observations.clear();
observationsAfterSwitch.clear();
coveredStates.clear();
coveredStates = targetStates;
coveredStatesAfterSwitch.clear();
observationUpdated.clear();
if (!allOfTheseAssumption.empty()) {
@ -540,20 +579,23 @@ namespace storm {
obs++;
}
// for(uint64_t i : targetStates) {
// assert(model->getBooleanValue(reachVars[i]));
// }
uncoveredStates = ~coveredStates;
for (uint64_t i : uncoveredStates) {
auto const& rv =reachVars[i];
auto const& rvExpr =reachVarExpressions[i];
if (model->getBooleanValue(rv)) {
if (observationUpdated.get(pomdp.getObservation(i)) && model->getBooleanValue(rv)) {
STORM_LOG_TRACE("New state: " << i);
smtSolver->add(rvExpr);
assert(!surelyReachSinkStates.get(i));
newObservations.set(pomdp.getObservation(i));
coveredStates.set(i);
if(lookaheadConstraintsRequired) {
if (options.pathVariableType == MemlessSearchPathVariables::IntegerRanking) {
smtSolver->add(pathVarExpressions[i][0] == expressionManager->integer(model->getIntegerValue(pathVars[i][0])));
} else if(options.pathVariableType == MemlessSearchPathVariables::RealRanking) {
smtSolver->add(pathVarExpressions[i][0] == expressionManager->rational(model->getRationalValue(pathVars[i][0])));
}
}
}
}
@ -594,6 +636,11 @@ namespace storm {
} else {
smtSolver->add(!switchVarExpressions[obs]);
}
if (model->getBooleanValue(followVars[obs])) {
smtSolver->add(followVarExpressions[obs]);
} else {
smtSolver->add(!followVarExpressions[obs]);
}
}
for (auto obs : newObservationsAfterSwitch) {
observationsAfterSwitch.set(obs);
@ -724,7 +771,7 @@ namespace storm {
if (observationsWithPartialWinners.getNumberOfSetBits() > 0) {
reset();
return analyze(k, ~targetStates & ~surelyReachSinkStates);
return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates);
}
}
@ -737,6 +784,10 @@ namespace storm {
STORM_LOG_WARN("Validating every step, for debug purposes only!");
validator->validate(surelyReachSinkStates);
}
if (stats.getIterations() % options.restartAfterNIterations == options.restartAfterNIterations-1) {
reset();
return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates);
}
stats.updateNewStrategySolverTime.start();
for(uint64_t observation : updated) {
updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);
@ -749,15 +800,27 @@ namespace storm {
assert(schedulerForObs.size() > obs);
(schedulerForObs[obs])++;
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs] << " policies for states with observation " << obs);
for (auto const &state : statesForObservation) {
if (!coveredStates.get(state)) {
auto constant = expressionManager->integer(schedulerForObs[obs]);
// PAPER COMMENT 14:
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
smtSolver->add(!(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] && (schedulerVariableExpressions[obs] == constant)));
if (winningRegion.observationIsWinning(obs)) {
for (auto const &state : statesForObservation) {
smtSolver->add(reachVarExpressions[state]);
}
auto constant = expressionManager->integer(schedulerForObs[obs]);
smtSolver->add(schedulerVariableExpressions[obs] == constant);
} else {
auto constant = expressionManager->integer(schedulerForObs[obs]);
for (auto const &state : statesForObservation) {
if (!coveredStates.get(state)) {
// PAPER COMMENT 14:
smtSolver->add(!(continuationVarExpressions[state] &&
(schedulerVariableExpressions[obs] == constant)));
smtSolver->add(!(reachVarExpressions[state] &&
followVarExpressions[pomdp.getObservation(state)] &&
(schedulerVariableExpressions[obs] == constant)));
}
}
}
}
++obs;
}
@ -766,11 +829,21 @@ namespace storm {
smtSolver->push();
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
auto constant = expressionManager->integer(schedulerForObs[obs]);
// PAPER COMMENT 13
smtSolver->add(schedulerVariableExpressions[obs] <= constant);
// PAPER COMMENT 12
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
if(winningRegion.observationIsWinning(obs)) {
auto constant = expressionManager->integer(schedulerForObs[obs]);
// PAPER COMMENT 13
// Scheduler variable is already fixed.
// PAPER COMMENT 12
// Observation will not be updated.
smtSolver->add(!observationUpdatedExpressions[obs]);
} else {
auto constant = expressionManager->integer(schedulerForObs[obs]);
// PAPER COMMENT 13
smtSolver->add(schedulerVariableExpressions[obs] <= constant);
// PAPER COMMENT 12
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs],
updateForObservationExpressions[obs]));
}
}
stats.updateNewStrategySolverTime.stop();

25
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -13,6 +13,20 @@
namespace storm {
namespace pomdp {
/// Selects how the path (ranking) variables of the memoryless strategy search
/// are encoded in the SMT problem: as Booleans, integers, or rationals.
enum class MemlessSearchPathVariables {
    BooleanRanking, IntegerRanking, RealRanking
};

/// Translates the textual name of a path-variable encoding into the enum value.
///
/// Recognised names are "int", "real" and "bool".
///
/// The function is `inline` because it is defined in a header: without it,
/// every translation unit including this header would emit its own external
/// definition and the program would fail to link (ODR violation).
///
/// @param in  Name of the encoding.
/// @return IntegerRanking for "int", RealRanking for "real", and
///         BooleanRanking otherwise.
/// @note Any name other than the three above trips the assertion in builds
///       with assertions enabled; when assertions are compiled out (NDEBUG)
///       such a name silently yields BooleanRanking.
inline MemlessSearchPathVariables pathVariableTypeFromString(std::string const& in) {
    if (in == "int") {
        return MemlessSearchPathVariables::IntegerRanking;
    }
    if (in == "real") {
        return MemlessSearchPathVariables::RealRanking;
    }
    // Only "bool" is expected to reach this point.
    assert(in == "bool");
    return MemlessSearchPathVariables::BooleanRanking;
}
class MemlessSearchOptions {
public:
@ -45,9 +59,13 @@ namespace pomdp {
}
bool onlyDeterministicStrategies = false;
bool forceLookahead = true;
bool forceLookahead = false;
bool validateEveryStep = false;
bool validateResult = false;
MemlessSearchPathVariables pathVariableType = MemlessSearchPathVariables::RealRanking;
uint64_t restartAfterNIterations = 250;
uint64_t extensionCallTimeout = 0u;
uint64_t localIterationMaximum = 600;
private:
std::string exportSATcalls = "";
@ -198,7 +216,7 @@ namespace pomdp {
void printScheduler(std::vector<InternalObservationScheduler> const& );
void printCoveredStates(storm::storage::BitVector const& remaining) const;
void initialize(uint64_t k);
bool initialize(uint64_t k);
bool smtCheck(uint64_t iteration, std::set<storm::expressions::Expression> const& assumptions = {});
@ -230,7 +248,8 @@ namespace pomdp {
std::vector<storm::expressions::Expression> followVarExpressions;
std::vector<storm::expressions::Variable> continuationVars;
std::vector<storm::expressions::Expression> continuationVarExpressions;
std::vector<std::vector<storm::expressions::Expression>> pathVars;
std::vector<std::vector<storm::expressions::Variable>> pathVars;
std::vector<std::vector<storm::expressions::Expression>> pathVarExpressions;
std::vector<InternalObservationScheduler> finalSchedulers;
std::vector<uint64_t> schedulerForObs;

Loading…
Cancel
Save