Browse Source

added a test case for SparseMatri::restrictRows and fixed it

tempestpy_adaptions
TimQu 8 years ago
parent
commit
35c9b58fda
  1. 37
      src/storm/storage/SparseMatrix.cpp
  2. 78
      src/test/storage/SparseMatrixTest.cpp

37
src/storm/storage/SparseMatrix.cpp

@ -919,32 +919,35 @@ namespace storm {
for (auto const& row : rowsToKeep) { for (auto const& row : rowsToKeep) {
entryCount += this->getRow(row).getNumberOfEntries(); entryCount += this->getRow(row).getNumberOfEntries();
} }
// Get the smallest row group index such that all row groups with at least this index are empty.
uint_fast64_t firstTrailingEmptyRowGroup = this->getRowGroupCount();
for (auto groupIndexIt = this->getRowGroupIndices().rbegin() + 1; groupIndexIt != this->getRowGroupIndices().rend(); ++groupIndexIt) {
if (rowsToKeep.getNextSetIndex(*groupIndexIt) != rowsToKeep.size()) {
break;
}
--firstTrailingEmptyRowGroup;
}
STORM_LOG_THROW(allowEmptyRowGroups || firstTrailingEmptyRowGroup == this->getRowGroupCount(), storm::exceptions::InvalidArgumentException, "Empty rows are not allowed, but row group " << firstTrailingEmptyRowGroup << " is empty.");
// build the matrix. The row grouping will always be considered as nontrivial. // build the matrix. The row grouping will always be considered as nontrivial.
SparseMatrixBuilder<ValueType> builder(rowsToKeep.getNumberOfSetBits(), this->getColumnCount(), entryCount, true, true, this->getRowGroupCount()); SparseMatrixBuilder<ValueType> builder(rowsToKeep.getNumberOfSetBits(), this->getColumnCount(), entryCount, true, true, this->getRowGroupCount());
uint_fast64_t newRow = 0; uint_fast64_t newRow = 0;
for (uint_fast64_t rowGroup = 0; rowGroup < this->getRowGroupCount(); ++rowGroup) {
for (uint_fast64_t rowGroup = 0; rowGroup < firstTrailingEmptyRowGroup; ++rowGroup) {
// Add a new row group
builder.newRowGroup(newRow); builder.newRowGroup(newRow);
bool rowGroupEmpty = true; bool rowGroupEmpty = true;
if (this->hasTrivialRowGrouping()) {
if (rowsToKeep.get(rowGroup)) {
rowGroupEmpty = false;
for (auto const& entry : this->getRow(rowGroup)) {
builder.addNextValue(newRow, entry.getColumn(), entry.getValue());
}
++newRow;
}
} else {
for (uint_fast64_t row = rowsToKeep.getNextSetIndex(this->getRowGroupIndices()[rowGroup]); row < this->getRowGroupIndices()[rowGroup + 1]; row = rowsToKeep.getNextSetIndex(row + 1)) {
rowGroupEmpty = false;
for (auto const& entry: this->getRow(row)) {
builder.addNextValue(newRow, entry.getColumn(), entry.getValue());
}
++newRow;
for (uint_fast64_t row = rowsToKeep.getNextSetIndex(this->getRowGroupIndices()[rowGroup]); row < this->getRowGroupIndices()[rowGroup + 1]; row = rowsToKeep.getNextSetIndex(row + 1)) {
rowGroupEmpty = false;
for (auto const& entry: this->getRow(row)) {
builder.addNextValue(newRow, entry.getColumn(), entry.getValue());
} }
++newRow;
} }
STORM_LOG_THROW(allowEmptyRowGroups || !rowGroupEmpty, storm::exceptions::InvalidArgumentException, "Empty rows are not allowed, but row group " << rowGroup << " is empty."); STORM_LOG_THROW(allowEmptyRowGroups || !rowGroupEmpty, storm::exceptions::InvalidArgumentException, "Empty rows are not allowed, but row group " << rowGroup << " is empty.");
} }
// The all remaining row groups will be empty. Note that it is not allowed to call builder.addNewGroup(...) if there are no more rows afterwards.
SparseMatrix<ValueType> res = builder.build(); SparseMatrix<ValueType> res = builder.build();
return res; return res;
} }

78
src/test/storage/SparseMatrixTest.cpp

