@ -584,6 +584,27 @@ namespace storm {
return rowGroupIndices . get ( ) ;
return rowGroupIndices . get ( ) ;
}
}
template < typename ValueType >
void SparseMatrix < ValueType > : : setRowGroupIndices ( std : : vector < index_type > const & newRowGroupIndices ) {
trivialRowGrouping = false ;
rowGroupIndices = newRowGroupIndices ;
}
template < typename ValueType >
bool SparseMatrix < ValueType > : : hasTrivialRowGrouping ( ) const {
return trivialRowGrouping ;
}
template < typename ValueType >
void SparseMatrix < ValueType > : : makeRowGroupingTrivial ( ) {
if ( trivialRowGrouping ) {
STORM_LOG_ASSERT ( ! rowGroupIndices | | rowGroupIndices . get ( ) = = storm : : utility : : vector : : buildVectorForRange ( 0 , this - > getRowGroupCount ( ) + 1 ) , " Row grouping is supposed to be trivial but actually it is not. " ) ;
} else {
trivialRowGrouping = true ;
rowGroupIndices = boost : : none ;
}
}
template < typename ValueType >
template < typename ValueType >
storm : : storage : : BitVector SparseMatrix < ValueType > : : getRowFilter ( storm : : storage : : BitVector const & groupConstraint ) const {
storm : : storage : : BitVector SparseMatrix < ValueType > : : getRowFilter ( storm : : storage : : BitVector const & groupConstraint ) const {
storm : : storage : : BitVector res ( this - > getRowCount ( ) , false ) ;
storm : : storage : : BitVector res ( this - > getRowCount ( ) , false ) ;
@ -984,6 +1005,30 @@ namespace storm {
return res ;
return res ;
}
}
template < typename ValueType >
SparseMatrix < ValueType > SparseMatrix < ValueType > : : filterEntries ( storm : : storage : : BitVector const & rowFilter ) const {
// Count the number of entries in the resulting matrix.
index_type entryCount = 0 ;
for ( auto const & row : rowFilter ) {
entryCount + = getRow ( row ) . getNumberOfEntries ( ) ;
}
// Build the resulting matrix.
SparseMatrixBuilder < ValueType > builder ( getRowCount ( ) , getColumnCount ( ) , entryCount ) ;
for ( auto const & row : rowFilter ) {
for ( auto const & entry : getRow ( row ) ) {
builder . addNextValue ( row , entry . getColumn ( ) , entry . getValue ( ) ) ;
}
}
SparseMatrix < ValueType > result = builder . build ( ) ;
// Add a row grouping if necessary.
if ( ! hasTrivialRowGrouping ( ) ) {
result . setRowGroupIndices ( getRowGroupIndices ( ) ) ;
}
return result ;
}
template < typename ValueType >
template < typename ValueType >
SparseMatrix < ValueType > SparseMatrix < ValueType > : : selectRowsFromRowGroups ( std : : vector < index_type > const & rowGroupToRowIndexMapping , bool insertDiagonalEntries ) const {
SparseMatrix < ValueType > SparseMatrix < ValueType > : : selectRowsFromRowGroups ( std : : vector < index_type > const & rowGroupToRowIndexMapping , bool insertDiagonalEntries ) const {
// First, we need to count how many non-zero entries the resulting matrix will have and reserve space for
// First, we need to count how many non-zero entries the resulting matrix will have and reserve space for
@ -1548,21 +1593,6 @@ namespace storm {
return this - > columnsAndValues . begin ( ) + this - > rowIndications [ rowCount ] ;
return this - > columnsAndValues . begin ( ) + this - > rowIndications [ rowCount ] ;
}
}
template < typename ValueType >
bool SparseMatrix < ValueType > : : hasTrivialRowGrouping ( ) const {
return trivialRowGrouping ;
}
template < typename ValueType >
void SparseMatrix < ValueType > : : makeRowGroupingTrivial ( ) {
if ( trivialRowGrouping ) {
STORM_LOG_ASSERT ( ! rowGroupIndices | | rowGroupIndices . get ( ) = = storm : : utility : : vector : : buildVectorForRange ( 0 , this - > getRowGroupCount ( ) + 1 ) , " Row grouping is supposed to be trivial but actually it is not. " ) ;
} else {
trivialRowGrouping = true ;
rowGroupIndices = boost : : none ;
}
}
template < typename ValueType >
template < typename ValueType >
ValueType SparseMatrix < ValueType > : : getRowSum ( index_type row ) const {
ValueType SparseMatrix < ValueType > : : getRowSum ( index_type row ) const {
ValueType sum = storm : : utility : : zero < ValueType > ( ) ;
ValueType sum = storm : : utility : : zero < ValueType > ( ) ;