@ -41,9 +41,11 @@ namespace storm {
STORM_PRINT_AND_LOG ( " SAT Calls time: " < < smtCheckTimer < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " SAT Calls time: " < < smtCheckTimer < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Outer iterations: " < < outerIterations < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Outer iterations: " < < outerIterations < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Solver initialization time: " < < initializeSolverTimer < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Solver initialization time: " < < initializeSolverTimer < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Extend partial scheduler time: " < < updateExtensionSolverTime < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Obtain partial scheduler time: " < < evaluateExtensionSolverTime < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Update solver to extend partial scheduler time: " < < encodeExtensionSolverTime < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Update solver with new scheduler time: " < < updateNewStrategySolverTime < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Update solver with new scheduler time: " < < updateNewStrategySolverTime < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Winning regions update time: " < < winningRegionUpdatesTimer < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Winning regions update time: " < < winningRegionUpdatesTimer < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Graph search time: " < < graphSearchTime < < std : : endl ) ;
}
}
template < typename ValueType >
template < typename ValueType >
@ -155,15 +157,21 @@ namespace storm {
uint64_t obs = 0 ;
uint64_t obs = 0 ;
if ( options . onlyDeterministicStrategies ) {
for ( auto const & statesForObservation : statesPerObservation ) {
for ( auto const & statesForObservation : statesPerObservation ) {
if ( pomdp . getNumberOfChoices ( statesForObservation . front ( ) ) = = 1 ) {
+ + obs ;
continue ;
}
if ( options . onlyDeterministicStrategies | | statesForObservation . size ( ) = = 1 ) {
for ( uint64_t a = 0 ; a < pomdp . getNumberOfChoices ( statesForObservation . front ( ) ) - 1 ; + + a ) {
for ( uint64_t a = 0 ; a < pomdp . getNumberOfChoices ( statesForObservation . front ( ) ) - 1 ; + + a ) {
for ( uint64_t b = a + 1 ; b < pomdp . getNumberOfChoices ( statesForObservation . front ( ) ) ; + + b ) {
for ( uint64_t b = a + 1 ; b < pomdp . getNumberOfChoices ( statesForObservation . front ( ) ) ; + + b ) {
smtSolver - > add ( ! actionSelectionVarExpressions [ obs ] [ a ] | | ! actionSelectionVarExpressions [ obs ] [ b ] ) ;
smtSolver - > add ( ! ( actionSelectionVarExpressions [ obs ] [ a ] ) | |
! ( actionSelectionVarExpressions [ obs ] [ b ] ) ) ;
}
}
}
}
+ + obs ;
}
}
+ + obs ;
}
}
// PAPER COMMENT: 1
// PAPER COMMENT: 1
@ -302,38 +310,107 @@ namespace storm {
smtSolver - > add ( storm : : expressions : : implies ( switchVarExpressions [ obs ] , storm : : expressions : : disjunction ( reachVarExpressionsPerObservation [ obs ] ) ) ) ;
smtSolver - > add ( storm : : expressions : : implies ( switchVarExpressions [ obs ] , storm : : expressions : : disjunction ( reachVarExpressionsPerObservation [ obs ] ) ) ) ;
}
}
// PAPER COMMENT 10
// PAPER COMMENT 10
if ( ! lookaheadConstraintsRequired ) {
uint64_t rowIndex = 0 ;
for ( uint64_t state = 0 ; state < pomdp . getNumberOfStates ( ) ; + + state ) {
uint64_t enabledActions = pomdp . getNumberOfChoices ( state ) ;
if ( ! surelyReachSinkStates . get ( state ) ) {
std : : vector < storm : : expressions : : Expression > successorVars ;
for ( uint64_t act = 0 ; act < enabledActions ; + + act ) {
for ( auto const & entries : pomdp . getTransitionMatrix ( ) . getRow ( rowIndex ) ) {
successorVars . push_back ( reachVarExpressions [ entries . getColumn ( ) ] ) ;
}
rowIndex + + ;
// if (!lookaheadConstraintsRequired) {
// uint64_t rowIndex = 0;
// for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
// uint64_t enabledActions = pomdp.getNumberOfChoices(state);
// if (!surelyReachSinkStates.get(state)) {
// std::vector<storm::expressions::Expression> successorVars;
// for (uint64_t act = 0; act < enabledActions; ++act) {
// for (auto const &entries : pomdp.getTransitionMatrix().getRow(rowIndex)) {
// successorVars.push_back(reachVarExpressions[entries.getColumn()]);
// }
// rowIndex++;
// }
// successorVars.push_back(!switchVars[pomdp.getObservation(state)]);
// smtSolver->add(storm::expressions::implies(storm::expressions::conjunction(successorVars), reachVarExpressions[state]));
// } else {
// rowIndex += enabledActions;
// }
// }
// } else {
// STORM_LOG_WARN("Some optimization not implemented yet.");
// }
// TODO: Update found schedulers if k is increased.
}
}
successorVars . push_back ( ! switchVars [ pomdp . getObservation ( state ) ] ) ;
smtSolver - > add ( storm : : expressions : : implies ( storm : : expressions : : conjunction ( successorVars ) , reachVarExpressions [ state ] ) ) ;
} else {
rowIndex + = enabledActions ;
template < typename ValueType >
uint64_t MemlessStrategySearchQualitative < ValueType > : : getOffsetFromObservation ( uint64_t state , uint64_t observation ) const {
if ( ! useFindOffset ) {
STORM_LOG_WARN ( " This code is slow and should only be used for debugging. " ) ;
useFindOffset = true ;
}
}
uint64_t offset = 0 ;
for ( uint64_t s : statesPerObservation [ observation ] ) {
if ( s = = state ) {
return offset ;
}
}
} else {
STORM_LOG_WARN ( " Some optimization not implemented yet. " ) ;
+ + offset ;
}
}
// TODO: Update found schedulers if k is increased.
assert ( false ) ; // State should have occured.
return 0 ;
}
}
template < typename ValueType >
template < typename ValueType >
bool MemlessStrategySearchQualitative < ValueType > : : analyze ( uint64_t k , storm : : storage : : BitVector const & oneOfTheseStates , storm : : storage : : BitVector const & allOfTheseStates ) {
bool MemlessStrategySearchQualitative < ValueType > : : analyze ( uint64_t k , storm : : storage : : BitVector const & oneOfTheseStates , storm : : storage : : BitVector const & allOfTheseStates ) {
std : : cout < < " Surely reach sink states: " < < surelyReachSinkStates < < std : : endl ;
std : : cout < < " Target states " < < targetStates < < std : : endl ;
std : : cout < < ( ~ surelyReachSinkStates & ~ targetStates ) < < std : : endl ;
stats . initializeSolverTimer . start ( ) ;
stats . initializeSolverTimer . start ( ) ;
// TODO: When do we need to reinitialize? When the solver has been reset.
// TODO: When do we need to reinitialize? When the solver has been reset.
initialize ( k ) ;
initialize ( k ) ;
maxK = k ;
maxK = k ;
stats . winningRegionUpdatesTimer . start ( ) ;
storm : : storage : : BitVector updated ( pomdp . getNrObservations ( ) ) ;
// TODO CODE DUPLICATION WITH UPDATE, PUT IN PROCEDURE
storm : : storage : : BitVector potentialWinner ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observationsWithPartialWinners ( pomdp . getNrObservations ( ) ) ;
for ( uint64_t observation = 0 ; observation < pomdp . getNrObservations ( ) ; + + observation ) {
if ( winningRegion . observationIsWinning ( observation ) ) {
continue ;
}
bool observationIsWinning = true ;
for ( uint64_t state : statesPerObservation [ observation ] ) {
if ( ! targetStates . get ( state ) ) {
observationIsWinning = false ;
observationsWithPartialWinners . set ( observation ) ;
} else {
potentialWinner . set ( observation ) ;
}
}
if ( observationIsWinning ) {
STORM_LOG_TRACE ( " Observation " < < observation < < " is winning. " ) ;
stats . incrementGraphBasedWinningObservations ( ) ;
winningRegion . setObservationIsWinning ( observation ) ;
updated . set ( observation ) ;
}
}
STORM_LOG_INFO ( " Graph based winning obs: " < < stats . getGraphBasedwinningObservations ( ) ) ;
observationsWithPartialWinners & = potentialWinner ;
for ( auto const & observation : observationsWithPartialWinners ) {
uint64_t nrStatesForObs = statesPerObservation [ observation ] . size ( ) ;
storm : : storage : : BitVector update ( nrStatesForObs ) ;
for ( uint64_t i = 0 ; i < nrStatesForObs ; + + i ) {
uint64_t state = statesPerObservation [ observation ] [ i ] ;
if ( targetStates . get ( state ) ) {
update . set ( i ) ;
}
}
assert ( ! update . empty ( ) ) ;
STORM_LOG_TRACE ( " Extend winning region for observation " < < observation < < " with target states/offsets " < < update ) ;
winningRegion . addTargetStates ( observation , update ) ;
assert ( winningRegion . query ( observation , update ) ) ; // "Cannot continue: No scheduler known for state " << i << " (observation " << obs << ").");
updated . set ( observation ) ;
}
for ( auto const & state : targetStates ) {
STORM_LOG_ASSERT ( winningRegion . isWinning ( pomdp . getObservation ( state ) , getOffsetFromObservation ( state , pomdp . getObservation ( state ) ) ) , " Target state " < < state < < " , observation " < < pomdp . getObservation ( state ) < < " is not reflected as winning. " ) ;
}
stats . winningRegionUpdatesTimer . stop ( ) ;
uint64_t maximalNrActions = 8 ;
uint64_t maximalNrActions = 8 ;
STORM_LOG_WARN ( " We have hardcoded (an upper bound on) the number of actions " ) ;
STORM_LOG_WARN ( " We have hardcoded (an upper bound on) the number of actions " ) ;
@ -415,6 +492,7 @@ namespace storm {
storm : : storage : : BitVector observations ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observations ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observationsAfterSwitch ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observationsAfterSwitch ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observationUpdated ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observationUpdated ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector uncoveredStates ( pomdp . getNumberOfStates ( ) ) ;
storm : : storage : : BitVector coveredStates ( pomdp . getNumberOfStates ( ) ) ;
storm : : storage : : BitVector coveredStates ( pomdp . getNumberOfStates ( ) ) ;
storm : : storage : : BitVector coveredStatesAfterSwitch ( pomdp . getNumberOfStates ( ) ) ;
storm : : storage : : BitVector coveredStatesAfterSwitch ( pomdp . getNumberOfStates ( ) ) ;
@ -449,15 +527,14 @@ namespace storm {
break ;
break ;
}
}
newSchedulerDiscovered = true ;
newSchedulerDiscovered = true ;
stats . upd ateExtensionSolverTime . start ( ) ;
auto model = smtSolver - > getModel ( ) ;
stats . eval uateExtensionSolverTime. start ( ) ;
auto const & model = smtSolver - > getModel ( ) ;
newObservationsAfterSwitch . clear ( ) ;
newObservationsAfterSwitch . clear ( ) ;
newObservations . clear ( ) ;
newObservations . clear ( ) ;
uint64_t obs = 0 ;
uint64_t obs = 0 ;
for ( auto const & ov : observationUpdatedVariables ) {
for ( auto const & ov : observationUpdatedVariables ) {
if ( ! observationUpdated . get ( obs ) & & model - > getBooleanValue ( ov ) ) {
if ( ! observationUpdated . get ( obs ) & & model - > getBooleanValue ( ov ) ) {
STORM_LOG_TRACE ( " New observation updated: " < < obs ) ;
STORM_LOG_TRACE ( " New observation updated: " < < obs ) ;
observationUpdated . set ( obs ) ;
observationUpdated . set ( obs ) ;
@ -465,32 +542,43 @@ namespace storm {
obs + + ;
obs + + ;
}
}
uint64_t i = 0 ;
for ( auto const & rv : reachVars ) {
if ( ! coveredStates . get ( i ) & & model - > getBooleanValue ( rv ) ) {
// for(uint64_t i : targetStates) {
// assert(model->getBooleanValue(reachVars[i]));
// }
uncoveredStates = ~ coveredStates ;
for ( uint64_t i : uncoveredStates ) {
auto const & rv = reachVars [ i ] ;
auto const & rvExpr = reachVarExpressions [ i ] ;
if ( model - > getBooleanValue ( rv ) ) {
STORM_LOG_TRACE ( " New state: " < < i ) ;
STORM_LOG_TRACE ( " New state: " < < i ) ;
smtSolver - > add ( rv . getExpression ( ) ) ;
smtSolver - > add ( rvExpr ) ;
assert ( ! surelyReachSinkStates . get ( i ) ) ;
assert ( ! surelyReachSinkStates . get ( i ) ) ;
newObservations . set ( pomdp . getObservation ( i ) ) ;
newObservations . set ( pomdp . getObservation ( i ) ) ;
coveredStates . set ( i ) ;
coveredStates . set ( i ) ;
}
}
+ + i ;
}
}
i = 0 ;
for ( auto const & rv : continuationVars ) {
if ( ! coveredStatesAfterSwitch . get ( i ) & & model - > getBooleanValue ( rv ) ) {
smtSolver - > add ( rv . getExpression ( ) ) ;
if ( ! observationsAfterSwitch . get ( pomdp . getObservation ( i ) ) ) {
newObservationsAfterSwitch . set ( pomdp . getObservation ( i ) ) ;
storm : : storage : : BitVector uncoveredStatesAfterSwitch ( ~ coveredStatesAfterSwitch ) ;
for ( uint64_t i : uncoveredStatesAfterSwitch ) {
auto const & cv = continuationVars [ i ] ;
if ( model - > getBooleanValue ( cv ) ) {
uint64_t obs = pomdp . getObservation ( i ) ;
STORM_LOG_ASSERT ( winningRegion . isWinning ( obs , getOffsetFromObservation ( i , obs ) ) , " Cannot continue: No scheduler known for state " < < i < < " (observation " < < obs < < " ). " ) ;
auto const & cvExpr = continuationVarExpressions [ i ] ;
smtSolver - > add ( cvExpr ) ;
if ( ! observationsAfterSwitch . get ( obs ) ) {
newObservationsAfterSwitch . set ( obs ) ;
}
}
+ + i ;
}
}
}
}
stats . evaluateExtensionSolverTime . stop ( ) ;
if ( options . computeTraceOutput ( ) ) {
if ( options . computeTraceOutput ( ) ) {
detail : : printRelevantInfoFromModel ( model , reachVars , continuationVars ) ;
detail : : printRelevantInfoFromModel ( model , reachVars , continuationVars ) ;
}
}
stats . encodeExtensionSolverTime . start ( ) ;
for ( auto obs : newObservations ) {
for ( auto obs : newObservations ) {
auto const & actionSelectionVarsForObs = actionSelectionVars [ obs ] ;
auto const & actionSelectionVarsForObs = actionSelectionVars [ obs ] ;
observations . set ( obs ) ;
observations . set ( obs ) ;
@ -534,16 +622,11 @@ namespace storm {
if ( remainingExpressions . empty ( ) ) {
if ( remainingExpressions . empty ( ) ) {
stats . updat eExtensionSolverTime. stop ( ) ;
stats . encod eExtensionSolverTime. stop ( ) ;
break ;
break ;
}
}
// Add scheduler
//std::cout << storm::expressions::disjunction(remainingExpressions) << std::endl;
smtSolver - > add ( storm : : expressions : : disjunction ( remainingExpressions ) ) ;
smtSolver - > add ( storm : : expressions : : disjunction ( remainingExpressions ) ) ;
stats . updateExtensionSolverTime . stop ( ) ;
stats . encodeExtensionSolverTime . stop ( ) ;
}
}
if ( ! newSchedulerDiscovered ) {
if ( ! newSchedulerDiscovered ) {
break ;
break ;
@ -591,45 +674,58 @@ namespace storm {
}
}
stats . winningRegionUpdatesTimer . stop ( ) ;
stats . winningRegionUpdatesTimer . stop ( ) ;
if ( newTargetObservations > 0 ) {
if ( newTargetObservations > 0 ) {
stats . graphSearchTime . start ( ) ;
storm : : analysis : : QualitativeAnalysisOnGraphs < ValueType > graphanalysis ( pomdp ) ;
storm : : analysis : : QualitativeAnalysisOnGraphs < ValueType > graphanalysis ( pomdp ) ;
uint64_t targetStatesBefore = targetStates . getNumberOfSetBits ( ) ;
uint64_t targetStatesBefore = targetStates . getNumberOfSetBits ( ) ;
STORM_LOG_INFO ( " Target states before graph based analysis " < < targetStates . getNumberOfSetBits ( ) ) ;
STORM_LOG_INFO ( " Target states before graph based analysis " < < targetStates . getNumberOfSetBits ( ) ) ;
storm : : storage : : BitVector newtargetStates = graphanalysis . analyseProb1Max ( ~ surelyReachSinkStates , targetStates ) ;
uint64_t targetStatesAfter = newtargetStates . getNumberOfSetBits ( ) ;
STORM_LOG_INFO ( " Target states after graph based analysis " < < newtargetStates . getNumberOfSetBits ( ) ) ;
targetStates = graphanalysis . analyseProb1Max ( ~ surelyReachSinkStates , targetStates ) ;
uint64_t targetStatesAfter = targetStates . getNumberOfSetBits ( ) ;
STORM_LOG_INFO ( " Target states after graph based analysis " < < targetStates . getNumberOfSetBits ( ) ) ;
stats . graphSearchTime . stop ( ) ;
if ( targetStatesAfter - targetStatesBefore > 0 ) {
if ( targetStatesAfter - targetStatesBefore > 0 ) {
stats . winningRegionUpdatesTimer . start ( ) ;
stats . winningRegionUpdatesTimer . start ( ) ;
// TODO CODE DUPLICATION WITH INIT, PUT IN PROCEDURE
storm : : storage : : BitVector potentialWinner ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observationsWithPartialWinners ( pomdp . getNrObservations ( ) ) ;
for ( uint64_t observation = 0 ; observation < pomdp . getNrObservations ( ) ; + + observation ) {
for ( uint64_t observation = 0 ; observation < pomdp . getNrObservations ( ) ; + + observation ) {
if ( winningRegion . observationIsWinning ( observation ) ) {
if ( winningRegion . observationIsWinning ( observation ) ) {
continue ;
continue ;
}
}
bool observationIsWinning = true ;
bool observationIsWinning = true ;
for ( uint64_t state : statesPerObservation [ observation ] ) {
for ( uint64_t state : statesPerObservation [ observation ] ) {
if ( ! new targetStates. get ( state ) ) {
if ( ! targetStates . get ( state ) ) {
observationIsWinning = false ;
observationIsWinning = false ;
break ;
observationsWithPartialWinners . set ( observation ) ;
} else {
potentialWinner . set ( observation ) ;
}
}
}
}
if ( observationIsWinning ) {
if ( observationIsWinning ) {
stats . incrementGraphBasedWinningObservations ( ) ;
stats . incrementGraphBasedWinningObservations ( ) ;
winningRegion . setObservationIsWinning ( observation ) ;
winningRegion . setObservationIsWinning ( observation ) ;
for ( auto const & state : statesPerObservation [ observation ] ) {
targetStates . set ( state ) ;
}
updated . set ( observation ) ;
updated . set ( observation ) ;
}
}
}
}
STORM_LOG_INFO ( " Graph based winning obs: " < < stats . getGraphBasedwinningObservations ( ) ) ;
STORM_LOG_INFO ( " Graph based winning obs: " < < stats . getGraphBasedwinningObservations ( ) ) ;
uint64_t nonWinObTargetStates = 0 ;
for ( uint64_t state : targetStates ) {
if ( ! winningRegion . observationIsWinning ( pomdp . getObservation ( state ) ) ) {
nonWinObTargetStates + + ;
observationsWithPartialWinners & = potentialWinner ;
for ( auto const & observation : observationsWithPartialWinners ) {
uint64_t nrStatesForObs = statesPerObservation [ observation ] . size ( ) ;
storm : : storage : : BitVector update ( nrStatesForObs ) ;
for ( uint64_t i = 0 ; i < nrStatesForObs ; + + i ) {
uint64_t state = statesPerObservation [ observation ] [ i ] ;
if ( targetStates . get ( state ) ) {
update . set ( i ) ;
}
}
}
assert ( ! update . empty ( ) ) ;
STORM_LOG_TRACE ( " Extend winning region for observation " < < observation < < " with target states/offsets " < < update ) ;
winningRegion . addTargetStates ( observation , update ) ;
assert ( winningRegion . query ( observation , update ) ) ; //
updated . set ( observation ) ;
}
}
stats . winningRegionUpdatesTimer . stop ( ) ;
stats . winningRegionUpdatesTimer . stop ( ) ;
if ( nonWinObTargetStates > 0 ) {
std : : cout < < " Non winning target states " < < nonWinObTargetStates < < std : : endl ;
if ( observationsWithPartialWinners . getNumberOfSetBits ( ) > 0 ) {
STORM_LOG_WARN ( " This case has been barely tested and likely contains bug " ) ;
STORM_LOG_WARN ( " This case has been barely tested and likely contains bug " ) ;
reset ( ) ;
reset ( ) ;
return analyze ( k , ~ targetStates & ~ surelyReachSinkStates ) ;
return analyze ( k , ~ targetStates & ~ surelyReachSinkStates ) ;