You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

259 lines
13 KiB

#ifndef STORM_SOLVER_LINEAREQUATIONSOLVER_H_
#define STORM_SOLVER_LINEAREQUATIONSOLVER_H_
#include <vector>
#include <memory>
#include "storm/solver/AbstractEquationSolver.h"
#include "storm/solver/MultiplicationStyle.h"
#include "storm/solver/LinearEquationSolverProblemFormat.h"
#include "storm/solver/OptimizationDirection.h"
#include "storm/utility/VectorHelper.h"
#include "storm/storage/SparseMatrix.h"
namespace storm {
namespace solver {
enum class LinearEquationSolverOperation {
SolveEquations, MultiplyRepeatedly
};
/*!
* An interface that represents an abstract linear equation solver. In addition to solving a system of linear
* equations, the functionality to repeatedly multiply a matrix with a given vector is provided.
*/
template<class ValueType>
class LinearEquationSolver : public AbstractEquationSolver<ValueType> {
public:
LinearEquationSolver();
virtual ~LinearEquationSolver() {
// Intentionally left empty.
}
virtual void setMatrix(storm::storage::SparseMatrix<ValueType> const& A) = 0;
virtual void setMatrix(storm::storage::SparseMatrix<ValueType>&& A) = 0;
/*!
* If the solver expects the equation system format, it solves Ax = b. If it it expects a fixed point
* format, it solves Ax + b = x. In both versions, the matrix A is required to be square and the problem
* is required to have a unique solution. The solution will be written to the vector x. Note that the matrix
* A has to be given upon construction time of the solver object.
*
* @param x The solution vector that has to be computed. Its length must be equal to the number of rows of A.
* @param b The vector b. Its length must be equal to the number of rows of A.
*
* @return true
*/
virtual bool solveEquations(std::vector<ValueType>& x, std::vector<ValueType> const& b) const = 0;
/*!
* Performs on matrix-vector multiplication x' = A*x + b.
*
* @param x The input vector with which to multiply the matrix. Its length must be equal
* to the number of columns of A.
* @param b If non-null, this vector is added after the multiplication. If given, its length must be equal
* to the number of rows of A.
* @param result The target vector into which to write the multiplication result. Its length must be equal
* to the number of rows of A.
*/
virtual void multiply(std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result) const = 0;
/*!
* Performs on matrix-vector multiplication x' = A*x + b and then minimizes/maximizes over the row groups
* so that the resulting vector has the size of number of row groups of A.
*
* @param dir The direction for the reduction step.
* @param rowGroupIndices A vector storing the row groups over which to reduce.
* @param x The input vector with which to multiply the matrix. Its length must be equal
* to the number of columns of A.
* @param b If non-null, this vector is added after the multiplication. If given, its length must be equal
* to the number of rows of A.
* @param result The target vector into which to write the multiplication result. Its length must be equal
* to the number of rows of A.
* @param choices If given, the choices made in the reduction process are written to this vector.
*/
virtual void multiplyAndReduce(OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result, std::vector<uint_fast64_t>* choices = nullptr) const;
/*!
* Retrieves whether this solver offers the gauss-seidel style multiplications.
*/
virtual bool supportsGaussSeidelMultiplication() const;
/*!
* Performs on matrix-vector multiplication x' = A*x + b. It does so in a gauss-seidel style, i.e. reusing
* the new x' components in the further multiplication.
*
* @param x The input vector with which to multiply the matrix. Its length must be equal
* to the number of columns of A.
* @param b If non-null, this vector is added after the multiplication. If given, its length must be equal
* to the number of rows of A.
*/
virtual void multiplyGaussSeidel(std::vector<ValueType>& x, std::vector<ValueType> const* b) const;
/*!
* Performs on matrix-vector multiplication x' = A*x + b and then minimizes/maximizes over the row groups
* so that the resulting vector has the size of number of row groups of A. It does so in a gauss-seidel
* style, i.e. reusing the new x' components in the further multiplication.
*
* @param dir The direction for the reduction step.
* @param rowGroupIndices A vector storing the row groups over which to reduce.
* @param x The input vector with which to multiply the matrix. Its length must be equal
* to the number of columns of A.
* @param b If non-null, this vector is added after the multiplication. If given, its length must be equal
* to the number of rows of A.
* @param choices If given, the choices made in the reduction process are written to this vector.
*/
virtual void multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<uint_fast64_t>* choices = nullptr) const;
/*!
* Performs repeated matrix-vector multiplication, using x[0] = x and x[i + 1] = A*x[i] + b. After
* performing the necessary multiplications, the result is written to the input vector x. Note that the
* matrix A has to be given upon construction time of the solver object.
*
* @param x The initial vector with which to perform matrix-vector multiplication. Its length must be equal
* to the number of columns of A.
* @param b If non-null, this vector is added after each multiplication. If given, its length must be equal
* to the number of rows of A.
* @param n The number of times to perform the multiplication.
*/
void repeatedMultiply(std::vector<ValueType>& x, std::vector<ValueType> const* b, uint_fast64_t n) const;
/*!
* Retrieves the format in which this solver expects to solve equations. If the solver expects the equation
* system format, it solves Ax = b. If it it expects a fixed point format, it solves Ax + b = x.
*/
virtual LinearEquationSolverProblemFormat getEquationProblemFormat() const = 0;
/*!
* Sets whether some of the generated data during solver calls should be cached.
* This possibly increases the runtime of subsequent calls but also increases memory consumption.
*/
void setCachingEnabled(bool value) const;
/*!
* Retrieves whether some of the generated data during solver calls should be cached.
*/
bool isCachingEnabled() const;
/*
* Clears the currently cached data that has been stored during previous calls of the solver.
*/
virtual void clearCache() const;
/*!
* Sets a lower bound for the solution that can potentially be used by the solver.
*/
void setLowerBound(ValueType const& value);
/*!
* Sets an upper bound for the solution that can potentially be used by the solver.
*/
void setUpperBound(ValueType const& value);
/*!
* Sets bounds for the solution that can potentially be used by the solver.
*/
void setBounds(ValueType const& lower, ValueType const& upper);
protected:
// auxiliary storage. If set, this vector has getMatrixRowCount() entries.
mutable std::unique_ptr<std::vector<ValueType>> cachedRowVector;
// A lower bound if one was set.
boost::optional<ValueType> lowerBound;
// An upper bound if one was set.
boost::optional<ValueType> upperBound;
private:
/*!
* Retrieves the row count of the matrix associated with this solver.
*/
virtual uint64_t getMatrixRowCount() const = 0;
/*!
* Retrieves the column count of the matrix associated with this solver.
*/
virtual uint64_t getMatrixColumnCount() const = 0;
/// Whether some of the generated data during solver calls should be cached.
mutable bool cachingEnabled;
/// An object that can be used to reduce vectors.
storm::utility::VectorHelper<ValueType> vectorHelper;
};
template<typename ValueType>
class LinearEquationSolverFactory {
public:
/*!
* Creates a new linear equation solver instance with the given matrix.
*
* @param matrix The matrix that defines the equation system.
* @return A pointer to the newly created solver.
*/
std::unique_ptr<LinearEquationSolver<ValueType>> create(storm::storage::SparseMatrix<ValueType> const& matrix) const;
/*!
* Creates a new linear equation solver instance with the given matrix. The caller gives up posession of the
* matrix by calling this function.
*
* @param matrix The matrix that defines the equation system.
* @return A pointer to the newly created solver.
*/
std::unique_ptr<LinearEquationSolver<ValueType>> create(storm::storage::SparseMatrix<ValueType>&& matrix) const;
/*!
* Creates an equation solver with the current settings, but without a matrix.
*/
virtual std::unique_ptr<LinearEquationSolver<ValueType>> create() const = 0;
/*!
* Creates a copy of this factory.
*/
virtual std::unique_ptr<LinearEquationSolverFactory<ValueType>> clone() const = 0;
/*!
* Retrieves the problem format that the solver expects if it was created with the current settings.
*/
virtual LinearEquationSolverProblemFormat getEquationProblemFormat() const;
};
template<typename ValueType>
class GeneralLinearEquationSolverFactory : public LinearEquationSolverFactory<ValueType> {
public:
using LinearEquationSolverFactory<ValueType>::create;
virtual std::unique_ptr<LinearEquationSolver<ValueType>> create() const override;
virtual std::unique_ptr<LinearEquationSolverFactory<ValueType>> clone() const override;
};
#ifdef STORM_HAVE_CARL
template<>
class GeneralLinearEquationSolverFactory<storm::RationalNumber> : public LinearEquationSolverFactory<storm::RationalNumber> {
public:
using LinearEquationSolverFactory<storm::RationalNumber>::create;
virtual std::unique_ptr<LinearEquationSolver<storm::RationalNumber>> create() const override;
virtual std::unique_ptr<LinearEquationSolverFactory<storm::RationalNumber>> clone() const override;
};
template<>
class GeneralLinearEquationSolverFactory<storm::RationalFunction> : public LinearEquationSolverFactory<storm::RationalFunction> {
public:
using LinearEquationSolverFactory<storm::RationalFunction>::create;
virtual std::unique_ptr<LinearEquationSolver<storm::RationalFunction>> create() const override;
virtual std::unique_ptr<LinearEquationSolverFactory<storm::RationalFunction>> clone() const override;
};
#endif
} // namespace solver
} // namespace storm
#endif /* STORM_SOLVER_LINEAREQUATIONSOLVER_H_ */