Browse Source

permute for matrices

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
8a7e40558d
  1. 22
      src/storm/storage/SparseMatrix.cpp
  2. 8
      src/storm/storage/SparseMatrix.h
  3. 25
      src/test/storm/storage/SparseMatrixTest.cpp

22
src/storm/storage/SparseMatrix.cpp

@ -1187,6 +1187,28 @@ namespace storm {
return matrixBuilder.build(); return matrixBuilder.build();
} }
template<typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::permuteRows(std::vector<index_type> const& inversePermutation) const {
// Now create the matrix to be returned with the appropriate size.
// The entry size is only adequate if this is indeed a permutation.
SparseMatrixBuilder<ValueType> matrixBuilder(inversePermutation.size(), columnCount, entryCount);
// Copy over the selected rows from the source matrix.
for (index_type writeTo = 0; writeTo < inversePermutation.size(); ++writeTo) {
index_type const &readFrom = inversePermutation[writeTo];
auto row = this->getRow(readFrom);
for (auto const& entry : row) {
matrixBuilder.addNextValue(writeTo, entry.getColumn(), entry.getValue());
}
}
// Finally create matrix and return result.
auto result = matrixBuilder.build();
result.setRowGroupIndices(this->rowGroupIndices.get());
return result;
}
template <typename ValueType> template <typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::transpose(bool joinGroups, bool keepZeros) const { SparseMatrix<ValueType> SparseMatrix<ValueType>::transpose(bool joinGroups, bool keepZeros) const {
index_type rowCount = this->getColumnCount(); index_type rowCount = this->getColumnCount();

8
src/storm/storage/SparseMatrix.h

@ -717,6 +717,14 @@ namespace storm {
*/ */
SparseMatrix restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups = false) const; SparseMatrix restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups = false) const;
/*
* Permute rows of the matrix according to the vector.
* That is, in row i, write the entry of row inversePermutation[i].
* Consequently, a single row might actually be written into multiple other rows, and the function application is not necessarily a permutation.
* Notice that this method does *not* touch column entries, nor the row grouping.
*/
SparseMatrix permuteRows(std::vector<index_type> const& inversePermutation) const;
/*! /*!
* Returns a copy of this matrix that only considers entries in the selected rows. * Returns a copy of this matrix that only considers entries in the selected rows.
* Non-selected rows will not have any entries * Non-selected rows will not have any entries

25
src/test/storm/storage/SparseMatrixTest.cpp

@ -694,3 +694,28 @@ TEST(SparseMatrix, IsSubmatrix) {
ASSERT_FALSE(matrix3.isSubmatrixOf(matrix)); ASSERT_FALSE(matrix3.isSubmatrixOf(matrix));
ASSERT_FALSE(matrix3.isSubmatrixOf(matrix2)); ASSERT_FALSE(matrix3.isSubmatrixOf(matrix2));
} }
TEST(SparseMatrix, Permute) {
storm::storage::SparseMatrixBuilder<double> matrixBuilder(5, 4, 8);
ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 1, 1.0));
ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 2, 1.2));
ASSERT_NO_THROW(matrixBuilder.addNextValue(1, 0, 0.5));
ASSERT_NO_THROW(matrixBuilder.addNextValue(1, 1, 0.7));
ASSERT_NO_THROW(matrixBuilder.addNextValue(3, 2, 1.1));
ASSERT_NO_THROW(matrixBuilder.addNextValue(4, 0, 0.1));
ASSERT_NO_THROW(matrixBuilder.addNextValue(4, 1, 0.2));
ASSERT_NO_THROW(matrixBuilder.addNextValue(4, 3, 0.3));
storm::storage::SparseMatrix<double> matrix;
ASSERT_NO_THROW(matrix = matrixBuilder.build());
std::vector<uint64_t> inversePermutation = {1,4,0,3,2};
storm::storage::SparseMatrix<double> matrixperm = matrix.permuteRows(inversePermutation);
EXPECT_EQ(5, matrixperm.getRowCount());
EXPECT_EQ(4, matrixperm.getColumnCount());
EXPECT_EQ(8, matrixperm.getEntryCount());
EXPECT_EQ(matrix.getRowSum(1), matrixperm.getRowSum(0));
EXPECT_EQ(matrix.getRowSum(4), matrixperm.getRowSum(1));
EXPECT_EQ(matrix.getRowSum(0), matrixperm.getRowSum(2));
EXPECT_EQ(matrix.getRowSum(3), matrixperm.getRowSum(3));
EXPECT_EQ(matrix.getRowSum(2), matrixperm.getRowSum(4));
}
Loading…
Cancel
Save