|
|
@ -3,6 +3,7 @@ |
|
|
|
|
|
|
|
#include "storm-pomdp/analysis/QualitativeStrategySearchNaive.h"
|
|
|
|
#include "storm-pomdp/analysis/QualitativeAnalysis.h"
|
|
|
|
#include "storm-pomdp/analysis/QualitativeAnalysisOnGraphs.h"
|
|
|
|
|
|
|
|
namespace storm { |
|
|
|
namespace pomdp { |
|
|
@ -53,9 +54,9 @@ namespace storm { |
|
|
|
MemlessSearchOptions const& options) : |
|
|
|
pomdp(pomdp), |
|
|
|
surelyReachSinkStates(surelyReachSinkStates), |
|
|
|
targetObservations(storm::pomdp::extractObservations(pomdp, targetStates)), |
|
|
|
targetStates(targetStates), |
|
|
|
options(options) |
|
|
|
options(options), |
|
|
|
smtSolverFactory(smtSolverFactory) |
|
|
|
{ |
|
|
|
this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>(); |
|
|
|
smtSolver = smtSolverFactory->create(*expressionManager); |
|
|
@ -86,113 +87,142 @@ namespace storm { |
|
|
|
} else { |
|
|
|
lookaheadConstraintsRequired = qualitative::isLookaheadRequired(pomdp, targetStates, surelyReachSinkStates); |
|
|
|
} |
|
|
|
if (maxK == std::numeric_limits<uint64_t>::max()) { |
|
|
|
// not initialized at all.
|
|
|
|
// Create some data structures.
|
|
|
|
for(uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
|
|
|
|
|
if (actionSelectionVars.empty()) { |
|
|
|
for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) { |
|
|
|
actionSelectionVars.push_back(std::vector<storm::expressions::Variable>()); |
|
|
|
actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>()); |
|
|
|
} |
|
|
|
|
|
|
|
// Fill the states-per-observation mapping,
|
|
|
|
// declare the reachability variables,
|
|
|
|
// declare the path variables.
|
|
|
|
for(uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { |
|
|
|
if(lookaheadConstraintsRequired) { |
|
|
|
pathVars.push_back(std::vector<storm::expressions::Expression>()); |
|
|
|
for (uint64_t i = 0; i < k; ++i) { |
|
|
|
pathVars.back().push_back(expressionManager->declareBooleanVariable("P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression()); |
|
|
|
} |
|
|
|
} |
|
|
|
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { |
|
|
|
reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId))); |
|
|
|
reachVarExpressions.push_back(reachVars.back().getExpression()); |
|
|
|
reachVarExpressionsPerObservation[pomdp.getObservation(stateId)].push_back(reachVarExpressions.back()); |
|
|
|
continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId))); |
|
|
|
reachVarExpressionsPerObservation[pomdp.getObservation(stateId)].push_back( |
|
|
|
reachVarExpressions.back()); |
|
|
|
continuationVars.push_back( |
|
|
|
expressionManager->declareBooleanVariable("D-" + std::to_string(stateId))); |
|
|
|
continuationVarExpressions.push_back(continuationVars.back().getExpression()); |
|
|
|
} |
|
|
|
assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates()); |
|
|
|
assert(reachVars.size() == pomdp.getNumberOfStates()); |
|
|
|
assert(reachVarExpressions.size() == pomdp.getNumberOfStates()); |
|
|
|
|
|
|
|
// Create the action selection variables.
|
|
|
|
uint64_t obs = 0; |
|
|
|
for(auto const& statesForObservation : statesPerObservation) { |
|
|
|
for (auto const &statesForObservation : statesPerObservation) { |
|
|
|
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) { |
|
|
|
std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a); |
|
|
|
actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName)); |
|
|
|
actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression()); |
|
|
|
actionSelectionVarExpressions.at(obs).push_back( |
|
|
|
actionSelectionVars.at(obs).back().getExpression()); |
|
|
|
} |
|
|
|
schedulerVariables.push_back(expressionManager->declareBitVectorVariable("scheduler-obs-" + std::to_string(obs), statesPerObservation.size())); |
|
|
|
schedulerVariables.push_back( |
|
|
|
expressionManager->declareBitVectorVariable("scheduler-obs-" + std::to_string(obs), |
|
|
|
statesPerObservation.size())); |
|
|
|
schedulerVariableExpressions.push_back(schedulerVariables.back()); |
|
|
|
switchVars.push_back(expressionManager->declareBooleanVariable("S-" + std::to_string(obs))); |
|
|
|
switchVarExpressions.push_back(switchVars.back().getExpression()); |
|
|
|
observationUpdatedVariables.push_back(expressionManager->declareBooleanVariable("U-" + std::to_string(obs))); |
|
|
|
observationUpdatedVariables.push_back( |
|
|
|
expressionManager->declareBooleanVariable("U-" + std::to_string(obs))); |
|
|
|
observationUpdatedExpressions.push_back(observationUpdatedVariables.back().getExpression()); |
|
|
|
if (options.onlyDeterministicStrategies) { |
|
|
|
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) { |
|
|
|
for (uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) { |
|
|
|
smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
followVars.push_back(expressionManager->declareBooleanVariable("F-"+std::to_string(obs))); |
|
|
|
followVarExpressions.push_back(followVars.back().getExpression()); |
|
|
|
|
|
|
|
++obs; |
|
|
|
} |
|
|
|
|
|
|
|
// PAPER COMMENT: 1
|
|
|
|
for (auto const& actionVars : actionSelectionVarExpressions) { |
|
|
|
smtSolver->add(storm::expressions::disjunction(actionVars)); |
|
|
|
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { |
|
|
|
pathVars.push_back(std::vector<storm::expressions::Expression>()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Update at least one observation.
|
|
|
|
// PAPER COMMENT: 2
|
|
|
|
smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions)); |
|
|
|
|
|
|
|
// PAPER COMMENT: 3
|
|
|
|
if (lookaheadConstraintsRequired) { |
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
if (targetStates.get(state)) { |
|
|
|
smtSolver->add(pathVars[state][0]); |
|
|
|
} else { |
|
|
|
smtSolver->add(!pathVars[state][0]); |
|
|
|
uint64_t initK = 0; |
|
|
|
if (maxK != std::numeric_limits<uint64_t>::max()) { |
|
|
|
initK = maxK; |
|
|
|
} |
|
|
|
if (initK < k) { |
|
|
|
for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) { |
|
|
|
if (lookaheadConstraintsRequired) { |
|
|
|
for (uint64_t i = initK; i < k; ++i) { |
|
|
|
pathVars[stateId].push_back(expressionManager->declareBooleanVariable( |
|
|
|
"P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// PAPER COMMENT: 4
|
|
|
|
uint64_t rowindex = 0; |
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
assert(!lookaheadConstraintsRequired || pathVars.size() == pomdp.getNumberOfStates()); |
|
|
|
assert(reachVars.size() == pomdp.getNumberOfStates()); |
|
|
|
assert(reachVarExpressions.size() == pomdp.getNumberOfStates()); |
|
|
|
|
|
|
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|
|
|
std::vector<storm::expressions::Expression> subexprreachSwitch; |
|
|
|
std::vector<storm::expressions::Expression> subexprreachNoSwitch; |
|
|
|
subexprreachSwitch.push_back(!reachVarExpressions[state]); |
|
|
|
subexprreachSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); |
|
|
|
subexprreachSwitch.push_back(!switchVarExpressions[pomdp.getObservation(state)]); |
|
|
|
subexprreachNoSwitch.push_back(!reachVarExpressions[state]); |
|
|
|
subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); |
|
|
|
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); |
|
|
|
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
|
|
|
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); |
|
|
|
smtSolver->add(storm::expressions::disjunction(subexprreachSwitch)); |
|
|
|
subexprreachSwitch.pop_back(); |
|
|
|
subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn())); |
|
|
|
smtSolver->add(storm::expressions::disjunction(subexprreachNoSwitch)); |
|
|
|
subexprreachNoSwitch.pop_back(); |
|
|
|
} |
|
|
|
|
|
|
|
rowindex++; |
|
|
|
uint64_t obs = 0; |
|
|
|
if (options.onlyDeterministicStrategies) { |
|
|
|
for(auto const& statesForObservation : statesPerObservation) { |
|
|
|
for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front())-1; ++a) { |
|
|
|
for (uint64_t b = a+1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) { |
|
|
|
smtSolver->add(!actionSelectionVarExpressions[obs][a] || !actionSelectionVarExpressions[obs][b]); |
|
|
|
} |
|
|
|
} |
|
|
|
++obs; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
smtSolver->push(); |
|
|
|
} else { |
|
|
|
smtSolver->pop(); |
|
|
|
smtSolver->pop(); |
|
|
|
smtSolver->push(); |
|
|
|
assert(false); |
|
|
|
// PAPER COMMENT: 1
|
|
|
|
obs = 0; |
|
|
|
for (auto const& actionVars : actionSelectionVarExpressions) { |
|
|
|
std::vector<storm::expressions::Expression> actExprs = actionVars; |
|
|
|
//actExprs.push_back(followVarExpressions[obs]);
|
|
|
|
smtSolver->add(storm::expressions::disjunction(actExprs)); |
|
|
|
//for (auto const& av : actionVars) {
|
|
|
|
// smtSolver->add(!followVarExpressions[obs] || !av);
|
|
|
|
//}
|
|
|
|
++obs; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Update at least one observation.
|
|
|
|
// PAPER COMMENT: 2
|
|
|
|
smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions)); |
|
|
|
|
|
|
|
// PAPER COMMENT: 3
|
|
|
|
if (lookaheadConstraintsRequired) { |
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
if (targetStates.get(state)) { |
|
|
|
smtSolver->add(pathVars[state][0]); |
|
|
|
} else { |
|
|
|
smtSolver->add(!pathVars[state][0]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// PAPER COMMENT: 4
|
|
|
|
uint64_t rowindex = 0; |
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
if (targetStates.get(state) || surelyReachSinkStates.get(state)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { |
|
|
|
std::vector<storm::expressions::Expression> subexprreachSwitch; |
|
|
|
std::vector<storm::expressions::Expression> subexprreachNoSwitch; |
|
|
|
subexprreachSwitch.push_back(!reachVarExpressions[state]); |
|
|
|
subexprreachSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); |
|
|
|
subexprreachSwitch.push_back(!switchVarExpressions[pomdp.getObservation(state)]); |
|
|
|
subexprreachNoSwitch.push_back(!reachVarExpressions[state]); |
|
|
|
subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]); |
|
|
|
subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]); |
|
|
|
for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowindex)) { |
|
|
|
subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn())); |
|
|
|
smtSolver->add(storm::expressions::disjunction(subexprreachSwitch)); |
|
|
|
subexprreachSwitch.pop_back(); |
|
|
|
subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn())); |
|
|
|
smtSolver->add(storm::expressions::disjunction(subexprreachNoSwitch)); |
|
|
|
subexprreachNoSwitch.pop_back(); |
|
|
|
} |
|
|
|
|
|
|
|
rowindex++; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rowindex = 0; |
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
// PAPER COMMENT 5
|
|
|
|
if (surelyReachSinkStates.get(state)) { |
|
|
@ -244,10 +274,12 @@ namespace storm { |
|
|
|
} |
|
|
|
|
|
|
|
// PAPER COMMENT 8
|
|
|
|
uint64_t obs = 0; |
|
|
|
obs = 0; |
|
|
|
for(auto const& statesForObservation : statesPerObservation) { |
|
|
|
for(auto const& state : statesForObservation) { |
|
|
|
smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0); |
|
|
|
if (!targetStates.get(state)) { |
|
|
|
smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0); |
|
|
|
} |
|
|
|
} |
|
|
|
++obs; |
|
|
|
} |
|
|
@ -284,10 +316,10 @@ namespace storm { |
|
|
|
template <typename ValueType> |
|
|
|
bool MemlessStrategySearchQualitative<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) { |
|
|
|
stats.initializeSolverTimer.start(); |
|
|
|
if (k < maxK) { |
|
|
|
initialize(k); |
|
|
|
maxK = k; |
|
|
|
} |
|
|
|
// TODO: When do we need to reinitialize? When the solver has been reset.
|
|
|
|
initialize(k); |
|
|
|
maxK = k; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uint64_t maximalNrActions = 8; |
|
|
@ -374,7 +406,7 @@ namespace storm { |
|
|
|
newObservations.clear(); |
|
|
|
|
|
|
|
uint64_t obs = 0; |
|
|
|
for (auto ov : observationUpdatedVariables) { |
|
|
|
for (auto const& ov : observationUpdatedVariables) { |
|
|
|
|
|
|
|
if (!observationUpdated.get(obs) && model->getBooleanValue(ov)) { |
|
|
|
STORM_LOG_TRACE("New observation updated: " << obs); |
|
|
@ -384,17 +416,18 @@ namespace storm { |
|
|
|
} |
|
|
|
|
|
|
|
uint64_t i = 0; |
|
|
|
for (auto rv : reachVars) { |
|
|
|
for (auto const& rv : reachVars) { |
|
|
|
if (!coveredStates.get(i) && model->getBooleanValue(rv)) { |
|
|
|
STORM_LOG_TRACE("New state: " << i); |
|
|
|
smtSolver->add(rv.getExpression()); |
|
|
|
assert(!surelyReachSinkStates.get(i)); |
|
|
|
newObservations.set(pomdp.getObservation(i)); |
|
|
|
coveredStates.set(i); |
|
|
|
} |
|
|
|
++i; |
|
|
|
} |
|
|
|
i = 0; |
|
|
|
for (auto rv : continuationVars) { |
|
|
|
for (auto const& rv : continuationVars) { |
|
|
|
if (!coveredStatesAfterSwitch.get(i) && model->getBooleanValue(rv) ) { |
|
|
|
smtSolver->add(rv.getExpression()); |
|
|
|
if (!observationsAfterSwitch.get(pomdp.getObservation(i))) { |
|
|
@ -478,12 +511,14 @@ namespace storm { |
|
|
|
|
|
|
|
stats.winningRegionUpdatesTimer.start(); |
|
|
|
storm::storage::BitVector updated(observations.size()); |
|
|
|
uint64_t newTargetObservations = 0; |
|
|
|
for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { |
|
|
|
STORM_LOG_TRACE("consider observation " << observation); |
|
|
|
storm::storage::BitVector update(statesPerObservation[observation].size()); |
|
|
|
uint64_t i = 0; |
|
|
|
for (uint64_t state : statesPerObservation[observation]) { |
|
|
|
if (coveredStates.get(state)) { |
|
|
|
assert(!surelyReachSinkStates.get(state)); |
|
|
|
update.set(i); |
|
|
|
} |
|
|
|
++i; |
|
|
@ -493,19 +528,77 @@ namespace storm { |
|
|
|
bool updateResult = winningRegion.update(observation, update); |
|
|
|
STORM_LOG_TRACE("Region changed:" << updateResult); |
|
|
|
if (updateResult) { |
|
|
|
if (winningRegion.observationIsWinning(observation)) { |
|
|
|
++newTargetObservations; |
|
|
|
for (uint64_t state : statesPerObservation[observation]) { |
|
|
|
targetStates.set(state); |
|
|
|
} |
|
|
|
} |
|
|
|
updated.set(observation); |
|
|
|
updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place"); |
|
|
|
stats.winningRegionUpdatesTimer.stop(); |
|
|
|
if (newTargetObservations>0) { |
|
|
|
storm::analysis::QualitativeAnalysisOnGraphs<ValueType> graphanalysis(pomdp); |
|
|
|
uint64_t targetStatesBefore = targetStates.getNumberOfSetBits(); |
|
|
|
STORM_LOG_INFO("Target states before graph based analysis " << targetStates.getNumberOfSetBits()); |
|
|
|
targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates); |
|
|
|
uint64_t targetStatesAfter = targetStates.getNumberOfSetBits(); |
|
|
|
STORM_LOG_INFO("Target states after graph based analysis " << targetStates.getNumberOfSetBits()); |
|
|
|
if (targetStatesAfter - targetStatesBefore > 0) { |
|
|
|
stats.winningRegionUpdatesTimer.start(); |
|
|
|
|
|
|
|
for(uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) { |
|
|
|
if (winningRegion.observationIsWinning(observation)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
bool observationIsWinning = true; |
|
|
|
for (uint64_t state : statesPerObservation[observation]) { |
|
|
|
if(!targetStates.get(state)) { |
|
|
|
observationIsWinning = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if(observationIsWinning) { |
|
|
|
stats.incrementGraphBasedWinningObservations(); |
|
|
|
winningRegion.setObservationIsWinning(observation); |
|
|
|
updated.set(observation); |
|
|
|
} |
|
|
|
} |
|
|
|
STORM_LOG_INFO("Graph based winning obs: " << stats.getGraphBasedwinningObservations()); |
|
|
|
uint64_t nonWinObTargetStates =0; |
|
|
|
for (uint64_t state : targetStates) { |
|
|
|
if (!winningRegion.observationIsWinning(pomdp.getObservation(state))) { |
|
|
|
nonWinObTargetStates++; |
|
|
|
} |
|
|
|
} |
|
|
|
stats.winningRegionUpdatesTimer.stop(); |
|
|
|
if (nonWinObTargetStates > 0) { |
|
|
|
std::cout << "Non winning target states " << nonWinObTargetStates << std::endl; |
|
|
|
STORM_LOG_WARN("This case has been barely tested and likely contains bug"); |
|
|
|
reset(); |
|
|
|
return analyze(k, ~targetStates & ~surelyReachSinkStates); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
// TODO temporarily switched off due to initialization issues when restarting.
|
|
|
|
STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place"); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if(options.computeDebugOutput()) { |
|
|
|
winningRegion.print(); |
|
|
|
} |
|
|
|
|
|
|
|
stats.updateNewStrategySolverTime.start(); |
|
|
|
for(uint64_t observation : updated) { |
|
|
|
updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]); |
|
|
|
} |
|
|
|
|
|
|
|
uint64_t obs = 0; |
|
|
|
for (auto const &statesForObservation : statesPerObservation) { |
|
|
|
if (observations.get(obs) && updated.get(obs)) { |
|
|
@ -537,7 +630,7 @@ namespace storm { |
|
|
|
} |
|
|
|
stats.updateNewStrategySolverTime.stop(); |
|
|
|
|
|
|
|
|
|
|
|
STORM_LOG_INFO("... after iteration " << stats.getIterations() << " so far " << stats.getChecks() << " checks." ); |
|
|
|
} |
|
|
|
winningRegion.print(); |
|
|
|
|
|
|
|