Browse Source

various changes to allow restarting and more finegrained selection of switch-and-finish-with-policy

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
b2e7c5d5ed
  1. 90
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp
  2. 2
      src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

90
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp

@ -171,11 +171,11 @@ namespace storm {
obs = 0; obs = 0;
for (auto const& actionVars : actionSelectionVarExpressions) { for (auto const& actionVars : actionSelectionVarExpressions) {
std::vector<storm::expressions::Expression> actExprs = actionVars; std::vector<storm::expressions::Expression> actExprs = actionVars;
//actExprs.push_back(followVarExpressions[obs]);
actExprs.push_back(followVarExpressions[obs]);
smtSolver->add(storm::expressions::disjunction(actExprs)); smtSolver->add(storm::expressions::disjunction(actExprs));
//for (auto const& av : actionVars) {
// smtSolver->add(!followVarExpressions[obs] || !av);
//}
for (auto const& av : actionVars) {
smtSolver->add(!followVarExpressions[obs] || !av);
}
++obs; ++obs;
} }
@ -191,7 +191,7 @@ namespace storm {
if (targetStates.get(state)) { if (targetStates.get(state)) {
smtSolver->add(pathVars[state][0]); smtSolver->add(pathVars[state][0]);
} else { } else {
smtSolver->add(!pathVars[state][0]);
smtSolver->add(!pathVars[state][0] || followVarExpressions[pomdp.getObservation(state)]);
} }
} }
} }
@ -213,7 +213,11 @@ namespace storm {
subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]);
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn()));
if (pomdp.getObservation(entries.getColumn() == pomdp.getObservation(state))) {
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn()));
} else {
subexprreachSwitch.push_back(reachVarExpressions.at(entries.getColumn()));
}
smtSolver->add(storm::expressions::disjunction(subexprreachSwitch)); smtSolver->add(storm::expressions::disjunction(subexprreachSwitch));
subexprreachSwitch.pop_back(); subexprreachSwitch.pop_back();
subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn())); subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn()));
@ -270,10 +274,14 @@ namespace storm {
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
} }
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]);
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
} }
} }
} else { } else {
for (uint64_t j = 1; j < k; ++j) {
smtSolver->add(pathVars[state][j]);
}
rowindex += pomdp.getNumberOfChoices(state); rowindex += pomdp.getNumberOfChoices(state);
} }
} }
@ -284,6 +292,7 @@ namespace storm {
for(auto const& state : statesForObservation) { for(auto const& state : statesForObservation) {
if (!targetStates.get(state)) { if (!targetStates.get(state)) {
smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0); smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0);
smtSolver->add(!reachVarExpressions[state] || !followVarExpressions[obs] || schedulerVariableExpressions[obs] > 0);
} }
} }
++obs; ++obs;
@ -335,9 +344,10 @@ namespace storm {
STORM_LOG_ASSERT(reachVarExpressions.size() > state, "state id " << state << " exceeds number of states (" << reachVarExpressions.size() << ")" ); STORM_LOG_ASSERT(reachVarExpressions.size() > state, "state id " << state << " exceeds number of states (" << reachVarExpressions.size() << ")" );
atLeastOneOfStates.push_back(reachVarExpressions[state]); atLeastOneOfStates.push_back(reachVarExpressions[state]);
} }
assert(atLeastOneOfStates.size() > 0);
// PAPER COMMENT 11 // PAPER COMMENT 11
smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates));
if (!atLeastOneOfStates.empty()) {
smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates));
}
smtSolver->push(); smtSolver->push();
std::set<storm::expressions::Expression> allOfTheseAssumption; std::set<storm::expressions::Expression> allOfTheseAssumption;
@ -349,18 +359,54 @@ namespace storm {
allOfTheseAssumption.insert(reachVarExpressions[state]); allOfTheseAssumption.insert(reachVarExpressions[state]);
} }
for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) {
updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob]));
schedulerForObs.push_back(std::vector<uint64_t>());
if (winningRegion.empty()) {
// Keep it simple here to help bughunting if necessary.
for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) {
updateForObservationExpressions.push_back(
storm::expressions::disjunction(reachVarExpressionsPerObservation[ob]));
schedulerForObs.push_back(0);
}
} else {
uint64_t obs = 0;
for (auto const &statesForObservation : statesPerObservation) {
schedulerForObs.push_back(0);
for (auto const& winningSet : winningRegion.getWinningSetsPerObservation(obs)) {
for (auto const &stateOffset : ~winningSet) {
uint64_t state = statesForObservation[stateOffset];
assert(obs < schedulerForObs.size());
++(schedulerForObs[obs]);
auto constant = expressionManager->integer(schedulerForObs[obs]);
// PAPER COMMENT 14:
smtSolver->add(!(continuationVarExpressions[state] &&
(schedulerVariableExpressions[obs] == constant)));
smtSolver->add(!(reachVarExpressions[state] &&
followVarExpressions[pomdp.getObservation(state)] &&
(schedulerVariableExpressions[obs] == constant)));
}
}
if (winningRegion.getWinningSetsPerObservation(obs).empty()) {
updateForObservationExpressions.push_back(
storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]));
} else {
updateForObservationExpressions.push_back(winningRegion.extensionExpression(obs, reachVarExpressionsPerObservation[obs]));
}
++obs;
}
} }
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
auto constant = expressionManager->integer(schedulerForObs[obs].size());
auto constant = expressionManager->integer(schedulerForObs[obs]);
smtSolver->add(schedulerVariableExpressions[obs] <= constant); smtSolver->add(schedulerVariableExpressions[obs] <= constant);
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
} }
assert(pomdp.getNrObservations() == schedulerForObs.size());
InternalObservationScheduler scheduler; InternalObservationScheduler scheduler;
@ -549,9 +595,9 @@ namespace storm {
storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp); storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp);
uint64_t targetStatesBefore = targetStates.getNumberOfSetBits(); uint64_t targetStatesBefore = targetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits()); STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits());
targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates);
uint64_t targetStatesAfter = targetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states after graph based analysis " << targetStates.getNumberOfSetBits());
storm::storage::BitVector newtargetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates);
uint64_t targetStatesAfter = newtargetStates.getNumberOfSetBits();
STORM_LOG_INFO("Target states after graph based analysis " << newtargetStates.getNumberOfSetBits());
if (targetStatesAfter - targetStatesBefore > 0) { if (targetStatesAfter - targetStatesBefore > 0) {
stats.winningRegionUpdatesTimer.start(); stats.winningRegionUpdatesTimer.start();
@ -561,7 +607,7 @@ namespace storm {
} }
bool observationIsWinning = true; bool observationIsWinning = true;
for (uint64_t state : statesPerObservation[observation]) { for (uint64_t state : statesPerObservation[observation]) {
if(!targetStates.get(state)) {
if(!newtargetStates.get(state)) {
observationIsWinning = false; observationIsWinning = false;
break; break;
} }
@ -569,6 +615,9 @@ namespace storm {
if(observationIsWinning) { if(observationIsWinning) {
stats.incrementGraphBasedWinningObservations(); stats.incrementGraphBasedWinningObservations();
winningRegion.setObservationIsWinning(observation); winningRegion.setObservationIsWinning(observation);
for(auto const& state : statesPerObservation[observation]) {
targetStates.set(state);
}
updated.set(observation); updated.set(observation);
} }
} }
@ -608,14 +657,15 @@ namespace storm {
if (observations.get(obs) && updated.get(obs)) { if (observations.get(obs) && updated.get(obs)) {
STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "."); STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << ".");
assert(schedulerForObs.size() > obs); assert(schedulerForObs.size() > obs);
schedulerForObs[obs].push_back(finalSchedulers.size());
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs);
(schedulerForObs[obs])++;
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs] << " policies for states with observation " << obs);
for (auto const &state : statesForObservation) { for (auto const &state : statesForObservation) {
if (!coveredStates.get(state)) { if (!coveredStates.get(state)) {
auto constant = expressionManager->integer(schedulerForObs[obs].size());
auto constant = expressionManager->integer(schedulerForObs[obs]);
// PAPER COMMENT 14: // PAPER COMMENT 14:
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
smtSolver->add(!(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] && (schedulerVariableExpressions[obs] == constant)));
} }
} }
} }
@ -626,7 +676,7 @@ namespace storm {
smtSolver->push(); smtSolver->push();
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
auto constant = expressionManager->integer(schedulerForObs[obs].size());
auto constant = expressionManager->integer(schedulerForObs[obs]);
// PAPER COMMENT 13 // PAPER COMMENT 13
smtSolver->add(schedulerVariableExpressions[obs] <= constant); smtSolver->add(schedulerVariableExpressions[obs] <= constant);
// PAPER COMMENT 12 // PAPER COMMENT 12

2
src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h

@ -231,7 +231,7 @@ namespace pomdp {
std::vector<std::vector<storm::expressions::Expression>> pathVars; std::vector<std::vector<storm::expressions::Expression>> pathVars;
std::vector<InternalObservationScheduler> finalSchedulers; std::vector<InternalObservationScheduler> finalSchedulers;
std::vector<std::vector<uint64_t>> schedulerForObs;
std::vector<uint64_t> schedulerForObs;
WinningRegion winningRegion; WinningRegion winningRegion;
MemlessSearchOptions options; MemlessSearchOptions options;

Loading…
Cancel
Save