You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

128 lines
4.3 KiB

#include <iostream>
#include <Eigen/Core>
#include <Eigen/Dense>
#include <Eigen/IterativeLinearSolvers>
#include <unsupported/Eigen/IterativeSolvers>
// Forward declaration: the traits specialization below has to name the
// wrapper type before the wrapper class itself is defined.
class MatrixReplacement;
using StormEigen::SparseMatrix;
namespace StormEigen {
namespace internal {
// MatrixReplacement is meant to look like a SparseMatrix<double>, so it simply
// reuses that type's traits unchanged.
template<>
struct traits<MatrixReplacement>
  : public traits< StormEigen::SparseMatrix<double> >
{
};
} // namespace internal
} // namespace StormEigen
// Example of a matrix-free wrapper from a user type to an Eigen-compatible type.
// For the sake of simplicity, this example simply wraps a StormEigen::SparseMatrix.
//
// The wrapper does not own the matrix: attachMyMatrix() only stores a pointer,
// so the attached matrix must outlive every use of the wrapper, and rows(),
// cols() and my_matrix() require a matrix to have been attached first.
class MatrixReplacement : public StormEigen::EigenBase<MatrixReplacement> {
public:
  // Required typedefs, constants, and method:
  typedef double Scalar;
  typedef double RealScalar;
  typedef int StorageIndex;
  enum {
    ColsAtCompileTime = StormEigen::Dynamic,
    MaxColsAtCompileTime = StormEigen::Dynamic,
    IsRowMajor = false
  };
  Index rows() const { return mp_mat->rows(); }
  Index cols() const { return mp_mat->cols(); }
  // Builds a lazy, alias-free product expression. The actual evaluation is
  // done by the generic_product_impl specialization for MatrixReplacement.
  template<typename Rhs>
  StormEigen::Product<MatrixReplacement,Rhs,StormEigen::AliasFreeProduct> operator*(const StormEigen::MatrixBase<Rhs>& x) const {
    return StormEigen::Product<MatrixReplacement,Rhs,StormEigen::AliasFreeProduct>(*this, x.derived());
  }
  // Custom API:
  MatrixReplacement() : mp_mat(0) {}
  // Attaches (without copying) the matrix to be wrapped.
  void attachMyMatrix(const SparseMatrix<double> &mat) {
    mp_mat = &mat;
  }
  // Returns a reference to the wrapped matrix. Returning by value would
  // deep-copy the whole sparse matrix on every call.
  const SparseMatrix<double>& my_matrix() const { return *mp_mat; }
private:
  const SparseMatrix<double> *mp_mat;  // non-owning; null until attachMyMatrix()
};
// Implementation of MatrixReplacement * StormEigen::DenseVector through a specialization of internal::generic_product_impl:
namespace StormEigen {
namespace internal {
template<typename Rhs>
struct generic_product_impl<MatrixReplacement, Rhs, SparseShape, DenseShape, GemvProduct> // GEMV stands for matrix-vector
  : generic_product_impl_base<MatrixReplacement,Rhs,generic_product_impl<MatrixReplacement,Rhs> >
{
  typedef typename Product<MatrixReplacement,Rhs>::Scalar Scalar;

  // Computes "dst += alpha * lhs * rhs" in place.
  // Iterative solvers always pass alpha == 1, but honouring alpha costs
  // nothing and keeps the kernel correct for any other caller.
  template<typename Dest>
  static void scaleAndAddTo(Dest& dst, const MatrixReplacement& lhs, const Rhs& rhs, const Scalar& alpha)
  {
    // Fetch the wrapped matrix once, outside the loop. Binding a const
    // reference avoids re-fetching it per column, and if my_matrix() returns
    // by value, it also avoids copying the matrix on every iteration.
    const SparseMatrix<double>& mat = lhs.my_matrix();
    // Here we could simply call dst.noalias() += alpha * mat * rhs,
    // but let's do something fancier (and less efficient): accumulate the
    // product one column at a time.
    for(Index i=0; i<lhs.cols(); ++i)
      dst += (alpha * rhs(i)) * mat.col(i);
  }
};
}
}
int main()
{
int n = 10;
StormEigen::SparseMatrix<double> S = StormEigen::MatrixXd::Random(n,n).sparseView(0.5,1);
S = S.transpose()*S;
MatrixReplacement A;
A.attachMyMatrix(S);
StormEigen::VectorXd b(n), x;
b.setRandom();
// Solve Ax = b using various iterative solver with matrix-free version:
{
StormEigen::ConjugateGradient<MatrixReplacement, StormEigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> cg;
cg.compute(A);
x = cg.solve(b);
std::cout << "CG: #iterations: " << cg.iterations() << ", estimated error: " << cg.error() << std::endl;
}
{
StormEigen::BiCGSTAB<MatrixReplacement, StormEigen::IdentityPreconditioner> bicg;
bicg.compute(A);
x = bicg.solve(b);
std::cout << "BiCGSTAB: #iterations: " << bicg.iterations() << ", estimated error: " << bicg.error() << std::endl;
}
{
StormEigen::GMRES<MatrixReplacement, StormEigen::IdentityPreconditioner> gmres;
gmres.compute(A);
x = gmres.solve(b);
std::cout << "GMRES: #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl;
}
{
StormEigen::DGMRES<MatrixReplacement, StormEigen::IdentityPreconditioner> gmres;
gmres.compute(A);
x = gmres.solve(b);
std::cout << "DGMRES: #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl;
}
{
StormEigen::MINRES<MatrixReplacement, StormEigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> minres;
minres.compute(A);
x = minres.solve(b);
std::cout << "MINRES: #iterations: " << minres.iterations() << ", estimated error: " << minres.error() << std::endl;
}
}