diff --git a/src/storm/storage/SparseMatrix.cpp b/src/storm/storage/SparseMatrix.cpp index b8f694939..99a716415 100644 --- a/src/storm/storage/SparseMatrix.cpp +++ b/src/storm/storage/SparseMatrix.cpp @@ -919,32 +919,35 @@ namespace storm { for (auto const& row : rowsToKeep) { entryCount += this->getRow(row).getNumberOfEntries(); } - + + // Get the smallest row group index such that all row groups with at least this index are empty. + uint_fast64_t firstTrailingEmptyRowGroup = this->getRowGroupCount(); + for (auto groupIndexIt = this->getRowGroupIndices().rbegin() + 1; groupIndexIt != this->getRowGroupIndices().rend(); ++groupIndexIt) { + if (rowsToKeep.getNextSetIndex(*groupIndexIt) != rowsToKeep.size()) { + break; + } + --firstTrailingEmptyRowGroup; + } + STORM_LOG_THROW(allowEmptyRowGroups || firstTrailingEmptyRowGroup == this->getRowGroupCount(), storm::exceptions::InvalidArgumentException, "Empty rows are not allowed, but row group " << firstTrailingEmptyRowGroup << " is empty."); + // build the matrix. The row grouping will always be considered as nontrivial. SparseMatrixBuilder builder(rowsToKeep.getNumberOfSetBits(), this->getColumnCount(), entryCount, true, true, this->getRowGroupCount()); uint_fast64_t newRow = 0; - for (uint_fast64_t rowGroup = 0; rowGroup < this->getRowGroupCount(); ++rowGroup) { + for (uint_fast64_t rowGroup = 0; rowGroup < firstTrailingEmptyRowGroup; ++rowGroup) { + // Add a new row group builder.newRowGroup(newRow); bool rowGroupEmpty = true; - if (this->hasTrivialRowGrouping()) { - if (rowsToKeep.get(rowGroup)) { - rowGroupEmpty = false; - for (auto const& entry : this->getRow(rowGroup)) { - builder.addNextValue(newRow, entry.getColumn(), entry.getValue()); - } - ++newRow; - } - } else { - for (uint_fast64_t row = rowsToKeep.getNextSetIndex(this->getRowGroupIndices()[rowGroup]); row < this->getRowGroupIndices()[rowGroup + 1]; row = rowsToKeep.getNextSetIndex(row + 1)) { - rowGroupEmpty = false; - for (auto const& entry: this->getRow(row)) { - builder.addNextValue(newRow, entry.getColumn(), entry.getValue()); - } - ++newRow; + for (uint_fast64_t row = rowsToKeep.getNextSetIndex(this->getRowGroupIndices()[rowGroup]); row < this->getRowGroupIndices()[rowGroup + 1]; row = rowsToKeep.getNextSetIndex(row + 1)) { + rowGroupEmpty = false; + for (auto const& entry: this->getRow(row)) { + builder.addNextValue(newRow, entry.getColumn(), entry.getValue()); } + ++newRow; } STORM_LOG_THROW(allowEmptyRowGroups || !rowGroupEmpty, storm::exceptions::InvalidArgumentException, "Empty rows are not allowed, but row group " << rowGroup << " is empty."); } + + // The all remaining row groups will be empty. Note that it is not allowed to call builder.addNewGroup(...) if there are no more rows afterwards. SparseMatrix res = builder.build(); return res; } diff --git a/src/test/storage/SparseMatrixTest.cpp b/src/test/storage/SparseMatrixTest.cpp index b331fefe1..1b468c318 100644 --- a/src/test/storage/SparseMatrixTest.cpp +++ b/src/test/storage/SparseMatrixTest.cpp @@ -3,6 +3,7 @@ #include "storm/storage/BitVector.h" #include "storm/exceptions/InvalidStateException.h" #include "storm/exceptions/OutOfRangeException.h" +#include "storm/exceptions/InvalidArgumentException.h" TEST(SparseMatrixBuilder, CreationWithDimensions) { storm::storage::SparseMatrixBuilder matrixBuilder(3, 4, 5); @@ -374,6 +375,83 @@ TEST(SparseMatrix, Submatrix) { ASSERT_TRUE(matrix4 == matrix5); } +TEST(SparseMatrix, RestrictRows) { + storm::storage::SparseMatrixBuilder matrixBuilder1(7, 4, 9, true, true, 3); + ASSERT_NO_THROW(matrixBuilder1.newRowGroup(0)); + ASSERT_NO_THROW(matrixBuilder1.addNextValue(0, 1, 1.0)); + ASSERT_NO_THROW(matrixBuilder1.addNextValue(0, 2, 1.2)); + ASSERT_NO_THROW(matrixBuilder1.addNextValue(1, 0, 0.5)); + ASSERT_NO_THROW(matrixBuilder1.addNextValue(1, 1, 0.7)); + ASSERT_NO_THROW(matrixBuilder1.newRowGroup(2)); + ASSERT_NO_THROW(matrixBuilder1.addNextValue(2, 0, 0.5)); + ASSERT_NO_THROW(matrixBuilder1.addNextValue(3, 2, 1.1)); + ASSERT_NO_THROW(matrixBuilder1.newRowGroup(4)); + ASSERT_NO_THROW(matrixBuilder1.addNextValue(4, 0, 0.1)); + ASSERT_NO_THROW(matrixBuilder1.addNextValue(4, 1, 0.2)); + ASSERT_NO_THROW(matrixBuilder1.addNextValue(6, 3, 0.3)); + storm::storage::SparseMatrix matrix1; + ASSERT_NO_THROW(matrix1 = matrixBuilder1.build()); + + storm::storage::BitVector constraint1(7); + constraint1.set(0); + constraint1.set(1); + constraint1.set(2); + constraint1.set(5); + + storm::storage::SparseMatrix matrix1Prime; + ASSERT_NO_THROW(matrix1Prime = matrix1.restrictRows(constraint1)); + + storm::storage::SparseMatrixBuilder matrixBuilder2(4, 4, 5, true, true, 3); + ASSERT_NO_THROW(matrixBuilder2.newRowGroup(0)); + ASSERT_NO_THROW(matrixBuilder2.addNextValue(0, 1, 1.0)); + ASSERT_NO_THROW(matrixBuilder2.addNextValue(0, 2, 1.2)); + ASSERT_NO_THROW(matrixBuilder2.addNextValue(1, 0, 0.5)); + ASSERT_NO_THROW(matrixBuilder2.addNextValue(1, 1, 0.7)); + ASSERT_NO_THROW(matrixBuilder2.newRowGroup(2)); + ASSERT_NO_THROW(matrixBuilder2.addNextValue(2, 0, 0.5)); + ASSERT_NO_THROW(matrixBuilder2.newRowGroup(3)); + storm::storage::SparseMatrix matrix2; + ASSERT_NO_THROW(matrix2 = matrixBuilder2.build()); + + ASSERT_EQ(matrix2, matrix1Prime); + + storm::storage::BitVector constraint2(4); + constraint2.set(1); + constraint2.set(2); + + storm::storage::SparseMatrix matrix2Prime; + ASSERT_THROW(matrix2Prime = matrix2.restrictRows(constraint2), storm::exceptions::InvalidArgumentException); + ASSERT_NO_THROW(matrix2Prime = matrix2.restrictRows(constraint2, true)); + + storm::storage::SparseMatrixBuilder matrixBuilder3(2, 4, 3, true, true, 3); + ASSERT_NO_THROW(matrixBuilder3.newRowGroup(0)); + ASSERT_NO_THROW(matrixBuilder3.addNextValue(0, 0, 0.5)); + ASSERT_NO_THROW(matrixBuilder3.addNextValue(0, 1, 0.7)); + ASSERT_NO_THROW(matrixBuilder3.newRowGroup(1)); + ASSERT_NO_THROW(matrixBuilder3.addNextValue(1, 0, 0.5)); + storm::storage::SparseMatrix matrix3; + ASSERT_NO_THROW(matrix3 = matrixBuilder3.build()); + + ASSERT_EQ(matrix3, matrix2Prime); + + matrix3.makeRowGroupingTrivial(); + storm::storage::BitVector constraint3(2); + constraint3.set(1); + + storm::storage::SparseMatrix matrix3Prime; + ASSERT_THROW(matrix3Prime = matrix3.restrictRows(constraint3), storm::exceptions::InvalidArgumentException); + ASSERT_NO_THROW(matrix3Prime = matrix3.restrictRows(constraint3, true)); + + storm::storage::SparseMatrixBuilder matrixBuilder4(1, 4, 1, true, true, 2); + ASSERT_NO_THROW(matrixBuilder4.newRowGroup(0)); + ASSERT_NO_THROW(matrixBuilder4.newRowGroup(0)); + ASSERT_NO_THROW(matrixBuilder4.addNextValue(0, 0, 0.5)); + storm::storage::SparseMatrix matrix4; + ASSERT_NO_THROW(matrix4 = matrixBuilder4.build()); + + ASSERT_EQ(matrix4, matrix3Prime); +} + TEST(SparseMatrix, Transpose) { storm::storage::SparseMatrixBuilder matrixBuilder(5, 4, 9); ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 1, 1.0));