@@ -404,48 +404,10 @@ namespace storm {
}
template<typename ValueType>
std::tuple<typename SparseMdpLearningModelChecker<ValueType>::StateType, ValueType, ValueType> SparseMdpLearningModelChecker<ValueType>::performLearningProcedure(storm::expressions::Expression const& targetStateExpression, storm::storage::sparse::StateStorage<StateType>& stateStorage, storm::generator::PrismNextStateGenerator<ValueType, StateType>& generator, std::function<StateType (storm::generator::CompressedState const&)> const& stateToIdCallback, std::vector<std::vector<storm::storage::MatrixEntry<StateType, ValueType>>>& matrix, std::vector<StateType>& rowGroupIndices, std::vector<StateType>& stateToRowGroupMapping, std::unordered_map<StateType, storm::generator::CompressedState>& unexploredStates, StateType const& unexploredMarker) {
    // Generate the initial state so we know where to start the simulation.
    stateStorage.initialStateIndices = generator.getInitialStates(stateToIdCallback);
    STORM_LOG_THROW(stateStorage.initialStateIndices.size() == 1, storm::exceptions::NotSupportedException, "Currently only models with one initial state are supported by the learning engine.");
    StateType initialStateIndex = stateStorage.initialStateIndices.front();
    // A set storing all states in which to terminate the search.
    boost::container::flat_set<StateType> terminalStates;
    // Vectors to store the lower/upper bounds for each action (in each state).
    std::vector<ValueType> lowerBoundsPerAction;
    std::vector<ValueType> upperBoundsPerAction;
    std::vector<ValueType> lowerBoundsPerState;
    std::vector<ValueType> upperBoundsPerState;
    // Since we might run into end-components, we track a mapping from states in ECs to all leaving choices of
    // that EC.
    std::unordered_map<StateType, ChoiceSetPointer> stateToLeavingChoicesOfEndComponent;
    // Now perform the actual sampling.
    std::vector<std::pair<StateType, uint32_t>> stateActionStack;
bool SparseMdpLearningModelChecker<ValueType>::exploreState(storm::generator::PrismNextStateGenerator<ValueType, StateType>& generator, std::function<StateType (storm::generator::CompressedState const&)> const& stateToIdCallback, StateType const& currentStateId, storm::generator::CompressedState const& compressedState, StateType const& unexploredMarker, boost::container::flat_set<StateType>& terminalStates, std::vector<std::vector<storm::storage::MatrixEntry<StateType, ValueType>>>& matrix, std::vector<StateType>& rowGroupIndices, std::vector<StateType>& stateToRowGroupMapping, storm::expressions::Expression const& targetStateExpression, std::vector<ValueType>& lowerBoundsPerAction, std::vector<ValueType>& upperBoundsPerAction, std::vector<ValueType>& lowerBoundsPerState, std::vector<ValueType>& upperBoundsPerState, Statistics& stats) {
    std::size_t iterations = 0;
    std::size_t maxPathLength = 0;
    std::size_t numberOfTargetStates = 0;
    std::size_t pathLengthUntilEndComponentDetection = 27;
    bool convergenceCriterionMet = false;
    while (!convergenceCriterionMet) {
        // Start the search from the initial state.
        stateActionStack.push_back(std::make_pair(initialStateIndex, 0));
        bool foundTerminalState = false;
        bool foundTargetState = false;
        while (!foundTerminalState) {
            StateType const& currentStateId = stateActionStack.back().first;
            STORM_LOG_TRACE("State on top of stack is: " << currentStateId << ".");
            // If the state is not yet expanded, we need to retrieve its behaviors.
            auto unexploredIt = unexploredStates.find(currentStateId);
            if (unexploredIt != unexploredStates.end()) {
                STORM_LOG_TRACE("State was not yet explored.");
    bool isTerminalState = false;
    bool isTargetState = false;
                // Map the unexplored state to a row group.
                stateToRowGroupMapping[currentStateId] = rowGroupIndices.size() - 1;
@@ -453,18 +415,13 @@ namespace storm {
    lowerBoundsPerState.push_back(storm::utility::zero<ValueType>());
    upperBoundsPerState.push_back(storm::utility::one<ValueType>());
    // We need to get the compressed state back from the id to explore it.
    STORM_LOG_ASSERT(unexploredIt != unexploredStates.end(), "Unable to find unexplored state " << currentStateId << ".");
    storm::storage::BitVector const& currentState = unexploredIt->second;
    // Before generating the behavior of the state, we need to determine whether it's a target state that
    // does not need to be expanded.
    generator.load(currentState);
    generator.load(compressedState);
    if (generator.satisfies(targetStateExpression)) {
        STORM_LOG_TRACE("State does not need to be expanded, because it is a target state.");
        ++numberOfTargetStates;
        foundTargetState = true;
        foundTerminalState = true;
        ++stats.numberOfTargetStates;
        isTargetState = true;
        isTerminalState = true;
    } else {
        STORM_LOG_TRACE("Exploring state.");
@@ -474,20 +431,20 @@ namespace storm {
        // Clumsily check whether we have found a state that forms a trivial BMEC.
        if (behavior.getNumberOfChoices() == 0) {
            foundTerminalState = true;
            isTerminalState = true;
        } else if (behavior.getNumberOfChoices() == 1) {
            auto const& onlyChoice = *behavior.begin();
            if (onlyChoice.size() == 1) {
                auto const& onlyEntry = *onlyChoice.begin();
                if (onlyEntry.first == currentStateId) {
                    foundTerminalState = true;
                    isTerminalState = true;
                }
            }
        }
        // If the state was neither a trivial (non-accepting) terminal state nor a target state, we
        // need to store its behavior.
        if (!foundTerminalState) {
        if (!isTerminalState) {
            // Next, we insert the behavior into our matrix structure.
            StateType startRow = matrix.size();
            matrix.resize(startRow + behavior.getNumberOfChoices());
@@ -522,11 +479,11 @@ namespace storm {
        }
    }
    if (foundTerminalState) {
        STORM_LOG_TRACE("State does not need to be explored, because it is " << (foundTargetState ? "a target state" : "a rejecting terminal state") << ".");
    if (isTerminalState) {
        STORM_LOG_TRACE("State does not need to be explored, because it is " << (isTargetState ? "a target state" : "a rejecting terminal state") << ".");
        terminalStates.insert(currentStateId);
        if (foundTargetState) {
        if (isTargetState) {
            lowerBoundsPerState.back() = storm::utility::one<ValueType>();
            lowerBoundsPerAction.push_back(storm::utility::one<ValueType>());
            upperBoundsPerAction.push_back(storm::utility::one<ValueType>());
@@ -543,20 +500,38 @@ namespace storm {
        rowGroupIndices.push_back(matrix.size());
    }
    // Now that we have explored the state, we can dispose of it.
    return isTerminalState;
}
template<typename ValueType>
bool SparseMdpLearningModelChecker<ValueType>::samplePathFromState(storm::generator::PrismNextStateGenerator<ValueType, StateType>& generator, std::function<StateType (storm::generator::CompressedState const&)> const& stateToIdCallback, StateType initialStateIndex, std::vector<std::pair<StateType, uint32_t>>& stateActionStack, std::unordered_map<StateType, storm::generator::CompressedState>& unexploredStates, StateType const& unexploredMarker, boost::container::flat_set<StateType>& terminalStates, std::vector<std::vector<storm::storage::MatrixEntry<StateType, ValueType>>>& matrix, std::vector<StateType>& rowGroupIndices, std::vector<StateType>& stateToRowGroupMapping, std::unordered_map<StateType, ChoiceSetPointer>& stateToLeavingChoicesOfEndComponent, storm::expressions::Expression const& targetStateExpression, std::vector<ValueType>& lowerBoundsPerAction, std::vector<ValueType>& upperBoundsPerAction, std::vector<ValueType>& lowerBoundsPerState, std::vector<ValueType>& upperBoundsPerState, Statistics& stats) {
    // Start the search from the initial state.
    stateActionStack.push_back(std::make_pair(initialStateIndex, 0));
    bool foundTerminalState = false;
    while (!foundTerminalState) {
        StateType const& currentStateId = stateActionStack.back().first;
        STORM_LOG_TRACE("State on top of stack is: " << currentStateId << ".");
        // If the state is not yet explored, we need to retrieve its behaviors.
        auto unexploredIt = unexploredStates.find(currentStateId);
        if (unexploredIt != unexploredStates.end()) {
            STORM_LOG_TRACE("State was not yet explored.");
            // Explore the previously unexplored state.
            foundTerminalState = exploreState(generator, stateToIdCallback, currentStateId, unexploredIt->second);
            unexploredStates.erase(unexploredIt);
        } else {
            // If the state was already explored, we check whether it is a terminal state or not.
            if (terminalStates.find(currentStateId) != terminalStates.end()) {
                STORM_LOG_TRACE("Found already explored terminal state: " << currentStateId << ".");
                foundTerminalState = true;
            }
        }
        if (foundTerminalState) {
            // Update the bounds along the path to the terminal state.
            STORM_LOG_TRACE("Found terminal state, updating probabilities along path.");
            updateProbabilitiesUsingStack(stateActionStack, matrix, rowGroupIndices, stateToRowGroupMapping, lowerBoundsPerAction, upperBoundsPerAction, lowerBoundsPerState, upperBoundsPerState, unexploredMarker);
        } else {
            // If the state was not a terminal state, we continue the path search and sample the next state.
            if (!foundTerminalState) {
                std::cout << "(2) stack is: " << std::endl;
                for (auto const& el : stateActionStack) {
                    std::cout << el.first << " -[" << el.second << "]-> ";
@@ -586,11 +561,11 @@ namespace storm {
            // Put the successor state and a dummy action on top of the stack.
            stateActionStack.emplace_back(successor, 0);
            maxPathLength = std::max(maxPathLength, stateActionStack.size());
            stats.maxPathLength = std::max(stats.maxPathLength, stateActionStack.size());
            // If the current path length exceeds the threshold and the model is a nondeterministic one, we
            // perform an EC detection.
            if (stateActionStack.size() > pathLengthUntilEndComponentDetection && !program.isDeterministicModel()) {
            if (stateActionStack.size() > stats.pathLengthUntilEndComponentDetection && !program.isDeterministicModel()) {
                detectEndComponents(stateActionStack, terminalStates, matrix, rowGroupIndices, stateToRowGroupMapping, lowerBoundsPerAction, upperBoundsPerAction, lowerBoundsPerState, upperBoundsPerState, stateToLeavingChoicesOfEndComponent, unexploredMarker);
                // Abort the current search.
@@ -601,40 +576,83 @@ namespace storm {
            }
        }
    // Sanity check of results.
    for (StateType state = 0; state < stateToRowGroupMapping.size(); ++state) {
        if (stateToRowGroupMapping[state] != unexploredMarker) {
            STORM_LOG_ASSERT(lowerBoundsPerState[stateToRowGroupMapping[state]] <= upperBoundsPerState[stateToRowGroupMapping[state]], "The bounds for state " << state << " are not in a sane relation: " << lowerBoundsPerState[stateToRowGroupMapping[state]] << " > " << upperBoundsPerState[stateToRowGroupMapping[state]] << ".");
    return foundTerminalState;
}
template<typename ValueType>
std::tuple<typename SparseMdpLearningModelChecker<ValueType>::StateType, ValueType, ValueType> SparseMdpLearningModelChecker<ValueType>::performLearningProcedure(storm::expressions::Expression const& targetStateExpression, storm::storage::sparse::StateStorage<StateType>& stateStorage, storm::generator::PrismNextStateGenerator<ValueType, StateType>& generator, std::function<StateType (storm::generator::CompressedState const&)> const& stateToIdCallback, std::vector<std::vector<storm::storage::MatrixEntry<StateType, ValueType>>>& matrix, std::vector<StateType>& rowGroupIndices, std::vector<StateType>& stateToRowGroupMapping, std::unordered_map<StateType, storm::generator::CompressedState>& unexploredStates, StateType const& unexploredMarker) {
    // Generate the initial state so we know where to start the simulation.
    stateStorage.initialStateIndices = generator.getInitialStates(stateToIdCallback);
    STORM_LOG_THROW(stateStorage.initialStateIndices.size() == 1, storm::exceptions::NotSupportedException, "Currently only models with one initial state are supported by the learning engine.");
    StateType initialStateIndex = stateStorage.initialStateIndices.front();
    // A set storing all states in which to terminate the search.
    boost::container::flat_set<StateType> terminalStates;
    // Vectors to store the lower/upper bounds for each action (in each state).
    std::vector<ValueType> lowerBoundsPerAction;
    std::vector<ValueType> upperBoundsPerAction;
    std::vector<ValueType> lowerBoundsPerState;
    std::vector<ValueType> upperBoundsPerState;
    // Since we might run into end-components, we track a mapping from states in ECs to all leaving choices of
    // that EC.
    std::unordered_map<StateType, ChoiceSetPointer> stateToLeavingChoicesOfEndComponent;
    // Now perform the actual sampling.
    std::vector<std::pair<StateType, uint32_t>> stateActionStack;
    Statistics stats;
    bool convergenceCriterionMet = false;
    while (!convergenceCriterionMet) {
        bool result = samplePathFromState(generator, stateToIdCallback, initialStateIndex, stateActionStack, unexploredStates, unexploredMarker, terminalStates, matrix, rowGroupIndices, stateToRowGroupMapping, stateToLeavingChoicesOfEndComponent, targetStateExpression, lowerBoundsPerAction, upperBoundsPerAction, lowerBoundsPerState, upperBoundsPerState, stats);
        // If a terminal state was found, we update the probabilities along the path contained in the stack.
        if (result) {
            // Update the bounds along the path to the terminal state.
            STORM_LOG_TRACE("Found terminal state, updating probabilities along path.");
            updateProbabilitiesUsingStack(stateActionStack, matrix, rowGroupIndices, stateToRowGroupMapping, lowerBoundsPerAction, upperBoundsPerAction, lowerBoundsPerState, upperBoundsPerState, unexploredMarker);
        } else {
            // If no terminal state was found, the search was aborted, possibly because of an EC detection. In this
            // case, we cannot update the probabilities.
            STORM_LOG_TRACE("Did not find terminal state.");
        }
        // Sanity check of results.
        for (StateType state = 0; state < stateToRowGroupMapping.size(); ++state) {
            if (stateToRowGroupMapping[state] != unexploredMarker) {
                std::cout << "state " << state << " (grp " << stateToRowGroupMapping[state] << ") has bounds [" << lowerBoundsPerState[stateToRowGroupMapping[state]] << ", " << upperBoundsPerState[stateToRowGroupMapping[state]] << "], actions: ";
                for (auto choice = rowGroupIndices[stateToRowGroupMapping[state]]; choice < rowGroupIndices[stateToRowGroupMapping[state] + 1]; ++choice) {
                    std::cout << choice << " = [" << lowerBoundsPerAction[choice] << ", " << upperBoundsPerAction[choice] << "], ";
                }
                std::cout << std::endl;
            } else {
                std::cout << "state " << state << " is unexplored" << std::endl;
                STORM_LOG_ASSERT(lowerBoundsPerState[stateToRowGroupMapping[state]] <= upperBoundsPerState[stateToRowGroupMapping[state]], "The bounds for state " << state << " are not in a sane relation: " << lowerBoundsPerState[stateToRowGroupMapping[state]] << " > " << upperBoundsPerState[stateToRowGroupMapping[state]] << ".");
            }
        }
        // TODO: remove debug output when superfluous
        // for (StateType state = 0; state < stateToRowGroupMapping.size(); ++state) {
        //     if (stateToRowGroupMapping[state] != unexploredMarker) {
        //         std::cout << "state " << state << " (grp " << stateToRowGroupMapping[state] << ") has bounds [" << lowerBoundsPerState[stateToRowGroupMapping[state]] << ", " << upperBoundsPerState[stateToRowGroupMapping[state]] << "], actions: ";
        //         for (auto choice = rowGroupIndices[stateToRowGroupMapping[state]]; choice < rowGroupIndices[stateToRowGroupMapping[state] + 1]; ++choice) {
        //             std::cout << choice << " = [" << lowerBoundsPerAction[choice] << ", " << upperBoundsPerAction[choice] << "], ";
        //         }
        //         std::cout << std::endl;
        //     } else {
        //         std::cout << "state " << state << " is unexplored" << std::endl;
        //     }
        // }
        STORM_LOG_DEBUG("Discovered states: " << stateStorage.numberOfStates << " (" << unexploredStates.size() << " unexplored).");
        STORM_LOG_DEBUG("Value of initial state is in [" << lowerBoundsPerState[initialStateIndex] << ", " << upperBoundsPerState[initialStateIndex] << "].");
        ValueType difference = upperBoundsPerState[initialStateIndex] - lowerBoundsPerState[initialStateIndex];
        STORM_LOG_DEBUG("Difference after iteration " << iterations << " is " << difference << ".");
        STORM_LOG_DEBUG("Difference after iteration " << stats.iterations << " is " << difference << ".");
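        // The sampling loop terminates once the gap between the lower and upper bound of the initial state
        // drops below the (currently hard-coded) precision of 1e-6.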
        convergenceCriterionMet = difference < 1e-6;
        ++iterations;
        ++stats.iterations;
    }
    if (storm::settings::generalSettings().isShowStatisticsSet()) {
        std::cout << std::endl << "Learning summary -------------------------" << std::endl;
        std::cout << "Discovered states: " << stateStorage.numberOfStates << " (" << unexploredStates.size() << " unexplored, " << numberOfTargetStates << " target states)" << std::endl;
        std::cout << "Sampling iterations: " << iterations << std::endl;
        std::cout << "Maximal path length: " << maxPathLength << std::endl;
        std::cout << "Discovered states: " << stateStorage.numberOfStates << " (" << unexploredStates.size() << " unexplored, " << stats.numberOfTargetStates << " target states)" << std::endl;
        std::cout << "Sampling iterations: " << stats.iterations << std::endl;
        std::cout << "Maximal path length: " << stats.maxPathLength << std::endl;
    }
    return std::make_tuple(initialStateIndex, lowerBoundsPerState[initialStateIndex], upperBoundsPerState[initialStateIndex]);
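Note: the hunks above read stats.iterations, stats.maxPathLength, stats.numberOfTargetStates and stats.pathLengthUntilEndComponentDetection, but the definition of the Statistics type is not part of this diff. A minimal sketch of what such an aggregate could look like, assuming it simply bundles the counters that previously lived as locals in performLearningProcedure (the member names are taken from the usages above; the default of 27 mirrors the removed local pathLengthUntilEndComponentDetection):

#include <cstddef>  // std::size_t

// Hypothetical sketch only; the actual definition lives outside these hunks.
struct Statistics {
    std::size_t iterations = 0;
    std::size_t maxPathLength = 0;
    std::size_t numberOfTargetStates = 0;
    std::size_t pathLengthUntilEndComponentDetection = 27;
};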