diff --git a/src/storm/solver/AcyclicLinearEquationSolver.cpp b/src/storm/solver/AcyclicLinearEquationSolver.cpp index d5a2940dd..7d5396506 100644 --- a/src/storm/solver/AcyclicLinearEquationSolver.cpp +++ b/src/storm/solver/AcyclicLinearEquationSolver.cpp @@ -62,9 +62,11 @@ namespace storm { orderedMatrix = helper::createReorderedMatrix(*this->A, *rowOrdering, bFactors); this->multiplier = storm::solver::MultiplierFactory().create(env, *orderedMatrix); } - auxiliaryRowVector = std::vector(); + auxiliaryRowVector = std::vector(this->A->getRowCount()); + auxiliaryRowVector2 = std::vector(this->A->getRowCount()); } + std::vector* xPtr = &x; std::vector const* bPtr = &b; if (rowOrdering) { STORM_LOG_ASSERT(rowOrdering->size() == b.size(), "b-vector has unexpected size."); @@ -74,9 +76,16 @@ namespace storm { (*auxiliaryRowVector)[bFactor.first] *= bFactor.second; } bPtr = &auxiliaryRowVector.get(); + xPtr = &auxiliaryRowVector2.get(); } - this->multiplier->multiplyGaussSeidel(env, x, bPtr, true); + this->multiplier->multiplyGaussSeidel(env, *xPtr, bPtr, true); + + if (rowOrdering) { + for (uint64_t newRow = 0; newRow < x.size(); ++newRow) { + x[(*rowOrdering)[newRow]] = (*xPtr)[newRow]; + } + } if (!this->isCachingEnabled()) { this->clearCache(); @@ -103,6 +112,7 @@ namespace storm { orderedMatrix = boost::none; rowOrdering = boost::none; auxiliaryRowVector = boost::none; + auxiliaryRowVector2 = boost::none; bFactors.clear(); } diff --git a/src/storm/solver/AcyclicLinearEquationSolver.h b/src/storm/solver/AcyclicLinearEquationSolver.h index e16931d09..6a151936d 100644 --- a/src/storm/solver/AcyclicLinearEquationSolver.h +++ b/src/storm/solver/AcyclicLinearEquationSolver.h @@ -55,6 +55,8 @@ namespace storm { mutable boost::optional> rowOrdering; // A.rowGroupCount() entries // can be used if the entries in 'b' need to be reordered mutable boost::optional> auxiliaryRowVector; // A.rowCount() entries + // can be used if the entries in 'x' need to be reordered + mutable boost::optional> auxiliaryRowVector2; // A.rowCount() entries // contains factors applied to scale the entries of the 'b' vector mutable std::vector> bFactors; diff --git a/src/storm/solver/AcyclicMinMaxLinearEquationSolver.cpp b/src/storm/solver/AcyclicMinMaxLinearEquationSolver.cpp index 3bee1aa3a..a11fa137e 100644 --- a/src/storm/solver/AcyclicMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/AcyclicMinMaxLinearEquationSolver.cpp @@ -47,24 +47,41 @@ namespace storm { orderedMatrix = helper::createReorderedMatrix(*this->A, *rowGroupOrdering, bFactors); this->multiplier = storm::solver::MultiplierFactory().create(env, *orderedMatrix); } - auxiliaryRowVector = std::vector(); + auxiliaryRowVector = std::vector(this->A->getRowCount()); + auxiliaryRowGroupVector = std::vector(this->A->getRowGroupCount()); } + std::vector* xPtr = &x; std::vector const* bPtr = &b; if (rowGroupOrdering) { - STORM_LOG_ASSERT(rowGroupOrdering->size() == b.size(), "b-vector has unexpected size."); - auxiliaryRowVector->resize(b.size()); - storm::utility::vector::selectVectorValues(*auxiliaryRowVector, *rowGroupOrdering, b); + STORM_LOG_ASSERT(rowGroupOrdering->size() == x.size(), "x-vector has unexpected size."); + STORM_LOG_ASSERT(auxiliaryRowGroupVector->size() == x.size(), "x-vector has unexpected size."); + STORM_LOG_ASSERT(auxiliaryRowVector->size() == b.size(), "b-vector has unexpected size."); + for (uint64_t newGroupIndex = 0; newGroupIndex < x.size(); ++newGroupIndex) { + uint64_t newRow = orderedMatrix->getRowGroupIndices()[newGroupIndex]; + uint64_t newRowGroupEnd = orderedMatrix->getRowGroupIndices()[newGroupIndex + 1]; + uint64_t oldRow = this->A->getRowGroupIndices()[(*rowGroupOrdering)[newGroupIndex]]; + for (; newRow < newRowGroupEnd; ++newRow, ++oldRow) { + (*auxiliaryRowVector)[newRow] = b[oldRow]; + } + } for (auto const& bFactor : bFactors) { (*auxiliaryRowVector)[bFactor.first] *= bFactor.second; } + xPtr = &auxiliaryRowGroupVector.get(); bPtr = &auxiliaryRowVector.get(); } if (this->isTrackSchedulerSet()) { - this->multiplier->multiplyAndReduceGaussSeidel(env, dir, x, bPtr, &this->schedulerChoices.get(), true); + this->multiplier->multiplyAndReduceGaussSeidel(env, dir, *xPtr, bPtr, &this->schedulerChoices.get(), true); } else { - this->multiplier->multiplyAndReduceGaussSeidel(env, dir, x, bPtr, nullptr, true); + this->multiplier->multiplyAndReduceGaussSeidel(env, dir, *xPtr, bPtr, nullptr, true); + } + + if (rowGroupOrdering) { + for (uint64_t newGroupIndex = 0; newGroupIndex < x.size(); ++newGroupIndex) { + x[(*rowGroupOrdering)[newGroupIndex]] = (*xPtr)[newGroupIndex]; + } } if (!this->isCachingEnabled()) { @@ -87,6 +104,7 @@ namespace storm { orderedMatrix = boost::none; rowGroupOrdering = boost::none; auxiliaryRowVector = boost::none; + auxiliaryRowGroupVector = boost::none; bFactors.clear(); } diff --git a/src/storm/solver/AcyclicMinMaxLinearEquationSolver.h b/src/storm/solver/AcyclicMinMaxLinearEquationSolver.h index f199c98ab..682f14432 100644 --- a/src/storm/solver/AcyclicMinMaxLinearEquationSolver.h +++ b/src/storm/solver/AcyclicMinMaxLinearEquationSolver.h @@ -42,6 +42,8 @@ namespace storm { mutable boost::optional> rowGroupOrdering; // A.rowGroupCount() entries // can be used if the entries in 'b' need to be reordered mutable boost::optional> auxiliaryRowVector; // A.rowCount() entries + // can be used if the entries in 'x' need to be reordered + mutable boost::optional> auxiliaryRowGroupVector; // A.rowGroupCount() entries // contains factors applied to scale the entries of the 'b' vector mutable std::vector> bFactors; diff --git a/src/storm/solver/helper/AcyclicSolverHelper.cpp b/src/storm/solver/helper/AcyclicSolverHelper.cpp index 9bbdf6fbb..9fb0b944a 100644 --- a/src/storm/solver/helper/AcyclicSolverHelper.cpp +++ b/src/storm/solver/helper/AcyclicSolverHelper.cpp @@ -80,7 +80,7 @@ namespace storm { for (uint64_t newRowGroup = 0; newRowGroup < newToOrigIndexMap.size(); ++newRowGroup) { auto const& origRowGroup = newToOrigIndexMap[newRowGroup]; if (hasRowGrouping) { - builder.newRowGroup(newRowGroup); + builder.newRowGroup(newRow); } for (uint64_t origRow = matrix.getRowGroupIndices()[origRowGroup]; origRow < matrix.getRowGroupIndices()[origRowGroup + 1]; ++origRow) { for (auto const& entry : matrix.getRow(origRow)) {