@ -136,48 +136,108 @@ namespace storm {
template < typename ValueType , typename StateType >
std : : vector < StateType > PrismNextStateGenerator < ValueType , StateType > : : getInitialStates ( StateToIdCallback const & stateToIdCallback ) {
// Prepare an SMT solver to enumerate all initial states.
storm : : utility : : solver : : SmtSolverFactory factory ;
std : : unique_ptr < storm : : solver : : SmtSolver > solver = factory . create ( program . getManager ( ) ) ;
std : : vector < storm : : expressions : : Expression > rangeExpressions = program . getAllRangeExpressions ( ) ;
for ( auto const & expression : rangeExpressions ) {
solver - > add ( expression ) ;
}
solver - > add ( program . getInitialStatesExpression ( ) ) ;
// Proceed ss long as the solver can still enumerate initial states.
std : : vector < StateType > initialStateIndices ;
while ( solver - > check ( ) = = storm : : solver : : SmtSolver : : CheckResult : : Sat ) {
// Create fresh state.
// If all states are initial, we can simplify the enumeration substantially.
if ( program . hasInitialConstruct ( ) & & program . getInitialConstruct ( ) . getInitialStatesExpression ( ) . isTrue ( ) ) {
CompressedState initialState ( this - > variableInformation . getTotalBitOffset ( true ) ) ;
// Read variable assignment from the solution of the solver. Also, create an expression we can use to
// prevent the variable assignment from being enumerated again.
storm : : expressions : : Expression blockingExpression ;
std : : shared_ptr < storm : : solver : : SmtSolver : : ModelReference > model = solver - > getModel ( ) ;
for ( auto const & booleanVariable : this - > variableInformation . booleanVariables ) {
bool variableValue = model - > getBooleanValue ( booleanVariable . variable ) ;
storm : : expressions : : Expression localBlockingExpression = variableValue ? ! booleanVariable . variable : booleanVariable . variable ;
blockingExpression = blockingExpression . isInitialized ( ) ? blockingExpression | | localBlockingExpression : localBlockingExpression ;
initialState . set ( booleanVariable . bitOffset , variableValue ) ;
std : : vector < int_fast64_t > currentIntegerValues ;
currentIntegerValues . reserve ( this - > variableInformation . integerVariables . size ( ) ) ;
for ( auto const & variable : this - > variableInformation . integerVariables ) {
STORM_LOG_THROW ( variable . lowerBound < = variable . upperBound , storm : : exceptions : : InvalidArgumentException , " Expecting variable with non-empty set of possible values. " ) ;
currentIntegerValues . emplace_back ( 0 ) ;
initialState . setFromInt ( variable . bitOffset , variable . bitWidth , 0 ) ;
}
for ( auto const & integerVariable : this - > variableInformation . integerVariables ) {
int_fast64_t variableValue = model - > getIntegerValue ( integerVariable . variable ) ;
storm : : expressions : : Expression localBlockingExpression = integerVariable . variable ! = model - > getManager ( ) . integer ( variableValue ) ;
blockingExpression = blockingExpression . isInitialized ( ) ? blockingExpression | | localBlockingExpression : localBlockingExpression ;
initialState . setFromInt ( integerVariable . bitOffset , integerVariable . bitWidth , static_cast < uint_fast64_t > ( variableValue - integerVariable . lowerBound ) ) ;
initialStateIndices . emplace_back ( stateToIdCallback ( initialState ) ) ;
bool done = false ;
while ( ! done ) {
bool changedBooleanVariable = false ;
for ( auto const & booleanVariable : this - > variableInformation . booleanVariables ) {
if ( initialState . get ( booleanVariable . bitOffset ) ) {
initialState . set ( booleanVariable . bitOffset ) ;
changedBooleanVariable = true ;
break ;
} else {
initialState . set ( booleanVariable . bitOffset , false ) ;
}
}
bool changedIntegerVariable = false ;
if ( changedBooleanVariable ) {
initialStateIndices . emplace_back ( stateToIdCallback ( initialState ) ) ;
} else {
for ( uint64_t integerVariableIndex = 0 ; integerVariableIndex < this - > variableInformation . integerVariables . size ( ) ; + + integerVariableIndex ) {
auto const & integerVariable = this - > variableInformation . integerVariables [ integerVariableIndex ] ;
if ( currentIntegerValues [ integerVariableIndex ] < integerVariable . upperBound - integerVariable . lowerBound ) {
+ + currentIntegerValues [ integerVariableIndex ] ;
changedIntegerVariable = true ;
} else {
currentIntegerValues [ integerVariableIndex ] = integerVariable . lowerBound ;
}
initialState . setFromInt ( integerVariable . bitOffset , integerVariable . bitWidth , currentIntegerValues [ integerVariableIndex ] ) ;
if ( changedIntegerVariable ) {
break ;
}
}
}
if ( changedIntegerVariable ) {
initialStateIndices . emplace_back ( stateToIdCallback ( initialState ) ) ;
}
done = ! changedBooleanVariable & & ! changedIntegerVariable ;
}
// Register initial state and return it.
StateType id = stateToIdCallback ( initialState ) ;
initialStateIndices . push_back ( id ) ;
STORM_LOG_DEBUG ( " Enumerated " < < initialStateIndices . size ( ) < < " initial states using brute force enumeration. " ) ;
} else {
// Prepare an SMT solver to enumerate all initial states.
storm : : utility : : solver : : SmtSolverFactory factory ;
std : : unique_ptr < storm : : solver : : SmtSolver > solver = factory . create ( program . getManager ( ) ) ;
// Block the current initial state to search for the next one.
if ( ! blockingExpression . isInitialized ( ) ) {
break ;
std : : vector < storm : : expressions : : Expression > rangeExpressions = program . getAllRangeExpressions ( ) ;
for ( auto const & expression : rangeExpressions ) {
solver - > add ( expression ) ;
}
solver - > add ( blockingExpression ) ;
solver - > add ( program . getInitialStatesExpression ( ) ) ;
// Proceed ss long as the solver can still enumerate initial states.
while ( solver - > check ( ) = = storm : : solver : : SmtSolver : : CheckResult : : Sat ) {
// Create fresh state.
CompressedState initialState ( this - > variableInformation . getTotalBitOffset ( true ) ) ;
// Read variable assignment from the solution of the solver. Also, create an expression we can use to
// prevent the variable assignment from being enumerated again.
storm : : expressions : : Expression blockingExpression ;
std : : shared_ptr < storm : : solver : : SmtSolver : : ModelReference > model = solver - > getModel ( ) ;
for ( auto const & booleanVariable : this - > variableInformation . booleanVariables ) {
bool variableValue = model - > getBooleanValue ( booleanVariable . variable ) ;
storm : : expressions : : Expression localBlockingExpression = variableValue ? ! booleanVariable . variable : booleanVariable . variable ;
blockingExpression = blockingExpression . isInitialized ( ) ? blockingExpression | | localBlockingExpression : localBlockingExpression ;
initialState . set ( booleanVariable . bitOffset , variableValue ) ;
}
for ( auto const & integerVariable : this - > variableInformation . integerVariables ) {
int_fast64_t variableValue = model - > getIntegerValue ( integerVariable . variable ) ;
storm : : expressions : : Expression localBlockingExpression = integerVariable . variable ! = model - > getManager ( ) . integer ( variableValue ) ;
blockingExpression = blockingExpression . isInitialized ( ) ? blockingExpression | | localBlockingExpression : localBlockingExpression ;
initialState . setFromInt ( integerVariable . bitOffset , integerVariable . bitWidth , static_cast < uint_fast64_t > ( variableValue - integerVariable . lowerBound ) ) ;
}
// Register initial state and return it.
StateType id = stateToIdCallback ( initialState ) ;
initialStateIndices . push_back ( id ) ;
// Block the current initial state to search for the next one.
if ( ! blockingExpression . isInitialized ( ) ) {
break ;
}
solver - > add ( blockingExpression ) ;
}
STORM_LOG_DEBUG ( " Enumerated " < < initialStateIndices . size ( ) < < " initial states using SMT solving. " ) ;
}
return initialStateIndices ;
@ -454,67 +514,60 @@ namespace storm {
return result ;
}
template < typename ValueType , typename StateType >
void PrismNextStateGenerator < ValueType , StateType > : : generateSynchronizedDistribution ( storm : : storage : : BitVector const & state , ValueType const & probability , uint64_t position , std : : vector < std : : vector < std : : reference_wrapper < storm : : prism : : Command const > > : : const_iterator > const & iteratorList , storm : : builder : : jit : : Distribution < StateType , ValueType > & distribution , StateToIdCallback stateToIdCallback ) {
if ( storm : : utility : : isZero < ValueType > ( probability ) ) {
return ;
}
if ( position > = iteratorList . size ( ) ) {
StateType id = stateToIdCallback ( state ) ;
distribution . add ( id , probability ) ;
} else {
storm : : prism : : Command const & command = * iteratorList [ position ] ;
for ( uint_fast64_t j = 0 ; j < command . getNumberOfUpdates ( ) ; + + j ) {
storm : : prism : : Update const & update = command . getUpdate ( j ) ;
generateSynchronizedDistribution ( applyUpdate ( state , update ) , probability * this - > evaluator - > asRational ( update . getLikelihoodExpression ( ) ) , position + 1 , iteratorList , distribution , stateToIdCallback ) ;
}
}
}
template < typename ValueType , typename StateType >
std : : vector < Choice < ValueType > > PrismNextStateGenerator < ValueType , StateType > : : getLabeledChoices ( CompressedState const & state , StateToIdCallback stateToIdCallback ) {
std : : vector < Choice < ValueType > > result ;
storm : : builder : : jit : : Distribution < CompressedState , ValueType > currentDistribution ;
storm : : builder : : jit : : Distribution < CompressedState , ValueType > nextDistribution ;
for ( uint_fast64_t actionIndex : program . getSynchronizingActionIndices ( ) ) {
boost : : optional < std : : vector < std : : vector < std : : reference_wrapper < storm : : prism : : Command const > > > > optionalActiveCommandLists = getActiveCommandsByActionIndex ( actionIndex ) ;
// Only process this action label, if there is at least one feasible solution.
if ( optionalActiveCommandLists ) {
std : : vector < std : : vector < std : : reference_wrapper < storm : : prism : : Command const > > > const & activeCommandList = optionalActiveCommandLists . get ( ) ;
std : : vector < std : : vector < std : : reference_wrapper < storm : : prism : : Command const > > : : const_iterator > iteratorList ( activeCommandList . size ( ) ) ;
// Initialize the list of iterators.
for ( size_t i = 0 ; i < activeCommandList . size ( ) ; + + i ) {
iteratorList [ i ] = activeCommandList [ i ] . cbegin ( ) ;
}
storm : : builder : : jit : : Distribution < StateType , ValueType > distribution ;
// As long as there is one feasible combination of commands, keep on expanding it.
bool done = false ;
while ( ! done ) {
currentDistribution . clear ( ) ;
nextDistribution . clear ( ) ;
distribution . clear ( ) ;
generateSynchronizedDistribution ( state , storm : : utility : : one < ValueType > ( ) , 0 , iteratorList , distribution , stateToIdCallback ) ;
distribution . compress ( ) ;
currentDistribution . add ( state , storm : : utility : : one < ValueType > ( ) ) ;
for ( uint_fast64_t i = 0 ; i < iteratorList . size ( ) ; + + i ) {
storm : : prism : : Command const & command = * iteratorList [ i ] ;
for ( uint_fast64_t j = 0 ; j < command . getNumberOfUpdates ( ) ; + + j ) {
storm : : prism : : Update const & update = command . getUpdate ( j ) ;
for ( auto const & stateProbability : currentDistribution ) {
ValueType probability = stateProbability . getValue ( ) * this - > evaluator - > asRational ( update . getLikelihoodExpression ( ) ) ;
if ( ! storm : : utility : : isZero < ValueType > ( probability ) ) {
// Compute the new state under the current update and add it to the set of new target states.
CompressedState newTargetState = applyUpdate ( stateProbability . getState ( ) , update ) ;
nextDistribution . add ( newTargetState , probability ) ;
}
}
}
nextDistribution . compress ( ) ;
// If there is one more command to come, shift the target states one time step back.
if ( i < iteratorList . size ( ) - 1 ) {
currentDistribution = std : : move ( nextDistribution ) ;
}
}
// At this point, we applied all commands of the current command combination and newTargetStates
// contains all target states and their respective probabilities. That means we are now ready to
// add the choice to the list of transitions.
result . push_back ( Choice < ValueType > ( actionIndex ) ) ;
// Now create the actual distribution.
Choice < ValueType > & choice = result . back ( ) ;
// Remember the choice label and origins only if we were asked to.
if ( this - > options . isBuildChoiceLabelsSet ( ) ) {
choice . addLabel ( program . getActionName ( actionIndex ) ) ;
@ -526,22 +579,21 @@ namespace storm {
}
choice . addOriginData ( boost : : any ( std : : move ( commandIndices ) ) ) ;
}
// Add the probabilities/rates to the newly created choice.
ValueType probabilitySum = storm : : utility : : zero < ValueType > ( ) ;
for ( auto const & stateProbability : nextDistribution ) {
StateType actualIndex = stateToIdCallback ( stateProbability . getState ( ) ) ;
choice . addProbability ( actualIndex , stateProbability . getValue ( ) ) ;
for ( auto const & stateProbability : distribution ) {
choice . addProbability ( stateProbability . getState ( ) , stateProbability . getValue ( ) ) ;
if ( this - > options . isExplorationChecksSet ( ) ) {
probabilitySum + = stateProbability . getValue ( ) ;
}
}
if ( this - > options . isExplorationChecksSet ( ) ) {
// Check that the resulting distribution is in fact a distribution.
STORM_LOG_THROW ( ! program . isDiscreteTimeModel ( ) | | ! this - > comparator . isConstant ( probabilitySum ) | | this - > comparator . isOne ( probabilitySum ) , storm : : exceptions : : WrongFormatException , " Sum of update probabilities do not some to one for some command (actually sum to " < < probabilitySum < < " ). " ) ;
}
// Create the state-action reward for the newly created choice.
for ( auto const & rewardModel : rewardModels ) {
ValueType stateActionRewardValue = storm : : utility : : zero < ValueType > ( ) ;
@ -554,7 +606,7 @@ namespace storm {
}
choice . addReward ( stateActionRewardValue ) ;
}
// Now, check whether there is one more command combination to consider.
bool movedIterator = false ;
for ( int_fast64_t j = iteratorList . size ( ) - 1 ; ! movedIterator & & j > = 0 ; - - j ) {
@ -566,12 +618,12 @@ namespace storm {
iteratorList [ j ] = activeCommandList [ j ] . begin ( ) ;
}
}
done = ! movedIterator ;
}
}
}
return result ;
}