CoDiPack  2.3.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
linearSystemHandler.hpp
1/*
2 * CoDiPack, a Code Differentiation Package
3 *
4 * Copyright (C) 2015-2024 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://scicomp.rptu.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of CoDiPack (http://scicomp.rptu.de/software/codi).
11 *
12 * CoDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU General Public License
14 * as published by the Free Software Foundation, either version 3 of the
15 * License, or (at your option) any later version.
16 *
17 * CoDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty
19 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU General Public License for more details.
22 * You should have received a copy of the GNU
23 * General Public License along with CoDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * For other licensing options please contact us.
27 *
28 * Authors:
29 * - SciComp, University of Kaiserslautern-Landau:
30 * - Max Sagebaum
31 * - Johannes Blühdorn
32 * - Former members:
33 * - Tim Albring
34 */
35
36#pragma once
37
38#include <vector>
39
40#include "../../../config.h"
41#include "../../../expressions/lhsExpressionInterface.hpp"
42#include "../../../misc/exceptions.hpp"
43#include "../../../traits/tapeTraits.hpp"
44#include "../../data/direction.hpp"
45#include "linearSystemFlags.hpp"
46#include "linearSystemInterface.hpp"
47#include "linearSystemSpecializationDetection.hpp"
48
50namespace codi {
51
68 template<typename T_LinearSystem, typename = void>
70 public:
71
75 using Type = CODI_DD(typename LinearSystem::Type, CODI_DEFAULT_LHS_EXPRESSION);
76
77 using Matrix = typename LinearSystem::Matrix;
78 using MatrixReal = typename LinearSystem::MatrixReal;
79 using MatrixIdentifier = typename LinearSystem::MatrixIdentifier;
80 using Vector = typename LinearSystem::Vector;
81 using VectorReal = typename LinearSystem::VectorReal;
82 using VectorIdentifier = typename LinearSystem::VectorIdentifier;
83
84 private:
85
86 /*******************************************************************************/
87 // Additional definitions
88
89 using Real = typename Type::Real;
90 using Identifier = typename Type::Identifier;
91 using Gradient = typename Type::Gradient;
92
94 using Tape = CODI_DD(typename Type::Tape, CODI_DEFAULT_TAPE);
95
98
101
102 /*******************************************************************************/
103 // Implementation of functors for iterators.
104
106 struct VectorAccessFunctor {
107 public:
108
109 size_t dim;
110 VectorAccess* adjointInterface;
111
113 VectorAccessFunctor(size_t dim, VectorAccess* adjointInterface)
114 : dim(dim), adjointInterface(adjointInterface) {}
115 };
116
118 static void extract(Type const& value, Real& value_v, Identifier& value_id) {
119 value_v = value.getValue();
120 value_id = value.getIdentifier();
121 }
122
124 struct ExtractAdjoint : public VectorAccessFunctor {
125 public:
126 using VectorAccessFunctor::VectorAccessFunctor;
127
128 void operator()(Real& value_b, Identifier const& value_id) {
129 value_b = this->adjointInterface->getAdjoint(value_id, this->dim);
130 this->adjointInterface->resetAdjoint(value_id, this->dim);
131 }
132 };
133
135 static void getOutput(Type const& value, Real& value_v) {
136 value_v = value.getValue();
137 }
138
140 struct GetAdjoint : public VectorAccessFunctor {
141 public:
142 using VectorAccessFunctor::VectorAccessFunctor;
143
144 void operator()(Real& value_b, Identifier const& value_id) {
145 value_b = this->adjointInterface->getAdjoint(value_id, this->dim);
146 }
147 };
148
150 struct GetPrimal : public VectorAccessFunctor {
151 public:
152 using VectorAccessFunctor::VectorAccessFunctor;
153
154 void operator()(Real& value_v, Identifier const& value_id) {
155 value_v = this->adjointInterface->getPrimal(value_id);
156 }
157 };
158
160 struct GetPrimalAndGetAdjoint : public VectorAccessFunctor {
161 public:
162 using VectorAccessFunctor::VectorAccessFunctor;
163
164 void operator()(Real& value_v, Real& value_b, Identifier const& value_id) {
165 value_v = this->adjointInterface->getPrimal(value_id);
166 value_b = this->adjointInterface->getAdjoint(value_id, this->dim);
167 }
168 };
169
171 using GetPrimalAndGetTangent = GetPrimalAndGetAdjoint;
172
174 using GetTangent = GetAdjoint;
175
177 static Real registerOutput(Type& value, Real& value_v, Identifier& value_id) {
178 value = value_v;
179 Real oldTemp = Type::getTape().registerExternalFunctionOutput(value);
180 value_id = value.getIdentifier();
181
182 return oldTemp;
183 }
184
186 static void registerOutputWithPrimal(Type& value, Real& value_v, Identifier& value_id, Real& oldValue) {
187 oldValue = registerOutput(value, value_v, value_id);
188 }
189
191 static void setOutput(Type& value, Real const& value_v) {
192 value = value_v;
193 }
194
196 struct SetTangent : public VectorAccessFunctor {
197 public:
198 using VectorAccessFunctor::VectorAccessFunctor;
199
200 void operator()(Real& value_d, Identifier const& value_id) {
201 this->adjointInterface->resetAdjoint(value_id, this->dim);
202 this->adjointInterface->updateAdjoint(value_id, this->dim, value_d);
203 }
204 };
205
207 struct SetPrimal : public VectorAccessFunctor {
208 public:
209 using VectorAccessFunctor::VectorAccessFunctor;
210
211 void operator()(Real& value_v, Identifier const& value_id) {
212 this->adjointInterface->setPrimal(value_id, value_v);
213 }
214 };
215
217 struct SetPrimalAndSetTangent : public VectorAccessFunctor {
218 public:
219 using VectorAccessFunctor::VectorAccessFunctor;
220
221 void operator()(Real& value_v, Real& value_d, Identifier const& value_id) {
222 this->adjointInterface->setPrimal(value_id, value_v);
223 this->adjointInterface->resetAdjoint(value_id, this->dim);
224 this->adjointInterface->updateAdjoint(value_id, this->dim, value_d);
225 }
226 };
227
229 struct SetPrimalAndSetTangentAndUpdateOldPrimal : public VectorAccessFunctor {
230 public:
231 using VectorAccessFunctor::VectorAccessFunctor;
232
233 void operator()(Real& value_v, Real& value_d, Identifier const& value_id, Real& oldValue) {
234 oldValue = this->adjointInterface->getPrimal(value_id);
235 this->adjointInterface->setPrimal(value_id, value_v);
236 this->adjointInterface->resetAdjoint(value_id, this->dim);
237 this->adjointInterface->updateAdjoint(value_id, this->dim, value_d);
238 }
239 };
240
242 struct SetPrimalAndUpdateOldPrimals : public VectorAccessFunctor {
243 public:
244 using VectorAccessFunctor::VectorAccessFunctor;
245
246 void operator()(Real& value_v, Identifier const& value_id, Real& oldValue) {
247 oldValue = this->adjointInterface->getPrimal(value_id);
248 this->adjointInterface->setPrimal(value_id, value_v);
249 }
250 };
251
253 struct UpdateAdjoint : public VectorAccessFunctor {
254 public:
255 using VectorAccessFunctor::VectorAccessFunctor;
256
257 void operator()(Real& value_b, Identifier const& value_id) {
258 this->adjointInterface->updateAdjoint(value_id, this->dim, value_b);
259 }
260 };
261
263 struct UpdateAdjointDyadic : public VectorAccessFunctor {
264 public:
265 using VectorAccessFunctor::VectorAccessFunctor;
266
267 void operator()(Identifier& mat_id, Real const& x_v, Real const& b_b) {
268 Real adjoint = -x_v * b_b;
269 this->adjointInterface->updateAdjoint(mat_id, this->dim, adjoint);
270 }
271 };
272
273 /*******************************************************************************/
274 // Detection of constant properties
275
277 static bool constexpr IsPrimalValueTape = TapeTraits::IsPrimalValueTape<Tape>::value;
278 static bool constexpr StoreOldPrimals = IsPrimalValueTape & !Tape::LinearIndexHandling;
279
280 /*******************************************************************************/
281 // External function handle implementations
282
284 struct ExtFuncData {
285 MatrixReal* A_v;
286 MatrixReal* A_v_trans;
287 MatrixIdentifier* A_id;
288
289 VectorIdentifier* b_id;
290
291 VectorReal* x_v;
292 VectorIdentifier* x_id;
293
294 VectorReal* oldPrimals;
295
296 LinearSystem lsi;
298
299 ExtFuncData(LinearSystem lsi, LinearSystemSolverHints hints)
300 : A_v(NULL),
301 A_v_trans(NULL),
302 A_id(NULL),
303 b_id(NULL),
304 x_v(NULL),
305 x_id(NULL),
306 oldPrimals(NULL),
307 lsi(lsi),
308 hints(hints) {}
309
310 ~ExtFuncData() {
311 if (NULL != A_v) {
312 lsi.deleteMatrixReal(A_v);
313 }
314 if (NULL != A_v_trans) {
315 lsi.deleteMatrixReal(A_v_trans);
316 }
317 if (NULL != A_id) {
318 lsi.deleteMatrixIdentifier(A_id);
319 }
320 if (NULL != b_id) {
321 lsi.deleteVectorIdentifier(b_id);
322 }
323 if (NULL != x_v) {
324 lsi.deleteVectorReal(x_v);
325 }
326 if (NULL != x_id) {
327 lsi.deleteVectorIdentifier(x_id);
328 }
329 if (NULL != oldPrimals) {
330 lsi.deleteVectorReal(oldPrimals);
331 }
332 }
333 };
334
342 static void solve_b(Tape* tape, void* d, VectorAccess* adjointInterface) {
343 CODI_UNUSED(tape);
344
346 CODI_EXCEPTION("Missing functionality for linear system reverse mode. iterateDyadic(%d), transposeMatrix(%d)",
348 }
349
350 ExtFuncData* data = (ExtFuncData*)d;
351
352 if (!data->hints.test(LinearSystemSolverFlags::ReverseEvaluation)) {
353 CODI_EXCEPTION(
354 "Linear system reverse mode called without hint 'LinearSystemSolverFlags::ReverseEvaluation'.");
355 }
356
357 VectorReal* x_b = data->lsi.createVectorReal(data->x_id);
358 VectorReal* s = data->lsi.createVectorReal(data->b_id);
359
360 if (NULL != data->oldPrimals) {
361 data->lsi.iterateVector(SetPrimal(0, adjointInterface), data->oldPrimals, data->x_id);
362 }
363
364 size_t maxDim = adjointInterface->getVectorSize();
365 for (size_t curDim = 0; curDim < maxDim; curDim += 1) {
366 data->lsi.iterateVector(ExtractAdjoint(curDim, adjointInterface), x_b, data->x_id);
367
368 data->lsi.solveSystem(data->A_v_trans, x_b, s);
369
370 data->lsi.iterateDyadic(UpdateAdjointDyadic(curDim, adjointInterface), data->A_id, data->x_v, s);
371 data->lsi.iterateVector(UpdateAdjoint(curDim, adjointInterface), s, data->b_id);
372 }
373
374 data->lsi.deleteVectorReal(x_b);
375 data->lsi.deleteVectorReal(s);
376 }
377
383 static void solve_d(Tape* tape, void* d, VectorAccess* adjointInterface) {
384 CODI_UNUSED(tape);
385
386 ExtFuncData* data = (ExtFuncData*)d;
387
389 CODI_EXCEPTION("Missing functionality for linear system forward mode. subtractMultiply(%d)",
391 }
392 if (!data->hints.test(LinearSystemSolverFlags::ForwardEvaluation)) {
393 CODI_EXCEPTION(
394 "Linear system forward mode called without hint 'LinearSystemSolverFlags::ForwardEvaluation'.");
395 }
396
397 bool const updatePrimals =
398 IsPrimalValueTape && data->hints.test(LinearSystemSolverFlags::RecomputePrimalInForwardEvaluation);
399
400 MatrixReal* A_d = data->lsi.createMatrixReal(data->A_id);
401 VectorReal* b_v = data->lsi.createVectorReal(data->b_id); // b_v is also used as a temporary.
402 VectorReal* b_d = data->lsi.createVectorReal(data->b_id);
403 VectorReal* x_d = data->lsi.createVectorReal(data->x_id);
404
405 size_t maxDim = adjointInterface->getVectorSize();
406 for (size_t curDim = 0; curDim < maxDim; curDim += 1) {
407 if (0 == curDim && updatePrimals) {
408 data->lsi.iterateMatrix(GetPrimalAndGetTangent(curDim, adjointInterface), data->A_v, A_d, data->A_id);
409 data->lsi.iterateVector(GetPrimalAndGetTangent(curDim, adjointInterface), b_v, b_d, data->b_id);
410 } else {
411 data->lsi.iterateMatrix(GetTangent(curDim, adjointInterface), A_d, data->A_id);
412 data->lsi.iterateVector(GetTangent(curDim, adjointInterface), b_d, data->b_id);
413 }
414
415 if (0 == curDim && updatePrimals) { // Solve primal system only once and transposed setup only once.
416
417 if (NULL != data->A_v_trans) {
418 // Only renew A_v_trans if it already exists.
419 data->lsi.deleteMatrixReal(data->A_v_trans);
420 data->A_v_trans = data->lsi.transposeMatrix(data->A_v);
421 }
422
423 data->lsi.solveSystem(data->A_v, b_v, data->x_v);
424 }
425
426 data->lsi.subtractMultiply(b_v, b_d, A_d, data->x_v); // Use of b_v as temporary.
427
428 data->lsi.solveSystem(data->A_v, b_v /* temporary */, x_d);
429
430 if (updatePrimals) {
431 if (NULL != data->oldPrimals) {
432 data->lsi.iterateVector(SetPrimalAndSetTangentAndUpdateOldPrimal(curDim, adjointInterface), data->x_v,
433 x_d, data->x_id, data->oldPrimals);
434 } else {
435 data->lsi.iterateVector(SetPrimalAndSetTangent(curDim, adjointInterface), data->x_v, x_d, data->x_id);
436 }
437 } else {
438 data->lsi.iterateVector(SetTangent(curDim, adjointInterface), x_d, data->x_id);
439 }
440 }
441
442 data->lsi.deleteMatrixReal(A_d);
443 data->lsi.deleteVectorReal(b_v);
444 data->lsi.deleteVectorReal(b_d);
445 data->lsi.deleteVectorReal(x_d);
446 }
447
452 static void solve_p(Tape* tape, void* d, VectorAccess* adjointInterface) {
453 CODI_UNUSED(tape);
454
455 ExtFuncData* data = (ExtFuncData*)d;
456
457 if (!data->hints.test(LinearSystemSolverFlags::PrimalEvaluation)) {
458 CODI_EXCEPTION("Linear system primal mode called without hint 'LinearSystemSolverFlags::PrimalEvaluation'.");
459 }
460
461 VectorReal* b_v = data->lsi.createVectorReal(data->b_id);
462
463 data->lsi.iterateMatrix(GetPrimal(0, adjointInterface), data->A_v, data->A_id);
464 data->lsi.iterateVector(GetPrimal(0, adjointInterface), b_v, data->b_id);
465
466 data->lsi.solveSystem(data->A_v, b_v, data->x_v);
467
468 if (NULL != data->A_v_trans) {
469 // Only renew trans if it already exists.
470
471 data->lsi.deleteMatrixReal(data->A_v_trans);
472 data->A_v_trans = data->lsi.transposeMatrix(data->A_v);
473 }
474
475 if (NULL != data->oldPrimals) {
476 data->lsi.iterateVector(SetPrimalAndUpdateOldPrimals(0, adjointInterface), data->x_v, data->x_id,
477 data->oldPrimals);
478 } else {
479 data->lsi.iterateVector(SetPrimal(0, adjointInterface), data->x_v, data->x_id);
480 }
481
482 data->lsi.deleteVectorReal(b_v);
483 }
484
485 static void deleteData(Tape* tape, void* d) {
486 CODI_UNUSED(tape);
487
488 ExtFuncData* data = (ExtFuncData*)d;
489 delete data;
490 }
491
492 public:
493
500 Tape& tape = Type::getTape();
501
502 MatrixReal* A_v = lsi.createMatrixReal(A);
503 MatrixIdentifier* A_id = lsi.createMatrixIdentifier(A);
504 VectorReal* b_v = lsi.createVectorReal(b);
505 VectorIdentifier* b_id = lsi.createVectorIdentifier(b);
506 VectorReal* x_v = lsi.createVectorReal(x);
507 VectorIdentifier* x_id = lsi.createVectorIdentifier(x);
508
509 lsi.iterateMatrix(extract, A, A_v, A_id);
510 lsi.iterateVector(extract, b, b_v, b_id);
511
512 if (hints.test(LinearSystemSolverFlags::ProvidePrimalSolution)) {
513 lsi.iterateVector(getOutput, x, x_v);
514 }
515
517 lsi.solveSystemPrimal(A_v, b_v, x_v);
518 } else {
519 lsi.solveSystem(A_v, b_v, x_v);
520 }
521
522 if (tape.isActive()) {
523 MatrixReal* A_v_trans = NULL;
524 if (hints.test(LinearSystemSolverFlags::ReverseEvaluation)) {
525 A_v_trans = lsi.transposeMatrix(A_v);
526 }
527
528 VectorReal* oldPrimals = NULL;
529 if (StoreOldPrimals && hints.test(LinearSystemSolverFlags::ReverseEvaluation)) {
530 oldPrimals = b_v; // Reuse b_v here for the primal value handling
531 lsi.iterateVector(registerOutputWithPrimal, x, x_v, x_id, oldPrimals);
532 b_v = NULL; // Do not delete b_v
533 } else {
534 lsi.iterateVector(registerOutput, x, x_v, x_id);
535 }
536
537 ExtFuncData* data = new ExtFuncData(lsi, hints);
538 if (hints.test(LinearSystemSolverFlags::ForwardEvaluation) ||
539 hints.test(LinearSystemSolverFlags::PrimalEvaluation)) {
540 data->A_v = A_v;
541 A_v = NULL; // Do not delete A_v
542 }
543 data->A_v_trans = A_v_trans;
544 data->A_id = A_id;
545 data->b_id = b_id;
546 data->x_v = x_v;
547 data->x_id = x_id;
548 data->oldPrimals = oldPrimals;
549
550 tape.pushExternalFunction(ExternalFunction<Tape>::create(solve_b, data, deleteData, solve_d, solve_p));
551
552 if (b_v != NULL) {
553 lsi.deleteVectorReal(b_v);
554 }
555 if (A_v != NULL) {
556 lsi.deleteMatrixReal(A_v);
557 }
558 } else {
559 lsi.iterateVector(setOutput, x, x_v);
560
561 lsi.deleteMatrixReal(A_v);
562 lsi.deleteMatrixIdentifier(A_id);
563 lsi.deleteVectorReal(b_v);
564 lsi.deleteVectorIdentifier(b_id);
565 lsi.deleteVectorReal(x_v);
566 lsi.deleteVectorIdentifier(x_id);
567 }
568 }
569 };
570
571#ifndef DOXYGEN_DISABLE
574 template<typename T_LinearSystem>
575 struct LinearSystemSolverHandler<T_LinearSystem, RealTraits::EnableIfPassiveReal<typename T_LinearSystem::Type>> {
576 public:
577
579 using LinearSystem = CODI_DD(T_LinearSystem, CODI_T(LinearSystemInterface<LinearSystemInterfaceTypes>));
580 using Matrix = typename LinearSystem::Matrix;
581 using Vector = typename LinearSystem::Vector;
582
583 private:
584
585 using Overloads = LinearSystemSpecializationDetection<LinearSystem>;
586
587 public:
588
593 void solve(LinearSystem lsi, Matrix* A, Vector* b, Vector* x, LinearSystemSolverHints hints) {
594 CODI_UNUSED(hints);
595
597 lsi.solveSystemPrimal(A, b, x);
598 } else {
599 lsi.solveSystem(A, b, x);
600 }
601 }
602 };
603
606 template<typename T_LinearSystem>
607 struct LinearSystemSolverHandler<T_LinearSystem,
608 TapeTraits::EnableIfForwardTape<typename T_LinearSystem::Type::Tape>> {
609 public:
610
611 using LinearSystem =
612 CODI_DD(T_LinearSystem,
613 CODI_T(LinearSystemInterface<LinearSystemInterfaceTypes>));
614
615 using Type = CODI_DD(typename LinearSystem::Type,
616 CODI_DEFAULT_LHS_EXPRESSION);
617
618 using Matrix = typename LinearSystem::Matrix;
619 using MatrixReal = typename LinearSystem::MatrixReal;
620 using MatrixIdentifier = typename LinearSystem::MatrixIdentifier;
621 using Vector = typename LinearSystem::Vector;
622 using VectorReal = typename LinearSystem::VectorReal;
623 using VectorIdentifier = typename LinearSystem::VectorIdentifier;
624
625 private:
626
627 /*******************************************************************************/
628 // Additional definitions
629
630 using Real = typename Type::Real;
631 using Identifier = typename Type::Identifier;
632 using Gradient = typename Type::Gradient;
633
635 using Overloads = LinearSystemSpecializationDetection<LinearSystem>;
636
637 /*******************************************************************************/
638 // Implementation of functors for iterators.
639
641 struct DimFunctor {
642 size_t dim;
643
644 DimFunctor(size_t dim) : dim(dim) {}
645 };
646
648 static void getOutput(Type& value, Real& value_v) {
649 value_v = value.getValue();
650 }
651
653 struct GetPrimalAndGetTangent : public DimFunctor {
654 public:
655 using DimFunctor::DimFunctor;
656
657 void operator()(Type const& value, Real& value_v, Real& value_d) {
658 value_v = value.getValue();
659 value_d = GradientTraits::at(value.getGradient(), this->dim);
660 }
661 };
662
664 struct GetTangent : public DimFunctor {
665 public:
666 using DimFunctor::DimFunctor;
667
668 void operator()(Type const& value, Real& value_d) {
669 value_d = GradientTraits::at(value.getGradient(), this->dim);
670 }
671 };
672
674 struct SetPrimalAndSetTangent : public DimFunctor {
675 public:
676 using DimFunctor::DimFunctor;
677
678 void operator()(Type& value, Real const& value_v, Real const& value_d) {
679 value.value() = value_v;
680 GradientTraits::at(value.gradient(), this->dim) = value_d;
681 }
682 };
683
685 struct SetTangent : public DimFunctor {
686 public:
687 using DimFunctor::DimFunctor;
688
689 void operator()(Type& value, Real const& value_d) {
690 GradientTraits::at(value.gradient(), this->dim) = value_d;
691 }
692 };
693
694 public:
695
701 void solve(LinearSystem lsi, Matrix* A, Vector* b, Vector* x, LinearSystemSolverHints hints) {
702 CODI_UNUSED(hints);
703
704 MatrixReal* A_v = lsi.createMatrixReal(A);
705 MatrixReal* A_d = lsi.createMatrixReal(A);
706 VectorReal* b_v = lsi.createVectorReal(b);
707 VectorReal* b_d = lsi.createVectorReal(b);
708 VectorReal* x_v = lsi.createVectorReal(x);
709 VectorReal* x_d = lsi.createVectorReal(x);
710
711 size_t maxDim = GradientTraits::dim<Gradient>();
712
713 if (hints.test(LinearSystemSolverFlags::ProvidePrimalSolution)) {
714 lsi.iterateVector(getOutput, x, x_v);
715 }
716
717 for (size_t curDim = 0; curDim < maxDim; curDim += 1) {
718 if (0 == curDim) {
719 lsi.iterateMatrix(GetPrimalAndGetTangent(curDim), A, A_v, A_d);
720 lsi.iterateVector(GetPrimalAndGetTangent(curDim), b, b_v, b_d);
721 } else {
722 lsi.iterateMatrix(GetTangent(curDim), A, A_d);
723 lsi.iterateVector(GetTangent(curDim), b, b_d);
724 }
725
726 if (0 == curDim) { // Solve primal system only once.
727 // Solve Ax = b
729 lsi.solveSystemPrimal(A_v, b_v, x_v);
730 } else {
731 lsi.solveSystem(A_v, b_v, x_v);
732 }
733 }
734
735 // temp(x_d) = b_d - A_d * x
736 lsi.subtractMultiply(x_d, b_d, A_d, x_v);
737
738 std::swap(b_d, x_d); // Move temporary to b_d.
739
740 // Solve A x_d = temp
741 lsi.solveSystem(A_v, b_d, x_d);
742
743 if (0 == curDim) {
744 lsi.iterateVector(SetPrimalAndSetTangent(curDim), x, x_v, x_d);
745 } else {
746 lsi.iterateVector(SetTangent(curDim), x, x_d);
747 }
748 }
749
750 lsi.deleteMatrixReal(A_v);
751 lsi.deleteMatrixReal(A_d);
752 lsi.deleteVectorReal(b_v);
753 lsi.deleteVectorReal(b_d);
754 lsi.deleteVectorReal(x_v);
755 lsi.deleteVectorReal(x_d);
756 }
757 };
758#endif
759
769 template<typename LSInterface>
770 void solveLinearSystem(LSInterface lsi, typename LSInterface::Matrix& A, typename LSInterface::Vector& b,
771 typename LSInterface::Vector& x,
774 handler.solve(lsi, &A, &b, &x, hints);
775 }
776}
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:94
#define CODI_T(...)
Abbreviation for CODI_TEMPLATE.
Definition macros.hpp:111
size_t constexpr dim()
Number of dimensions this gradient value has.
Definition gradientTraits.hpp:96
TraitsImplementation< Gradient >::Real & at(Gradient &gradient, size_t dim)
Get the entry at the given index.
Definition gradientTraits.hpp:102
typename std::enable_if< IsForwardTape< Tape >::value >::type EnableIfForwardTape
Enable if wrapper for IsForwardTape.
Definition tapeTraits.hpp:93
CoDiPack - Code Differentiation Package.
Definition codi.hpp:91
void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:46
void solveLinearSystem(LSInterface lsi, typename LSInterface::Matrix &A, typename LSInterface::Vector &b, typename LSInterface::Vector &x, LinearSystemSolverHints hints=LinearSystemSolverHints::ALL())
Definition linearSystemHandler.hpp:770
EnumBitset< LinearSystemSolverFlags > LinearSystemSolverHints
All hints for the LinearSystemSolverHelper.
Definition linearSystemFlags.hpp:58
Identifier & getIdentifier()
Definition activeTypeBase.hpp:156
Represents a concrete lvalue in the CoDiPack expression tree.
Definition activeType.hpp:52
static constexpr EnumBitset ALL()
Constructor for a bitset with all values flagged as true.
Definition enumBitset.hpp:181
bool test(Enum pos) const
Test if the bit for the enum is set.
Definition enumBitset.hpp:93
User-defined evaluation functions for the taping process.
Definition externalFunction.hpp:102
Real const & getValue() const
Get the primal value of this lvalue.
Definition lhsExpressionInterface.hpp:125
Definition linearSystemInterface.hpp:109
Definition linearSystemHandler.hpp:69
T_LinearSystem LinearSystem
See LinearSystemSolverHandler.
Definition linearSystemHandler.hpp:73
typename LinearSystem::MatrixIdentifier MatrixIdentifier
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:79
void solve(LinearSystem lsi, Matrix *A, Vector *b, Vector *x, LinearSystemSolverHints hints)
Definition linearSystemHandler.hpp:499
typename LinearSystem::VectorIdentifier VectorIdentifier
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:82
typename LinearSystem::Vector Vector
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:80
typename LinearSystem::Type Type
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:75
typename LinearSystem::MatrixReal MatrixReal
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:78
typename LinearSystem::VectorReal VectorReal
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:81
typename LinearSystem::Matrix Matrix
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:77
Definition linearSystemSpecializationDetection.hpp:57
static bool SupportsForwardMode()
True if all functions for the forward mode support are specialized.
Definition linearSystemSpecializationDetection.hpp:122
static bool IsSubtractMultiplyImplemented()
Checks if subtractMultiply is specialized in LinearSystem.
Definition linearSystemSpecializationDetection.hpp:103
static bool IsDyadicImplemented()
Checks if iterateDyadic is specialized in LinearSystem.
Definition linearSystemSpecializationDetection.hpp:90
static bool IsSolvePrimalImplemented()
Checks if solveSystemPrimal is specialized in LinearSystem.
Definition linearSystemSpecializationDetection.hpp:108
static bool SupportsReverseMode()
True if all functions for the reverse mode support are specialized.
Definition linearSystemSpecializationDetection.hpp:117
static bool IsTransposeImplemented()
Checks if transposeMatrix is specialized in LinearSystem.
Definition linearSystemSpecializationDetection.hpp:98
virtual void resetAdjoint(Identifier const &index, size_t dim)=0
Set the adjoint component to zero.
virtual void setPrimal(Identifier const &index, Real const &primal)=0
Set the primal value.
virtual size_t getVectorSize() const =0
Vector size in the current tape evaluation.
virtual Real getPrimal(Identifier const &index)=0
Get the primal value.
virtual void updateAdjoint(Identifier const &index, size_t dim, Real const &adjoint)=0
Update the adjoint component.
virtual Real getAdjoint(Identifier const &index, size_t dim)=0
Get the adjoint component.