From 3fd72a11d8fcf0ca81c5efba31e1a45158666166 Mon Sep 17 00:00:00 2001 From: TimQu Date: Mon, 12 Jun 2017 12:07:27 +0200 Subject: [PATCH] Improved SparseMatrix::restrictRows so it can handle empty row groups --- src/storm/storage/SparseMatrix.cpp | 41 +++++++++++++++++++++++++----- src/storm/storage/SparseMatrix.h | 8 ++++-- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/src/storm/storage/SparseMatrix.cpp b/src/storm/storage/SparseMatrix.cpp index ed775b2b6..b8f694939 100644 --- a/src/storm/storage/SparseMatrix.cpp +++ b/src/storm/storage/SparseMatrix.cpp @@ -911,14 +911,41 @@ namespace storm { } template - SparseMatrix SparseMatrix::restrictRows(storm::storage::BitVector const& rowsToKeep) const { - // For now, we use the expensive call to submatrix. + SparseMatrix SparseMatrix::restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups) const { STORM_LOG_ASSERT(rowsToKeep.size() == this->getRowCount(), "Dimensions mismatch."); - STORM_LOG_ASSERT(rowsToKeep.getNumberOfSetBits() >= this->getRowGroupCount(), "Invalid dimensions."); - SparseMatrix res(getSubmatrix(false, rowsToKeep, storm::storage::BitVector(getColumnCount(), true), false)); - STORM_LOG_ASSERT(res.getRowCount() == rowsToKeep.getNumberOfSetBits(), "Invalid dimensions"); - STORM_LOG_ASSERT(res.getColumnCount() == this->getColumnCount(), "Invalid dimensions"); - STORM_LOG_ASSERT(this->getRowGroupCount() == res.getRowGroupCount(), "Invalid dimensions"); + + // Count the number of entries of the resulting matrix + uint_fast64_t entryCount = 0; + for (auto const& row : rowsToKeep) { + entryCount += this->getRow(row).getNumberOfEntries(); + } + + // build the matrix. The row grouping will always be considered as nontrivial. + SparseMatrixBuilder builder(rowsToKeep.getNumberOfSetBits(), this->getColumnCount(), entryCount, true, true, this->getRowGroupCount()); + uint_fast64_t newRow = 0; + for (uint_fast64_t rowGroup = 0; rowGroup < this->getRowGroupCount(); ++rowGroup) { + builder.newRowGroup(newRow); + bool rowGroupEmpty = true; + if (this->hasTrivialRowGrouping()) { + if (rowsToKeep.get(rowGroup)) { + rowGroupEmpty = false; + for (auto const& entry : this->getRow(rowGroup)) { + builder.addNextValue(newRow, entry.getColumn(), entry.getValue()); + } + ++newRow; + } + } else { + for (uint_fast64_t row = rowsToKeep.getNextSetIndex(this->getRowGroupIndices()[rowGroup]); row < this->getRowGroupIndices()[rowGroup + 1]; row = rowsToKeep.getNextSetIndex(row + 1)) { + rowGroupEmpty = false; + for (auto const& entry: this->getRow(row)) { + builder.addNextValue(newRow, entry.getColumn(), entry.getValue()); + } + ++newRow; + } + } + STORM_LOG_THROW(allowEmptyRowGroups || !rowGroupEmpty, storm::exceptions::InvalidArgumentException, "Empty rows are not allowed, but row group " << rowGroup << " is empty."); + } + SparseMatrix res = builder.build(); return res; } diff --git a/src/storm/storage/SparseMatrix.h b/src/storm/storage/SparseMatrix.h index 19a63efe5..bd8c1bad1 100644 --- a/src/storm/storage/SparseMatrix.h +++ b/src/storm/storage/SparseMatrix.h @@ -657,9 +657,13 @@ namespace storm { * Restrict rows in grouped rows matrix. Ensures that the number of groups stays the same. * * @param rowsToKeep A bit vector indicating which rows to keep. - * + * @param allowEmptyRowGroups if set to true, the result can potentially have empty row groups. + * Otherwise, it is asserted that there are no empty row groups. + * + * @note The resulting matrix will always have a non-trivial row grouping even if the current one is trivial. + * */ - SparseMatrix restrictRows(storm::storage::BitVector const& rowsToKeep) const; + SparseMatrix restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups = false) const; /** * Compares two rows.