Browse Source

started making row grouping optional

Former-commit-id: b90ae91e75
tempestpy_adaptions
dehnert 9 years ago
parent
commit
f81ce1cac1
  1. 152
      src/storage/SparseMatrix.cpp
  2. 46
      src/storage/SparseMatrix.h

152
src/storage/SparseMatrix.cpp

@ -80,8 +80,9 @@ namespace storm {
if (initialEntryCountSet) {
columnsAndValues.reserve(initialEntryCount);
}
if (initialRowGroupCountSet) {
rowGroupIndices.reserve(initialRowGroupCount + 1);
if (initialRowGroupCountSet && hasCustomRowGrouping) {
rowGroupIndices = std::vector<index_type>();
rowGroupIndices.get().reserve(initialRowGroupCount + 1);
}
rowIndications.push_back(0);
}
@ -92,8 +93,8 @@ namespace storm {
// If the matrix has a custom row grouping, we move it and remove the last element to make it 'open' again.
if (hasCustomRowGrouping) {
rowGroupIndices = std::move(matrix.rowGroupIndices);
rowGroupIndices.pop_back();
currentRowGroup = rowGroupIndices.size() - 1;
rowGroupIndices.get().pop_back();
currentRowGroup = rowGroupIndices.get().size() - 1;
}
// Likewise, we need to 'open' the row indications again.
@ -141,7 +142,7 @@ namespace storm {
void SparseMatrixBuilder<ValueType>::newRowGroup(index_type startingRow) {
STORM_LOG_THROW(hasCustomRowGrouping, storm::exceptions::InvalidStateException, "Matrix was not created to have a custom row grouping.");
STORM_LOG_THROW(startingRow >= lastRow, storm::exceptions::InvalidStateException, "Illegal row group with negative size.");
rowGroupIndices.push_back(startingRow);
rowGroupIndices.get().push_back(startingRow);
++currentRowGroup;
// Close all rows from the most recent one to the starting row.
@ -186,11 +187,7 @@ namespace storm {
}
// Check whether row groups are missing some entries.
if (!hasCustomRowGrouping) {
for (index_type i = 0; i <= rowCount; ++i) {
rowGroupIndices.push_back(i);
}
} else {
if (hasCustomRowGrouping) {
uint_fast64_t rowGroupCount = currentRowGroup;
if (initialRowGroupCountSet && forceInitialDimensions) {
STORM_LOG_THROW(rowGroupCount <= initialRowGroupCount, storm::exceptions::InvalidStateException, "Expected not more than " << initialRowGroupCount << " row groups, but got " << rowGroupCount << ".");
@ -199,7 +196,7 @@ namespace storm {
rowGroupCount = std::max(rowGroupCount, overriddenRowGroupCount);
for (index_type i = currentRowGroup; i <= rowGroupCount; ++i) {
rowGroupIndices.push_back(rowCount);
rowGroupIndices.get().push_back(rowCount);
}
}
@ -295,12 +292,12 @@ namespace storm {
}
template<typename ValueType>
SparseMatrix<ValueType>::SparseMatrix() : rowCount(0), columnCount(0), entryCount(0), nonzeroEntryCount(0), columnsAndValues(), rowIndications(), nontrivialRowGrouping(false), rowGroupIndices() {
SparseMatrix<ValueType>::SparseMatrix() : rowCount(0), columnCount(0), entryCount(0), nonzeroEntryCount(0), columnsAndValues(), rowIndications(), rowGroupIndices() {
// Intentionally left empty.
}
template<typename ValueType>
SparseMatrix<ValueType>::SparseMatrix(SparseMatrix<ValueType> const& other) : rowCount(other.rowCount), columnCount(other.columnCount), entryCount(other.entryCount), nonzeroEntryCount(other.nonzeroEntryCount), columnsAndValues(other.columnsAndValues), rowIndications(other.rowIndications), nontrivialRowGrouping(other.nontrivialRowGrouping), rowGroupIndices(other.rowGroupIndices) {
SparseMatrix<ValueType>::SparseMatrix(SparseMatrix<ValueType> const& other) : rowCount(other.rowCount), columnCount(other.columnCount), entryCount(other.entryCount), nonzeroEntryCount(other.nonzeroEntryCount), columnsAndValues(other.columnsAndValues), rowIndications(other.rowIndications), rowGroupIndices(other.rowGroupIndices) {
// Intentionally left empty.
}
@ -312,7 +309,7 @@ namespace storm {
}
template<typename ValueType>
SparseMatrix<ValueType>::SparseMatrix(SparseMatrix<ValueType>&& other) : rowCount(other.rowCount), columnCount(other.columnCount), entryCount(other.entryCount), nonzeroEntryCount(other.nonzeroEntryCount), columnsAndValues(std::move(other.columnsAndValues)), rowIndications(std::move(other.rowIndications)), nontrivialRowGrouping(other.nontrivialRowGrouping), rowGroupIndices(std::move(other.rowGroupIndices)) {
SparseMatrix<ValueType>::SparseMatrix(SparseMatrix<ValueType>&& other) : rowCount(other.rowCount), columnCount(other.columnCount), entryCount(other.entryCount), nonzeroEntryCount(other.nonzeroEntryCount), columnsAndValues(std::move(other.columnsAndValues)), rowIndications(std::move(other.rowIndications)), rowGroupIndices(std::move(other.rowGroupIndices)) {
// Now update the source matrix
other.rowCount = 0;
other.columnCount = 0;
@ -320,12 +317,12 @@ namespace storm {
}
template<typename ValueType>
SparseMatrix<ValueType>::SparseMatrix(index_type columnCount, std::vector<index_type> const& rowIndications, std::vector<MatrixEntry<index_type, ValueType>> const& columnsAndValues, std::vector<index_type> const& rowGroupIndices, bool nontrivialRowGrouping) : rowCount(rowIndications.size() - 1), columnCount(columnCount), entryCount(columnsAndValues.size()), nonzeroEntryCount(0), columnsAndValues(columnsAndValues), rowIndications(rowIndications), nontrivialRowGrouping(nontrivialRowGrouping), rowGroupIndices(rowGroupIndices) {
SparseMatrix<ValueType>::SparseMatrix(index_type columnCount, std::vector<index_type> const& rowIndications, std::vector<MatrixEntry<index_type, ValueType>> const& columnsAndValues, boost::optional<std::vector<index_type>> const& rowGroupIndices) : rowCount(rowIndications.size() - 1), columnCount(columnCount), entryCount(columnsAndValues.size()), nonzeroEntryCount(0), columnsAndValues(columnsAndValues), rowIndications(rowIndications), trivialRowGrouping(!rowGroupIndices), rowGroupIndices(rowGroupIndices) {
this->updateNonzeroEntryCount();
}
template<typename ValueType>
SparseMatrix<ValueType>::SparseMatrix(index_type columnCount, std::vector<index_type>&& rowIndications, std::vector<MatrixEntry<index_type, ValueType>>&& columnsAndValues, std::vector<index_type>&& rowGroupIndices, bool nontrivialRowGrouping) : rowCount(rowIndications.size() - 1), columnCount(columnCount), entryCount(columnsAndValues.size()), nonzeroEntryCount(0), columnsAndValues(std::move(columnsAndValues)), rowIndications(std::move(rowIndications)), nontrivialRowGrouping(nontrivialRowGrouping), rowGroupIndices(std::move(rowGroupIndices)) {
SparseMatrix<ValueType>::SparseMatrix(index_type columnCount, std::vector<index_type>&& rowIndications, std::vector<MatrixEntry<index_type, ValueType>>&& columnsAndValues, boost::optional<std::vector<index_type>>&& rowGroupIndices, bool nontrivialRowGrouping) : rowCount(rowIndications.size() - 1), columnCount(columnCount), entryCount(columnsAndValues.size()), nonzeroEntryCount(0), columnsAndValues(std::move(columnsAndValues)), rowIndications(std::move(rowIndications)), trivialRowGrouping(!rowGroupIndices), rowGroupIndices(std::move(rowGroupIndices)) {
this->updateNonzeroEntryCount();
}
@ -341,7 +338,7 @@ namespace storm {
columnsAndValues = other.columnsAndValues;
rowIndications = other.rowIndications;
rowGroupIndices = other.rowGroupIndices;
nontrivialRowGrouping = other.nontrivialRowGrouping;
trivialRowGrouping = other.trivialRowGrouping;
}
return *this;
@ -359,7 +356,7 @@ namespace storm {
columnsAndValues = std::move(other.columnsAndValues);
rowIndications = std::move(other.rowIndications);
rowGroupIndices = std::move(other.rowGroupIndices);
nontrivialRowGrouping = other.nontrivialRowGrouping;
trivialRowGrouping = other.trivialRowGrouping;
}
return *this;
@ -381,7 +378,11 @@ namespace storm {
if (!equalityResult) {
return false;
}
if (!this->hasTrivialRowGrouping() && !other.hasTrivialRowGrouping()) {
equalityResult &= this->getRowGroupIndices() == other.getRowGroupIndices();
} else {
equalityResult &= this->hasTrivialRowGrouping() && other.hasTrivialRowGrouping();
}
if (!equalityResult) {
return false;
}
@ -433,9 +434,13 @@ namespace storm {
template<typename T>
uint_fast64_t SparseMatrix<T>::getRowGroupEntryCount(uint_fast64_t const group) const {
uint_fast64_t result = 0;
if (!this->hasTrivialRowGrouping()) {
for (uint_fast64_t row = this->getRowGroupIndices()[group]; row < this->getRowGroupIndices()[group + 1]; ++row) {
result += (this->rowIndications[row + 1] - this->rowIndications[row]);
}
} else {
result += (this->rowIndications[group + 1] - this->rowIndications[group]);
}
return result;
}
@ -473,7 +478,11 @@ namespace storm {
template<typename ValueType>
typename SparseMatrix<ValueType>::index_type SparseMatrix<ValueType>::getRowGroupCount() const {
return rowGroupIndices.size() - 1;
if (!this->hasTrivialRowGrouping()) {
return rowGroupIndices.get().size() - 1;
} else {
return rowCount;
}
}
template<typename ValueType>
@ -483,7 +492,15 @@ namespace storm {
template<typename ValueType>
std::vector<typename SparseMatrix<ValueType>::index_type> const& SparseMatrix<ValueType>::getRowGroupIndices() const {
return rowGroupIndices;
// If there is no current row grouping, we need to create it.
if (!this->rowGroupIndices) {
STORM_LOG_ASSERT(trivialRowGrouping, "Only trivial row-groupings can be constructed on-the-fly.");
this->rowGroupIndices = std::vector<index_type>(this->getRowCount() + 1);
for (uint_fast64_t group = 0; group < this->getRowCount(); ++group) {
this->rowGroupIndices.get()[group] = group;
}
}
return rowGroupIndices.get();
}
template<typename ValueType>
@ -495,11 +512,17 @@ namespace storm {
template<typename ValueType>
void SparseMatrix<ValueType>::makeRowGroupsAbsorbing(storm::storage::BitVector const& rowGroupConstraint) {
if (!this->hasTrivialRowGrouping()) {
for (auto rowGroup : rowGroupConstraint) {
for (index_type row = this->getRowGroupIndices()[rowGroup]; row < this->getRowGroupIndices()[rowGroup + 1]; ++row) {
makeRowDirac(row, rowGroup);
}
}
} else {
for (auto rowGroup : rowGroupConstraint) {
makeRowDirac(rowGroup, rowGroup);
}
}
}
template<typename ValueType>
@ -550,11 +573,17 @@ namespace storm {
std::vector<ValueType> SparseMatrix<ValueType>::getConstrainedRowGroupSumVector(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint) const {
std::vector<ValueType> result;
result.reserve(rowGroupConstraint.getNumberOfSetBits());
if (!this->hasTrivialRowGrouping()) {
for (auto rowGroup : rowGroupConstraint) {
for (index_type row = this->getRowGroupIndices()[rowGroup]; row < this->getRowGroupIndices()[rowGroup + 1]; ++row) {
result.push_back(getConstrainedRowSum(row, columnConstraint));
}
}
} else {
for (auto rowGroup : rowGroupConstraint) {
result.push_back(getConstrainedRowSum(rowGroup, columnConstraint));
}
}
return result;
}
@ -571,7 +600,9 @@ namespace storm {
}
auto res = getSubmatrix(rowConstraint, columnConstraint, fakeRowGroupIndices, insertDiagonalElements);
// Create a new row grouping that reflects the new sizes of the row groups.
// Create a new row grouping that reflects the new sizes of the row groups if the current matrix has a
// non trivial row-grouping.
if (!this->hasTrivialRowGrouping()) {
std::vector<uint_fast64_t> newRowGroupIndices;
newRowGroupIndices.push_back(0);
auto selectedRowIt = rowConstraint.begin();
@ -588,7 +619,10 @@ namespace storm {
}
}
res.trivialRowGrouping = false;
res.rowGroupIndices = newRowGroupIndices;
}
return res;
}
}
@ -635,13 +669,13 @@ namespace storm {
}
// Create and initialize resulting matrix.
SparseMatrixBuilder<ValueType> matrixBuilder(subRows, submatrixColumnCount, subEntries, true, this->hasNontrivialRowGrouping());
SparseMatrixBuilder<ValueType> matrixBuilder(subRows, submatrixColumnCount, subEntries, true, !this->hasTrivialRowGrouping());
// Copy over selected entries.
rowGroupCount = 0;
index_type rowCount = 0;
for (auto index : rowGroupConstraint) {
if (this->hasNontrivialRowGrouping()) {
if (!this->hasTrivialRowGrouping()) {
matrixBuilder.newRowGroup(rowCount);
}
for (index_type i = rowGroupIndices[index]; i < rowGroupIndices[index + 1]; ++i) {
@ -677,12 +711,12 @@ namespace storm {
template<typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::restrictRows(storm::storage::BitVector const& rowsToKeep) const {
// For now, we use the expensive call to submatrix.
assert(rowsToKeep.size() == getRowCount());
assert(rowsToKeep.getNumberOfSetBits() >= getRowGroupCount());
STORM_LOG_ASSERT(rowsToKeep.size() == this->getRowCount(), "Dimensions mismatch.");
STORM_LOG_ASSERT(rowsToKeep.getNumberOfSetBits() >= this->getRowGroupCount(), "Invalid dimensions.");
SparseMatrix<ValueType> res(getSubmatrix(false, rowsToKeep, storm::storage::BitVector(getColumnCount(), true), false));
assert(res.getRowCount() == rowsToKeep.getNumberOfSetBits());
assert(res.getColumnCount() == getColumnCount());
assert(getRowGroupCount() == res.getRowGroupCount());
STORM_LOG_ASSERT(res.getRowCount() == rowsToKeep.getNumberOfSetBits(), "Invalid dimensions");
STORM_LOG_ASSERT(res.getColumnCount() == this->getColumnCount(), "Invalid dimensions");
STORM_LOG_ASSERT(this->getRowGroupCount() == res.getRowGroupCount(), "Invalid dimensions");
return res;
}
@ -693,7 +727,7 @@ namespace storm {
index_type subEntries = 0;
for (index_type rowGroupIndex = 0, rowGroupIndexEnd = rowGroupToRowIndexMapping.size(); rowGroupIndex < rowGroupIndexEnd; ++rowGroupIndex) {
// Determine which row we need to select from the current row group.
index_type rowToCopy = rowGroupIndices[rowGroupIndex] + rowGroupToRowIndexMapping[rowGroupIndex];
index_type rowToCopy = this->getRowGroupIndices()[rowGroupIndex] + rowGroupToRowIndexMapping[rowGroupIndex];
// Iterate through that row and count the number of slots we have to reserve for copying.
bool foundDiagonalElement = false;
@ -709,12 +743,12 @@ namespace storm {
}
// Now create the matrix to be returned with the appropriate size.
SparseMatrixBuilder<ValueType> matrixBuilder(rowGroupIndices.size() - 1, columnCount, subEntries);
SparseMatrixBuilder<ValueType> matrixBuilder(rowGroupIndices.get().size() - 1, columnCount, subEntries);
// Copy over the selected lines from the source matrix.
for (index_type rowGroupIndex = 0, rowGroupIndexEnd = rowGroupToRowIndexMapping.size(); rowGroupIndex < rowGroupIndexEnd; ++rowGroupIndex) {
// Determine which row we need to select from the current row group.
index_type rowToCopy = rowGroupIndices[rowGroupIndex] + rowGroupToRowIndexMapping[rowGroupIndex];
index_type rowToCopy = this->getRowGroupIndices()[rowGroupIndex] + rowGroupToRowIndexMapping[rowGroupIndex];
// Iterate through that row and copy the entries. This also inserts a zero element on the diagonal if
// there is no entry yet.
@ -781,12 +815,7 @@ namespace storm {
}
}
std::vector<index_type> rowGroupIndices(rowCount + 1);
for (index_type i = 0; i <= rowCount; ++i) {
rowGroupIndices[i] = i;
}
storm::storage::SparseMatrix<ValueType> transposedMatrix(columnCount, std::move(rowIndications), std::move(columnsAndValues), std::move(rowGroupIndices), false);
storm::storage::SparseMatrix<ValueType> transposedMatrix(columnCount, std::move(rowIndications), std::move(columnsAndValues), boost::none);
return transposedMatrix;
}
@ -1056,30 +1085,46 @@ namespace storm {
template<typename ValueType>
typename SparseMatrix<ValueType>::const_rows SparseMatrix<ValueType>::getRow(index_type rowGroup, index_type offset) const {
assert(rowGroup < this->getRowGroupCount());
assert(offset < this->getRowGroupEntryCount(rowGroup));
return getRow(rowGroupIndices[rowGroup] + offset);
STORM_LOG_ASSERT(rowGroup < this->getRowGroupCount(), "Row group is out-of-bounds.");
STORM_LOG_ASSERT(offset < this->getRowGroupEntryCount(rowGroup), "Row offset in row-group is out-of-bounds.");
if (!this->hasTrivialRowGrouping()) {
return getRow(this->getRowGroupIndices()[rowGroup] + offset);
} else {
return getRow(this->getRowGroupIndices()[rowGroup] + offset);
}
}
template<typename ValueType>
typename SparseMatrix<ValueType>::rows SparseMatrix<ValueType>::getRow(index_type rowGroup, index_type offset) {
assert(rowGroup < this->getRowGroupCount());
assert(offset < this->getRowGroupEntryCount(rowGroup));
return getRow(rowGroupIndices[rowGroup] + offset);
STORM_LOG_ASSERT(rowGroup < this->getRowGroupCount(), "Row group is out-of-bounds.");
STORM_LOG_ASSERT(offset < this->getRowGroupEntryCount(rowGroup), "Row offset in row-group is out-of-bounds.");
if (!this->hasTrivialRowGrouping()) {
return getRow(this->getRowGroupIndices()[rowGroup] + offset);
} else {
return getRow(rowGroup + offset);
}
}
template<typename ValueType>
typename SparseMatrix<ValueType>::const_rows SparseMatrix<ValueType>::getRowGroup(index_type rowGroup) const {
assert(rowGroup < this->getRowGroupCount());
return getRows(rowGroupIndices[rowGroup], rowGroupIndices[rowGroup + 1] - 1);
STORM_LOG_ASSERT(rowGroup < this->getRowGroupCount(), "Row group is out-of-bounds.");
if (!this->hasTrivialRowGrouping()) {
return getRows(this->getRowGroupIndices()[rowGroup], this->getRowGroupIndices()[rowGroup + 1] - 1);
} else {
return getRows(rowGroup, rowGroup + 1);
}
}
template<typename ValueType>
typename SparseMatrix<ValueType>::rows SparseMatrix<ValueType>::getRowGroup(index_type rowGroup) {
assert(rowGroup < this->getRowGroupCount());
return getRows(rowGroupIndices[rowGroup], rowGroupIndices[rowGroup + 1] - 1);
STORM_LOG_ASSERT(rowGroup < this->getRowGroupCount(), "Row group is out-of-bounds.");
if (!this->hasTrivialRowGrouping()) {
return getRows(this->getRowGroupIndices()[rowGroup], this->getRowGroupIndices()[rowGroup + 1] - 1);
} else {
return getRows(rowGroup, rowGroup + 1);
}
}
template<typename ValueType>
@ -1113,8 +1158,8 @@ namespace storm {
}
template<typename ValueType>
bool SparseMatrix<ValueType>::hasNontrivialRowGrouping() const {
return nontrivialRowGrouping;
bool SparseMatrix<ValueType>::hasTrivialRowGrouping() const {
return trivialRowGrouping;
}
template<typename ValueType>
@ -1143,6 +1188,9 @@ namespace storm {
// Check for matching sizes.
if (this->getRowCount() != matrix.getRowCount()) return false;
if (this->getColumnCount() != matrix.getColumnCount()) return false;
if (this->hasTrivialRowGrouping() && !matrix.hasTrivialRowGrouping()) return false;
if (!this->hasTrivialRowGrouping() && matrix.hasTrivialRowGrouping()) return false;
if (!this->hasTrivialRowGrouping() && !matrix.hasTrivialRowGrouping() && this->getRowGroupIndices() != matrix.getRowGroupIndices()) return false;
if (this->getRowGroupIndices() != matrix.getRowGroupIndices()) return false;
// Check the subset property for all rows individually.
@ -1174,7 +1222,7 @@ namespace storm {
// Iterate over all row groups.
for (typename SparseMatrix<ValueType>::index_type group = 0; group < matrix.getRowGroupCount(); ++group) {
out << "\t---- group " << group << "/" << (matrix.getRowGroupCount() - 1) << " ---- " << std::endl;
for (typename SparseMatrix<ValueType>::index_type i = matrix.getRowGroupIndices()[group]; i < matrix.getRowGroupIndices()[group + 1]; ++i) {
for (typename SparseMatrix<ValueType>::index_type i = matrix.hasTrivialRowGrouping() ? group : matrix.getRowGroupIndices()[group]; i < matrix.hasTrivialRowGrouping() ? group + 1 : matrix.getRowGroupIndices()[group + 1]; ++i) {
typename SparseMatrix<ValueType>::index_type nextIndex = matrix.rowIndications[i];
// Print the actual row.
@ -1237,7 +1285,9 @@ namespace storm {
boost::hash_combine(result, this->getEntryCount());
boost::hash_combine(result, boost::hash_range(columnsAndValues.begin(), columnsAndValues.end()));
boost::hash_combine(result, boost::hash_range(rowIndications.begin(), rowIndications.end()));
boost::hash_combine(result, boost::hash_range(rowGroupIndices.begin(), rowGroupIndices.end()));
if (!this->hasTrivialRowGrouping()) {
boost::hash_combine(result, boost::hash_range(rowGroupIndices.get().begin(), rowGroupIndices.get().end()));
}
return result;
}

46
src/storage/SparseMatrix.h

@ -7,9 +7,11 @@
#include <vector>
#include <iterator>
#include <boost/functional/hash.hpp>
#include <boost/optional.hpp>
#include "src/utility/OsDetection.h"
#include "src/adapters/CarlAdapter.h"
#include <boost/functional/hash.hpp>
// Forward declaration for adapter classes.
namespace storm {
@ -259,7 +261,8 @@ namespace storm {
// The number of row groups in the matrix.
index_type initialRowGroupCount;
std::vector<index_type> rowGroupIndices;
// The vector that stores the row-group indices (if they are non-trivial).
boost::optional<std::vector<index_type>> rowGroupIndices;
// The storage for the columns and values of all entries in the matrix.
std::vector<MatrixEntry<index_type, value_type>> columnsAndValues;
@ -447,9 +450,8 @@ namespace storm {
* @param rowIndications The row indications vector of the matrix to be constructed.
* @param columnsAndValues The vector containing the columns and values of the entries in the matrix.
* @param rowGroupIndices The vector representing the row groups in the matrix.
* @param hasNontrivialRowGrouping If set to true, this indicates that the row grouping is non-trivial.
*/
SparseMatrix(index_type columnCount, std::vector<index_type> const& rowIndications, std::vector<MatrixEntry<index_type, value_type>> const& columnsAndValues, std::vector<index_type> const& rowGroupIndices, bool hasNontrivialRowGrouping);
SparseMatrix(index_type columnCount, std::vector<index_type> const& rowIndications, std::vector<MatrixEntry<index_type, value_type>> const& columnsAndValues, boost::optional<std::vector<index_type>> const& rowGroupIndices);
/*!
* Constructs a sparse matrix by moving the given contents.
@ -458,9 +460,8 @@ namespace storm {
* @param rowIndications The row indications vector of the matrix to be constructed.
* @param columnsAndValues The vector containing the columns and values of the entries in the matrix.
* @param rowGroupIndices The vector representing the row groups in the matrix.
* @param hasNontrivialRowGrouping If set to true, this indicates that the row grouping is non-trivial.
*/
SparseMatrix(index_type columnCount, std::vector<index_type>&& rowIndications, std::vector<MatrixEntry<index_type, value_type>>&& columnsAndValues, std::vector<index_type>&& rowGroupIndices, bool hasNontrivialRowGrouping);
SparseMatrix(index_type columnCount, std::vector<index_type>&& rowIndications, std::vector<MatrixEntry<index_type, value_type>>&& columnsAndValues, boost::optional<std::vector<index_type>>&& rowGroupIndices);
/*!
* Assigns the contents of the given matrix to the current one by deep-copying its contents.
@ -913,27 +914,28 @@ namespace storm {
iterator end();
/*!
* Retrieves whether the matrix has a (possibly) non-trivial row grouping.
* Retrieves whether the matrix has a trivial row grouping.
*
* @return True iff the matrix has a (possibly) non-trivial row grouping.
* @return True iff the matrix has a trivial row grouping.
*/
bool hasNontrivialRowGrouping() const;
bool hasTrivialRowGrouping() const;
/*!
* Returns a copy of the matrix with the chosen internal data type
*/
template<typename NewValueType>
SparseMatrix<NewValueType> toValueType() const {
std::vector<MatrixEntry<SparseMatrix::index_type, NewValueType>> new_columnsAndValues;
std::vector<SparseMatrix::index_type> new_rowIndications(rowIndications);
std::vector<SparseMatrix::index_type> new_rowGroupIndices(rowGroupIndices);
std::vector<MatrixEntry<SparseMatrix::index_type, NewValueType>> newColumnsAndValues;
std::vector<SparseMatrix::index_type> newRowIndications(rowIndications);
boost::optional<std::vector<SparseMatrix::index_type>> newRowGroupIndices(rowGroupIndices);
new_columnsAndValues.resize(columnsAndValues.size());
for (size_t i = 0, size = columnsAndValues.size(); i < size; ++i) {
new_columnsAndValues.at(i) = MatrixEntry<SparseMatrix::index_type, NewValueType>(columnsAndValues.at(i).getColumn(), static_cast<NewValueType>(columnsAndValues.at(i).getValue()));
}
newColumnsAndValues.resize(columnsAndValues.size());
std::transform(columnsAndValues.begin(), columnsAndValues.end(), newColumnsAndValues.begin(),
[] (MatrixEntry<SparseMatrix::index_type, ValueType> const& a) {
return MatrixEntry<SparseMatrix::index_type, NewValueType>(a.getColumn(), static_cast<NewValueType>(a.getValue()));
});
return SparseMatrix<NewValueType>(columnCount, std::move(new_rowIndications), std::move(new_columnsAndValues), std::move(new_rowGroupIndices), nontrivialRowGrouping);
return SparseMatrix<NewValueType>(columnCount, std::move(newRowIndications), std::move(newColumnsAndValues), std::move(newRowGroupIndices));
}
private:
@ -972,12 +974,12 @@ namespace storm {
// entry is not included anymore.
std::vector<index_type> rowIndications;
// A flag that indicates whether the matrix has a non-trivial row-grouping, i.e. (possibly) more than one
// row per row group.
bool nontrivialRowGrouping;
// A flag indicating whether the matrix has a trivial row grouping. Note that this may be true and yet
// there may be row group indices, because they were requested from the outside.
bool trivialRowGrouping;
// A vector indicating the row groups of the matrix.
std::vector<index_type> rowGroupIndices;
// A vector indicating the row groups of the matrix. This needs to be mutible in case we create it on-the-fly.
mutable boost::optional<std::vector<index_type>> rowGroupIndices;
};

Loading…
Cancel
Save