Browse Source

permute for matrices

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
8a7e40558d
  1. 30
      src/storm/storage/SparseMatrix.cpp
  2. 10
      src/storm/storage/SparseMatrix.h
  3. 25
      src/test/storm/storage/SparseMatrixTest.cpp

30
src/storm/storage/SparseMatrix.cpp

@ -1144,7 +1144,7 @@ namespace storm {
// Finalize created matrix and return result.
return matrixBuilder.build();
}
template<typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::selectRowsFromRowIndexSequence(std::vector<index_type> const& rowIndexSequence, bool insertDiagonalEntries) const{
// First, we need to count how many non-zero entries the resulting matrix will have and reserve space for
@ -1162,10 +1162,10 @@ namespace storm {
++newEntries;
}
}
// Now create the matrix to be returned with the appropriate size.
SparseMatrixBuilder<ValueType> matrixBuilder(rowIndexSequence.size(), columnCount, newEntries);
// Copy over the selected rows from the source matrix.
for(index_type row = 0, rowEnd = rowIndexSequence.size(); row < rowEnd; ++row) {
bool insertedDiagonalElement = false;
@ -1182,10 +1182,32 @@ namespace storm {
matrixBuilder.addNextValue(row, row, storm::utility::zero<ValueType>());
}
}
// Finally create matrix and return result.
return matrixBuilder.build();
}
template<typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::permuteRows(std::vector<index_type> const& inversePermutation) const {
// Now create the matrix to be returned with the appropriate size.
// The entry size is only adequate if this is indeed a permutation.
SparseMatrixBuilder<ValueType> matrixBuilder(inversePermutation.size(), columnCount, entryCount);
// Copy over the selected rows from the source matrix.
for (index_type writeTo = 0; writeTo < inversePermutation.size(); ++writeTo) {
index_type const &readFrom = inversePermutation[writeTo];
auto row = this->getRow(readFrom);
for (auto const& entry : row) {
matrixBuilder.addNextValue(writeTo, entry.getColumn(), entry.getValue());
}
}
// Finally create matrix and return result.
auto result = matrixBuilder.build();
result.setRowGroupIndices(this->rowGroupIndices.get());
return result;
}
template <typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::transpose(bool joinGroups, bool keepZeros) const {

10
src/storm/storage/SparseMatrix.h

@ -716,7 +716,15 @@ namespace storm {
*
*/
SparseMatrix restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups = false) const;
/*
* Permute rows of the matrix according to the vector.
* That is, in row i, write the entry of row inversePermutation[i].
* Consequently, a single row might actually be written into multiple other rows, and the function application is not necessarily a permutation.
* Notice that this method does *not* touch column entries, nor the row grouping.
*/
SparseMatrix permuteRows(std::vector<index_type> const& inversePermutation) const;
/*!
* Returns a copy of this matrix that only considers entries in the selected rows.
* Non-selected rows will not have any entries

25
src/test/storm/storage/SparseMatrixTest.cpp

@ -694,3 +694,28 @@ TEST(SparseMatrix, IsSubmatrix) {
ASSERT_FALSE(matrix3.isSubmatrixOf(matrix));
ASSERT_FALSE(matrix3.isSubmatrixOf(matrix2));
}
TEST(SparseMatrix, Permute) {
storm::storage::SparseMatrixBuilder<double> matrixBuilder(5, 4, 8);
ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 1, 1.0));
ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 2, 1.2));
ASSERT_NO_THROW(matrixBuilder.addNextValue(1, 0, 0.5));
ASSERT_NO_THROW(matrixBuilder.addNextValue(1, 1, 0.7));
ASSERT_NO_THROW(matrixBuilder.addNextValue(3, 2, 1.1));
ASSERT_NO_THROW(matrixBuilder.addNextValue(4, 0, 0.1));
ASSERT_NO_THROW(matrixBuilder.addNextValue(4, 1, 0.2));
ASSERT_NO_THROW(matrixBuilder.addNextValue(4, 3, 0.3));
storm::storage::SparseMatrix<double> matrix;
ASSERT_NO_THROW(matrix = matrixBuilder.build());
std::vector<uint64_t> inversePermutation = {1,4,0,3,2};
storm::storage::SparseMatrix<double> matrixperm = matrix.permuteRows(inversePermutation);
EXPECT_EQ(5, matrixperm.getRowCount());
EXPECT_EQ(4, matrixperm.getColumnCount());
EXPECT_EQ(8, matrixperm.getEntryCount());
EXPECT_EQ(matrix.getRowSum(1), matrixperm.getRowSum(0));
EXPECT_EQ(matrix.getRowSum(4), matrixperm.getRowSum(1));
EXPECT_EQ(matrix.getRowSum(0), matrixperm.getRowSum(2));
EXPECT_EQ(matrix.getRowSum(3), matrixperm.getRowSum(3));
EXPECT_EQ(matrix.getRowSum(2), matrixperm.getRowSum(4));
}
Loading…
Cancel
Save