@ -3,6 +3,7 @@
#include "storm/storage/BitVector.h" #include "storm/storage/BitVector.h"
#include "storm/exceptions/InvalidStateException.h" #include "storm/exceptions/InvalidStateException.h"
#include "storm/exceptions/OutOfRangeException.h" #include "storm/exceptions/OutOfRangeException.h"
#include "storm/exceptions/InvalidArgumentException.h"
TEST(SparseMatrixBuilder, CreationWithDimensions) { TEST(SparseMatrixBuilder, CreationWithDimensions) {
storm::storage::SparseMatrixBuilder<double> matrixBuilder(3, 4, 5); storm::storage::SparseMatrixBuilder<double> matrixBuilder(3, 4, 5);
@ -374,6 +375,83 @@ TEST(SparseMatrix, Submatrix) {
ASSERT_TRUE(matrix4 == matrix5); ASSERT_TRUE(matrix4 == matrix5);
} }
TEST(SparseMatrix, RestrictRows) {
storm::storage::SparseMatrixBuilder<double> matrixBuilder1(7, 4, 9, true, true, 3);
ASSERT_NO_THROW(matrixBuilder1.newRowGroup(0));
ASSERT_NO_THROW(matrixBuilder1.addNextValue(0, 1, 1.0));
ASSERT_NO_THROW(matrixBuilder1.addNextValue(0, 2, 1.2));
ASSERT_NO_THROW(matrixBuilder1.addNextValue(1, 0, 0.5));
ASSERT_NO_THROW(matrixBuilder1.addNextValue(1, 1, 0.7));
ASSERT_NO_THROW(matrixBuilder1.newRowGroup(2));
ASSERT_NO_THROW(matrixBuilder1.addNextValue(2, 0, 0.5));
ASSERT_NO_THROW(matrixBuilder1.addNextValue(3, 2, 1.1));
ASSERT_NO_THROW(matrixBuilder1.newRowGroup(4));
ASSERT_NO_THROW(matrixBuilder1.addNextValue(4, 0, 0.1));
ASSERT_NO_THROW(matrixBuilder1.addNextValue(4, 1, 0.2));
ASSERT_NO_THROW(matrixBuilder1.addNextValue(6, 3, 0.3));
storm::storage::SparseMatrix<double> matrix1;
ASSERT_NO_THROW(matrix1 = matrixBuilder1.build());
storm::storage::BitVector constraint1(7);
constraint1.set(0);
constraint1.set(1);
constraint1.set(2);
constraint1.set(5);
storm::storage::SparseMatrix<double> matrix1Prime;
ASSERT_NO_THROW(matrix1Prime = matrix1.restrictRows(constraint1));
storm::storage::SparseMatrixBuilder<double> matrixBuilder2(4, 4, 5, true, true, 3);
ASSERT_NO_THROW(matrixBuilder2.newRowGroup(0));
ASSERT_NO_THROW(matrixBuilder2.addNextValue(0, 1, 1.0));
ASSERT_NO_THROW(matrixBuilder2.addNextValue(0, 2, 1.2));
ASSERT_NO_THROW(matrixBuilder2.addNextValue(1, 0, 0.5));
ASSERT_NO_THROW(matrixBuilder2.addNextValue(1, 1, 0.7));
ASSERT_NO_THROW(matrixBuilder2.newRowGroup(2));
ASSERT_NO_THROW(matrixBuilder2.addNextValue(2, 0, 0.5));
ASSERT_NO_THROW(matrixBuilder2.newRowGroup(3));
storm::storage::SparseMatrix<double> matrix2;
ASSERT_NO_THROW(matrix2 = matrixBuilder2.build());
ASSERT_EQ(matrix2, matrix1Prime);
storm::storage::BitVector constraint2(4);
constraint2.set(1);
constraint2.set(2);
storm::storage::SparseMatrix<double> matrix2Prime;
ASSERT_THROW(matrix2Prime = matrix2.restrictRows(constraint2), storm::exceptions::InvalidArgumentException);
ASSERT_NO_THROW(matrix2Prime = matrix2.restrictRows(constraint2, true));
storm::storage::SparseMatrixBuilder<double> matrixBuilder3(2, 4, 3, true, true, 3);
ASSERT_NO_THROW(matrixBuilder3.newRowGroup(0));
ASSERT_NO_THROW(matrixBuilder3.addNextValue(0, 0, 0.5));
ASSERT_NO_THROW(matrixBuilder3.addNextValue(0, 1, 0.7));
ASSERT_NO_THROW(matrixBuilder3.newRowGroup(1));
ASSERT_NO_THROW(matrixBuilder3.addNextValue(1, 0, 0.5));
storm::storage::SparseMatrix<double> matrix3;
ASSERT_NO_THROW(matrix3 = matrixBuilder3.build());
ASSERT_EQ(matrix3, matrix2Prime);
matrix3.makeRowGroupingTrivial();
storm::storage::BitVector constraint3(2);
constraint3.set(1);
storm::storage::SparseMatrix<double> matrix3Prime;
ASSERT_THROW(matrix3Prime = matrix3.restrictRows(constraint3), storm::exceptions::InvalidArgumentException);
ASSERT_NO_THROW(matrix3Prime = matrix3.restrictRows(constraint3, true));
storm::storage::SparseMatrixBuilder<double> matrixBuilder4(1, 4, 1, true, true, 2);
ASSERT_NO_THROW(matrixBuilder4.newRowGroup(0));
ASSERT_NO_THROW(matrixBuilder4.newRowGroup(0));
ASSERT_NO_THROW(matrixBuilder4.addNextValue(0, 0, 0.5));
storm::storage::SparseMatrix<double> matrix4;
ASSERT_NO_THROW(matrix4 = matrixBuilder4.build());
ASSERT_EQ(matrix4, matrix3Prime);
}
TEST(SparseMatrix, Transpose) { TEST(SparseMatrix, Transpose) {
storm::storage::SparseMatrixBuilder<double> matrixBuilder(5, 4, 9); storm::storage::SparseMatrixBuilder<double> matrixBuilder(5, 4, 9);
ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 1, 1.0)); ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 1, 1.0));

Loading…
Cancel
Save