From d27954622afc869a9e9142e8ed1176449926ccba Mon Sep 17 00:00:00 2001 From: dehnert Date: Sun, 10 Sep 2017 09:30:23 +0200 Subject: [PATCH] slightly changed handling of gauss-seidel invocations in linear equation solver --- .../modules/MinMaxEquationSolverSettings.cpp | 2 +- .../IterativeMinMaxLinearEquationSolver.cpp | 11 ++--- src/storm/solver/LinearEquationSolver.cpp | 27 +++++++----- src/storm/solver/LinearEquationSolver.h | 44 +++++++++++++------ src/storm/solver/MultiplicationStyle.cpp | 2 +- src/storm/solver/MultiplicationStyle.h | 2 +- .../solver/NativeLinearEquationSolver.cpp | 24 ++++++++-- src/storm/solver/NativeLinearEquationSolver.h | 3 ++ 8 files changed, 76 insertions(+), 39 deletions(-) diff --git a/src/storm/settings/modules/MinMaxEquationSolverSettings.cpp b/src/storm/settings/modules/MinMaxEquationSolverSettings.cpp index 2a5b812c7..6c6b43be2 100644 --- a/src/storm/settings/modules/MinMaxEquationSolverSettings.cpp +++ b/src/storm/settings/modules/MinMaxEquationSolverSettings.cpp @@ -99,7 +99,7 @@ namespace storm { storm::solver::MultiplicationStyle MinMaxEquationSolverSettings::getValueIterationMultiplicationStyle() const { std::string multiplicationStyleString = this->getOption(valueIterationMultiplicationStyleOptionName).getArgumentByName("name").getValueAsString(); if (multiplicationStyleString == "gaussseidel" || multiplicationStyleString == "gs") { - return storm::solver::MultiplicationStyle::AllowGaussSeidel; + return storm::solver::MultiplicationStyle::GaussSeidel; } else if (multiplicationStyleString == "regular" || multiplicationStyleString == "r") { return storm::solver::MultiplicationStyle::Regular; } diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp index 87cba84fd..606946402 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp @@ -273,9 +273,7 @@ namespace storm { } // Allow aliased multiplications. - MultiplicationStyle multiplicationStyle = settings.getValueIterationMultiplicationStyle(); - MultiplicationStyle oldMultiplicationStyle = this->linEqSolverA->getMultiplicationStyle(); - this->linEqSolverA->setMultiplicationStyle(multiplicationStyle); + bool useGaussSeidelMultiplication = this->linEqSolverA->supportsGaussSeidelMultiplication() && settings.getValueIterationMultiplicationStyle() == storm::solver::MultiplicationStyle::GaussSeidel; std::vector* newX = auxiliaryRowGroupVector.get(); std::vector* currentX = &x; @@ -286,10 +284,10 @@ namespace storm { Status status = Status::InProgress; while (status == Status::InProgress) { // Compute x' = min/max(A*x + b). - if (multiplicationStyle == MultiplicationStyle::AllowGaussSeidel) { + if (useGaussSeidelMultiplication) { // Copy over the current vector so we can modify it in-place. *newX = *currentX; - this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), *newX, &b, *newX); + this->linEqSolverA->multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *newX, &b); } else { this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), *currentX, &b, *newX); } @@ -319,9 +317,6 @@ namespace storm { this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), x, &b, *currentX, &this->schedulerChoices.get()); } - // Restore whether aliased multiplications were allowed before. - this->linEqSolverA->setMultiplicationStyle(oldMultiplicationStyle); - if (!this->isCachingEnabled()) { clearCache(); } diff --git a/src/storm/solver/LinearEquationSolver.cpp b/src/storm/solver/LinearEquationSolver.cpp index 34b7a8734..2f7d60287 100644 --- a/src/storm/solver/LinearEquationSolver.cpp +++ b/src/storm/solver/LinearEquationSolver.cpp @@ -19,7 +19,7 @@ namespace storm { namespace solver { template - LinearEquationSolver::LinearEquationSolver() : cachingEnabled(false), multiplicationStyle(MultiplicationStyle::Regular) { + LinearEquationSolver::LinearEquationSolver() : cachingEnabled(false) { // Intentionally left empty. } @@ -86,6 +86,21 @@ namespace storm { } #endif + template + bool LinearEquationSolver::supportsGaussSeidelMultiplication() const { + return false; + } + + template + void LinearEquationSolver::multiplyGaussSeidel(std::vector& x, std::vector const* b) const { + STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "This solver does not support the function 'multiplyGaussSeidel'."); + } + + template + void LinearEquationSolver::multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector* choices) const { + STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "This solver does not support the function 'multiplyAndReduceGaussSeidel'."); + } + template void LinearEquationSolver::setCachingEnabled(bool value) const { if(cachingEnabled && !value) { @@ -121,16 +136,6 @@ namespace storm { setUpperBound(upper); } - template - void LinearEquationSolver::setMultiplicationStyle(MultiplicationStyle multiplicationStyle) { - this->multiplicationStyle = multiplicationStyle; - } - - template - MultiplicationStyle LinearEquationSolver::getMultiplicationStyle() const { - return multiplicationStyle; - } - template std::unique_ptr> LinearEquationSolverFactory::create(storm::storage::SparseMatrix&& matrix) const { return create(matrix); diff --git a/src/storm/solver/LinearEquationSolver.h b/src/storm/solver/LinearEquationSolver.h index f32f0b3d7..47e82204a 100644 --- a/src/storm/solver/LinearEquationSolver.h +++ b/src/storm/solver/LinearEquationSolver.h @@ -73,6 +73,37 @@ namespace storm { */ virtual void multiplyAndReduce(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector& result, std::vector* choices = nullptr) const; + /*! + * Retrieves whether this solver offers the gauss-seidel style multiplications. + */ + virtual bool supportsGaussSeidelMultiplication() const; + + /*! + * Performs on matrix-vector multiplication x' = A*x + b. It does so in a gauss-seidel style, i.e. reusing + * the new x' components in the further multiplication. + * + * @param x The input vector with which to multiply the matrix. Its length must be equal + * to the number of columns of A. + * @param b If non-null, this vector is added after the multiplication. If given, its length must be equal + * to the number of rows of A. + */ + virtual void multiplyGaussSeidel(std::vector& x, std::vector const* b) const; + + /*! + * Performs on matrix-vector multiplication x' = A*x + b and then minimizes/maximizes over the row groups + * so that the resulting vector has the size of number of row groups of A. It does so in a gauss-seidel + * style, i.e. reusing the new x' components in the further multiplication. + * + * @param dir The direction for the reduction step. + * @param rowGroupIndices A vector storing the row groups over which to reduce. + * @param x The input vector with which to multiply the matrix. Its length must be equal + * to the number of columns of A. + * @param b If non-null, this vector is added after the multiplication. If given, its length must be equal + * to the number of rows of A. + * @param choices If given, the choices made in the reduction process are written to this vector. + */ + virtual void multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector* choices = nullptr) const; + /*! * Performs repeated matrix-vector multiplication, using x[0] = x and x[i + 1] = A*x[i] + b. After * performing the necessary multiplications, the result is written to the input vector x. Note that the @@ -117,16 +148,6 @@ namespace storm { */ void setBounds(ValueType const& lower, ValueType const& upper); - /*! - * Sets the multiplication style. - */ - void setMultiplicationStyle(MultiplicationStyle multiplicationStyle); - - /*! - * Retrieves whether vector aliasing in multiplication is allowed. - */ - MultiplicationStyle getMultiplicationStyle() const; - protected: // auxiliary storage. If set, this vector has getMatrixRowCount() entries. mutable std::unique_ptr> cachedRowVector; @@ -150,9 +171,6 @@ namespace storm { /// Whether some of the generated data during solver calls should be cached. mutable bool cachingEnabled; - - /// The multiplication style. - MultiplicationStyle multiplicationStyle; }; template diff --git a/src/storm/solver/MultiplicationStyle.cpp b/src/storm/solver/MultiplicationStyle.cpp index 68dfd6381..a6a447679 100644 --- a/src/storm/solver/MultiplicationStyle.cpp +++ b/src/storm/solver/MultiplicationStyle.cpp @@ -5,7 +5,7 @@ namespace storm { std::ostream& operator<<(std::ostream& out, MultiplicationStyle const& style) { switch (style) { - case MultiplicationStyle::AllowGaussSeidel: out << "Allow-Gauss-Seidel"; break; + case MultiplicationStyle::GaussSeidel: out << "Gauss-Seidel"; break; case MultiplicationStyle::Regular: out << "Regular"; break; } return out; diff --git a/src/storm/solver/MultiplicationStyle.h b/src/storm/solver/MultiplicationStyle.h index 950643f4a..db974d17b 100644 --- a/src/storm/solver/MultiplicationStyle.h +++ b/src/storm/solver/MultiplicationStyle.h @@ -5,7 +5,7 @@ namespace storm { namespace solver { - enum class MultiplicationStyle { AllowGaussSeidel, Regular }; + enum class MultiplicationStyle { GaussSeidel, Regular }; std::ostream& operator<<(std::ostream& out, MultiplicationStyle const& style); diff --git a/src/storm/solver/NativeLinearEquationSolver.cpp b/src/storm/solver/NativeLinearEquationSolver.cpp index ba7399bbe..170edc03c 100644 --- a/src/storm/solver/NativeLinearEquationSolver.cpp +++ b/src/storm/solver/NativeLinearEquationSolver.cpp @@ -206,7 +206,7 @@ namespace storm { template void NativeLinearEquationSolver::multiply(std::vector& x, std::vector const* b, std::vector& result) const { - if (&x != &result || this->getMultiplicationStyle() == MultiplicationStyle::AllowGaussSeidel) { + if (&x != &result) { A->multiplyWithVector(x, result, b); } else { // If the two vectors are aliases, we need to create a temporary. @@ -225,15 +225,15 @@ namespace storm { template void NativeLinearEquationSolver::multiplyAndReduce(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector& result, std::vector* choices) const { - if (&x != &result || this->getMultiplicationStyle() == MultiplicationStyle::AllowGaussSeidel) { - A->multiplyAndReduce(dir, rowGroupIndices, x, b, result, choices, true); + if (&x != &result) { + A->multiplyAndReduce(dir, rowGroupIndices, x, b, result, choices); } else { // If the two vectors are aliases, we need to create a temporary. if (!this->cachedRowVector) { this->cachedRowVector = std::make_unique>(getMatrixRowCount()); } - this->A->multiplyAndReduce(dir, rowGroupIndices, x, b, *this->cachedRowVector, choices, false); + this->A->multiplyAndReduce(dir, rowGroupIndices, x, b, *this->cachedRowVector, choices); result.swap(*this->cachedRowVector); if (!this->isCachingEnabled()) { @@ -242,6 +242,22 @@ namespace storm { } } + template + bool NativeLinearEquationSolver::supportsGaussSeidelMultiplication() const { + return true; + } + + template + void NativeLinearEquationSolver::multiplyGaussSeidel(std::vector& x, std::vector const* b) const { + STORM_LOG_ASSERT(this->A->getRowCount() == this->A->getColumnCount(), "This function is only applicable for square matrices."); + A->multiplyWithVector(x, x, b, true, storm::storage::SparseMatrix::MultiplicationDirection::Backward); + } + + template + void NativeLinearEquationSolver::multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector* choices) const { + A->multiplyAndReduce(dir, rowGroupIndices, x, b, x, choices, true, storm::storage::SparseMatrix::MultiplicationDirection::Backward); + } + template void NativeLinearEquationSolver::setSettings(NativeLinearEquationSolverSettings const& newSettings) { settings = newSettings; diff --git a/src/storm/solver/NativeLinearEquationSolver.h b/src/storm/solver/NativeLinearEquationSolver.h index b5de30ffc..11cb39bd1 100644 --- a/src/storm/solver/NativeLinearEquationSolver.h +++ b/src/storm/solver/NativeLinearEquationSolver.h @@ -52,6 +52,9 @@ namespace storm { virtual bool solveEquations(std::vector& x, std::vector const& b) const override; virtual void multiply(std::vector& x, std::vector const* b, std::vector& result) const override; virtual void multiplyAndReduce(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector& result, std::vector* choices = nullptr) const override; + virtual bool supportsGaussSeidelMultiplication() const override; + virtual void multiplyGaussSeidel(std::vector& x, std::vector const* b) const override; + virtual void multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector* choices = nullptr) const override; void setSettings(NativeLinearEquationSolverSettings const& newSettings); NativeLinearEquationSolverSettings const& getSettings() const;