|
@ -974,9 +974,9 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
|
SparseMatrix<ValueType> SparseMatrix<ValueType>::getSubmatrix(bool useGroups, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalElements) const { |
|
|
|
|
|
|
|
|
SparseMatrix<ValueType> SparseMatrix<ValueType>::getSubmatrix(bool useGroups, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalElements, storm::storage::BitVector const& makeZeroColumns) const { |
|
|
if (useGroups) { |
|
|
if (useGroups) { |
|
|
return getSubmatrix(rowConstraint, columnConstraint, this->getRowGroupIndices(), insertDiagonalElements); |
|
|
|
|
|
|
|
|
return getSubmatrix(rowConstraint, columnConstraint, this->getRowGroupIndices(), insertDiagonalElements, makeZeroColumns); |
|
|
} else { |
|
|
} else { |
|
|
// Create a fake row grouping to reduce this to a call to a more general method.
|
|
|
// Create a fake row grouping to reduce this to a call to a more general method.
|
|
|
std::vector<index_type> fakeRowGroupIndices(rowCount + 1); |
|
|
std::vector<index_type> fakeRowGroupIndices(rowCount + 1); |
|
@ -984,7 +984,7 @@ namespace storm { |
|
|
for (std::vector<index_type>::iterator it = fakeRowGroupIndices.begin(); it != fakeRowGroupIndices.end(); ++it, ++i) { |
|
|
for (std::vector<index_type>::iterator it = fakeRowGroupIndices.begin(); it != fakeRowGroupIndices.end(); ++it, ++i) { |
|
|
*it = i; |
|
|
*it = i; |
|
|
} |
|
|
} |
|
|
auto res = getSubmatrix(rowConstraint, columnConstraint, fakeRowGroupIndices, insertDiagonalElements); |
|
|
|
|
|
|
|
|
auto res = getSubmatrix(rowConstraint, columnConstraint, fakeRowGroupIndices, insertDiagonalElements, makeZeroColumns); |
|
|
|
|
|
|
|
|
// Create a new row grouping that reflects the new sizes of the row groups if the current matrix has a
|
|
|
// Create a new row grouping that reflects the new sizes of the row groups if the current matrix has a
|
|
|
// non trivial row-grouping.
|
|
|
// non trivial row-grouping.
|
|
@ -1014,7 +1014,7 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
|
SparseMatrix<ValueType> SparseMatrix<ValueType>::getSubmatrix(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint, std::vector<index_type> const& rowGroupIndices, bool insertDiagonalEntries) const { |
|
|
|
|
|
|
|
|
SparseMatrix<ValueType> SparseMatrix<ValueType>::getSubmatrix(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint, std::vector<index_type> const& rowGroupIndices, bool insertDiagonalEntries, storm::storage::BitVector const& makeZeroColumns) const { |
|
|
STORM_LOG_THROW(!rowGroupConstraint.empty() && !columnConstraint.empty(), storm::exceptions::InvalidArgumentException, "Cannot build empty submatrix."); |
|
|
STORM_LOG_THROW(!rowGroupConstraint.empty() && !columnConstraint.empty(), storm::exceptions::InvalidArgumentException, "Cannot build empty submatrix."); |
|
|
uint_fast64_t submatrixColumnCount = columnConstraint.getNumberOfSetBits(); |
|
|
uint_fast64_t submatrixColumnCount = columnConstraint.getNumberOfSetBits(); |
|
|
|
|
|
|
|
@ -1037,7 +1037,7 @@ namespace storm { |
|
|
bool foundDiagonalElement = false; |
|
|
bool foundDiagonalElement = false; |
|
|
|
|
|
|
|
|
for (const_iterator it = this->begin(i), ite = this->end(i); it != ite; ++it) { |
|
|
for (const_iterator it = this->begin(i), ite = this->end(i); it != ite; ++it) { |
|
|
if (columnConstraint.get(it->getColumn())) { |
|
|
|
|
|
|
|
|
if (columnConstraint.get(it->getColumn()) && (makeZeroColumns.size() == 0 || !makeZeroColumns.get(it->getColumn())) ) { |
|
|
++subEntries; |
|
|
++subEntries; |
|
|
|
|
|
|
|
|
if (columnBitsSetBeforeIndex[it->getColumn()] == rowBitsSetBeforeIndex[index]) { |
|
|
if (columnBitsSetBeforeIndex[it->getColumn()] == rowBitsSetBeforeIndex[index]) { |
|
@ -1069,7 +1069,7 @@ namespace storm { |
|
|
bool insertedDiagonalElement = false; |
|
|
bool insertedDiagonalElement = false; |
|
|
|
|
|
|
|
|
for (const_iterator it = this->begin(i), ite = this->end(i); it != ite; ++it) { |
|
|
for (const_iterator it = this->begin(i), ite = this->end(i); it != ite; ++it) { |
|
|
if (columnConstraint.get(it->getColumn())) { |
|
|
|
|
|
|
|
|
if (columnConstraint.get(it->getColumn()) && (makeZeroColumns.size() == 0 || !makeZeroColumns.get(it->getColumn()))) { |
|
|
if (columnBitsSetBeforeIndex[it->getColumn()] == rowBitsSetBeforeIndex[index]) { |
|
|
if (columnBitsSetBeforeIndex[it->getColumn()] == rowBitsSetBeforeIndex[index]) { |
|
|
insertedDiagonalElement = true; |
|
|
insertedDiagonalElement = true; |
|
|
} else if (insertDiagonalEntries && !insertedDiagonalElement && columnBitsSetBeforeIndex[it->getColumn()] > rowBitsSetBeforeIndex[index]) { |
|
|
} else if (insertDiagonalEntries && !insertedDiagonalElement && columnBitsSetBeforeIndex[it->getColumn()] > rowBitsSetBeforeIndex[index]) { |
|
|