From da6333cead5eb0c2d1e6d001ffcc19427b847e65 Mon Sep 17 00:00:00 2001 From: Tim Quatmann Date: Thu, 1 Oct 2020 15:50:38 +0200 Subject: [PATCH] Fix in scheduler export for acyclic Min Max solver --- .../AcyclicMinMaxLinearEquationSolver.cpp | 41 +++++++++++++------ .../AcyclicMinMaxLinearEquationSolver.h | 2 + 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/storm/solver/AcyclicMinMaxLinearEquationSolver.cpp b/src/storm/solver/AcyclicMinMaxLinearEquationSolver.cpp index a11fa137e..d17e46528 100644 --- a/src/storm/solver/AcyclicMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/AcyclicMinMaxLinearEquationSolver.cpp @@ -27,15 +27,7 @@ namespace storm { bool AcyclicMinMaxLinearEquationSolver::internalSolveEquations(Environment const& env, OptimizationDirection dir, std::vector& x, std::vector const& b) const { STORM_LOG_ASSERT(x.size() == this->A->getRowGroupCount(), "Provided x-vector has invalid size."); STORM_LOG_ASSERT(b.size() == this->A->getRowCount(), "Provided b-vector has invalid size."); - // Allocate memory for the scheduler (if required) - if (this->isTrackSchedulerSet()) { - if (this->schedulerChoices) { - this->schedulerChoices->resize(this->A->getRowGroupCount()); - } else { - this->schedulerChoices = std::vector(this->A->getRowGroupCount()); - } - } - + if (!multiplier) { // We have not allocated cache memory, yet rowGroupOrdering = helper::computeTopologicalGroupOrdering(*this->A); @@ -72,16 +64,40 @@ namespace storm { bPtr = &auxiliaryRowVector.get(); } + // Allocate memory for the scheduler (if required) + std::vector* choicesPtr = nullptr; if (this->isTrackSchedulerSet()) { - this->multiplier->multiplyAndReduceGaussSeidel(env, dir, *xPtr, bPtr, &this->schedulerChoices.get(), true); - } else { - this->multiplier->multiplyAndReduceGaussSeidel(env, dir, *xPtr, bPtr, nullptr, true); + if (this->schedulerChoices) { + this->schedulerChoices->resize(this->A->getRowGroupCount()); + } else { + this->schedulerChoices = std::vector(this->A->getRowGroupCount()); + } + if (rowGroupOrdering) { + if (auxiliaryRowGroupIndexVector) { + auxiliaryRowGroupIndexVector->resize(this->A->getRowGroupCount()); + } else { + auxiliaryRowGroupIndexVector = std::vector(this->A->getRowGroupCount()); + } + choicesPtr = &(auxiliaryRowGroupIndexVector.get()); + } else { + choicesPtr = &(this->schedulerChoices.get()); + } } + // Since a topological ordering is guaranteed, we can solve the equations with a single matrix-vector Multiplication step. + this->multiplier->multiplyAndReduceGaussSeidel(env, dir, *xPtr, bPtr, choicesPtr, true); + if (rowGroupOrdering) { + // Restore the correct input-order for the output vector for (uint64_t newGroupIndex = 0; newGroupIndex < x.size(); ++newGroupIndex) { x[(*rowGroupOrdering)[newGroupIndex]] = (*xPtr)[newGroupIndex]; } + if (this->isTrackSchedulerSet()) { + // Do the same for the scheduler choices + for (uint64_t newGroupIndex = 0; newGroupIndex < x.size(); ++newGroupIndex) { + this->schedulerChoices.get()[(*rowGroupOrdering)[newGroupIndex]] = (*choicesPtr)[newGroupIndex]; + } + } } if (!this->isCachingEnabled()) { @@ -105,6 +121,7 @@ namespace storm { rowGroupOrdering = boost::none; auxiliaryRowVector = boost::none; auxiliaryRowGroupVector = boost::none; + auxiliaryRowGroupIndexVector = boost::none; bFactors.clear(); } diff --git a/src/storm/solver/AcyclicMinMaxLinearEquationSolver.h b/src/storm/solver/AcyclicMinMaxLinearEquationSolver.h index 682f14432..41c94c918 100644 --- a/src/storm/solver/AcyclicMinMaxLinearEquationSolver.h +++ b/src/storm/solver/AcyclicMinMaxLinearEquationSolver.h @@ -44,6 +44,8 @@ namespace storm { mutable boost::optional> auxiliaryRowVector; // A.rowCount() entries // can be used if the entries in 'x' need to be reordered mutable boost::optional> auxiliaryRowGroupVector; // A.rowGroupCount() entries + // can be used if the performed scheduler choices need to be reordered + mutable boost::optional> auxiliaryRowGroupIndexVector; // A.rowGroupCount() entries // contains factors applied to scale the entries of the 'b' vector mutable std::vector> bFactors;