diff --git a/src/storm/storage/SparseMatrix.cpp b/src/storm/storage/SparseMatrix.cpp index d37177b6c..360e68ee8 100644 --- a/src/storm/storage/SparseMatrix.cpp +++ b/src/storm/storage/SparseMatrix.cpp @@ -1144,7 +1144,7 @@ namespace storm { // Finalize created matrix and return result. return matrixBuilder.build(); } - + template SparseMatrix SparseMatrix::selectRowsFromRowIndexSequence(std::vector const& rowIndexSequence, bool insertDiagonalEntries) const{ // First, we need to count how many non-zero entries the resulting matrix will have and reserve space for @@ -1162,10 +1162,10 @@ namespace storm { ++newEntries; } } - + // Now create the matrix to be returned with the appropriate size. SparseMatrixBuilder matrixBuilder(rowIndexSequence.size(), columnCount, newEntries); - + // Copy over the selected rows from the source matrix. for(index_type row = 0, rowEnd = rowIndexSequence.size(); row < rowEnd; ++row) { bool insertedDiagonalElement = false; @@ -1182,10 +1182,32 @@ namespace storm { matrixBuilder.addNextValue(row, row, storm::utility::zero()); } } - + // Finally create matrix and return result. return matrixBuilder.build(); } + + template + SparseMatrix SparseMatrix::permuteRows(std::vector const& inversePermutation) const { + + // Now create the matrix to be returned with the appropriate size. + // The entry size is only adequate if this is indeed a permutation. + SparseMatrixBuilder matrixBuilder(inversePermutation.size(), columnCount, entryCount); + + // Copy over the selected rows from the source matrix. + + for (index_type writeTo = 0; writeTo < inversePermutation.size(); ++writeTo) { + index_type const &readFrom = inversePermutation[writeTo]; + auto row = this->getRow(readFrom); + for (auto const& entry : row) { + matrixBuilder.addNextValue(writeTo, entry.getColumn(), entry.getValue()); + } + } + // Finally create matrix and return result. + auto result = matrixBuilder.build(); + result.setRowGroupIndices(this->rowGroupIndices.get()); + return result; + } template SparseMatrix SparseMatrix::transpose(bool joinGroups, bool keepZeros) const { diff --git a/src/storm/storage/SparseMatrix.h b/src/storm/storage/SparseMatrix.h index 880c1ee36..d3bbb5930 100644 --- a/src/storm/storage/SparseMatrix.h +++ b/src/storm/storage/SparseMatrix.h @@ -716,7 +716,15 @@ namespace storm { * */ SparseMatrix restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups = false) const; - + + /* + * Permute rows of the matrix according to the vector. + * That is, in row i, write the entry of row inversePermutation[i]. + * Consequently, a single row might actually be written into multiple other rows, and the function application is not necessarily a permutation. + * Notice that this method does *not* touch column entries, nor the row grouping. + */ + SparseMatrix permuteRows(std::vector const& inversePermutation) const; + /*! * Returns a copy of this matrix that only considers entries in the selected rows. * Non-selected rows will not have any entries diff --git a/src/test/storm/storage/SparseMatrixTest.cpp b/src/test/storm/storage/SparseMatrixTest.cpp index 0c1c208d2..6d0ea4480 100644 --- a/src/test/storm/storage/SparseMatrixTest.cpp +++ b/src/test/storm/storage/SparseMatrixTest.cpp @@ -694,3 +694,28 @@ TEST(SparseMatrix, IsSubmatrix) { ASSERT_FALSE(matrix3.isSubmatrixOf(matrix)); ASSERT_FALSE(matrix3.isSubmatrixOf(matrix2)); } + +TEST(SparseMatrix, Permute) { + storm::storage::SparseMatrixBuilder matrixBuilder(5, 4, 8); + ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 1, 1.0)); + ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 2, 1.2)); + ASSERT_NO_THROW(matrixBuilder.addNextValue(1, 0, 0.5)); + ASSERT_NO_THROW(matrixBuilder.addNextValue(1, 1, 0.7)); + ASSERT_NO_THROW(matrixBuilder.addNextValue(3, 2, 1.1)); + ASSERT_NO_THROW(matrixBuilder.addNextValue(4, 0, 0.1)); + ASSERT_NO_THROW(matrixBuilder.addNextValue(4, 1, 0.2)); + ASSERT_NO_THROW(matrixBuilder.addNextValue(4, 3, 0.3)); + storm::storage::SparseMatrix matrix; + ASSERT_NO_THROW(matrix = matrixBuilder.build()); + + std::vector inversePermutation = {1,4,0,3,2}; + storm::storage::SparseMatrix matrixperm = matrix.permuteRows(inversePermutation); + EXPECT_EQ(5, matrixperm.getRowCount()); + EXPECT_EQ(4, matrixperm.getColumnCount()); + EXPECT_EQ(8, matrixperm.getEntryCount()); + EXPECT_EQ(matrix.getRowSum(1), matrixperm.getRowSum(0)); + EXPECT_EQ(matrix.getRowSum(4), matrixperm.getRowSum(1)); + EXPECT_EQ(matrix.getRowSum(0), matrixperm.getRowSum(2)); + EXPECT_EQ(matrix.getRowSum(3), matrixperm.getRowSum(3)); + EXPECT_EQ(matrix.getRowSum(2), matrixperm.getRowSum(4)); +} \ No newline at end of file