diff --git a/src/storm/storage/SparseMatrix.cpp b/src/storm/storage/SparseMatrix.cpp index 26e1c0dc7..65ab0f101 100644 --- a/src/storm/storage/SparseMatrix.cpp +++ b/src/storm/storage/SparseMatrix.cpp @@ -123,22 +123,40 @@ namespace storm { void SparseMatrixBuilder::addNextValue(index_type row, index_type column, ValueType const& value) { // Check that we did not move backwards wrt. the row. STORM_LOG_THROW(row >= lastRow, storm::exceptions::InvalidArgumentException, "Adding an element in row " << row << ", but an element in row " << lastRow << " has already been added."); + STORM_LOG_ASSERT(columnsAndValues.size() == currentEntryCount, "Unexpected size of columnsAndValues vector."); + + // Check if a diagonal entry shall be inserted before + if (pendingDiagonalEntry) { + index_type diagColumn = hasCustomRowGrouping ? currentRowGroupCount - 1 : lastRow; + if (row > lastRow || column >= diagColumn) { + ValueType diagValue = std::move(pendingDiagonalEntry.get()); + pendingDiagonalEntry = boost::none; + // Add the pending diagonal value now + if (row == lastRow && column == diagColumn) { + // The currently added value coincides with the diagonal entry! + // We add up the values and repeat this call. + addNextValue(row, column, diagValue + value); + // We return here because the above call already did all the work. + return; + } else { + addNextValue(lastRow, diagColumn, diagValue); + } + } + } // If the element is in the same row, but was not inserted in the correct order, we need to fix the row after // the insertion. bool fixCurrentRow = row == lastRow && column < lastColumn; - - // If the element is in the same row and column as the previous entry, we add them up. - if (row == lastRow && column == lastColumn && !columnsAndValues.empty()) { + // If the element is in the same row and column as the previous entry, we add them up... + // unless there is no entry in this row yet, which might happen either for the very first entry or when only a diagonal value has been added + if (row == lastRow && column == lastColumn && rowIndications.back() < currentEntryCount) { columnsAndValues.back().setValue(columnsAndValues.back().getValue() + value); } else { // If we switched to another row, we have to adjust the missing entries in the row indices vector. if (row != lastRow) { // Otherwise, we need to push the correct values to the vectors, which might trigger reallocations. - for (index_type i = lastRow + 1; i <= row; ++i) { - rowIndications.push_back(currentEntryCount); - } - + assert(rowIndications.size() == lastRow + 1); + rowIndications.resize(row + 1, currentEntryCount); lastRow = row; } @@ -183,15 +201,24 @@ namespace storm { void SparseMatrixBuilder::newRowGroup(index_type startingRow) { STORM_LOG_THROW(hasCustomRowGrouping, storm::exceptions::InvalidStateException, "Matrix was not created to have a custom row grouping."); STORM_LOG_THROW(startingRow >= lastRow, storm::exceptions::InvalidStateException, "Illegal row group with negative size."); - rowGroupIndices.get().push_back(startingRow); - ++currentRowGroupCount; - // Close all rows from the most recent one to the starting row. - for (index_type i = lastRow + 1; i < startingRow; ++i) { - rowIndications.push_back(currentEntryCount); + // If there still is a pending diagonal entry, we need to add it now (otherwise, the correct diagonal column will be unclear) + if (pendingDiagonalEntry) { + STORM_LOG_ASSERT(currentRowGroupCount > 0, "Diagonal entry was set before opening the first row group."); + index_type diagColumn = currentRowGroupCount - 1; + ValueType diagValue = std::move(pendingDiagonalEntry.get()); + pendingDiagonalEntry = boost::none; // clear now, so addNextValue works properly + addNextValue(lastRow, diagColumn, diagValue); } + rowGroupIndices.get().push_back(startingRow); + ++currentRowGroupCount; + + // Handle the case where the previous row group ends with one or more empty rows if (lastRow + 1 < startingRow) { + // Close all rows from the most recent one to the starting row. + assert(rowIndications.size() == lastRow + 1); + rowIndications.resize(startingRow, currentEntryCount); // Reset the most recently seen row/column to allow for proper insertion of the following elements. lastRow = startingRow - 1; lastColumn = 0; @@ -201,6 +228,14 @@ namespace storm { template SparseMatrix SparseMatrixBuilder::build(index_type overriddenRowCount, index_type overriddenColumnCount, index_type overriddenRowGroupCount) { + // If there still is a pending diagonal entry, we need to add it now + if (pendingDiagonalEntry) { + index_type diagColumn = hasCustomRowGrouping ? currentRowGroupCount - 1 : lastRow; + ValueType diagValue = std::move(pendingDiagonalEntry.get()); + pendingDiagonalEntry = boost::none; // clear now, so addNextValue works properly + addNextValue(lastRow, diagColumn, diagValue); + } + bool hasEntries = currentEntryCount != 0; uint_fast64_t rowCount = hasEntries ? lastRow + 1 : 0; @@ -332,9 +367,34 @@ namespace storm { } highestColumn = maxColumn; - lastColumn = columnsAndValues.empty() ? 0 : columnsAndValues[columnsAndValues.size() - 1].getColumn(); + lastColumn = columnsAndValues.empty() ? 0 : columnsAndValues.back().getColumn(); } + template + void SparseMatrixBuilder::addDiagonalEntry(index_type row, ValueType const& value) { + STORM_LOG_THROW(row >= lastRow, storm::exceptions::InvalidArgumentException, "Adding a diagonal element in row " << row << ", but an element in row " << lastRow << " has already been added."); + if (pendingDiagonalEntry) { + if (row == lastRow) { + // Add the two diagonal entries, nothing else to be done. + pendingDiagonalEntry.get() += value; + return; + } else { + // add the pending entry + index_type column = hasCustomRowGrouping ? currentRowGroupCount - 1 : lastRow; + ValueType diagValue = std::move(pendingDiagonalEntry.get()); + pendingDiagonalEntry = boost::none; // clear now, so addNextValue works properly + addNextValue(lastRow, column, diagValue); + } + } + pendingDiagonalEntry = value; + if (lastRow != row) { + assert(rowIndications.size() == lastRow + 1); + rowIndications.resize(row + 1, currentEntryCount); + lastRow = row; + lastColumn = 0; + } + } + template SparseMatrix::rows::rows(iterator begin, index_type entryCount) : beginIterator(begin), entryCount(entryCount) { // Intentionally left empty. diff --git a/src/storm/storage/SparseMatrix.h b/src/storm/storage/SparseMatrix.h index ad9c00ecd..dc175c0b8 100644 --- a/src/storm/storage/SparseMatrix.h +++ b/src/storm/storage/SparseMatrix.h @@ -243,7 +243,16 @@ namespace storm { * @param offset Offset to add to each id in vector index. */ void replaceColumns(std::vector const& replacements, index_type offset); - + + /*! + * Makes sure that a diagonal entry will be inserted at the given row. + * All other entries of this row must be set immediately after calling this (without setting values at other rows in between) + * The provided row must not be smaller than the row of the most recent insertion. + * If there is a row grouping, the column of the diagonal entry will correspond to the current row group. + * If addNextValue is called on the given row and the diagonal column, we take the sum of the two values provided to addDiagonalEntry and addNextValue + */ + void addDiagonalEntry(index_type row, ValueType const& value); + private: // A flag indicating whether a row count was set upon construction. bool initialRowCountSet; @@ -305,6 +314,8 @@ namespace storm { // Stores the currently active row group. This is used for correctly constructing the row grouping of the // matrix. index_type currentRowGroupCount; + + boost::optional pendingDiagonalEntry; }; /*! diff --git a/src/test/storm/storage/SparseMatrixTest.cpp b/src/test/storm/storage/SparseMatrixTest.cpp index 38566112d..4984e6607 100644 --- a/src/test/storm/storage/SparseMatrixTest.cpp +++ b/src/test/storm/storage/SparseMatrixTest.cpp @@ -148,6 +148,79 @@ TEST(SparseMatrix, Build) { ASSERT_EQ(5ul, matrix5.getEntryCount()); } +TEST(SparseMatrix, DiagonalEntries) { + { + // No row groupings + storm::storage::SparseMatrixBuilder builder(4, 4, 7); + storm::storage::SparseMatrixBuilder builderCmp(4, 4, 7); + for (uint64_t i = 0; i < 4; ++i) { + ASSERT_NO_THROW(builder.addDiagonalEntry(i, i)); + ASSERT_NO_THROW(builder.addNextValue(i, 2, 100.0 + i)); + if (i < 2) { + ASSERT_NO_THROW(builderCmp.addNextValue(i, i, i)); + ASSERT_NO_THROW(builderCmp.addNextValue(i, 2, 100.0 + i)); + } else { + ASSERT_NO_THROW(builderCmp.addNextValue(i, 2, 100.0 + i)); + ASSERT_NO_THROW(builderCmp.addNextValue(i, i, i)); + } + } + auto matrix = builder.build(); + auto matrixCmp = builderCmp.build(); + EXPECT_EQ(matrix, matrixCmp); + } + { + // With row groupings (each row group has 3 rows) + storm::storage::SparseMatrixBuilder builder(12, 4, 21, true, true, 4); + storm::storage::SparseMatrixBuilder builderCmp(12, 4, 21, true, true, 4); + for (uint64_t i = 0; i < 4; ++i) { + uint64_t row = 3*i; + builder.newRowGroup(row); + builderCmp.newRowGroup(row); + for (; row < 3*(i+1); ++row) { + ASSERT_NO_THROW(builder.addDiagonalEntry(row, row)); + ASSERT_NO_THROW(builder.addNextValue(row, 2, 100 + row)); + if (i < 2) { + ASSERT_NO_THROW(builderCmp.addNextValue(row, i, row)); + ASSERT_NO_THROW(builderCmp.addNextValue(row, 2, 100.0 + row)); + } else { + ASSERT_NO_THROW(builderCmp.addNextValue(row, 2, 100.0 + row)); + ASSERT_NO_THROW(builderCmp.addNextValue(row, i, row)); + } + } + } + auto matrix = builder.build(); + auto matrixCmp = builderCmp.build(); + EXPECT_EQ(matrix, matrixCmp); + } + { + // With row groupings (every second row is empty) + storm::storage::SparseMatrixBuilder builder(12, 4, 10, true, true, 4); + storm::storage::SparseMatrixBuilder builderCmp(12, 4, 10, true, true, 4); + for (uint64_t i = 0; i < 4; ++i) { + uint64_t row = 3*i; + builder.newRowGroup(row); + builderCmp.newRowGroup(row); + for (; row < 3*(i+1); ++row) { + if (row % 2 == 1) { + continue; + } + ASSERT_NO_THROW(builder.addDiagonalEntry(row, row)); + ASSERT_NO_THROW(builder.addNextValue(row, 2, 100 + row)); + if (i < 2) { + ASSERT_NO_THROW(builderCmp.addNextValue(row, i, row)); + ASSERT_NO_THROW(builderCmp.addNextValue(row, 2, 100.0 + row)); + } else { + ASSERT_NO_THROW(builderCmp.addNextValue(row, i, row)); + ASSERT_NO_THROW(builderCmp.addNextValue(row, 2, 100.0 + row)); + } + } + } + auto matrix = builder.build(); + auto matrixCmp = builderCmp.build(); + EXPECT_EQ(matrix, matrixCmp); + } +} + TEST(SparseMatrix, CreationWithMovingContents) { std::vector> columnsAndValues; columnsAndValues.emplace_back(1, 1.0);