Browse Source

get submatrix now can set some columns to zero but retain them

tempestpy_adaptions
Sebastian Junges 4 years ago
parent
commit
406ffbdc7f
  1. 12
      src/storm/storage/SparseMatrix.cpp
  2. 13
      src/storm/storage/SparseMatrix.h

12
src/storm/storage/SparseMatrix.cpp

@ -974,9 +974,9 @@ namespace storm {
} }
template<typename ValueType> template<typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::getSubmatrix(bool useGroups, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalElements) const {
SparseMatrix<ValueType> SparseMatrix<ValueType>::getSubmatrix(bool useGroups, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalElements, storm::storage::BitVector const& makeZeroColumns) const {
if (useGroups) { if (useGroups) {
return getSubmatrix(rowConstraint, columnConstraint, this->getRowGroupIndices(), insertDiagonalElements);
return getSubmatrix(rowConstraint, columnConstraint, this->getRowGroupIndices(), insertDiagonalElements, makeZeroColumns);
} else { } else {
// Create a fake row grouping to reduce this to a call to a more general method. // Create a fake row grouping to reduce this to a call to a more general method.
std::vector<index_type> fakeRowGroupIndices(rowCount + 1); std::vector<index_type> fakeRowGroupIndices(rowCount + 1);
@ -984,7 +984,7 @@ namespace storm {
for (std::vector<index_type>::iterator it = fakeRowGroupIndices.begin(); it != fakeRowGroupIndices.end(); ++it, ++i) { for (std::vector<index_type>::iterator it = fakeRowGroupIndices.begin(); it != fakeRowGroupIndices.end(); ++it, ++i) {
*it = i; *it = i;
} }
auto res = getSubmatrix(rowConstraint, columnConstraint, fakeRowGroupIndices, insertDiagonalElements);
auto res = getSubmatrix(rowConstraint, columnConstraint, fakeRowGroupIndices, insertDiagonalElements, makeZeroColumns);
// Create a new row grouping that reflects the new sizes of the row groups if the current matrix has a // Create a new row grouping that reflects the new sizes of the row groups if the current matrix has a
// non trivial row-grouping. // non trivial row-grouping.
@ -1014,7 +1014,7 @@ namespace storm {
} }
template<typename ValueType> template<typename ValueType>
SparseMatrix<ValueType> SparseMatrix<ValueType>::getSubmatrix(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint, std::vector<index_type> const& rowGroupIndices, bool insertDiagonalEntries) const {
SparseMatrix<ValueType> SparseMatrix<ValueType>::getSubmatrix(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint, std::vector<index_type> const& rowGroupIndices, bool insertDiagonalEntries, storm::storage::BitVector const& makeZeroColumns) const {
STORM_LOG_THROW(!rowGroupConstraint.empty() && !columnConstraint.empty(), storm::exceptions::InvalidArgumentException, "Cannot build empty submatrix."); STORM_LOG_THROW(!rowGroupConstraint.empty() && !columnConstraint.empty(), storm::exceptions::InvalidArgumentException, "Cannot build empty submatrix.");
uint_fast64_t submatrixColumnCount = columnConstraint.getNumberOfSetBits(); uint_fast64_t submatrixColumnCount = columnConstraint.getNumberOfSetBits();
@ -1037,7 +1037,7 @@ namespace storm {
bool foundDiagonalElement = false; bool foundDiagonalElement = false;
for (const_iterator it = this->begin(i), ite = this->end(i); it != ite; ++it) { for (const_iterator it = this->begin(i), ite = this->end(i); it != ite; ++it) {
if (columnConstraint.get(it->getColumn())) {
if (columnConstraint.get(it->getColumn()) && (makeZeroColumns.size() == 0 || !makeZeroColumns.get(it->getColumn())) ) {
++subEntries; ++subEntries;
if (columnBitsSetBeforeIndex[it->getColumn()] == rowBitsSetBeforeIndex[index]) { if (columnBitsSetBeforeIndex[it->getColumn()] == rowBitsSetBeforeIndex[index]) {
@ -1069,7 +1069,7 @@ namespace storm {
bool insertedDiagonalElement = false; bool insertedDiagonalElement = false;
for (const_iterator it = this->begin(i), ite = this->end(i); it != ite; ++it) { for (const_iterator it = this->begin(i), ite = this->end(i); it != ite; ++it) {
if (columnConstraint.get(it->getColumn())) {
if (columnConstraint.get(it->getColumn()) && (makeZeroColumns.size() == 0 || !makeZeroColumns.get(it->getColumn()))) {
if (columnBitsSetBeforeIndex[it->getColumn()] == rowBitsSetBeforeIndex[index]) { if (columnBitsSetBeforeIndex[it->getColumn()] == rowBitsSetBeforeIndex[index]) {
insertedDiagonalElement = true; insertedDiagonalElement = true;
} else if (insertDiagonalEntries && !insertedDiagonalElement && columnBitsSetBeforeIndex[it->getColumn()] > rowBitsSetBeforeIndex[index]) { } else if (insertDiagonalEntries && !insertedDiagonalElement && columnBitsSetBeforeIndex[it->getColumn()] > rowBitsSetBeforeIndex[index]) {

13
src/storm/storage/SparseMatrix.h

@ -1,5 +1,4 @@
#ifndef STORM_STORAGE_SPARSEMATRIX_H_
#define STORM_STORAGE_SPARSEMATRIX_H_
#pragma once
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
@ -10,6 +9,7 @@
#include <boost/functional/hash.hpp> #include <boost/functional/hash.hpp>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include "storm/storage/BitVector.h"
#include "storm/solver/OptimizationDirection.h" #include "storm/solver/OptimizationDirection.h"
#include "storm/utility/OsDetection.h" #include "storm/utility/OsDetection.h"
@ -34,8 +34,7 @@ namespace storm {
namespace storm { namespace storm {
namespace storage { namespace storage {
class BitVector;
// Forward declare matrix class. // Forward declare matrix class.
template<typename T> template<typename T>
@ -719,7 +718,7 @@ namespace storm {
* @return A matrix corresponding to a submatrix of the current matrix in which only rows and columns given * @return A matrix corresponding to a submatrix of the current matrix in which only rows and columns given
* by the constraints are kept and all others are dropped. * by the constraints are kept and all others are dropped.
*/ */
SparseMatrix getSubmatrix(bool useGroups, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false) const;
SparseMatrix getSubmatrix(bool useGroups, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false, storm::storage::BitVector const& makeZeroColumns = storm::storage::BitVector()) const;
/*! /*!
* Restrict rows in grouped rows matrix. Ensures that the number of groups stays the same. * Restrict rows in grouped rows matrix. Ensures that the number of groups stays the same.
@ -1169,7 +1168,7 @@ namespace storm {
* @return A matrix corresponding to a submatrix of the current matrix in which only row groups and columns * @return A matrix corresponding to a submatrix of the current matrix in which only row groups and columns
* given by the row group constraint are kept and all others are dropped. * given by the row group constraint are kept and all others are dropped.
*/ */
SparseMatrix getSubmatrix(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint, std::vector<index_type> const& rowGroupIndices, bool insertDiagonalEntries = false) const;
SparseMatrix getSubmatrix(storm::storage::BitVector const& rowGroupConstraint, storm::storage::BitVector const& columnConstraint, std::vector<index_type> const& rowGroupIndices, bool insertDiagonalEntries = false, storm::storage::BitVector const& makeZeroColumns = storm::storage::BitVector()) const;
// The number of rows of the matrix. // The number of rows of the matrix.
index_type rowCount; index_type rowCount;
@ -1206,5 +1205,3 @@ namespace storm {
} // namespace storage } // namespace storage
} // namespace storm } // namespace storm
#endif // STORM_STORAGE_SPARSEMATRIX_H_
Loading…
Cancel
Save