diff --git a/src/storm/storage/SparseMatrix.cpp b/src/storm/storage/SparseMatrix.cpp index 65ab0f101..ac40bfb00 100644 --- a/src/storm/storage/SparseMatrix.cpp +++ b/src/storm/storage/SparseMatrix.cpp @@ -974,9 +974,9 @@ namespace storm { } template - SparseMatrix SparseMatrix::getSubmatrix(bool useGroups, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalElements) const { + SparseMatrix SparseMatrix::getSubmatrix(bool useGroups, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalElements, storm::storage::BitVector const& makeZeroColumns) const { if (useGroups) { - return getSubmatrix(rowConstraint, columnConstraint, this->getRowGroupIndices(), insertDiagonalElements); + return getSubmatrix(rowConstraint, columnConstraint, this->getRowGroupIndices(), insertDiagonalElements, makeZeroColumns); } else { // Create a fake row grouping to reduce this to a call to a more general method. std::vector fakeRowGroupIndices(rowCount + 1); @@ -984,7 +984,7 @@ namespace storm { for (std::vector::iterator it = fakeRowGroupIndices.begin(); it != fakeRowGroupIndices.end(); ++it, ++i) { *it = i; } - auto res = getSubmatrix(rowConstraint, columnConstraint, fakeRowGroupIndices, insertDiagonalElements); + auto res = getSubmatrix(rowConstraint, columnConstraint, fakeRowGroupIndices, insertDiagonalElements, makeZeroColumns); // Create a new row grouping that reflects the new sizes of the row groups if the current matrix has a // non trivial row-grouping. @@ -1014,7 +1014,7 @@ namespace storm { } template - SparseMatrix SparseMatrix::getSubmatrix(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint, std::vector const& rowGroupIndices, bool insertDiagonalEntries) const { + SparseMatrix SparseMatrix::getSubmatrix(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint, std::vector const& rowGroupIndices, bool insertDiagonalEntries, storm::storage::BitVector const& makeZeroColumns) const { STORM_LOG_THROW(!rowGroupConstraint.empty() && !columnConstraint.empty(), storm::exceptions::InvalidArgumentException, "Cannot build empty submatrix."); uint_fast64_t submatrixColumnCount = columnConstraint.getNumberOfSetBits(); @@ -1037,7 +1037,7 @@ namespace storm { bool foundDiagonalElement = false; for (const_iterator it = this->begin(i), ite = this->end(i); it != ite; ++it) { - if (columnConstraint.get(it->getColumn())) { + if (columnConstraint.get(it->getColumn()) && (makeZeroColumns.size() == 0 || !makeZeroColumns.get(it->getColumn())) ) { ++subEntries; if (columnBitsSetBeforeIndex[it->getColumn()] == rowBitsSetBeforeIndex[index]) { @@ -1069,7 +1069,7 @@ namespace storm { bool insertedDiagonalElement = false; for (const_iterator it = this->begin(i), ite = this->end(i); it != ite; ++it) { - if (columnConstraint.get(it->getColumn())) { + if (columnConstraint.get(it->getColumn()) && (makeZeroColumns.size() == 0 || !makeZeroColumns.get(it->getColumn()))) { if (columnBitsSetBeforeIndex[it->getColumn()] == rowBitsSetBeforeIndex[index]) { insertedDiagonalElement = true; } else if (insertDiagonalEntries && !insertedDiagonalElement && columnBitsSetBeforeIndex[it->getColumn()] > rowBitsSetBeforeIndex[index]) { diff --git a/src/storm/storage/SparseMatrix.h b/src/storm/storage/SparseMatrix.h index dc175c0b8..b16a1ea99 100644 --- a/src/storm/storage/SparseMatrix.h +++ b/src/storm/storage/SparseMatrix.h @@ -1,5 +1,4 @@ -#ifndef STORM_STORAGE_SPARSEMATRIX_H_ -#define STORM_STORAGE_SPARSEMATRIX_H_ +#pragma once #include #include @@ -10,6 +9,7 @@ #include #include +#include "storm/storage/BitVector.h" #include "storm/solver/OptimizationDirection.h" #include "storm/utility/OsDetection.h" @@ -34,8 +34,7 @@ namespace storm { namespace storm { namespace storage { - - class BitVector; + // Forward declare matrix class. template @@ -719,7 +718,7 @@ namespace storm { * @return A matrix corresponding to a submatrix of the current matrix in which only rows and columns given * by the constraints are kept and all others are dropped. */ - SparseMatrix getSubmatrix(bool useGroups, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false) const; + SparseMatrix getSubmatrix(bool useGroups, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false, storm::storage::BitVector const& makeZeroColumns = storm::storage::BitVector()) const; /*! * Restrict rows in grouped rows matrix. Ensures that the number of groups stays the same. @@ -1169,7 +1168,7 @@ namespace storm { * @return A matrix corresponding to a submatrix of the current matrix in which only row groups and columns * given by the row group constraint are kept and all others are dropped. */ - SparseMatrix getSubmatrix(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint, std::vector const& rowGroupIndices, bool insertDiagonalEntries = false) const; + SparseMatrix getSubmatrix(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint, std::vector const& rowGroupIndices, bool insertDiagonalEntries = false, storm::storage::BitVector const& makeZeroColumns = storm::storage::BitVector()) const; // The number of rows of the matrix. index_type rowCount; @@ -1206,5 +1205,3 @@ namespace storm { } // namespace storage } // namespace storm - -#endif // STORM_STORAGE_SPARSEMATRIX_H_