IT++ Logo
ls_solve.cpp
Go to the documentation of this file.
1 
29 #ifndef _MSC_VER
30 # include <itpp/config.h>
31 #else
32 # include <itpp/config_msvc.h>
33 #endif
34 
35 #if defined(HAVE_LAPACK)
36 # include <itpp/base/algebra/lapack.h>
37 #endif
38 
40 
41 
42 namespace itpp
43 {
44 
45 // ----------- ls_solve_chol -----------------------------------------------------------
46 
47 #if defined(HAVE_LAPACK)
48 
49 bool ls_solve_chol(const mat &A, const vec &b, vec &x)
50 {
51  int n, lda, ldb, nrhs, info;
52  n = lda = ldb = A.rows();
53  nrhs = 1;
54  char uplo = 'U';
55 
56  it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
57  it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
58 
59  ivec ipiv(n);
60  x = b;
61  mat Chol = A;
62 
63  dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info);
64 
65  return (info == 0);
66 }
67 
68 
69 bool ls_solve_chol(const mat &A, const mat &B, mat &X)
70 {
71  int n, lda, ldb, nrhs, info;
72  n = lda = ldb = A.rows();
73  nrhs = B.cols();
74  char uplo = 'U';
75 
76  it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
77  it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
78 
79  ivec ipiv(n);
80  X = B;
81  mat Chol = A;
82 
83  dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info);
84 
85  return (info == 0);
86 }
87 
88 bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x)
89 {
90  int n, lda, ldb, nrhs, info;
91  n = lda = ldb = A.rows();
92  nrhs = 1;
93  char uplo = 'U';
94 
95  it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
96  it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
97 
98  ivec ipiv(n);
99  x = b;
100  cmat Chol = A;
101 
102  zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info);
103 
104  return (info == 0);
105 }
106 
107 bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X)
108 {
109  int n, lda, ldb, nrhs, info;
110  n = lda = ldb = A.rows();
111  nrhs = B.cols();
112  char uplo = 'U';
113 
114  it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
115  it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
116 
117  ivec ipiv(n);
118  X = B;
119  cmat Chol = A;
120 
121  zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info);
122 
123  return (info == 0);
124 }
125 
126 #else
127 
128 bool ls_solve_chol(const mat &A, const vec &b, vec &x)
129 {
130  it_error("LAPACK library is needed to use ls_solve_chol() function");
131  return false;
132 }
133 
134 bool ls_solve_chol(const mat &A, const mat &B, mat &X)
135 {
136  it_error("LAPACK library is needed to use ls_solve_chol() function");
137  return false;
138 }
139 
140 bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x)
141 {
142  it_error("LAPACK library is needed to use ls_solve_chol() function");
143  return false;
144 }
145 
146 bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X)
147 {
148  it_error("LAPACK library is needed to use ls_solve_chol() function");
149  return false;
150 }
151 
152 #endif // HAVE_LAPACK
153 
154 vec ls_solve_chol(const mat &A, const vec &b)
155 {
156  vec x;
157  bool info;
158  info = ls_solve_chol(A, b, x);
159  it_assert_debug(info, "ls_solve_chol: Failed solving the system");
160  return x;
161 }
162 
163 mat ls_solve_chol(const mat &A, const mat &B)
164 {
165  mat X;
166  bool info;
167  info = ls_solve_chol(A, B, X);
168  it_assert_debug(info, "ls_solve_chol: Failed solving the system");
169  return X;
170 }
171 
172 cvec ls_solve_chol(const cmat &A, const cvec &b)
173 {
174  cvec x;
175  bool info;
176  info = ls_solve_chol(A, b, x);
177  it_assert_debug(info, "ls_solve_chol: Failed solving the system");
178  return x;
179 }
180 
181 cmat ls_solve_chol(const cmat &A, const cmat &B)
182 {
183  cmat X;
184  bool info;
185  info = ls_solve_chol(A, B, X);
186  it_assert_debug(info, "ls_solve_chol: Failed solving the system");
187  return X;
188 }
189 
190 
191 // --------- ls_solve ---------------------------------------------------------------
192 #if defined(HAVE_LAPACK)
193 
194 bool ls_solve(const mat &A, const vec &b, vec &x)
195 {
196  int n, lda, ldb, nrhs, info;
197  n = lda = ldb = A.rows();
198  nrhs = 1;
199 
200  it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
201  it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
202 
203  ivec ipiv(n);
204  x = b;
205  mat LU = A;
206 
207  dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info);
208 
209  return (info == 0);
210 }
211 
212 bool ls_solve(const mat &A, const mat &B, mat &X)
213 {
214  int n, lda, ldb, nrhs, info;
215  n = lda = ldb = A.rows();
216  nrhs = B.cols();
217 
218  it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
219  it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
220 
221  ivec ipiv(n);
222  X = B;
223  mat LU = A;
224 
225  dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info);
226 
227  return (info == 0);
228 }
229 
230 bool ls_solve(const cmat &A, const cvec &b, cvec &x)
231 {
232  int n, lda, ldb, nrhs, info;
233  n = lda = ldb = A.rows();
234  nrhs = 1;
235 
236  it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
237  it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
238 
239  ivec ipiv(n);
240  x = b;
241  cmat LU = A;
242 
243  zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info);
244 
245  return (info == 0);
246 }
247 
248 bool ls_solve(const cmat &A, const cmat &B, cmat &X)
249 {
250  int n, lda, ldb, nrhs, info;
251  n = lda = ldb = A.rows();
252  nrhs = B.cols();
253 
254  it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
255  it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
256 
257  ivec ipiv(n);
258  X = B;
259  cmat LU = A;
260 
261  zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info);
262 
263  return (info == 0);
264 }
265 
266 #else
267 
268 bool ls_solve(const mat &A, const vec &b, vec &x)
269 {
270  it_error("LAPACK library is needed to use ls_solve() function");
271  return false;
272 }
273 
274 bool ls_solve(const mat &A, const mat &B, mat &X)
275 {
276  it_error("LAPACK library is needed to use ls_solve() function");
277  return false;
278 }
279 
280 bool ls_solve(const cmat &A, const cvec &b, cvec &x)
281 {
282  it_error("LAPACK library is needed to use ls_solve() function");
283  return false;
284 }
285 
286 bool ls_solve(const cmat &A, const cmat &B, cmat &X)
287 {
288  it_error("LAPACK library is needed to use ls_solve() function");
289  return false;
290 }
291 
292 #endif // HAVE_LAPACK
293 
294 vec ls_solve(const mat &A, const vec &b)
295 {
296  vec x;
297  bool info;
298  info = ls_solve(A, b, x);
299  it_assert_debug(info, "ls_solve: Failed solving the system");
300  return x;
301 }
302 
303 mat ls_solve(const mat &A, const mat &B)
304 {
305  mat X;
306  bool info;
307  info = ls_solve(A, B, X);
308  it_assert_debug(info, "ls_solve: Failed solving the system");
309  return X;
310 }
311 
312 cvec ls_solve(const cmat &A, const cvec &b)
313 {
314  cvec x;
315  bool info;
316  info = ls_solve(A, b, x);
317  it_assert_debug(info, "ls_solve: Failed solving the system");
318  return x;
319 }
320 
321 cmat ls_solve(const cmat &A, const cmat &B)
322 {
323  cmat X;
324  bool info;
325  info = ls_solve(A, B, X);
326  it_assert_debug(info, "ls_solve: Failed solving the system");
327  return X;
328 }
329 
330 
331 // ----------------- ls_solve_od ------------------------------------------------------------------
332 #if defined(HAVE_LAPACK)
333 
334 bool ls_solve_od(const mat &A, const vec &b, vec &x)
335 {
336  int m, n, lda, ldb, nrhs, lwork, info;
337  char trans = 'N';
338  m = lda = ldb = A.rows();
339  n = A.cols();
340  nrhs = 1;
341  lwork = n + std::max(m, nrhs);
342 
343  it_assert_debug(m >= n, "The system is under-determined!");
344  it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
345 
346  vec work(lwork);
347  x = b;
348  mat QR = A;
349 
350  dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
351  x.set_size(n, true);
352 
353  return (info == 0);
354 }
355 
356 bool ls_solve_od(const mat &A, const mat &B, mat &X)
357 {
358  int m, n, lda, ldb, nrhs, lwork, info;
359  char trans = 'N';
360  m = lda = ldb = A.rows();
361  n = A.cols();
362  nrhs = B.cols();
363  lwork = n + std::max(m, nrhs);
364 
365  it_assert_debug(m >= n, "The system is under-determined!");
366  it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
367 
368  vec work(lwork);
369  X = B;
370  mat QR = A;
371 
372  dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
373  X.set_size(n, nrhs, true);
374 
375  return (info == 0);
376 }
377 
378 bool ls_solve_od(const cmat &A, const cvec &b, cvec &x)
379 {
380  int m, n, lda, ldb, nrhs, lwork, info;
381  char trans = 'N';
382  m = lda = ldb = A.rows();
383  n = A.cols();
384  nrhs = 1;
385  lwork = n + std::max(m, nrhs);
386 
387  it_assert_debug(m >= n, "The system is under-determined!");
388  it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
389 
390  cvec work(lwork);
391  x = b;
392  cmat QR = A;
393 
394  zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
395  x.set_size(n, true);
396 
397  return (info == 0);
398 }
399 
400 bool ls_solve_od(const cmat &A, const cmat &B, cmat &X)
401 {
402  int m, n, lda, ldb, nrhs, lwork, info;
403  char trans = 'N';
404  m = lda = ldb = A.rows();
405  n = A.cols();
406  nrhs = B.cols();
407  lwork = n + std::max(m, nrhs);
408 
409  it_assert_debug(m >= n, "The system is under-determined!");
410  it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
411 
412  cvec work(lwork);
413  X = B;
414  cmat QR = A;
415 
416  zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
417  X.set_size(n, nrhs, true);
418 
419  return (info == 0);
420 }
421 
422 #else
423 
424 bool ls_solve_od(const mat &A, const vec &b, vec &x)
425 {
426  it_error("LAPACK library is needed to use ls_solve_od() function");
427  return false;
428 }
429 
430 bool ls_solve_od(const mat &A, const mat &B, mat &X)
431 {
432  it_error("LAPACK library is needed to use ls_solve_od() function");
433  return false;
434 }
435 
436 bool ls_solve_od(const cmat &A, const cvec &b, cvec &x)
437 {
438  it_error("LAPACK library is needed to use ls_solve_od() function");
439  return false;
440 }
441 
442 bool ls_solve_od(const cmat &A, const cmat &B, cmat &X)
443 {
444  it_error("LAPACK library is needed to use ls_solve_od() function");
445  return false;
446 }
447 
448 #endif // HAVE_LAPACK
449 
450 vec ls_solve_od(const mat &A, const vec &b)
451 {
452  vec x;
453  bool info;
454  info = ls_solve_od(A, b, x);
455  it_assert_debug(info, "ls_solve_od: Failed solving the system");
456  return x;
457 }
458 
459 mat ls_solve_od(const mat &A, const mat &B)
460 {
461  mat X;
462  bool info;
463  info = ls_solve_od(A, B, X);
464  it_assert_debug(info, "ls_solve_od: Failed solving the system");
465  return X;
466 }
467 
468 cvec ls_solve_od(const cmat &A, const cvec &b)
469 {
470  cvec x;
471  bool info;
472  info = ls_solve_od(A, b, x);
473  it_assert_debug(info, "ls_solve_od: Failed solving the system");
474  return x;
475 }
476 
477 cmat ls_solve_od(const cmat &A, const cmat &B)
478 {
479  cmat X;
480  bool info;
481  info = ls_solve_od(A, B, X);
482  it_assert_debug(info, "ls_solve_od: Failed solving the system");
483  return X;
484 }
485 
486 // ------------------- ls_solve_ud -----------------------------------------------------------
487 #if defined(HAVE_LAPACK)
488 
489 bool ls_solve_ud(const mat &A, const vec &b, vec &x)
490 {
491  int m, n, lda, ldb, nrhs, lwork, info;
492  char trans = 'N';
493  m = lda = A.rows();
494  n = A.cols();
495  ldb = n;
496  nrhs = 1;
497  lwork = m + std::max(n, nrhs);
498 
499  it_assert_debug(m < n, "The system is over-determined!");
500  it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
501 
502  vec work(lwork);
503  x = b;
504  x.set_size(n, true);
505  mat QR = A;
506 
507  dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
508 
509  return (info == 0);
510 }
511 
512 bool ls_solve_ud(const mat &A, const mat &B, mat &X)
513 {
514  int m, n, lda, ldb, nrhs, lwork, info;
515  char trans = 'N';
516  m = lda = A.rows();
517  n = A.cols();
518  ldb = n;
519  nrhs = B.cols();
520  lwork = m + std::max(n, nrhs);
521 
522  it_assert_debug(m < n, "The system is over-determined!");
523  it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
524 
525  vec work(lwork);
526  X = B;
527  X.set_size(n, std::max(m, nrhs), true);
528  mat QR = A;
529 
530  dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
531  X.set_size(n, nrhs, true);
532 
533  return (info == 0);
534 }
535 
536 bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x)
537 {
538  int m, n, lda, ldb, nrhs, lwork, info;
539  char trans = 'N';
540  m = lda = A.rows();
541  n = A.cols();
542  ldb = n;
543  nrhs = 1;
544  lwork = m + std::max(n, nrhs);
545 
546  it_assert_debug(m < n, "The system is over-determined!");
547  it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
548 
549  cvec work(lwork);
550  x = b;
551  x.set_size(n, true);
552  cmat QR = A;
553 
554  zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
555 
556  return (info == 0);
557 }
558 
559 bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X)
560 {
561  int m, n, lda, ldb, nrhs, lwork, info;
562  char trans = 'N';
563  m = lda = A.rows();
564  n = A.cols();
565  ldb = n;
566  nrhs = B.cols();
567  lwork = m + std::max(n, nrhs);
568 
569  it_assert_debug(m < n, "The system is over-determined!");
570  it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
571 
572  cvec work(lwork);
573  X = B;
574  X.set_size(n, std::max(m, nrhs), true);
575  cmat QR = A;
576 
577  zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
578  X.set_size(n, nrhs, true);
579 
580  return (info == 0);
581 }
582 
583 #else
584 
585 bool ls_solve_ud(const mat &A, const vec &b, vec &x)
586 {
587  it_error("LAPACK library is needed to use ls_solve_ud() function");
588  return false;
589 }
590 
591 bool ls_solve_ud(const mat &A, const mat &B, mat &X)
592 {
593  it_error("LAPACK library is needed to use ls_solve_ud() function");
594  return false;
595 }
596 
597 bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x)
598 {
599  it_error("LAPACK library is needed to use ls_solve_ud() function");
600  return false;
601 }
602 
603 bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X)
604 {
605  it_error("LAPACK library is needed to use ls_solve_ud() function");
606  return false;
607 }
608 
609 #endif // HAVE_LAPACK
610 
611 
612 vec ls_solve_ud(const mat &A, const vec &b)
613 {
614  vec x;
615  bool info;
616  info = ls_solve_ud(A, b, x);
617  it_assert_debug(info, "ls_solve_ud: Failed solving the system");
618  return x;
619 }
620 
621 mat ls_solve_ud(const mat &A, const mat &B)
622 {
623  mat X;
624  bool info;
625  info = ls_solve_ud(A, B, X);
626  it_assert_debug(info, "ls_solve_ud: Failed solving the system");
627  return X;
628 }
629 
630 cvec ls_solve_ud(const cmat &A, const cvec &b)
631 {
632  cvec x;
633  bool info;
634  info = ls_solve_ud(A, b, x);
635  it_assert_debug(info, "ls_solve_ud: Failed solving the system");
636  return x;
637 }
638 
639 cmat ls_solve_ud(const cmat &A, const cmat &B)
640 {
641  cmat X;
642  bool info;
643  info = ls_solve_ud(A, B, X);
644  it_assert_debug(info, "ls_solve_ud: Failed solving the system");
645  return X;
646 }
647 
648 
649 // ---------------------- backslash -----------------------------------------
650 
651 bool backslash(const mat &A, const vec &b, vec &x)
652 {
653  int m = A.rows(), n = A.cols();
654  bool info;
655 
656  if (m == n)
657  info = ls_solve(A, b, x);
658  else if (m > n)
659  info = ls_solve_od(A, b, x);
660  else
661  info = ls_solve_ud(A, b, x);
662 
663  return info;
664 }
665 
666 
667 vec backslash(const mat &A, const vec &b)
668 {
669  vec x;
670  bool info;
671  info = backslash(A, b, x);
672  it_assert_debug(info, "backslash(): solution was not found");
673  return x;
674 }
675 
676 
677 bool backslash(const mat &A, const mat &B, mat &X)
678 {
679  int m = A.rows(), n = A.cols();
680  bool info;
681 
682  if (m == n)
683  info = ls_solve(A, B, X);
684  else if (m > n)
685  info = ls_solve_od(A, B, X);
686  else
687  info = ls_solve_ud(A, B, X);
688 
689  return info;
690 }
691 
692 
693 mat backslash(const mat &A, const mat &B)
694 {
695  mat X;
696  bool info;
697  info = backslash(A, B, X);
698  it_assert_debug(info, "backslash(): solution was not found");
699  return X;
700 }
701 
702 
703 bool backslash(const cmat &A, const cvec &b, cvec &x)
704 {
705  int m = A.rows(), n = A.cols();
706  bool info;
707 
708  if (m == n)
709  info = ls_solve(A, b, x);
710  else if (m > n)
711  info = ls_solve_od(A, b, x);
712  else
713  info = ls_solve_ud(A, b, x);
714 
715  return info;
716 }
717 
718 
719 cvec backslash(const cmat &A, const cvec &b)
720 {
721  cvec x;
722  bool info;
723  info = backslash(A, b, x);
724  it_assert_debug(info, "backslash(): solution was not found");
725  return x;
726 }
727 
728 
729 bool backslash(const cmat &A, const cmat &B, cmat &X)
730 {
731  int m = A.rows(), n = A.cols();
732  bool info;
733 
734  if (m == n)
735  info = ls_solve(A, B, X);
736  else if (m > n)
737  info = ls_solve_od(A, B, X);
738  else
739  info = ls_solve_ud(A, B, X);
740 
741  return info;
742 }
743 
744 cmat backslash(const cmat &A, const cmat &B)
745 {
746  cmat X;
747  bool info;
748  info = backslash(A, B, X);
749  it_assert_debug(info, "backslash(): solution was not found");
750  return X;
751 }
752 
753 
754 // --------------------------------------------------------------------------
755 
756 vec forward_substitution(const mat &L, const vec &b)
757 {
758  int n = L.rows();
759  vec x(n);
760 
761  forward_substitution(L, b, x);
762 
763  return x;
764 }
765 
766 void forward_substitution(const mat &L, const vec &b, vec &x)
767 {
768  it_assert(L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size(),
769  "forward_substitution: dimension mismatch");
770  int n = L.rows(), i, j;
771  double temp;
772 
773  x(0) = b(0) / L(0, 0);
774  for (i = 1;i < n;i++) {
775  // Should be: x(i)=((b(i)-L(i,i,0,i-1)*x(0,i-1))/L(i,i))(0); but this is to slow.
776  //i_pos=i*L._row_offset();
777  temp = 0;
778  for (j = 0; j < i; j++) {
779  temp += L._elem(i, j) * x(j);
780  //temp+=L._data()[i_pos+j]*x(j);
781  }
782  x(i) = (b(i) - temp) / L._elem(i, i);
783  //x(i)=(b(i)-temp)/L._data()[i_pos+i];
784  }
785 }
786 
787 vec forward_substitution(const mat &L, int p, const vec &b)
788 {
789  int n = L.rows();
790  vec x(n);
791 
792  forward_substitution(L, p, b, x);
793 
794  return x;
795 }
796 
797 void forward_substitution(const mat &L, int p, const vec &b, vec &x)
798 {
799  it_assert(L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size() && p <= L.rows() / 2,
800  "forward_substitution: dimension mismatch");
801  int n = L.rows(), i, j;
802 
803  x = b;
804 
805  for (j = 0;j < n;j++) {
806  x(j) /= L(j, j);
807  for (i = j + 1;i < std::min(j + p + 1, n);i++) {
808  x(i) -= L(i, j) * x(j);
809  }
810  }
811 }
812 
813 vec backward_substitution(const mat &U, const vec &b)
814 {
815  vec x(U.rows());
816  backward_substitution(U, b, x);
817 
818  return x;
819 }
820 
821 void backward_substitution(const mat &U, const vec &b, vec &x)
822 {
823  it_assert(U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size(),
824  "backward_substitution: dimension mismatch");
825  int n = U.rows(), i, j;
826  double temp;
827 
828  x(n - 1) = b(n - 1) / U(n - 1, n - 1);
829  for (i = n - 2; i >= 0; i--) {
830  // Should be: x(i)=((b(i)-U(i,i,i+1,n-1)*x(i+1,n-1))/U(i,i))(0); but this is too slow.
831  temp = 0;
832  //i_pos=i*U._row_offset();
833  for (j = i + 1; j < n; j++) {
834  temp += U._elem(i, j) * x(j);
835  //temp+=U._data()[i_pos+j]*x(j);
836  }
837  x(i) = (b(i) - temp) / U._elem(i, i);
838  //x(i)=(b(i)-temp)/U._data()[i_pos+i];
839  }
840 }
841 
842 vec backward_substitution(const mat &U, int q, const vec &b)
843 {
844  vec x(U.rows());
845  backward_substitution(U, q, b, x);
846 
847  return x;
848 }
849 
850 void backward_substitution(const mat &U, int q, const vec &b, vec &x)
851 {
852  it_assert(U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size() && q <= U.rows() / 2,
853  "backward_substitution: dimension mismatch");
854  int n = U.rows(), i, j;
855 
856  x = b;
857 
858  for (j = n - 1; j >= 0; j--) {
859  x(j) /= U(j, j);
860  for (i = std::max(0, j - q); i < j; i++) {
861  x(i) -= U(i, j) * x(j);
862  }
863  }
864 }
865 
866 } // namespace itpp
SourceForge Logo

Generated on Sat May 25 2013 16:32:18 for IT++ by Doxygen 1.8.2