|
|
@ -171,11 +171,11 @@ namespace storm { |
|
|
|
obs = 0; |
|
|
|
for (auto const& actionVars : actionSelectionVarExpressions) { |
|
|
|
std::vector<storm::expressions::Expression> actExprs = actionVars; |
|
|
|
//actExprs.push_back(followVarExpressions[obs]);
|
|
|
|
actExprs.push_back(followVarExpressions[obs]); |
|
|
|
smtSolver->add(storm::expressions::disjunction(actExprs)); |
|
|
|
//for (auto const& av : actionVars) {
|
|
|
|
// smtSolver->add(!followVarExpressions[obs] || !av);
|
|
|
|
//}
|
|
|
|
for (auto const& av : actionVars) { |
|
|
|
smtSolver->add(!followVarExpressions[obs] || !av); |
|
|
|
} |
|
|
|
++obs; |
|
|
|
} |
|
|
|
|
|
|
@ -191,7 +191,7 @@ namespace storm { |
|
|
|
if (targetStates.get(state)) { |
|
|
|
smtSolver->add(pathVars[state][0]); |
|
|
|
} else { |
|
|
|
smtSolver->add(!pathVars[state][0]); |
|
|
|
smtSolver->add(!pathVars[state][0] || followVarExpressions[pomdp.getObservation(state)]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@ -213,7 +213,11 @@ namespace storm { |
|
|
|
subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); |
|
|
|
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); |
|
|
|
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
|
|
|
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); |
|
|
|
if (pomdp.getObservation(entries.getColumn() == pomdp.getObservation(state))) { |
|
|
|
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); |
|
|
|
} else { |
|
|
|
subexprreachSwitch.push_back(reachVarExpressions.at(entries.getColumn())); |
|
|
|
} |
|
|
|
smtSolver->add(storm::expressions::disjunction(subexprreachSwitch)); |
|
|
|
subexprreachSwitch.pop_back(); |
|
|
|
subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn())); |
|
|
@ -270,10 +274,14 @@ namespace storm { |
|
|
|
pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) && storm::expressions::disjunction(pathsubsubexprs[j - 1][action])); |
|
|
|
} |
|
|
|
pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state))); |
|
|
|
pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]); |
|
|
|
smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs))); |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (uint64_t j = 1; j < k; ++j) { |
|
|
|
smtSolver->add(pathVars[state][j]); |
|
|
|
} |
|
|
|
rowindex += pomdp.getNumberOfChoices(state); |
|
|
|
} |
|
|
|
} |
|
|
@ -284,6 +292,7 @@ namespace storm { |
|
|
|
for(auto const& state : statesForObservation) { |
|
|
|
if (!targetStates.get(state)) { |
|
|
|
smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0); |
|
|
|
smtSolver->add(!reachVarExpressions[state] || !followVarExpressions[obs] || schedulerVariableExpressions[obs] > 0); |
|
|
|
} |
|
|
|
} |
|
|
|
++obs; |
|
|
@ -335,9 +344,10 @@ namespace storm { |
|
|
|
STORM_LOG_ASSERT(reachVarExpressions.size() > state, "state id " << state << " exceeds number of states (" << reachVarExpressions.size() << ")" ); |
|
|
|
atLeastOneOfStates.push_back(reachVarExpressions[state]); |
|
|
|
} |
|
|
|
assert(atLeastOneOfStates.size() > 0); |
|
|
|
// PAPER COMMENT 11
|
|
|
|
smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); |
|
|
|
if (!atLeastOneOfStates.empty()) { |
|
|
|
smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates)); |
|
|
|
} |
|
|
|
smtSolver->push(); |
|
|
|
|
|
|
|
std::set<storm::expressions::Expression> allOfTheseAssumption; |
|
|
@ -349,18 +359,54 @@ namespace storm { |
|
|
|
allOfTheseAssumption.insert(reachVarExpressions[state]); |
|
|
|
} |
|
|
|
|
|
|
|
for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { |
|
|
|
updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob])); |
|
|
|
schedulerForObs.push_back(std::vector<uint64_t>()); |
|
|
|
if (winningRegion.empty()) { |
|
|
|
// Keep it simple here to help bughunting if necessary.
|
|
|
|
for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) { |
|
|
|
updateForObservationExpressions.push_back( |
|
|
|
storm::expressions::disjunction(reachVarExpressionsPerObservation[ob])); |
|
|
|
schedulerForObs.push_back(0); |
|
|
|
} |
|
|
|
} else { |
|
|
|
|
|
|
|
uint64_t obs = 0; |
|
|
|
for (auto const &statesForObservation : statesPerObservation) { |
|
|
|
schedulerForObs.push_back(0); |
|
|
|
for (auto const& winningSet : winningRegion.getWinningSetsPerObservation(obs)) { |
|
|
|
for (auto const &stateOffset : ~winningSet) { |
|
|
|
uint64_t state = statesForObservation[stateOffset]; |
|
|
|
assert(obs < schedulerForObs.size()); |
|
|
|
++(schedulerForObs[obs]); |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs]); |
|
|
|
// PAPER COMMENT 14:
|
|
|
|
smtSolver->add(!(continuationVarExpressions[state] && |
|
|
|
(schedulerVariableExpressions[obs] == constant))); |
|
|
|
smtSolver->add(!(reachVarExpressions[state] && |
|
|
|
followVarExpressions[pomdp.getObservation(state)] && |
|
|
|
(schedulerVariableExpressions[obs] == constant))); |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
if (winningRegion.getWinningSetsPerObservation(obs).empty()) { |
|
|
|
updateForObservationExpressions.push_back( |
|
|
|
storm::expressions::disjunction(reachVarExpressionsPerObservation[obs])); |
|
|
|
} else { |
|
|
|
updateForObservationExpressions.push_back(winningRegion.extensionExpression(obs, reachVarExpressionsPerObservation[obs])); |
|
|
|
} |
|
|
|
++obs; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs].size()); |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs]); |
|
|
|
smtSolver->add(schedulerVariableExpressions[obs] <= constant); |
|
|
|
smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs])); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert(pomdp.getNrObservations() == schedulerForObs.size()); |
|
|
|
|
|
|
|
|
|
|
|
InternalObservationScheduler scheduler; |
|
|
@ -549,9 +595,9 @@ namespace storm { |
|
|
|
storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp); |
|
|
|
uint64_t targetStatesBefore = targetStates.getNumberOfSetBits(); |
|
|
|
STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits()); |
|
|
|
targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates); |
|
|
|
uint64_t targetStatesAfter = targetStates.getNumberOfSetBits(); |
|
|
|
STORM_LOG_INFO("Target states after graph based analysis " << targetStates.getNumberOfSetBits()); |
|
|
|
storm::storage::BitVector newtargetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates); |
|
|
|
uint64_t targetStatesAfter = newtargetStates.getNumberOfSetBits(); |
|
|
|
STORM_LOG_INFO("Target states after graph based analysis " << newtargetStates.getNumberOfSetBits()); |
|
|
|
if (targetStatesAfter - targetStatesBefore > 0) { |
|
|
|
stats.winningRegionUpdatesTimer.start(); |
|
|
|
|
|
|
@ -561,7 +607,7 @@ namespace storm { |
|
|
|
} |
|
|
|
bool observationIsWinning = true; |
|
|
|
for (uint64_t state : statesPerObservation[observation]) { |
|
|
|
if(!targetStates.get(state)) { |
|
|
|
if(!newtargetStates.get(state)) { |
|
|
|
observationIsWinning = false; |
|
|
|
break; |
|
|
|
} |
|
|
@ -569,6 +615,9 @@ namespace storm { |
|
|
|
if(observationIsWinning) { |
|
|
|
stats.incrementGraphBasedWinningObservations(); |
|
|
|
winningRegion.setObservationIsWinning(observation); |
|
|
|
for(auto const& state : statesPerObservation[observation]) { |
|
|
|
targetStates.set(state); |
|
|
|
} |
|
|
|
updated.set(observation); |
|
|
|
} |
|
|
|
} |
|
|
@ -608,14 +657,15 @@ namespace storm { |
|
|
|
if (observations.get(obs) && updated.get(obs)) { |
|
|
|
STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << "."); |
|
|
|
assert(schedulerForObs.size() > obs); |
|
|
|
schedulerForObs[obs].push_back(finalSchedulers.size()); |
|
|
|
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs].size() << " policies for states with observation " << obs); |
|
|
|
(schedulerForObs[obs])++; |
|
|
|
STORM_LOG_DEBUG("We now have " << schedulerForObs[obs] << " policies for states with observation " << obs); |
|
|
|
|
|
|
|
for (auto const &state : statesForObservation) { |
|
|
|
if (!coveredStates.get(state)) { |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs].size()); |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs]); |
|
|
|
// PAPER COMMENT 14:
|
|
|
|
smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant))); |
|
|
|
smtSolver->add(!(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] && (schedulerVariableExpressions[obs] == constant))); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@ -626,7 +676,7 @@ namespace storm { |
|
|
|
smtSolver->push(); |
|
|
|
|
|
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs].size()); |
|
|
|
auto constant = expressionManager->integer(schedulerForObs[obs]); |
|
|
|
// PAPER COMMENT 13
|
|
|
|
smtSolver->add(schedulerVariableExpressions[obs] <= constant); |
|
|
|
// PAPER COMMENT 12
|
|
|
|