Browse Source

New methods for the SparseMatrix: SetRowGroupIndices and filterEntries

tempestpy_adaptions
TimQu 7 years ago
parent
commit
142d034765
  1. 62
      src/storm/storage/SparseMatrix.cpp
  2. 43
      src/storm/storage/SparseMatrix.h

62
src/storm/storage/SparseMatrix.cpp

@ -584,6 +584,27 @@ namespace storm {
return rowGroupIndices.get(); return rowGroupIndices.get();
} }
template<typename ValueType>
void SparseMatrix<ValueType>::setRowGroupIndices(std::vector<index_type> const& newRowGroupIndices) {
trivialRowGrouping = false;
rowGroupIndices = newRowGroupIndices;
}
template<typename ValueType>
bool SparseMatrix<ValueType>::hasTrivialRowGrouping() const {
return trivialRowGrouping;
}
template<typename ValueType>
void SparseMatrix<ValueType>::makeRowGroupingTrivial() {
if (trivialRowGrouping) {
STORM_LOG_ASSERT(!rowGroupIndices || rowGroupIndices.get() == storm::utility::vector::buildVectorForRange(0, this->getRowGroupCount() + 1), "Row grouping is supposed to be trivial but actually it is not.");
} else {
trivialRowGrouping = true;
rowGroupIndices = boost::none;
}
}
template<typename ValueType> template<typename ValueType>
storm::storage::BitVector SparseMatrix<ValueType>::getRowFilter(storm::storage::BitVector const& groupConstraint) const { storm::storage::BitVector SparseMatrix<ValueType>::getRowFilter(storm::storage::BitVector const& groupConstraint) const {
storm::storage::BitVector res(this->getRowCount(), false); storm::storage::BitVector res(this->getRowCount(), false);
@ -984,6 +1005,30 @@ namespace storm {
return res; return res;
} }
template<typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::filterEntries(storm::storage::BitVector const& rowFilter) const {
// Count the number of entries in the resulting matrix.
index_type entryCount = 0;
for (auto const& row : rowFilter) {
entryCount += getRow(row).getNumberOfEntries();
}
// Build the resulting matrix.
SparseMatrixBuilder<ValueType> builder(getRowCount(), getColumnCount(), entryCount);
for (auto const& row : rowFilter) {
for (auto const& entry : getRow(row)) {
builder.addNextValue(row, entry.getColumn(), entry.getValue());
}
}
SparseMatrix<ValueType> result = builder.build();
// Add a row grouping if necessary.
if (!hasTrivialRowGrouping()) {
result.setRowGroupIndices(getRowGroupIndices());
}
return result;
}
template<typename ValueType> template<typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::selectRowsFromRowGroups(std::vector<index_type> const& rowGroupToRowIndexMapping, bool insertDiagonalEntries) const { SparseMatrix<ValueType> SparseMatrix<ValueType>::selectRowsFromRowGroups(std::vector<index_type> const& rowGroupToRowIndexMapping, bool insertDiagonalEntries) const {
// First, we need to count how many non-zero entries the resulting matrix will have and reserve space for // First, we need to count how many non-zero entries the resulting matrix will have and reserve space for
@ -1547,22 +1592,7 @@ namespace storm {
typename SparseMatrix<ValueType>::iterator SparseMatrix<ValueType>::end() { typename SparseMatrix<ValueType>::iterator SparseMatrix<ValueType>::end() {
return this->columnsAndValues.begin() + this->rowIndications[rowCount]; return this->columnsAndValues.begin() + this->rowIndications[rowCount];
} }
template<typename ValueType>
bool SparseMatrix<ValueType>::hasTrivialRowGrouping() const {
return trivialRowGrouping;
}
template<typename ValueType>
void SparseMatrix<ValueType>::makeRowGroupingTrivial() {
if (trivialRowGrouping) {
STORM_LOG_ASSERT(!rowGroupIndices || rowGroupIndices.get() == storm::utility::vector::buildVectorForRange(0, this->getRowGroupCount() + 1), "Row grouping is supposed to be trivial but actually it is not.");
} else {
trivialRowGrouping = true;
rowGroupIndices = boost::none;
}
}
template<typename ValueType> template<typename ValueType>
ValueType SparseMatrix<ValueType>::getRowSum(index_type row) const { ValueType SparseMatrix<ValueType>::getRowSum(index_type row) const {
ValueType sum = storm::utility::zero<ValueType>(); ValueType sum = storm::utility::zero<ValueType>();

43
src/storm/storage/SparseMatrix.h

@ -563,6 +563,27 @@ namespace storm {
*/ */
std::vector<index_type> const& getRowGroupIndices() const; std::vector<index_type> const& getRowGroupIndices() const;
/*!
* Sets the row grouping to the given one.
* @note It is assumed that the new row grouping is non-trivial.
*
* @param newRowGroupIndices The new row group indices.
*/
void setRowGroupIndices(std::vector<index_type> const& newRowGroupIndices);
/*!
* Retrieves whether the matrix has a trivial row grouping.
*
* @return True iff the matrix has a trivial row grouping.
*/
bool hasTrivialRowGrouping() const;
/*!
* Makes the row grouping of this matrix trivial.
* Has no effect when the row grouping is already trivial.
*/
void makeRowGroupingTrivial();
/*! /*!
* Returns the indices of the rows that belong to one of the selected row groups. * Returns the indices of the rows that belong to one of the selected row groups.
* *
@ -665,6 +686,15 @@ namespace storm {
*/ */
SparseMatrix restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups = false) const; SparseMatrix restrictRows(storm::storage::BitVector const& rowsToKeep, bool allowEmptyRowGroups = false) const;
/*!
* Returns a copy of this matrix that only considers entries in the selected rows.
* Non-selected rows will not have any entries
*
* @note does not change the dimensions (row-, column-, and rowGroup count) of this matrix
* @param rowFilter the selected rows
*/
SparseMatrix filterEntries(storm::storage::BitVector const& rowFilter) const;
/** /**
* Compares two rows. * Compares two rows.
* @param i1 Index of first row * @param i1 Index of first row
@ -1004,19 +1034,6 @@ namespace storm {
*/ */
iterator end(); iterator end();
/*!
* Retrieves whether the matrix has a trivial row grouping.
*
* @return True iff the matrix has a trivial row grouping.
*/
bool hasTrivialRowGrouping() const;
/*!
* Makes the row grouping of this matrix trivial.
* Has no effect when the row grouping is already trivial.
*/
void makeRowGroupingTrivial();
/*! /*!
* Returns a copy of the matrix with the chosen internal data type * Returns a copy of the matrix with the chosen internal data type
*/ */

Loading…
Cancel
Save