Browse Source

SparseMatrixBuilder: Added a function to easily add diagonal entries.

tempestpy_adaptions
Tim Quatmann 4 years ago
parent
commit
6f59c4f3eb
  1. 86
      src/storm/storage/SparseMatrix.cpp
  2. 11
      src/storm/storage/SparseMatrix.h
  3. 73
      src/test/storm/storage/SparseMatrixTest.cpp

86
src/storm/storage/SparseMatrix.cpp

@ -123,22 +123,40 @@ namespace storm {
void SparseMatrixBuilder<ValueType>::addNextValue(index_type row, index_type column, ValueType const& value) { void SparseMatrixBuilder<ValueType>::addNextValue(index_type row, index_type column, ValueType const& value) {
// Check that we did not move backwards wrt. the row. // Check that we did not move backwards wrt. the row.
STORM_LOG_THROW(row >= lastRow, storm::exceptions::InvalidArgumentException, "Adding an element in row " << row << ", but an element in row " << lastRow << " has already been added."); STORM_LOG_THROW(row >= lastRow, storm::exceptions::InvalidArgumentException, "Adding an element in row " << row << ", but an element in row " << lastRow << " has already been added.");
STORM_LOG_ASSERT(columnsAndValues.size() == currentEntryCount, "Unexpected size of columnsAndValues vector.");
// Check if a diagonal entry shall be inserted before
if (pendingDiagonalEntry) {
index_type diagColumn = hasCustomRowGrouping ? currentRowGroupCount - 1 : lastRow;
if (row > lastRow || column >= diagColumn) {
ValueType diagValue = std::move(pendingDiagonalEntry.get());
pendingDiagonalEntry = boost::none;
// Add the pending diagonal value now
if (row == lastRow && column == diagColumn) {
// The currently added value coincides with the diagonal entry!
// We add up the values and repeat this call.
addNextValue(row, column, diagValue + value);
// We return here because the above call already did all the work.
return;
} else {
addNextValue(lastRow, diagColumn, diagValue);
}
}
}
// If the element is in the same row, but was not inserted in the correct order, we need to fix the row after // If the element is in the same row, but was not inserted in the correct order, we need to fix the row after
// the insertion. // the insertion.
bool fixCurrentRow = row == lastRow && column < lastColumn; bool fixCurrentRow = row == lastRow && column < lastColumn;
// If the element is in the same row and column as the previous entry, we add them up.
if (row == lastRow && column == lastColumn && !columnsAndValues.empty()) {
// If the element is in the same row and column as the previous entry, we add them up...
// unless there is no entry in this row yet, which might happen either for the very first entry or when only a diagonal value has been added
if (row == lastRow && column == lastColumn && rowIndications.back() < currentEntryCount) {
columnsAndValues.back().setValue(columnsAndValues.back().getValue() + value); columnsAndValues.back().setValue(columnsAndValues.back().getValue() + value);
} else { } else {
// If we switched to another row, we have to adjust the missing entries in the row indices vector. // If we switched to another row, we have to adjust the missing entries in the row indices vector.
if (row != lastRow) { if (row != lastRow) {
// Otherwise, we need to push the correct values to the vectors, which might trigger reallocations. // Otherwise, we need to push the correct values to the vectors, which might trigger reallocations.
for (index_type i = lastRow + 1; i <= row; ++i) {
rowIndications.push_back(currentEntryCount);
}
assert(rowIndications.size() == lastRow + 1);
rowIndications.resize(row + 1, currentEntryCount);
lastRow = row; lastRow = row;
} }
@ -183,15 +201,24 @@ namespace storm {
void SparseMatrixBuilder<ValueType>::newRowGroup(index_type startingRow) { void SparseMatrixBuilder<ValueType>::newRowGroup(index_type startingRow) {
STORM_LOG_THROW(hasCustomRowGrouping, storm::exceptions::InvalidStateException, "Matrix was not created to have a custom row grouping."); STORM_LOG_THROW(hasCustomRowGrouping, storm::exceptions::InvalidStateException, "Matrix was not created to have a custom row grouping.");
STORM_LOG_THROW(startingRow >= lastRow, storm::exceptions::InvalidStateException, "Illegal row group with negative size."); STORM_LOG_THROW(startingRow >= lastRow, storm::exceptions::InvalidStateException, "Illegal row group with negative size.");
rowGroupIndices.get().push_back(startingRow);
++currentRowGroupCount;
// Close all rows from the most recent one to the starting row.
for (index_type i = lastRow + 1; i < startingRow; ++i) {
rowIndications.push_back(currentEntryCount);
// If there still is a pending diagonal entry, we need to add it now (otherwise, the correct diagonal column will be unclear)
if (pendingDiagonalEntry) {
STORM_LOG_ASSERT(currentRowGroupCount > 0, "Diagonal entry was set before opening the first row group.");
index_type diagColumn = currentRowGroupCount - 1;
ValueType diagValue = std::move(pendingDiagonalEntry.get());
pendingDiagonalEntry = boost::none; // clear now, so addNextValue works properly
addNextValue(lastRow, diagColumn, diagValue);
} }
rowGroupIndices.get().push_back(startingRow);
++currentRowGroupCount;
// Handle the case where the previous row group ends with one or more empty rows
if (lastRow + 1 < startingRow) { if (lastRow + 1 < startingRow) {
// Close all rows from the most recent one to the starting row.
assert(rowIndications.size() == lastRow + 1);
rowIndications.resize(startingRow, currentEntryCount);
// Reset the most recently seen row/column to allow for proper insertion of the following elements. // Reset the most recently seen row/column to allow for proper insertion of the following elements.
lastRow = startingRow - 1; lastRow = startingRow - 1;
lastColumn = 0; lastColumn = 0;
@ -201,6 +228,14 @@ namespace storm {
template<typename ValueType> template<typename ValueType>
SparseMatrix<ValueType> SparseMatrixBuilder<ValueType>::build(index_type overriddenRowCount, index_type overriddenColumnCount, index_type overriddenRowGroupCount) { SparseMatrix<ValueType> SparseMatrixBuilder<ValueType>::build(index_type overriddenRowCount, index_type overriddenColumnCount, index_type overriddenRowGroupCount) {
// If there still is a pending diagonal entry, we need to add it now
if (pendingDiagonalEntry) {
index_type diagColumn = hasCustomRowGrouping ? currentRowGroupCount - 1 : lastRow;
ValueType diagValue = std::move(pendingDiagonalEntry.get());
pendingDiagonalEntry = boost::none; // clear now, so addNextValue works properly
addNextValue(lastRow, diagColumn, diagValue);
}
bool hasEntries = currentEntryCount != 0; bool hasEntries = currentEntryCount != 0;
uint_fast64_t rowCount = hasEntries ? lastRow + 1 : 0; uint_fast64_t rowCount = hasEntries ? lastRow + 1 : 0;
@ -332,7 +367,32 @@ namespace storm {
} }
highestColumn = maxColumn; highestColumn = maxColumn;
lastColumn = columnsAndValues.empty() ? 0 : columnsAndValues[columnsAndValues.size() - 1].getColumn();
lastColumn = columnsAndValues.empty() ? 0 : columnsAndValues.back().getColumn();
}
template<typename ValueType>
void SparseMatrixBuilder<ValueType>::addDiagonalEntry(index_type row, ValueType const& value) {
STORM_LOG_THROW(row >= lastRow, storm::exceptions::InvalidArgumentException, "Adding a diagonal element in row " << row << ", but an element in row " << lastRow << " has already been added.");
if (pendingDiagonalEntry) {
if (row == lastRow) {
// Add the two diagonal entries, nothing else to be done.
pendingDiagonalEntry.get() += value;
return;
} else {
// add the pending entry
index_type column = hasCustomRowGrouping ? currentRowGroupCount - 1 : lastRow;
ValueType diagValue = std::move(pendingDiagonalEntry.get());
pendingDiagonalEntry = boost::none; // clear now, so addNextValue works properly
addNextValue(lastRow, column, diagValue);
}
}
pendingDiagonalEntry = value;
if (lastRow != row) {
assert(rowIndications.size() == lastRow + 1);
rowIndications.resize(row + 1, currentEntryCount);
lastRow = row;
lastColumn = 0;
}
} }
template<typename ValueType> template<typename ValueType>

11
src/storm/storage/SparseMatrix.h

@ -244,6 +244,15 @@ namespace storm {
*/ */
void replaceColumns(std::vector<index_type> const& replacements, index_type offset); void replaceColumns(std::vector<index_type> const& replacements, index_type offset);
/*!
* Makes sure that a diagonal entry will be inserted at the given row.
* All other entries of this row must be set immediately after calling this (without setting values at other rows in between)
* The provided row must not be smaller than the row of the most recent insertion.
* If there is a row grouping, the column of the diagonal entry will correspond to the current row group.
* If addNextValue is called on the given row and the diagonal column, we take the sum of the two values provided to addDiagonalEntry and addNextValue
*/
void addDiagonalEntry(index_type row, ValueType const& value);
private: private:
// A flag indicating whether a row count was set upon construction. // A flag indicating whether a row count was set upon construction.
bool initialRowCountSet; bool initialRowCountSet;
@ -305,6 +314,8 @@ namespace storm {
// Stores the currently active row group. This is used for correctly constructing the row grouping of the // Stores the currently active row group. This is used for correctly constructing the row grouping of the
// matrix. // matrix.
index_type currentRowGroupCount; index_type currentRowGroupCount;
boost::optional<ValueType> pendingDiagonalEntry;
}; };
/*! /*!

73
src/test/storm/storage/SparseMatrixTest.cpp

@ -148,6 +148,79 @@ TEST(SparseMatrix, Build) {
ASSERT_EQ(5ul, matrix5.getEntryCount()); ASSERT_EQ(5ul, matrix5.getEntryCount());
} }
TEST(SparseMatrix, DiagonalEntries) {
{
// No row groupings
storm::storage::SparseMatrixBuilder<double> builder(4, 4, 7);
storm::storage::SparseMatrixBuilder<double> builderCmp(4, 4, 7);
for (uint64_t i = 0; i < 4; ++i) {
ASSERT_NO_THROW(builder.addDiagonalEntry(i, i));
ASSERT_NO_THROW(builder.addNextValue(i, 2, 100.0 + i));
if (i < 2) {
ASSERT_NO_THROW(builderCmp.addNextValue(i, i, i));
ASSERT_NO_THROW(builderCmp.addNextValue(i, 2, 100.0 + i));
} else {
ASSERT_NO_THROW(builderCmp.addNextValue(i, 2, 100.0 + i));
ASSERT_NO_THROW(builderCmp.addNextValue(i, i, i));
}
}
auto matrix = builder.build();
auto matrixCmp = builderCmp.build();
EXPECT_EQ(matrix, matrixCmp);
}
{
// With row groupings (each row group has 3 rows)
storm::storage::SparseMatrixBuilder<double> builder(12, 4, 21, true, true, 4);
storm::storage::SparseMatrixBuilder<double> builderCmp(12, 4, 21, true, true, 4);
for (uint64_t i = 0; i < 4; ++i) {
uint64_t row = 3*i;
builder.newRowGroup(row);
builderCmp.newRowGroup(row);
for (; row < 3*(i+1); ++row) {
ASSERT_NO_THROW(builder.addDiagonalEntry(row, row));
ASSERT_NO_THROW(builder.addNextValue(row, 2, 100 + row));
if (i < 2) {
ASSERT_NO_THROW(builderCmp.addNextValue(row, i, row));
ASSERT_NO_THROW(builderCmp.addNextValue(row, 2, 100.0 + row));
} else {
ASSERT_NO_THROW(builderCmp.addNextValue(row, 2, 100.0 + row));
ASSERT_NO_THROW(builderCmp.addNextValue(row, i, row));
}
}
}
auto matrix = builder.build();
auto matrixCmp = builderCmp.build();
EXPECT_EQ(matrix, matrixCmp);
}
{
// With row groupings (every second row is empty)
storm::storage::SparseMatrixBuilder<double> builder(12, 4, 10, true, true, 4);
storm::storage::SparseMatrixBuilder<double> builderCmp(12, 4, 10, true, true, 4);
for (uint64_t i = 0; i < 4; ++i) {
uint64_t row = 3*i;
builder.newRowGroup(row);
builderCmp.newRowGroup(row);
for (; row < 3*(i+1); ++row) {
if (row % 2 == 1) {
continue;
}
ASSERT_NO_THROW(builder.addDiagonalEntry(row, row));
ASSERT_NO_THROW(builder.addNextValue(row, 2, 100 + row));
if (i < 2) {
ASSERT_NO_THROW(builderCmp.addNextValue(row, i, row));
ASSERT_NO_THROW(builderCmp.addNextValue(row, 2, 100.0 + row));
} else {
ASSERT_NO_THROW(builderCmp.addNextValue(row, i, row));
ASSERT_NO_THROW(builderCmp.addNextValue(row, 2, 100.0 + row));
}
}
}
auto matrix = builder.build();
auto matrixCmp = builderCmp.build();
EXPECT_EQ(matrix, matrixCmp);
}
}
TEST(SparseMatrix, CreationWithMovingContents) { TEST(SparseMatrix, CreationWithMovingContents) {
std::vector<storm::storage::MatrixEntry<uint_fast64_t, double>> columnsAndValues; std::vector<storm::storage::MatrixEntry<uint_fast64_t, double>> columnsAndValues;
columnsAndValues.emplace_back(1, 1.0); columnsAndValues.emplace_back(1, 1.0);

Loading…
Cancel
Save