diff --git a/src/storm/storage/SparseMatrix.cpp b/src/storm/storage/SparseMatrix.cpp index ac40bfb00..05a526702 100644 --- a/src/storm/storage/SparseMatrix.cpp +++ b/src/storm/storage/SparseMatrix.cpp @@ -744,6 +744,29 @@ namespace storm { return result; } + template + storm::storage::BitVector SparseMatrix::getRowGroupFilter(storm::storage::BitVector const& rowConstraint, bool setIfForAllRowsInGroup) const { + STORM_LOG_ASSERT(!this->hasTrivialRowGrouping(), "Tried to get a row group filter but this matrix does not have row groups"); + storm::storage::BitVector result(this->getRowGroupCount(), false); + auto const& groupIndices = this->getRowGroupIndices(); + if (setIfForAllRowsInGroup) { + for (uint64_t group = 0; group < this->getRowGroupCount(); ++group) { + if (rowConstraint.getNextUnsetIndex(groupIndices[group]) >= groupIndices[group + 1]) { + // All rows within this group are set + result.set(group, true); + } + } + } else { + for (uint64_t group = 0; group < this->getRowGroupCount(); ++group) { + if (rowConstraint.getNextSetIndex(groupIndices[group]) < groupIndices[group + 1]) { + // Some row is set + result.set(group, true); + } + } + } + return result; + } + template void SparseMatrix::makeRowsAbsorbing(storm::storage::BitVector const& rows) { for (auto row : rows) { diff --git a/src/storm/storage/SparseMatrix.h b/src/storm/storage/SparseMatrix.h index b16a1ea99..ba25a2819 100644 --- a/src/storm/storage/SparseMatrix.h +++ b/src/storm/storage/SparseMatrix.h @@ -642,6 +642,15 @@ namespace storm { */ storm::storage::BitVector getRowFilter(storm::storage::BitVector const& groupConstraint, storm::storage::BitVector const& columnConstraints) const; + /*! + * Returns the indices of all row groups selected by the row constraints + * + * @param rowConstraint the selected rows + * @param setIfForAllRowsInGroup if true, a group is selected if the rowConstraint is true for *all* rows within that group. If false, a group is selected if the rowConstraint is true for *some* row within that group + * @return a bit vector that is true at position i iff row i satisfies the constraints. + */ + storm::storage::BitVector getRowGroupFilter(storm::storage::BitVector const& rowConstraint, bool setIfForAllRowsInGroup) const; + /*! * This function makes the given rows absorbing. *