@ -19,18 +19,30 @@ namespace storm {
STORM_LOG_TRACE ( ss . str ( ) ) ;
STORM_LOG_TRACE ( ss . str ( ) ) ;
i = 0 ;
i = 0 ;
STORM_LOG_TRACE ( " states from which we continue: " ) ;
STORM_LOG_TRACE ( " states from which we continue: " ) ;
ss . clear ( ) ;
std : : stringstream ss2 ;
for ( auto rv : continuationVars ) {
for ( auto rv : continuationVars ) {
if ( model - > getBooleanValue ( rv ) ) {
if ( model - > getBooleanValue ( rv ) ) {
ss < < " " < < i ;
ss2 < < " " < < i ;
}
}
+ + i ;
+ + i ;
}
}
STORM_LOG_TRACE ( ss . str ( ) ) ;
STORM_LOG_TRACE ( ss2 . str ( ) ) ;
}
}
}
}
template < typename ValueType >
void MemlessStrategySearchQualitative < ValueType > : : Statistics : : print ( ) const {
STORM_PRINT_AND_LOG ( " Total time: " < < totalTimer < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " SAT Calls " < < satCalls < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " SAT Calls time: " < < smtCheckTimer < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Outer iterations: " < < outerIterations < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Solver initialization time: " < < initializeSolverTimer < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Extend partial scheduler time: " < < updateExtensionSolverTime < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Update solver with new scheduler time: " < < updateNewStrategySolverTime < < std : : endl ) ;
STORM_PRINT_AND_LOG ( " Winning regions update time: " < < winningRegionUpdatesTimer < < std : : endl ) ;
}
template < typename ValueType >
template < typename ValueType >
MemlessStrategySearchQualitative < ValueType > : : MemlessStrategySearchQualitative ( storm : : models : : sparse : : Pomdp < ValueType > const & pomdp ,
MemlessStrategySearchQualitative < ValueType > : : MemlessStrategySearchQualitative ( storm : : models : : sparse : : Pomdp < ValueType > const & pomdp ,
std : : set < uint32_t > const & targetObservationSet ,
std : : set < uint32_t > const & targetObservationSet ,
@ -39,9 +51,9 @@ namespace storm {
std : : shared_ptr < storm : : utility : : solver : : SmtSolverFactory > & smtSolverFactory ,
std : : shared_ptr < storm : : utility : : solver : : SmtSolverFactory > & smtSolverFactory ,
MemlessSearchOptions const & options ) :
MemlessSearchOptions const & options ) :
pomdp ( pomdp ) ,
pomdp ( pomdp ) ,
targetStates ( targetStates ) ,
surelyReachSinkStates ( surelyReachSinkStates ) ,
surelyReachSinkStates ( surelyReachSinkStates ) ,
targetObservations ( targetObservationSet ) ,
targetObservations ( targetObservationSet ) ,
targetStates ( targetStates ) ,
options ( options )
options ( options )
{
{
this - > expressionManager = std : : make_shared < storm : : expressions : : ExpressionManager > ( ) ;
this - > expressionManager = std : : make_shared < storm : : expressions : : ExpressionManager > ( ) ;
@ -49,6 +61,7 @@ namespace storm {
// Initialize states per observation.
// Initialize states per observation.
for ( uint64_t obs = 0 ; obs < pomdp . getNrObservations ( ) ; + + obs ) {
for ( uint64_t obs = 0 ; obs < pomdp . getNrObservations ( ) ; + + obs ) {
statesPerObservation . push_back ( std : : vector < uint64_t > ( ) ) ; // Consider using bitvectors instead.
statesPerObservation . push_back ( std : : vector < uint64_t > ( ) ) ; // Consider using bitvectors instead.
reachVarExpressionsPerObservation . push_back ( std : : vector < storm : : expressions : : Expression > ( ) ) ;
}
}
uint64_t state = 0 ;
uint64_t state = 0 ;
for ( auto obs : pomdp . getObservations ( ) ) {
for ( auto obs : pomdp . getObservations ( ) ) {
@ -60,10 +73,15 @@ namespace storm {
nrStatesPerObservation . push_back ( states . size ( ) ) ;
nrStatesPerObservation . push_back ( states . size ( ) ) ;
}
}
winningRegion = WinningRegion ( nrStatesPerObservation ) ;
winningRegion = WinningRegion ( nrStatesPerObservation ) ;
}
}
template < typename ValueType >
template < typename ValueType >
void MemlessStrategySearchQualitative < ValueType > : : initialize ( uint64_t k ) {
void MemlessStrategySearchQualitative < ValueType > : : initialize ( uint64_t k ) {
STORM_LOG_INFO ( " Start intializing solver... " ) ;
// TODO fix this
bool lookaheadConstraintsRequired = options . lookaheadRequired ;
STORM_LOG_WARN ( " We have hardcoded that we do not need lookahead " ) ;
if ( maxK = = std : : numeric_limits < uint64_t > : : max ( ) ) {
if ( maxK = = std : : numeric_limits < uint64_t > : : max ( ) ) {
// not initialized at all.
// not initialized at all.
// Create some data structures.
// Create some data structures.
@ -76,16 +94,19 @@ namespace storm {
// declare the reachability variables,
// declare the reachability variables,
// declare the path variables.
// declare the path variables.
for ( uint64_t stateId = 0 ; stateId < pomdp . getNumberOfStates ( ) ; + + stateId ) {
for ( uint64_t stateId = 0 ; stateId < pomdp . getNumberOfStates ( ) ; + + stateId ) {
pathVars . push_back ( std : : vector < storm : : expressions : : Expression > ( ) ) ;
for ( uint64_t i = 0 ; i < k ; + + i ) {
pathVars . back ( ) . push_back ( expressionManager - > declareBooleanVariable ( " P- " + std : : to_string ( stateId ) + " - " + std : : to_string ( i ) ) . getExpression ( ) ) ;
if ( lookaheadConstraintsRequired ) {
pathVars . push_back ( std : : vector < storm : : expressions : : Expression > ( ) ) ;
for ( uint64_t i = 0 ; i < k ; + + i ) {
pathVars . back ( ) . push_back ( expressionManager - > declareBooleanVariable ( " P- " + std : : to_string ( stateId ) + " - " + std : : to_string ( i ) ) . getExpression ( ) ) ;
}
}
}
reachVars . push_back ( expressionManager - > declareBooleanVariable ( " C- " + std : : to_string ( stateId ) ) ) ;
reachVars . push_back ( expressionManager - > declareBooleanVariable ( " C- " + std : : to_string ( stateId ) ) ) ;
reachVarExpressions . push_back ( reachVars . back ( ) . getExpression ( ) ) ;
reachVarExpressions . push_back ( reachVars . back ( ) . getExpression ( ) ) ;
reachVarExpressionsPerObservation [ pomdp . getObservation ( stateId ) ] . push_back ( reachVarExpressions . back ( ) ) ;
continuationVars . push_back ( expressionManager - > declareBooleanVariable ( " D- " + std : : to_string ( stateId ) ) ) ;
continuationVars . push_back ( expressionManager - > declareBooleanVariable ( " D- " + std : : to_string ( stateId ) ) ) ;
continuationVarExpressions . push_back ( continuationVars . back ( ) . getExpression ( ) ) ;
continuationVarExpressions . push_back ( continuationVars . back ( ) . getExpression ( ) ) ;
}
}
assert ( pathVars . size ( ) = = pomdp . getNumberOfStates ( ) ) ;
assert ( ! lookaheadConstraintsRequired | | pathVars . size ( ) = = pomdp . getNumberOfStates ( ) ) ;
assert ( reachVars . size ( ) = = pomdp . getNumberOfStates ( ) ) ;
assert ( reachVars . size ( ) = = pomdp . getNumberOfStates ( ) ) ;
assert ( reachVarExpressions . size ( ) = = pomdp . getNumberOfStates ( ) ) ;
assert ( reachVarExpressions . size ( ) = = pomdp . getNumberOfStates ( ) ) ;
@ -101,7 +122,15 @@ namespace storm {
schedulerVariableExpressions . push_back ( schedulerVariables . back ( ) ) ;
schedulerVariableExpressions . push_back ( schedulerVariables . back ( ) ) ;
switchVars . push_back ( expressionManager - > declareBooleanVariable ( " S- " + std : : to_string ( obs ) ) ) ;
switchVars . push_back ( expressionManager - > declareBooleanVariable ( " S- " + std : : to_string ( obs ) ) ) ;
switchVarExpressions . push_back ( switchVars . back ( ) . getExpression ( ) ) ;
switchVarExpressions . push_back ( switchVars . back ( ) . getExpression ( ) ) ;
observationUpdatedVariables . push_back ( expressionManager - > declareBooleanVariable ( " U- " + std : : to_string ( obs ) ) ) ;
observationUpdatedExpressions . push_back ( observationUpdatedVariables . back ( ) . getExpression ( ) ) ;
if ( options . onlyDeterministicStrategies ) {
for ( uint64_t a = 0 ; a < pomdp . getNumberOfChoices ( statesForObservation . front ( ) ) - 1 ; + + a ) {
for ( uint64_t b = a + 1 ; b < pomdp . getNumberOfChoices ( statesForObservation . front ( ) ) ; + + b ) {
smtSolver - > add ( ! actionSelectionVarExpressions [ obs ] [ a ] | | ! actionSelectionVarExpressions [ obs ] [ b ] ) ;
}
}
}
+ + obs ;
+ + obs ;
}
}
@ -109,11 +138,15 @@ namespace storm {
smtSolver - > add ( storm : : expressions : : disjunction ( actionVars ) ) ;
smtSolver - > add ( storm : : expressions : : disjunction ( actionVars ) ) ;
}
}
for ( uint64_t state = 0 ; state < pomdp . getNumberOfStates ( ) ; + + state ) {
if ( targetStates . get ( state ) ) {
smtSolver - > add ( pathVars [ state ] [ 0 ] ) ;
} else {
smtSolver - > add ( ! pathVars [ state ] [ 0 ] ) ;
smtSolver - > add ( storm : : expressions : : disjunction ( observationUpdatedExpressions ) ) ;
if ( lookaheadConstraintsRequired ) {
for ( uint64_t state = 0 ; state < pomdp . getNumberOfStates ( ) ; + + state ) {
if ( targetStates . get ( state ) ) {
smtSolver - > add ( pathVars [ state ] [ 0 ] ) ;
} else {
smtSolver - > add ( ! pathVars [ state ] [ 0 ] ) ;
}
}
}
}
}
@ -152,42 +185,49 @@ namespace storm {
uint64_t rowindex = 0 ;
uint64_t rowindex = 0 ;
for ( uint64_t state = 0 ; state < pomdp . getNumberOfStates ( ) ; + + state ) {
for ( uint64_t state = 0 ; state < pomdp . getNumberOfStates ( ) ; + + state ) {
if ( surelyReachSinkStates . get ( state ) ) {
if ( surelyReachSinkStates . get ( state ) ) {
smtSolver - > add ( ! reachVarExpressions [ state ] ) ;
smtSolver - > add ( ! reachVarExpressions [ state ] ) ;
for ( uint64_t j = 1 ; j < k ; + + j ) {
smtSolver - > add ( ! pathVars [ state ] [ j ] ) ;
}
smtSolver - > add ( ! continuationVarExpressions [ state ] ) ;
smtSolver - > add ( ! continuationVarExpressions [ state ] ) ;
} else if ( ! targetStates . get ( state ) ) {
std : : vector < std : : vector < std : : vector < storm : : expressions : : Expression > > > pathsubsubexprs ;
for ( uint64_t j = 1 ; j < k ; + + j ) {
pathsubsubexprs . push_back ( std : : vector < std : : vector < storm : : expressions : : Expression > > ( ) ) ;
for ( uint64_t action = 0 ; action < pomdp . getNumberOfChoices ( state ) ; + + action ) {
pathsubsubexprs . back ( ) . push_back ( std : : vector < storm : : expressions : : Expression > ( ) ) ;
if ( lookaheadConstraintsRequired ) {
for ( uint64_t j = 1 ; j < k ; + + j ) {
smtSolver - > add ( ! pathVars [ state ] [ j ] ) ;
}
}
}
}
rowindex + = pomdp . getNumberOfChoices ( state ) ;
} else if ( ! targetStates . get ( state ) ) {
if ( lookaheadConstraintsRequired ) {
smtSolver - > add ( storm : : expressions : : implies ( reachVarExpressions . at ( state ) , pathVars . at ( state ) . back ( ) ) ) ;
std : : vector < std : : vector < std : : vector < storm : : expressions : : Expression > > > pathsubsubexprs ;
for ( uint64_t j = 1 ; j < k ; + + j ) {
pathsubsubexprs . push_back ( std : : vector < std : : vector < storm : : expressions : : Expression > > ( ) ) ;
for ( uint64_t action = 0 ; action < pomdp . getNumberOfChoices ( state ) ; + + action ) {
pathsubsubexprs . back ( ) . push_back ( std : : vector < storm : : expressions : : Expression > ( ) ) ;
}
}
for ( uint64_t action = 0 ; action < pomdp . getNumberOfChoices ( state ) ; + + action ) {
std : : vector < storm : : expressions : : Expression > subexprreach ;
for ( auto const & entries : pomdp . getTransitionMatrix ( ) . getRow ( rowindex ) ) {
for ( uint64_t j = 1 ; j < k ; + + j ) {
pathsubsubexprs [ j - 1 ] [ action ] . push_back ( pathVars [ entries . getColumn ( ) ] [ j - 1 ] ) ;
for ( uint64_t action = 0 ; action < pomdp . getNumberOfChoices ( state ) ; + + action ) {
std : : vector < storm : : expressions : : Expression > subexprreach ;
for ( auto const & entries : pomdp . getTransitionMatrix ( ) . getRow ( rowindex ) ) {
for ( uint64_t j = 1 ; j < k ; + + j ) {
pathsubsubexprs [ j - 1 ] [ action ] . push_back ( pathVars [ entries . getColumn ( ) ] [ j - 1 ] ) ;
}
}
}
rowindex + + ;
}
}
rowindex + + ;
}
smtSolver - > add ( storm : : expressions : : implies ( reachVarExpressions . at ( state ) , pathVars . at ( state ) . back ( ) ) ) ;
for ( uint64_t j = 1 ; j < k ; + + j ) {
std : : vector < storm : : expressions : : Expression > pathsubexprs ;
for ( uint64_t j = 1 ; j < k ; + + j ) {
std : : vector < storm : : expressions : : Expression > pathsubexprs ;
for ( uint64_t action = 0 ; action < pomdp . getNumberOfChoices ( state ) ; + + action ) {
pathsubexprs . push_back ( actionSelectionVarExpressions . at ( pomdp . getObservation ( state ) ) . at ( action ) & & storm : : expressions : : disjunction ( pathsubsubexprs [ j - 1 ] [ action ] ) ) ;
for ( uint64_t action = 0 ; action < pomdp . getNumberOfChoices ( state ) ; + + action ) {
pathsubexprs . push_back ( actionSelectionVarExpressions . at ( pomdp . getObservation ( state ) ) . at ( action ) & & storm : : expressions : : disjunction ( pathsubsubexprs [ j - 1 ] [ action ] ) ) ;
}
pathsubexprs . push_back ( switchVarExpressions . at ( pomdp . getObservation ( state ) ) ) ;
smtSolver - > add ( storm : : expressions : : iff ( pathVars [ state ] [ j ] , storm : : expressions : : disjunction ( pathsubexprs ) ) ) ;
}
}
pathsubexprs . push_back ( switchVarExpressions . at ( pomdp . getObservation ( state ) ) ) ;
smtSolver - > add ( storm : : expressions : : iff ( pathVars [ state ] [ j ] , storm : : expressions : : disjunction ( pathsubexprs ) ) ) ;
}
}
} else {
rowindex + = pomdp . getNumberOfChoices ( state ) ;
}
}
}
}
@ -199,17 +239,47 @@ namespace storm {
+ + obs ;
+ + obs ;
}
}
for ( uint64_t obs = 0 ; obs < pomdp . getNrObservations ( ) ; + + obs ) {
smtSolver - > add ( storm : : expressions : : implies ( switchVarExpressions [ obs ] , storm : : expressions : : disjunction ( reachVarExpressionsPerObservation [ obs ] ) ) ) ;
}
if ( ! lookaheadConstraintsRequired ) {
uint64_t rowIndex = 0 ;
for ( uint64_t state = 0 ; state < pomdp . getNumberOfStates ( ) ; + + state ) {
uint64_t enabledActions = pomdp . getNumberOfChoices ( state ) ;
if ( ! surelyReachSinkStates . get ( state ) ) {
std : : vector < storm : : expressions : : Expression > successorVars ;
for ( uint64_t act = 0 ; act < enabledActions ; + + act ) {
for ( auto const & entries : pomdp . getTransitionMatrix ( ) . getRow ( rowIndex ) ) {
successorVars . push_back ( reachVarExpressions [ entries . getColumn ( ) ] ) ;
}
rowIndex + + ;
}
successorVars . push_back ( ! switchVars [ pomdp . getObservation ( state ) ] ) ;
smtSolver - > add ( storm : : expressions : : implies ( storm : : expressions : : conjunction ( successorVars ) , reachVarExpressions [ state ] ) ) ;
} else {
rowIndex + = enabledActions ;
}
}
} else {
STORM_LOG_WARN ( " Some optimization not implemented yet. " ) ;
}
// TODO: Update found schedulers if k is increased.
// TODO: Update found schedulers if k is increased.
}
}
template < typename ValueType >
template < typename ValueType >
bool MemlessStrategySearchQualitative < ValueType > : : analyze ( uint64_t k , storm : : storage : : BitVector const & oneOfTheseStates , storm : : storage : : BitVector const & allOfTheseStates ) {
bool MemlessStrategySearchQualitative < ValueType > : : analyze ( uint64_t k , storm : : storage : : BitVector const & oneOfTheseStates , storm : : storage : : BitVector const & allOfTheseStates ) {
stats . initializeSolverTimer . start ( ) ;
if ( k < maxK ) {
if ( k < maxK ) {
initialize ( k ) ;
initialize ( k ) ;
maxK = k ;
maxK = k ;
}
}
uint64_t maximalNrActions = 8 ;
STORM_LOG_WARN ( " We have hardcoded (an upper bound on) the number of actions " ) ;
std : : vector < storm : : expressions : : Expression > atLeastOneOfStates ;
std : : vector < storm : : expressions : : Expression > atLeastOneOfStates ;
for ( uint64_t state : oneOfTheseStates ) {
for ( uint64_t state : oneOfTheseStates ) {
@ -225,112 +295,117 @@ namespace storm {
}
}
smtSolver - > push ( ) ;
smtSolver - > push ( ) ;
uint64_t obs = 0 ;
for ( auto const & statesForObservation : statesPerObservation ) {
smtSolver - > add ( schedulerVariableExpressions [ obs ] < = schedulerForObs . size ( ) ) ;
+ + obs ;
}
std : : vector < storm : : expressions : : Expression > updateForObservationExpressions ;
for ( uint64_t ob = 0 ; ob < pomdp . getNrObservations ( ) ; + + ob ) {
for ( uint64_t ob = 0 ; ob < pomdp . getNrObservations ( ) ; + + ob ) {
updateForObservationExpressions . push_back ( storm : : expressions : : disjunction ( reachVarExpressionsPerObservation [ ob ] ) ) ;
schedulerForObs . push_back ( std : : vector < uint64_t > ( ) ) ;
schedulerForObs . push_back ( std : : vector < uint64_t > ( ) ) ;
}
}
for ( uint64_t obs = 0 ; obs < pomdp . getNrObservations ( ) ; + + obs ) {
auto constant = expressionManager - > integer ( schedulerForObs [ obs ] . size ( ) ) ;
smtSolver - > add ( schedulerVariableExpressions [ obs ] < = constant ) ;
smtSolver - > add ( storm : : expressions : : iff ( observationUpdatedExpressions [ obs ] , updateForObservationExpressions [ obs ] ) ) ;
}
InternalObservationScheduler scheduler ;
InternalObservationScheduler scheduler ;
scheduler . switchObservations = storm : : storage : : BitVector ( pomdp . getNrObservations ( ) ) ;
scheduler . switchObservations = storm : : storage : : BitVector ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector newObservations ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector newObservationsAfterSwitch ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observations ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observations ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observationsAfterSwitch ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector observationsAfterSwitch ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector remainingstates ( pomdp . getNumberOfStates ( ) ) ;
storm : : storage : : BitVector observationUpdated ( pomdp . getNrObservations ( ) ) ;
storm : : storage : : BitVector coveredStates ( pomdp . getNumberOfStates ( ) ) ;
storm : : storage : : BitVector coveredStatesAfterSwitch ( pomdp . getNumberOfStates ( ) ) ;
stats . initializeSolverTimer . stop ( ) ;
STORM_LOG_INFO ( " Start iterative solver... " ) ;
uint64_t iterations = 0 ;
uint64_t iterations = 0 ;
while ( true ) {
while ( true ) {
scheduler . clear ( ) ;
stats . incrementOuterIterations ( ) ;
scheduler . reset ( pomdp . getNrObservations ( ) , maximalNrActions ) ;
observations . clear ( ) ;
observations . clear ( ) ;
observationsAfterSwitch . clear ( ) ;
observationsAfterSwitch . clear ( ) ;
remainingstates . clear ( ) ;
coveredStates . clear ( ) ;
coveredStatesAfterSwitch . clear ( ) ;
observationUpdated . clear ( ) ;
bool newSchedulerDiscovered = false ;
while ( true ) {
while ( true ) {
+ + iterations ;
+ + iterations ;
if ( options . isExportSATSet ( ) ) {
STORM_LOG_DEBUG ( " Export SMT Solver Call ( " < < iterations < < " ) " ) ;
std : : string filepath = options . getExportSATCallsPath ( ) + " call_ " + std : : to_string ( iterations ) + " .smt2 " ;
std : : ofstream filestream ;
storm : : utility : : openFile ( filepath , filestream ) ;
filestream < < smtSolver - > getSmtLibString ( ) < < std : : endl ;
storm : : utility : : closeFile ( filestream ) ;
}
STORM_LOG_DEBUG ( " Call to SMT Solver ( " < < iterations < < " ) " ) ;
auto result = smtSolver - > check ( ) ;
uint64_t i = 0 ;
if ( result = = storm : : solver : : SmtSolver : : CheckResult : : Unknown ) {
STORM_LOG_THROW ( false , storm : : exceptions : : UnexpectedException , " SMT solver yielded an unexpected result " ) ;
} else if ( result = = storm : : solver : : SmtSolver : : CheckResult : : Unsat ) {
STORM_LOG_DEBUG ( " Unsatisfiable! " ) ;
bool foundScheduler = this - > smtCheck ( iterations ) ;
if ( ! foundScheduler ) {
break ;
break ;
}
}
STORM_LOG_DEBUG ( " Satisfying assignment: " ) ;
STORM_LOG_TRACE ( smtSolver - > getModelAsValuation ( ) . toString ( true ) ) ;
newSchedulerDiscovered = true ;
stats . updateExtensionSolverTime . start ( ) ;
auto model = smtSolver - > getModel ( ) ;
auto model = smtSolver - > getModel ( ) ;
newObservationsAfterSwitch . clear ( ) ;
newObservations . clear ( ) ;
observations . clear ( ) ;
observationsAfterSwitch . clear ( ) ;
remainingstates . clear ( ) ;
scheduler . clear ( ) ;
uint64_t obs = 0 ;
for ( auto ov : observationUpdatedVariables ) {
if ( ! observationUpdated . get ( obs ) & & model - > getBooleanValue ( ov ) ) {
STORM_LOG_TRACE ( " New observation updated: " < < obs ) ;
observationUpdated . set ( obs ) ;
}
obs + + ;
}
uint64_t i = 0 ;
for ( auto rv : reachVars ) {
for ( auto rv : reachVars ) {
if ( model - > getBooleanValue ( rv ) ) {
if ( ! coveredStates . get ( i ) & & model - > getBooleanValue ( rv ) ) {
STORM_LOG_TRACE ( " New state: " < < i ) ;
smtSolver - > add ( rv . getExpression ( ) ) ;
smtSolver - > add ( rv . getExpression ( ) ) ;
observations . set ( pomdp . getObservation ( i ) ) ;
} else {
remainingstates . set ( i ) ;
newObservations . set ( pomdp . getObservation ( i ) ) ;
coveredStates . set ( i ) ;
}
}
+ + i ;
+ + i ;
}
}
i = 0 ;
i = 0 ;
for ( auto rv : continuationVars ) {
for ( auto rv : continuationVars ) {
if ( model - > getBooleanValue ( rv ) ) {
if ( ! coveredStatesAfterSwitch . get ( i ) & & model - > getBooleanValue ( rv ) ) {
smtSolver - > add ( rv . getExpression ( ) ) ;
smtSolver - > add ( rv . getExpression ( ) ) ;
observationsAfterSwitch . set ( pomdp . getObservation ( i ) ) ;
if ( ! observationsAfterSwitch . get ( pomdp . getObservation ( i ) ) ) {
newObservationsAfterSwitch . set ( pomdp . getObservation ( i ) ) ;
}
+ + i ;
}
}
+ + i ;
}
}
if ( options . computeTraceOutput ( ) ) {
if ( options . computeTraceOutput ( ) ) {
detail : : printRelevantInfoFromModel ( model , reachVars , continuationVars ) ;
detail : : printRelevantInfoFromModel ( model , reachVars , continuationVars ) ;
}
}
// TODO do not repush everyting to the solver.
std : : vector < storm : : expressions : : Expression > schedulerSoFar ;
uint64_t obs = 0 ;
for ( auto const & actionSelectionVarsForObs : actionSelectionVars ) {
scheduler . actions . push_back ( std : : set < uint64_t > ( ) ) ;
if ( observations . get ( obs ) ) {
for ( uint64_t act = 0 ; act < actionSelectionVarsForObs . size ( ) ; + + act ) {
auto const & asv = actionSelectionVarsForObs [ act ] ;
if ( model - > getBooleanValue ( asv ) ) {
scheduler . actions . back ( ) . insert ( act ) ;
schedulerSoFar . push_back ( actionSelectionVarExpressions [ obs ] [ act ] ) ;
}
}
if ( model - > getBooleanValue ( switchVars [ obs ] ) ) {
scheduler . switchObservations . set ( obs ) ;
schedulerSoFar . push_back ( switchVarExpressions [ obs ] ) ;
for ( auto obs : newObservations ) {
auto const & actionSelectionVarsForObs = actionSelectionVars [ obs ] ;
observations . set ( obs ) ;
for ( uint64_t act = 0 ; act < actionSelectionVarsForObs . size ( ) ; + + act ) {
if ( model - > getBooleanValue ( actionSelectionVarsForObs [ act ] ) ) {
scheduler . actions [ obs ] . set ( act ) ;
smtSolver - > add ( actionSelectionVarExpressions [ obs ] [ act ] ) ;
} else {
} else {
schedulerSoFar . push_back ( ! switch VarExpressions [ obs ] ) ;
smtSolver - > add ( ! actionSelectionVarExpressions [ obs ] [ act ] ) ;
}
}
}
}
if ( observationsAfterSwitch . get ( obs ) ) {
scheduler . schedulerRef . push_back ( model - > getIntegerValue ( schedulerVariables [ obs ] ) ) ;
schedulerSoFar . push_back ( schedulerVariableExpressions [ obs ] = = expressionManager - > integer ( scheduler . schedulerRef . back ( ) ) ) ;
if ( model - > getBooleanValue ( switchVars [ obs ] ) ) {
scheduler . switchObservations . set ( obs ) ;
smtSolver - > add ( switchVarExpressions [ obs ] ) ;
} else {
} else {
scheduler . schedulerRef . push_back ( 0 ) ;
smtSolver - > add ( ! switchVarExpressions [ obs ] ) ;
}
}
obs + + ;
}
for ( auto obs : newObservationsAfterSwitch ) {
observationsAfterSwitch . set ( obs ) ;
scheduler . schedulerRef [ obs ] = model - > getIntegerValue ( schedulerVariables [ obs ] ) ;
smtSolver - > add ( schedulerVariableExpressions [ obs ] = = expressionManager - > integer ( scheduler . schedulerRef . back ( ) ) ) ;
}
}
if ( options . computeTraceOutput ( ) ) {
if ( options . computeTraceOutput ( ) ) {
@ -341,56 +416,82 @@ namespace storm {
}
}
std : : vector < storm : : expressions : : Expression > remainingExpressions ;
std : : vector < storm : : expressions : : Expression > remainingExpressions ;
for ( auto index : remainingstates ) {
remainingExpressions . push_back ( reachVarExpressions [ index ] ) ;
for ( auto index : ~ coveredStates ) {
if ( observationUpdated . get ( pomdp . getObservation ( index ) ) ) {
remainingExpressions . push_back ( reachVarExpressions [ index ] ) ;
}
}
for ( auto index : ~ observationUpdated ) {
remainingExpressions . push_back ( observationUpdatedExpressions [ index ] ) ;
}
if ( remainingExpressions . empty ( ) ) {
stats . updateExtensionSolverTime . stop ( ) ;
break ;
}
}
// Add scheduler
// Add scheduler
smtSolver - > add ( storm : : expressions : : conjunction ( schedulerSoFar ) ) ;
//std::cout << storm::expressions::disjunction(remainingExpressions) << std::endl;
smtSolver - > add ( storm : : expressions : : disjunction ( remainingExpressions ) ) ;
smtSolver - > add ( storm : : expressions : : disjunction ( remainingExpressions ) ) ;
stats . updateExtensionSolverTime . stop ( ) ;
}
}
if ( scheduler . empty ( ) ) {
if ( ! newSchedulerDiscovered ) {
break ;
break ;
}
}
smtSolver - > pop ( ) ;
smtSolver - > pop ( ) ;
if ( options . computeDebugOutput ( ) ) {
if ( options . computeDebugOutput ( ) ) {
printCoveredStates ( remainings tates) ;
printCoveredStates ( ~ coveredS tates) ;
// generates info output, but here we only want it for debug level.
// generates info output, but here we only want it for debug level.
// For consistency, all output on info level.
// For consistency, all output on info level.
STORM_LOG_DEBUG ( " the scheduler: " ) ;
STORM_LOG_DEBUG ( " the scheduler: " ) ;
scheduler . printForObservations ( observations , observationsAfterSwitch ) ;
scheduler . printForObservations ( observations , observationsAfterSwitch ) ;
}
}
std : : vector < storm : : expressions : : Expression > remainingExpressions ;
for ( auto index : remainingstates ) {
remainingExpressions . push_back ( reachVarExpressions [ index ] ) ;
}
stats . winningRegionUpdatesTimer . start ( ) ;
storm : : storage : : BitVector updated ( observations . size ( ) ) ;
for ( uint64_t observation = 0 ; observation < pomdp . getNrObservations ( ) ; + + observation ) {
for ( uint64_t observation = 0 ; observation < pomdp . getNrObservations ( ) ; + + observation ) {
storm : : storage : : BitVector update = storm : : storage : : BitVector ( statesPerObservation [ observation ] . size ( ) ) ;
STORM_LOG_TRACE ( " consider observation " < < observation ) ;
storm : : storage : : BitVector update ( statesPerObservation [ observation ] . size ( ) ) ;
uint64_t i = 0 ;
uint64_t i = 0 ;
for ( uint64_t state : statesPerObservation [ observation ] ) {
for ( uint64_t state : statesPerObservation [ observation ] ) {
if ( ! remainings tates. get ( state ) ) {
if ( coveredS tates. get ( state ) ) {
update . set ( i ) ;
update . set ( i ) ;
}
}
+ + i ;
}
if ( ! update . empty ( ) ) {
STORM_LOG_TRACE ( " Update Winning Region: Observation " < < observation < < " with update " < < update ) ;
bool updateResult = winningRegion . update ( observation , update ) ;
STORM_LOG_TRACE ( " Region changed: " < < updateResult ) ;
if ( updateResult ) {
updated . set ( observation ) ;
updateForObservationExpressions [ observation ] = winningRegion . extensionExpression ( observation , reachVarExpressionsPerObservation [ observation ] ) ;
}
}
}
winningRegion . update ( observation , update ) ;
+ + i ;
}
}
STORM_LOG_ASSERT ( ! updated . empty ( ) , " The strategy should be new in at least one place " ) ;
stats . winningRegionUpdatesTimer . stop ( ) ;
smtSolver - > add ( storm : : expressions : : disjunction ( remainingExpressions ) ) ;
if ( options . computeDebugOutput ( ) ) {
winningRegion . print ( ) ;
}
stats . updateNewStrategySolverTime . start ( ) ;
uint64_t obs = 0 ;
uint64_t obs = 0 ;
for ( auto const & statesForObservation : statesPerObservation ) {
for ( auto const & statesForObservation : statesPerObservation ) {
if ( observations . get ( obs ) ) {
if ( observations . get ( obs ) & & updated . get ( obs ) ) {
STORM_LOG_DEBUG ( " We have a new policy ( " < < finalSchedulers . size ( ) < < " ) for states with observation " < < obs < < " . " ) ;
STORM_LOG_DEBUG ( " We have a new policy ( " < < finalSchedulers . size ( ) < < " ) for states with observation " < < obs < < " . " ) ;
assert ( schedulerForObs . size ( ) > obs ) ;
assert ( schedulerForObs . size ( ) > obs ) ;
schedulerForObs [ obs ] . push_back ( finalSchedulers . size ( ) ) ;
schedulerForObs [ obs ] . push_back ( finalSchedulers . size ( ) ) ;
STORM_LOG_DEBUG ( " We now have " < < schedulerForObs [ obs ] . size ( ) < < " policies for states with observation " < < obs ) ;
STORM_LOG_DEBUG ( " We now have " < < schedulerForObs [ obs ] . size ( ) < < " policies for states with observation " < < obs ) ;
for ( auto const & state : statesForObservation ) {
for ( auto const & state : statesForObservation ) {
if ( remainings tates. get ( state ) ) {
if ( ! coveredS tates. get ( state ) ) {
auto constant = expressionManager - > integer ( schedulerForObs [ obs ] . size ( ) ) ;
auto constant = expressionManager - > integer ( schedulerForObs [ obs ] . size ( ) ) ;
smtSolver - > add ( ! ( continuationVarExpressions [ state ] & & ( schedulerVariableExpressions [ obs ] = = constant ) ) ) ;
smtSolver - > add ( ! ( continuationVarExpressions [ state ] & & ( schedulerVariableExpressions [ obs ] = = constant ) ) ) ;
}
}
@ -399,14 +500,20 @@ namespace storm {
+ + obs ;
+ + obs ;
}
}
finalSchedulers . push_back ( scheduler ) ;
finalSchedulers . push_back ( scheduler ) ;
smtSolver - > push ( ) ;
smtSolver - > push ( ) ;
for ( uint64_t obs = 0 ; obs < pomdp . getNrObservations ( ) ; + + obs ) {
for ( uint64_t obs = 0 ; obs < pomdp . getNrObservations ( ) ; + + obs ) {
auto constant = expressionManager - > integer ( schedulerForObs [ obs ] . size ( ) ) ;
auto constant = expressionManager - > integer ( schedulerForObs [ obs ] . size ( ) ) ;
smtSolver - > add ( schedulerVariableExpressions [ obs ] < = constant ) ;
smtSolver - > add ( schedulerVariableExpressions [ obs ] < = constant ) ;
smtSolver - > add ( storm : : expressions : : iff ( observationUpdatedExpressions [ obs ] , updateForObservationExpressions [ obs ] ) ) ;
}
}
stats . updateNewStrategySolverTime . stop ( ) ;
}
}
winningRegion . print ( ) ;
return true ;
return true ;
}
}
@ -423,21 +530,53 @@ namespace storm {
}
}
template < typename ValueType >
void MemlessStrategySearchQualitative < ValueType > : : printScheduler ( std : : vector < InternalObservationScheduler > const & ) {
}
template < typename ValueType >
template < typename ValueType >
void MemlessStrategySearchQualitative < ValueType > : : printScheduler ( std : : vector < InternalObservationScheduler > const & ) {
void MemlessStrategySearchQualitative < ValueType > : : finalizeStatistics ( ) {
}
}
template < typename ValueType >
typename MemlessStrategySearchQualitative < ValueType > : : Statistics const & MemlessStrategySearchQualitative < ValueType > : : getStatistics ( ) const {
return stats ;
}
template < typename ValueType >
template < typename ValueType >
storm : : expressions : : Expression const & MemlessStrategySearchQualitative < ValueType > : : getDoneActionExpression ( uint64_t obs ) const {
return actionSelectionVarExpressions [ obs ] . back ( ) ;
}
bool MemlessStrategySearchQualitative < ValueType > : : smtCheck ( uint64_t iteration ) {
if ( options . isExportSATSet ( ) ) {
STORM_LOG_DEBUG ( " Export SMT Solver Call ( " < < iteration < < " ) " ) ;
std : : string filepath = options . getExportSATCallsPath ( ) + " call_ " + std : : to_string ( iteration ) + " .smt2 " ;
std : : ofstream filestream ;
storm : : utility : : openFile ( filepath , filestream ) ;
filestream < < smtSolver - > getSmtLibString ( ) < < std : : endl ;
storm : : utility : : closeFile ( filestream ) ;
}
STORM_LOG_DEBUG ( " Call to SMT Solver ( " < < iteration < < " ) " ) ;
stats . smtCheckTimer . start ( ) ;
auto result = smtSolver - > check ( ) ;
stats . smtCheckTimer . stop ( ) ;
stats . incrementSmtChecks ( ) ;
if ( result = = storm : : solver : : SmtSolver : : CheckResult : : Unknown ) {
STORM_LOG_THROW ( false , storm : : exceptions : : UnexpectedException , " SMT solver yielded an unexpected result " ) ;
} else if ( result = = storm : : solver : : SmtSolver : : CheckResult : : Unsat ) {
STORM_LOG_DEBUG ( " Unsatisfiable! " ) ;
return false ;
}
STORM_LOG_DEBUG ( " Satisfying assignment: " ) ;
STORM_LOG_TRACE ( smtSolver - > getModelAsValuation ( ) . toString ( true ) ) ;
return true ;
}
template class MemlessStrategySearchQualitative < double > ;
template class MemlessStrategySearchQualitative < double > ;
template class MemlessStrategySearchQualitative < storm : : RationalNumber > ;
template class MemlessStrategySearchQualitative < storm : : RationalNumber > ;
}
}
}
}