#include "storm/storage/FlexibleSparseMatrix.h"

#include <algorithm>
#include <ostream>
#include <vector>

#include "storm/storage/SparseMatrix.h"
#include "storm/storage/BitVector.h"
#include "storm/adapters/RationalFunctionAdapter.h"

#include "storm/utility/macros.h"
#include "storm/utility/constants.h"
#include "storm/exceptions/InvalidArgumentException.h"

namespace storm {
namespace storage {

/*!
 * Creates a matrix with the given number of (empty) rows, zero columns and no entries.
 *
 * No row group indices are set up, so the matrix is flagged as having a trivial row grouping.
 *
 * @param rows The number of rows of the matrix.
 */
template<typename ValueType>
FlexibleSparseMatrix<ValueType>::FlexibleSparseMatrix(index_type rows)
    : data(rows), columnCount(0), nonzeroEntryCount(0), trivialRowGrouping(true) {
    // Intentionally left empty.
}

/*!
 * Creates a flexible matrix from the given sparse matrix.
 *
 * Entries whose value is zero are dropped. With revertEquationSystem set, every diagonal value d is
 * stored as (1 - d) and every off-diagonal value v as -v; diagonal entries that become zero this way
 * (d == 1) are dropped as well.
 *
 * @param matrix The matrix to copy the entries from.
 * @param setAllValuesToOne If set, every stored entry gets the value one instead of its (transformed) value.
 * @param revertEquationSystem If set, the entries are transformed as described above. Only allowed for
 *                             matrices with a trivial row grouping.
 * @throws storm::exceptions::InvalidArgumentException if revertEquationSystem is set for a matrix with a
 *         non-trivial row grouping.
 */
template<typename ValueType>
FlexibleSparseMatrix<ValueType>::FlexibleSparseMatrix(storm::storage::SparseMatrix<ValueType> const& matrix, bool setAllValuesToOne,
                                                      bool revertEquationSystem)
    : data(matrix.getRowCount()),
      columnCount(matrix.getColumnCount()),
      nonzeroEntryCount(matrix.getNonzeroEntryCount()),
      trivialRowGrouping(matrix.hasTrivialRowGrouping()) {
    STORM_LOG_THROW(!revertEquationSystem || trivialRowGrouping, storm::exceptions::InvalidArgumentException, "Illegal option for creating flexible matrix.");

    if (!trivialRowGrouping) {
        rowGroupIndices = matrix.getRowGroupIndices();
    }

    // NOTE(review): nonzeroEntryCount is taken over from the source matrix and is not adjusted for the
    // entries that are skipped below -- confirm whether callers are expected to call updateDimensions().
    for (index_type rowIndex = 0; rowIndex < matrix.getRowCount(); ++rowIndex) {
        typename storm::storage::SparseMatrix<ValueType>::const_rows row = matrix.getRow(rowIndex);
        reserveInRow(rowIndex, row.getNumberOfEntries());
        row_type& targetRow = getRow(rowIndex);

        for (auto const& element : row) {
            bool const isDiagonal = element.getColumn() == rowIndex;

            if (storm::utility::isZero(element.getValue())) {
                // Zero entries are skipped, except for the diagonal of a reverted system, which becomes
                // 1 - 0 = 1. The entry is added exactly once here; falling through to the code below
                // would add a second entry for the same column.
                if (revertEquationSystem && isDiagonal) {
                    targetRow.emplace_back(element.getColumn(), storm::utility::one<ValueType>());
                }
                continue;
            }

            // A diagonal one turns into zero when reverting the equation system, so it is not stored.
            if (revertEquationSystem && isDiagonal && storm::utility::isOne(element.getValue())) {
                continue;
            }

            if (setAllValuesToOne) {
                targetRow.emplace_back(element.getColumn(), storm::utility::one<ValueType>());
            } else if (revertEquationSystem) {
                if (isDiagonal) {
                    targetRow.emplace_back(element.getColumn(), storm::utility::one<ValueType>() - element.getValue());
                } else {
                    targetRow.emplace_back(element.getColumn(), -element.getValue());
                }
            } else {
                targetRow.emplace_back(element);
            }
        }
    }
}

/*!
 * Reserves storage for the given number of entries in the given row.
 *
 * @param row The index of the row.
 * @param numberOfElements The number of entries to reserve space for.
 */
template<typename ValueType>
void FlexibleSparseMatrix<ValueType>::reserveInRow(index_type row, index_type numberOfElements) {
    this->data[row].reserve(numberOfElements);
}

/*!
 * Retrieves the row with the given (absolute) index. The index is not range-checked.
 */
template<typename ValueType>
typename FlexibleSparseMatrix<ValueType>::row_type& FlexibleSparseMatrix<ValueType>::getRow(index_type index) {
    return this->data[index];
}

/*!
 * Retrieves the row with the given (absolute) index. The index is not range-checked.
 */
template<typename ValueType>
typename FlexibleSparseMatrix<ValueType>::row_type const& FlexibleSparseMatrix<ValueType>::getRow(index_type index) const {
    return this->data[index];
}

/*!
 * Retrieves the row at the given offset within the given row group.
 *
 * @param rowGroup The index of the row group.
 * @param offset The offset of the row within the group.
 */
template<typename ValueType>
typename FlexibleSparseMatrix<ValueType>::row_type& FlexibleSparseMatrix<ValueType>::getRow(index_type rowGroup, index_type offset) {
    STORM_LOG_ASSERT(rowGroup < this->getRowGroupCount(), "Invalid rowGroup.");
    STORM_LOG_ASSERT(offset < this->getRowGroupSize(rowGroup), "Invalid offset.");
    return getRow(rowGroupIndices[rowGroup] + offset);
}

/*!
 * Retrieves the row at the given offset within the given row group.
 *
 * @param rowGroup The index of the row group.
 * @param offset The offset of the row within the group.
 */
template<typename ValueType>
typename FlexibleSparseMatrix<ValueType>::row_type const& FlexibleSparseMatrix<ValueType>::getRow(index_type rowGroup, index_type offset) const {
    STORM_LOG_ASSERT(rowGroup < this->getRowGroupCount(), "Invalid rowGroup.");
    STORM_LOG_ASSERT(offset < this->getRowGroupSize(rowGroup), "Invalid offset.");
    return getRow(rowGroupIndices[rowGroup] + offset);
}

/*!
 * Retrieves the row group indices. The vector is only filled by the converting constructor, and only
 * if the source matrix has a non-trivial row grouping.
 */
template<typename ValueType>
std::vector<typename FlexibleSparseMatrix<ValueType>::index_type> const& FlexibleSparseMatrix<ValueType>::getRowGroupIndices() const {
    return rowGroupIndices;
}

/*!
 * Retrieves the number of rows.
 */
template<typename ValueType>
typename FlexibleSparseMatrix<ValueType>::index_type FlexibleSparseMatrix<ValueType>::getRowCount() const {
    return this->data.size();
}

/*!
 * Retrieves the stored number of columns (see updateDimensions()).
 */
template<typename ValueType>
typename FlexibleSparseMatrix<ValueType>::index_type FlexibleSparseMatrix<ValueType>::getColumnCount() const {
    return columnCount;
}

/*!
 * Retrieves the stored number of nonzero entries (see updateDimensions()).
 */
template<typename ValueType>
typename FlexibleSparseMatrix<ValueType>::index_type FlexibleSparseMatrix<ValueType>::getNonzeroEntryCount() const {
    return nonzeroEntryCount;
}

/*!
 * Retrieves the number of row groups, i.e. the number of row group indices minus the end sentinel.
 * Must not be called while the row group indices are empty (the subtraction would wrap around).
 */
template<typename ValueType>
typename FlexibleSparseMatrix<ValueType>::index_type FlexibleSparseMatrix<ValueType>::getRowGroupCount() const {
    return rowGroupIndices.size() - 1;
}

/*!
 * Retrieves the number of rows in the given row group.
 */
template<typename ValueType>
typename FlexibleSparseMatrix<ValueType>::index_type FlexibleSparseMatrix<ValueType>::getRowGroupSize(index_type group) const {
    return rowGroupIndices[group + 1] - rowGroupIndices[group];
}

/*!
 * Computes the sum of all entries stored in the given row.
 */
template<typename ValueType>
ValueType FlexibleSparseMatrix<ValueType>::getRowSum(index_type row) const {
    ValueType sum = storm::utility::zero<ValueType>();
    for (auto const& element : getRow(row)) {
        sum += element.getValue();
    }
    return sum;
}

/*!
 * Recomputes the number of nonzero entries and the column count (largest used column index plus one)
 * from the entries that are currently stored.
 */
template<typename ValueType>
void FlexibleSparseMatrix<ValueType>::updateDimensions() {
    this->nonzeroEntryCount = 0;
    this->columnCount = 0;
    for (auto const& row : this->data) {
        for (auto const& element : row) {
            STORM_LOG_ASSERT(!storm::utility::isZero(element.getValue()), "Entry is 0.");
            ++this->nonzeroEntryCount;
            this->columnCount = std::max(element.getColumn() + 1, this->columnCount);
        }
    }
}

/*!
 * Checks whether no row of the matrix stores any entry.
 */
template<typename ValueType>
bool FlexibleSparseMatrix<ValueType>::empty() const {
    for (auto const& row : this->data) {
        if (!row.empty()) {
            return false;
        }
    }
    return true;
}

/*!
 * Retrieves whether the matrix is flagged as having a trivial row grouping.
 */
template<typename ValueType>
bool FlexibleSparseMatrix<ValueType>::hasTrivialRowGrouping() const {
    return trivialRowGrouping;
}

/*!
 * Erases all entries whose row is not set in rowConstraint or whose column is not set in
 * columnConstraint. The number of rows is unchanged and the stored dimensions are not updated.
 */
template<typename ValueType>
void FlexibleSparseMatrix<ValueType>::filterEntries(storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint) {
    for (uint_fast64_t rowIndex = 0; rowIndex < this->data.size(); ++rowIndex) {
        auto& row = this->data[rowIndex];
        if (!rowConstraint.get(rowIndex)) {
            // The row is dropped completely; also release its storage.
            row.clear();
            row.shrink_to_fit();
            continue;
        }
        row_type newRow;
        for (auto const& element : row) {
            if (columnConstraint.get(element.getColumn())) {
                newRow.push_back(element);
            }
        }
        row = std::move(newRow);
    }
}

/*!
 * Creates a sparse matrix holding all entries of this matrix, keeping the row grouping.
 */
template<typename ValueType>
storm::storage::SparseMatrix<ValueType> FlexibleSparseMatrix<ValueType>::createSparseMatrix() {
    uint_fast64_t numEntries = 0;
    for (auto const& row : this->data) {
        numEntries += row.size();
    }

    // NOTE(review): the restricting overload below passes (..., true, !hasTrivialRowGrouping(), numRowGroups)
    // to the builder, while this call passes only five arguments -- confirm against the
    // SparseMatrixBuilder constructor that the arguments end up in the intended parameters.
    storm::storage::SparseMatrixBuilder<ValueType> matrixBuilder(getRowCount(), getColumnCount(), numEntries, hasTrivialRowGrouping(),
                                                                 hasTrivialRowGrouping() ? 0 : getRowGroupCount());

    bool const customRowGrouping = !hasTrivialRowGrouping();
    auto const& groupIndices = getRowGroupIndices();
    auto rowGroupIndexIt = groupIndices.begin();
    // The last row group index is the end sentinel and does not start a row group of its own.
    auto const rowGroupSentinelIt = (customRowGrouping && !groupIndices.empty()) ? groupIndices.end() - 1 : groupIndices.end();

    // Opens all row groups that start at or before the given row and have not been opened yet.
    auto openPendingRowGroups = [&](uint_fast64_t rowIndex) {
        while (rowGroupIndexIt != rowGroupSentinelIt && rowIndex >= *rowGroupIndexIt) {
            matrixBuilder.newRowGroup(rowIndex);
            ++rowGroupIndexIt;
        }
    };

    uint_fast64_t currRowIndex = 0;
    for (auto const& row : this->data) {
        if (customRowGrouping) {
            openPendingRowGroups(currRowIndex);
        }
        for (auto const& entry : row) {
            matrixBuilder.addNextValue(currRowIndex, entry.getColumn(), entry.getValue());
        }
        ++currRowIndex;
    }
    // The matrix might end with one or more empty row groups
    if (customRowGrouping) {
        openPendingRowGroups(currRowIndex);
    }
    return matrixBuilder.build();
}

/*!
 * Creates a sparse matrix from the rows selected by rowConstraint and the columns selected by
 * columnConstraint. Selected rows and columns are renumbered consecutively; row groups without any
 * selected row are omitted.
 */
template<typename ValueType>
storm::storage::SparseMatrix<ValueType> FlexibleSparseMatrix<ValueType>::createSparseMatrix(storm::storage::BitVector const& rowConstraint,
                                                                                            storm::storage::BitVector const& columnConstraint) {
    uint_fast64_t numEntries = 0;
    for (auto const& rowIndex : rowConstraint) {
        auto const& row = data[rowIndex];
        for (auto const& entry : row) {
            if (columnConstraint.get(entry.getColumn())) {
                ++numEntries;
            }
        }
    }

    uint_fast64_t numRowGroups = 0;
    if (!hasTrivialRowGrouping()) {
        auto const& groupIndices = getRowGroupIndices();
        for (index_type group = 0; group + 1 < groupIndices.size(); ++group) {
            // Check whether the rowGroup will be nonempty, i.e. whether the first selected row at or
            // after the start of the group lies before the start of the next group.
            if (rowConstraint.getNextSetIndex(groupIndices[group]) < groupIndices[group + 1]) {
                ++numRowGroups;
            }
        }
    }

    // Columns that are not selected keep the (out of range) value getColumnCount().
    std::vector<index_type> oldToNewColumnIndexMapping(getColumnCount(), getColumnCount());
    uint_fast64_t newColumnIndex = 0;
    for (auto const& oldColumnIndex : columnConstraint) {
        oldToNewColumnIndexMapping[oldColumnIndex] = newColumnIndex++;
    }

    storm::storage::SparseMatrixBuilder<ValueType> matrixBuilder(rowConstraint.getNumberOfSetBits(), newColumnIndex, numEntries, true,
                                                                 !hasTrivialRowGrouping(), numRowGroups);
    uint_fast64_t currRowIndex = 0;
    auto rowGroupIndexIt = getRowGroupIndices().begin();
    for (auto const& oldRowIndex : rowConstraint) {
        if (!hasTrivialRowGrouping() && oldRowIndex >= *rowGroupIndexIt) {
            matrixBuilder.newRowGroup(currRowIndex);
            // Skip empty row groups
            do {
                ++rowGroupIndexIt;
            } while (oldRowIndex >= *rowGroupIndexIt);
        }
        auto const& row = data[oldRowIndex];
        for (auto const& entry : row) {
            if (columnConstraint.get(entry.getColumn())) {
                matrixBuilder.addNextValue(currRowIndex, oldToNewColumnIndexMapping[entry.getColumn()], entry.getValue());
            }
        }
        ++currRowIndex;
    }
    return matrixBuilder.build();
}

/*!
 * Checks whether the row of the given state stores an entry on the diagonal.
 *
 * The search stops at the first entry whose column is not smaller than the state, so entries with a
 * smaller column that are stored behind such an entry are not considered.
 */
template<typename ValueType>
bool FlexibleSparseMatrix<ValueType>::rowHasDiagonalElement(storm::storage::sparse::state_type state) {
    for (auto const& entry : this->getRow(state)) {
        if (entry.getColumn() < state) {
            continue;
        }
        return entry.getColumn() == state;
    }
    return false;
}

/*!
 * Prints the given row in dense form: one tab-terminated value per column, "0" for columns without
 * a matching stored entry.
 *
 * @param out The stream to print to.
 * @param rowIndex The index of the row to print.
 * @return The given stream.
 */
template<typename ValueType>
std::ostream& FlexibleSparseMatrix<ValueType>::printRow(std::ostream& out, index_type const& rowIndex) const {
    // Bind by reference; the row is only read.
    row_type const& row = this->getRow(rowIndex);
    index_type entryIndex = 0;
    for (index_type column = 0; column < this->getColumnCount(); ++column) {
        if (entryIndex < row.size() && row[entryIndex].getColumn() == column) {
            // Insert entry
            out << row[entryIndex].getValue() << "\t";
            ++entryIndex;
        } else {
            // Insert zero
            out << "0\t";
        }
    }
    return out;
}

/*!
 * Prints the matrix in dense form, framed by a header and a footer line with the column numbers.
 * For a non-trivial row grouping the rows are printed per group and labelled with the group index,
 * otherwise they are labelled with the row index.
 */
template<typename ValueType>
std::ostream& operator<<(std::ostream& out, FlexibleSparseMatrix<ValueType> const& matrix) {
    typedef typename FlexibleSparseMatrix<ValueType>::index_type FlexibleIndex;

    // Print column numbers in header.
    out << "\t\t";
    for (FlexibleIndex i = 0; i < matrix.getColumnCount(); ++i) {
        out << i << "\t";
    }
    out << std::endl;

    if (!matrix.hasTrivialRowGrouping()) {
        // Iterate over all row groups
        FlexibleIndex rowGroupCount = matrix.getRowGroupCount();
        for (FlexibleIndex rowGroup = 0; rowGroup < rowGroupCount; ++rowGroup) {
            out << "\t---- group " << rowGroup << "/" << (rowGroupCount - 1) << " ---- " << std::endl;
            FlexibleIndex endRow = matrix.rowGroupIndices[rowGroup + 1];
            // Iterate over all rows.
            for (FlexibleIndex row = matrix.rowGroupIndices[rowGroup]; row < endRow; ++row) {
                // Print the actual row.
                out << rowGroup << "\t(\t";
                matrix.printRow(out, row);
                out << "\t)\t" << rowGroup << std::endl;
            }
        }
    } else {
        // Iterate over all rows
        for (FlexibleIndex row = 0; row < matrix.getRowCount(); ++row) {
            // Print the actual row.
            out << row << "\t(\t";
            matrix.printRow(out, row);
            out << "\t)\t" << row << std::endl;
        }
    }

    // Print column numbers in footer.
    out << "\t\t";
    for (FlexibleIndex i = 0; i < matrix.getColumnCount(); ++i) {
        out << i << "\t";
    }
    out << std::endl;
    return out;
}

// Explicitly instantiate the matrix.
// NOTE(review): the instantiated value types below are reconstructed -- confirm them against the
// instantiations the rest of the project expects.
template class FlexibleSparseMatrix<double>;
template std::ostream& operator<<<double>(std::ostream& out, FlexibleSparseMatrix<double> const& matrix);

#ifdef STORM_HAVE_CARL
template class FlexibleSparseMatrix<storm::RationalNumber>;
template std::ostream& operator<<<storm::RationalNumber>(std::ostream& out, FlexibleSparseMatrix<storm::RationalNumber> const& matrix);

template class FlexibleSparseMatrix<storm::RationalFunction>;
template std::ostream& operator<<<storm::RationalFunction>(std::ostream& out, FlexibleSparseMatrix<storm::RationalFunction> const& matrix);
#endif

}  // namespace storage
}  // namespace storm