NUMCXX  0.13.20181108
Numerical library for small projects and teaching purposes
tsolver-lapacklu.ixx
Go to the documentation of this file.
1 namespace numcxx
7 {
8  template<typename T>
9  inline TSolverLapackLU<T>::TSolverLapackLU(const std::shared_ptr<TMatrix<T>> pMatrix):
10  TLinSolver<T>(),
11  pMatrix(pMatrix),
12  pLU(pMatrix->clone()),
13  pIPiv(TArray1<int>::create(pMatrix->shape(0)))
14  {
15  update();;
16  }
17 
// Construct the solver directly from a matrix object (argument `Matrix`) and
// factorize it immediately via update(Matrix).
// pLU is allocated with Matrix.shape(0) x Matrix.shape(1) entries and pIPiv
// with Matrix.shape(0) pivot slots. No pointer to `Matrix` is retained
// (pMatrix is null), so calling the argument-less update() on this object
// throws std::runtime_error; use update(Matrix) to refactorize instead.
// NOTE(review): the signature line (listing line 19) is missing from this
// extraction; presumably it takes `const TMatrix<T> &Matrix` — TODO confirm
// against the class declaration.
18  template<typename T>
20  TLinSolver<T>(),
21  pMatrix(0),
22  pLU(TMatrix<T>::create(Matrix.shape(0),Matrix.shape(1))),
23  pIPiv(TArray1<int>::create(Matrix.shape(0)))
24  {
25  update(Matrix);
26  }
27 
28  template<typename T>
29  inline std::shared_ptr<TSolverLapackLU<T>> TSolverLapackLU<T>::create(const std::shared_ptr<TMatrix<T>> a)
30  {
31  return std::make_shared<TSolverLapackLU<T>>(a);
32  }
33 
// Declarations for the LAPACK Fortran routines used in this file.
//
// Fortran passes every argument by reference, hence the pointer parameters.
// Parameter names follow the LAPACK reference documentation:
//   ?getrf(M, N, A, LDA, IPIV, INFO)                    LU factorization of an M x N matrix
//   ?getrs(TRANS, N, NRHS, A, LDA, IPIV, B, LDB, INFO)  solve using the LU factors from ?getrf
//   ?getri(N, A, LDA, IPIV, WORK, LWORK, INFO)          inverse from the LU factors from ?getrf
// NOTE(review): a gfortran-built LAPACK additionally expects a hidden trailing
// string-length argument for the CHARACTER argument TRANS of ?getrs. It is not
// declared here, so the call relies on the callee tolerating its absence.
extern "C"
{
    void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
    void sgetrs_(char *trans, int *n, const int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
    void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);
    void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
    void dgetrs_(char *trans, int *n, const int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
    void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
}
44 
// update(Matrix), double specialization: copy `Matrix` into pLU and compute
// its LU factorization in place with LAPACK dgetrf; pivot indices are written
// to pIPiv.
// n = pLU->shape(0) is passed as row count, column count and leading
// dimension, so the matrix is assumed to be square; shapes are not checked here.
// Throws std::runtime_error when dgetrf returns info != 0. Per the LAPACK
// documentation, info < 0 flags an illegal argument and info > 0 a zero pivot
// (an exactly singular matrix).
// NOTE(review): the signature line (listing line 46) is missing from this
// extraction.
45  template<>
47  {
48  int n=pLU->shape(0);
49  int info=0;
50  *pLU=Matrix;
51  dgetrf_(&n,&n,pLU->data(),&n,pIPiv->data(),&info);
52  if (info!=0)
53  {
54  char errormsg[80];
55  snprintf(errormsg,80,"numcxx::TSolverLapackLU::update: dgetrf error %d\n",info);
56  throw std::runtime_error(errormsg);
57  }
58  }
59 
60  template<>
61  inline void TSolverLapackLU<double>::solve( TArray<double> & sol, const TArray<double> & rhs) const
62  {
63  assign(sol,rhs);
64  char trans[2]={'T','\0'};
65  int n=pLU->shape(0);
66  int one=1;
67  int info=0;
68  dgetrs_(trans,&n,&one,pLU->data(),&n,pIPiv->data(),sol.data(),&n,&info);
69  if (info!=0)
70  {
71  char errormsg[80];
72  snprintf(errormsg,80,"numcxx::TSolverLapackLU::update: dgetrs error %d\n",info);
73  throw std::runtime_error(errormsg);
74  }
75  }
76 
77 
// update(Matrix), float specialization: copy `Matrix` into pLU and compute
// its LU factorization in place with LAPACK sgetrf; pivot indices are written
// to pIPiv.
// n = pLU->shape(0) is passed as row count, column count and leading
// dimension, so the matrix is assumed to be square; shapes are not checked here.
// Throws std::runtime_error when sgetrf returns info != 0. Per the LAPACK
// documentation, info < 0 flags an illegal argument and info > 0 a zero pivot.
// NOTE(review): the signature line (listing line 79) is missing from this
// extraction.
78  template<>
80  {
81  int n=pLU->shape(0);
82  int info;
83  *pLU=Matrix;
84  sgetrf_(&n,&n,pLU->data(),&n,pIPiv->data(),&info);
85  if (info!=0)
86  {
87  char errormsg[80];
88  snprintf(errormsg,80,"numcxx::TSolverLapackLU::update: sgetrf error %d\n",info);
89  throw std::runtime_error(errormsg);
90  }
91 
92  }
93 
// update(): refactorize from the current contents of the matrix held through
// pMatrix by delegating to update(*pMatrix).
// Throws std::runtime_error if pMatrix is null. That is the case for solvers
// constructed from a matrix object rather than a shared pointer (that
// constructor initializes pMatrix(0)).
// NOTE(review): the signature line (listing line 95) is missing from this
// extraction.
94  template<typename T>
96  {
97  if (pMatrix==nullptr)
98  throw std::runtime_error("numcxx: TSolverLapackLU created without smartpointer");
99  update(*pMatrix);
100  }
101 
102 
103  template<>
104  inline void TSolverLapackLU<float>::solve( TArray<float> & sol, const TArray<float> & rhs) const
105  {
106  assign(sol,rhs);
107  char trans[2]={'T','\0'};
108  int n=pLU->shape(0);
109  int one=1;
110  int info;
111  sgetrs_(trans,&n,&one,pLU->data(),&n,pIPiv->data(),sol.data(),&n,&info);
112  if (info!=0)
113  {
114  char errormsg[80];
115  snprintf(errormsg,80,"numcxx::TSolverLapackLU::update: sgetrs error %d\n",info);
116  throw std::runtime_error(errormsg);
117  }
118 
119  }
120 
121 
122 
123  template<>
124  inline std::shared_ptr<TMatrix<double>> TSolverLapackLU<double>::calculate_inverse()
125  {
126  int n=pLU->shape(0);
127  auto pInverse=std::make_shared<TMatrix<double>>(*pLU);
128  int info;
129  TArray1<double> Work(n);
130  dgetri_(&n,
131  pInverse->data(),
132  &n,
133  pIPiv->data(),
134  Work.data(),
135  &n,
136  &info);
137  if (info!=0)
138  {
139  char errormsg[80];
140  snprintf(errormsg,80,"numcxx::TSolverLapackLU::calculate_inverse: dgetri error %d\n",info);
141  throw std::runtime_error(errormsg);
142  }
143  return pInverse;
144  }
145 
146  template<>
147  inline std::shared_ptr<TMatrix<float>> TSolverLapackLU<float>::calculate_inverse()
148  {
149  int n=pLU->shape(0);
150  auto pInverse=std::make_shared<TMatrix<float>>(*pLU);
151  int info;
152  TArray1<float> Work(n);
153  sgetri_(&n,
154  pInverse->data(),
155  &n,
156  pIPiv->data(),
157  Work.data(),
158  &n,
159  &info);
160  if (info!=0)
161  {
162  char errormsg[80];
163  snprintf(errormsg,80,"numcxx::TSolverLapackLU::calculate_inverse: sgetri error %d\n",info);
164  throw std::runtime_error(errormsg);
165  }
166  return pInverse;
167  }
168 
169 
170 }
171 
const std::shared_ptr< TMatrix< T > > pMatrix
std::shared_ptr< TMatrix< T > > calculate_inverse()
Calculate inverse of matrix A from its LU factors.
static std::shared_ptr< TSolverLapackLU< T > > create(const std::shared_ptr< TMatrix< T >> pMatrix)
Static wrapper around constructor.
void dgetrs_(char *trans, int *n, const int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info)
void solve(TArray< T > &Sol, const TArray< T > &Rhs) const
Solve LU factorized system.
TArray is the common template base class for arrays and dense matrices of the numcxx project...
Definition: tarray.hxx:17
void update()
Perform computation of LU factorization using actual state of matrix.
Base class for linear solvers and preconditioners.
Definition: tarray.hxx:275
TSolverLapackLU()
Default constructor for swig.
One dimensional array class.
Definition: tarray1.hxx:31
Dense matrix class.
Definition: tmatrix.hxx:38
void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info)
void dgetrf_(int *n, int *m, double *a, int *lda, int *ipiv, int *info)
Numcxx template library.
Definition: expression.ixx:41
const std::shared_ptr< TArray1< int > > pIPiv
TArray< T > & assign(TArray< T > &A, const EXPR &expr, const EXPR *x=0)
Definition: tarray.ixx:35
const std::shared_ptr< TMatrix< T > > pLU
void sgetrs_(char *trans, int *n, const int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info)
T * data() const
Obtain C-pointer of data array.
Definition: tarray.hxx:128
void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info)
void sgetrf_(int *n, int *m, float *a, int *lda, int *ipiv, int *info)