@ -136,6 +136,64 @@ namespace storm {
template < typename ValueType , typename StateType >
template < typename ValueType , typename StateType >
std : : vector < StateType > PrismNextStateGenerator < ValueType , StateType > : : getInitialStates ( StateToIdCallback const & stateToIdCallback ) {
std : : vector < StateType > PrismNextStateGenerator < ValueType , StateType > : : getInitialStates ( StateToIdCallback const & stateToIdCallback ) {
std : : vector < StateType > initialStateIndices ;
// If all states are initial, we can simplify the enumeration substantially.
if ( program . hasInitialConstruct ( ) & & program . getInitialConstruct ( ) . getInitialStatesExpression ( ) . isTrue ( ) ) {
CompressedState initialState ( this - > variableInformation . getTotalBitOffset ( true ) ) ;
std : : vector < int_fast64_t > currentIntegerValues ;
currentIntegerValues . reserve ( this - > variableInformation . integerVariables . size ( ) ) ;
for ( auto const & variable : this - > variableInformation . integerVariables ) {
STORM_LOG_THROW ( variable . lowerBound < = variable . upperBound , storm : : exceptions : : InvalidArgumentException , " Expecting variable with non-empty set of possible values. " ) ;
currentIntegerValues . emplace_back ( 0 ) ;
initialState . setFromInt ( variable . bitOffset , variable . bitWidth , 0 ) ;
}
initialStateIndices . emplace_back ( stateToIdCallback ( initialState ) ) ;
bool done = false ;
while ( ! done ) {
bool changedBooleanVariable = false ;
for ( auto const & booleanVariable : this - > variableInformation . booleanVariables ) {
if ( initialState . get ( booleanVariable . bitOffset ) ) {
initialState . set ( booleanVariable . bitOffset ) ;
changedBooleanVariable = true ;
break ;
} else {
initialState . set ( booleanVariable . bitOffset , false ) ;
}
}
bool changedIntegerVariable = false ;
if ( changedBooleanVariable ) {
initialStateIndices . emplace_back ( stateToIdCallback ( initialState ) ) ;
} else {
for ( uint64_t integerVariableIndex = 0 ; integerVariableIndex < this - > variableInformation . integerVariables . size ( ) ; + + integerVariableIndex ) {
auto const & integerVariable = this - > variableInformation . integerVariables [ integerVariableIndex ] ;
if ( currentIntegerValues [ integerVariableIndex ] < integerVariable . upperBound - integerVariable . lowerBound ) {
+ + currentIntegerValues [ integerVariableIndex ] ;
changedIntegerVariable = true ;
} else {
currentIntegerValues [ integerVariableIndex ] = integerVariable . lowerBound ;
}
initialState . setFromInt ( integerVariable . bitOffset , integerVariable . bitWidth , currentIntegerValues [ integerVariableIndex ] ) ;
if ( changedIntegerVariable ) {
break ;
}
}
}
if ( changedIntegerVariable ) {
initialStateIndices . emplace_back ( stateToIdCallback ( initialState ) ) ;
}
done = ! changedBooleanVariable & & ! changedIntegerVariable ;
}
STORM_LOG_DEBUG ( " Enumerated " < < initialStateIndices . size ( ) < < " initial states using brute force enumeration. " ) ;
} else {
// Prepare an SMT solver to enumerate all initial states.
// Prepare an SMT solver to enumerate all initial states.
storm : : utility : : solver : : SmtSolverFactory factory ;
storm : : utility : : solver : : SmtSolverFactory factory ;
std : : unique_ptr < storm : : solver : : SmtSolver > solver = factory . create ( program . getManager ( ) ) ;
std : : unique_ptr < storm : : solver : : SmtSolver > solver = factory . create ( program . getManager ( ) ) ;
@ -147,7 +205,6 @@ namespace storm {
solver - > add ( program . getInitialStatesExpression ( ) ) ;
solver - > add ( program . getInitialStatesExpression ( ) ) ;
// Proceed ss long as the solver can still enumerate initial states.
// Proceed ss long as the solver can still enumerate initial states.
std : : vector < StateType > initialStateIndices ;
while ( solver - > check ( ) = = storm : : solver : : SmtSolver : : CheckResult : : Sat ) {
while ( solver - > check ( ) = = storm : : solver : : SmtSolver : : CheckResult : : Sat ) {
// Create fresh state.
// Create fresh state.
CompressedState initialState ( this - > variableInformation . getTotalBitOffset ( true ) ) ;
CompressedState initialState ( this - > variableInformation . getTotalBitOffset ( true ) ) ;
@ -180,6 +237,9 @@ namespace storm {
solver - > add ( blockingExpression ) ;
solver - > add ( blockingExpression ) ;
}
}
STORM_LOG_DEBUG ( " Enumerated " < < initialStateIndices . size ( ) < < " initial states using SMT solving. " ) ;
}
return initialStateIndices ;
return initialStateIndices ;
}
}
@ -455,13 +515,29 @@ namespace storm {
return result ;
return result ;
}
}
template < typename ValueType , typename StateType >
void PrismNextStateGenerator < ValueType , StateType > : : generateSynchronizedDistribution ( storm : : storage : : BitVector const & state , ValueType const & probability , uint64_t position , std : : vector < std : : vector < std : : reference_wrapper < storm : : prism : : Command const > > : : const_iterator > const & iteratorList , storm : : builder : : jit : : Distribution < StateType , ValueType > & distribution , StateToIdCallback stateToIdCallback ) {
if ( storm : : utility : : isZero < ValueType > ( probability ) ) {
return ;
}
if ( position > = iteratorList . size ( ) ) {
StateType id = stateToIdCallback ( state ) ;
distribution . add ( id , probability ) ;
} else {
storm : : prism : : Command const & command = * iteratorList [ position ] ;
for ( uint_fast64_t j = 0 ; j < command . getNumberOfUpdates ( ) ; + + j ) {
storm : : prism : : Update const & update = command . getUpdate ( j ) ;
generateSynchronizedDistribution ( applyUpdate ( state , update ) , probability * this - > evaluator - > asRational ( update . getLikelihoodExpression ( ) ) , position + 1 , iteratorList , distribution , stateToIdCallback ) ;
}
}
}
template < typename ValueType , typename StateType >
template < typename ValueType , typename StateType >
std : : vector < Choice < ValueType > > PrismNextStateGenerator < ValueType , StateType > : : getLabeledChoices ( CompressedState const & state , StateToIdCallback stateToIdCallback ) {
std : : vector < Choice < ValueType > > PrismNextStateGenerator < ValueType , StateType > : : getLabeledChoices ( CompressedState const & state , StateToIdCallback stateToIdCallback ) {
std : : vector < Choice < ValueType > > result ;
std : : vector < Choice < ValueType > > result ;
storm : : builder : : jit : : Distribution < CompressedState , ValueType > currentDistribution ;
storm : : builder : : jit : : Distribution < CompressedState , ValueType > nextDistribution ;
for ( uint_fast64_t actionIndex : program . getSynchronizingActionIndices ( ) ) {
for ( uint_fast64_t actionIndex : program . getSynchronizingActionIndices ( ) ) {
boost : : optional < std : : vector < std : : vector < std : : reference_wrapper < storm : : prism : : Command const > > > > optionalActiveCommandLists = getActiveCommandsByActionIndex ( actionIndex ) ;
boost : : optional < std : : vector < std : : vector < std : : reference_wrapper < storm : : prism : : Command const > > > > optionalActiveCommandLists = getActiveCommandsByActionIndex ( actionIndex ) ;
@ -475,37 +551,14 @@ namespace storm {
iteratorList [ i ] = activeCommandList [ i ] . cbegin ( ) ;
iteratorList [ i ] = activeCommandList [ i ] . cbegin ( ) ;
}
}
storm : : builder : : jit : : Distribution < StateType , ValueType > distribution ;
// As long as there is one feasible combination of commands, keep on expanding it.
// As long as there is one feasible combination of commands, keep on expanding it.
bool done = false ;
bool done = false ;
while ( ! done ) {
while ( ! done ) {
currentDistribution . clear ( ) ;
nextDistribution . clear ( ) ;
currentDistribution . add ( state , storm : : utility : : one < ValueType > ( ) ) ;
for ( uint_fast64_t i = 0 ; i < iteratorList . size ( ) ; + + i ) {
storm : : prism : : Command const & command = * iteratorList [ i ] ;
for ( uint_fast64_t j = 0 ; j < command . getNumberOfUpdates ( ) ; + + j ) {
storm : : prism : : Update const & update = command . getUpdate ( j ) ;
for ( auto const & stateProbability : currentDistribution ) {
ValueType probability = stateProbability . getValue ( ) * this - > evaluator - > asRational ( update . getLikelihoodExpression ( ) ) ;
if ( ! storm : : utility : : isZero < ValueType > ( probability ) ) {
// Compute the new state under the current update and add it to the set of new target states.
CompressedState newTargetState = applyUpdate ( stateProbability . getState ( ) , update ) ;
nextDistribution . add ( newTargetState , probability ) ;
}
}
}
nextDistribution . compress ( ) ;
// If there is one more command to come, shift the target states one time step back.
if ( i < iteratorList . size ( ) - 1 ) {
currentDistribution = std : : move ( nextDistribution ) ;
}
}
distribution . clear ( ) ;
generateSynchronizedDistribution ( state , storm : : utility : : one < ValueType > ( ) , 0 , iteratorList , distribution , stateToIdCallback ) ;
distribution . compress ( ) ;
// At this point, we applied all commands of the current command combination and newTargetStates
// At this point, we applied all commands of the current command combination and newTargetStates
// contains all target states and their respective probabilities. That means we are now ready to
// contains all target states and their respective probabilities. That means we are now ready to
@ -529,9 +582,8 @@ namespace storm {
// Add the probabilities/rates to the newly created choice.
// Add the probabilities/rates to the newly created choice.
ValueType probabilitySum = storm : : utility : : zero < ValueType > ( ) ;
ValueType probabilitySum = storm : : utility : : zero < ValueType > ( ) ;
for ( auto const & stateProbability : nextDistribution ) {
StateType actualIndex = stateToIdCallback ( stateProbability . getState ( ) ) ;
choice . addProbability ( actualIndex , stateProbability . getValue ( ) ) ;
for ( auto const & stateProbability : distribution ) {
choice . addProbability ( stateProbability . getState ( ) , stateProbability . getValue ( ) ) ;
if ( this - > options . isExplorationChecksSet ( ) ) {
if ( this - > options . isExplorationChecksSet ( ) ) {
probabilitySum + = stateProbability . getValue ( ) ;
probabilitySum + = stateProbability . getValue ( ) ;
}
}