Browse Source

Improved SparseMatrix::restrictRows so it can handle empty row groups

tempestpy_adaptions
TimQu 8 years ago
parent
commit
3fd72a11d8
  1. 41
      src/storm/storage/SparseMatrix.cpp
  2. 8
      src/storm/storage/SparseMatrix.h

41
src/storm/storage/SparseMatrix.cpp

@ -911,14 +911,41 @@ namespace storm {
} }
template<typename ValueType> template<typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::restrictRows(storm::storage::BitVector const& rowsToKeep) const {
// For now, we use the expensive call to submatrix.
SparseMatrix<ValueType> SparseMatrix<ValueType>::restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups) const {
STORM_LOG_ASSERT(rowsToKeep.size() == this->getRowCount(), "Dimensions mismatch."); STORM_LOG_ASSERT(rowsToKeep.size() == this->getRowCount(), "Dimensions mismatch.");
STORM_LOG_ASSERT(rowsToKeep.getNumberOfSetBits() >= this->getRowGroupCount(), "Invalid dimensions.");
SparseMatrix<ValueType> res(getSubmatrix(false, rowsToKeep, storm::storage::BitVector(getColumnCount(), true), false));
STORM_LOG_ASSERT(res.getRowCount() == rowsToKeep.getNumberOfSetBits(), "Invalid dimensions");
STORM_LOG_ASSERT(res.getColumnCount() == this->getColumnCount(), "Invalid dimensions");
STORM_LOG_ASSERT(this->getRowGroupCount() == res.getRowGroupCount(), "Invalid dimensions");
// Count the number of entries of the resulting matrix
uint_fast64_t entryCount = 0;
for (auto const& row : rowsToKeep) {
entryCount += this->getRow(row).getNumberOfEntries();
}
// build the matrix. The row grouping will always be considered as nontrivial.
SparseMatrixBuilder<ValueType> builder(rowsToKeep.getNumberOfSetBits(), this->getColumnCount(), entryCount, true, true, this->getRowGroupCount());
uint_fast64_t newRow = 0;
for (uint_fast64_t rowGroup = 0; rowGroup < this->getRowGroupCount(); ++rowGroup) {
builder.newRowGroup(newRow);
bool rowGroupEmpty = true;
if (this->hasTrivialRowGrouping()) {
if (rowsToKeep.get(rowGroup)) {
rowGroupEmpty = false;
for (auto const& entry : this->getRow(rowGroup)) {
builder.addNextValue(newRow, entry.getColumn(), entry.getValue());
}
++newRow;
}
} else {
for (uint_fast64_t row = rowsToKeep.getNextSetIndex(this->getRowGroupIndices()[rowGroup]); row < this->getRowGroupIndices()[rowGroup + 1]; row = rowsToKeep.getNextSetIndex(row + 1)) {
rowGroupEmpty = false;
for (auto const& entry: this->getRow(row)) {
builder.addNextValue(newRow, entry.getColumn(), entry.getValue());
}
++newRow;
}
}
STORM_LOG_THROW(allowEmptyRowGroups || !rowGroupEmpty, storm::exceptions::InvalidArgumentException, "Empty rows are not allowed, but row group " << rowGroup << " is empty.");
}
SparseMatrix<ValueType> res = builder.build();
return res; return res;
} }

8
src/storm/storage/SparseMatrix.h

@ -657,9 +657,13 @@ namespace storm {
* Restrict rows in grouped rows matrix. Ensures that the number of groups stays the same. * Restrict rows in grouped rows matrix. Ensures that the number of groups stays the same.
* *
* @param rowsToKeep A bit vector indicating which rows to keep. * @param rowsToKeep A bit vector indicating which rows to keep.
*
* @param allowEmptyRowGroups if set to true, the result can potentially have empty row groups.
* Otherwise, it is asserted that there are no empty row groups.
*
* @note The resulting matrix will always have a non-trivial row grouping even if the current one is trivial.
*
*/ */
SparseMatrix restrictRows(storm::storage::BitVector const& rowsToKeep) const;
SparseMatrix restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups = false) const;
/** /**
* Compares two rows. * Compares two rows.

Loading…
Cancel
Save