@ -911,14 +911,41 @@ namespace storm {
}
// Returns a copy of this matrix that contains only the rows selected by rowsToKeep, keeping the
// row grouping: row group g of the result holds the kept rows of row group g of this matrix, in
// their original order. The column count is unchanged, and the result always carries an explicit
// (non-trivial) row grouping, even if this matrix has a trivial one.
//
// rowsToKeep          - bit vector of size getRowCount(); set bits mark the rows to keep.
// allowEmptyRowGroups - if false, dropping every row of some row group is an error.
// returns             - a matrix with rowsToKeep.getNumberOfSetBits() rows and getRowGroupCount() row groups.
// throws              - storm::exceptions::InvalidArgumentException if !allowEmptyRowGroups and a row group becomes empty.
template<typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups) const {
    STORM_LOG_ASSERT(rowsToKeep.size() == this->getRowCount(), "Dimensions mismatch.");
    // Note: no lower bound on the number of kept rows is asserted here, since with
    // allowEmptyRowGroups == true whole row groups may legitimately be dropped. The
    // non-empty case is enforced by the STORM_LOG_THROW inside the loop below.

    // Count the number of entries of the resulting matrix so the builder can preallocate.
    uint_fast64_t entryCount = 0;
    for (auto const& row : rowsToKeep) {
        entryCount += this->getRow(row).getNumberOfEntries();
    }

    // Loop-invariant properties of this matrix.
    uint_fast64_t const rowGroupCount = this->getRowGroupCount();
    bool const trivialRowGrouping = this->hasTrivialRowGrouping();

    // Build the matrix. The row grouping will always be considered as nontrivial.
    SparseMatrixBuilder<ValueType> builder(rowsToKeep.getNumberOfSetBits(), this->getColumnCount(), entryCount, true, true, rowGroupCount);
    uint_fast64_t newRow = 0;
    for (uint_fast64_t rowGroup = 0; rowGroup < rowGroupCount; ++rowGroup) {
        // Open the group even if it ends up empty, so group indices of the result line up with ours.
        builder.newRowGroup(newRow);
        bool rowGroupEmpty = true;
        if (trivialRowGrouping) {
            // With a trivial row grouping, row group `rowGroup` is exactly row `rowGroup`.
            if (rowsToKeep.get(rowGroup)) {
                rowGroupEmpty = false;
                for (auto const& entry : this->getRow(rowGroup)) {
                    builder.addNextValue(newRow, entry.getColumn(), entry.getValue());
                }
                ++newRow;
            }
        } else {
            auto const& rowGroupIndices = this->getRowGroupIndices();
            uint_fast64_t const groupEnd = rowGroupIndices[rowGroup + 1];
            // Visit only the kept rows of this group; stop once the next kept row lies past the group's end.
            for (uint_fast64_t row = rowsToKeep.getNextSetIndex(rowGroupIndices[rowGroup]); row < groupEnd; row = rowsToKeep.getNextSetIndex(row + 1)) {
                rowGroupEmpty = false;
                for (auto const& entry : this->getRow(row)) {
                    builder.addNextValue(newRow, entry.getColumn(), entry.getValue());
                }
                ++newRow;
            }
        }
        STORM_LOG_THROW(allowEmptyRowGroups || !rowGroupEmpty, storm::exceptions::InvalidArgumentException, "Empty row groups are not allowed, but row group " << rowGroup << " is empty.");
    }
    return builder.build();
}