41#include "../../../expressions/lhsExpressionInterface.hpp"
42#include "../../../misc/exceptions.hpp"
43#include "../../../traits/tapeTraits.hpp"
44#include "../../data/direction.hpp"
45#include "linearSystemFlags.hpp"
46#include "linearSystemInterface.hpp"
47#include "linearSystemSpecializationDetection.hpp"
68 template<
typename T_LinearSystem,
typename =
void>
75 using Type =
CODI_DD(
typename LinearSystem::Type, CODI_DEFAULT_LHS_EXPRESSION);
77 using Matrix =
typename LinearSystem::Matrix;
80 using Vector =
typename LinearSystem::Vector;
89 using Real =
typename Type::Real;
90 using Identifier =
typename Type::Identifier;
91 using Gradient =
typename Type::Gradient;
94 using Tape =
CODI_DD(
typename Type::Tape, CODI_DEFAULT_TAPE);
106 struct VectorAccessFunctor {
113 VectorAccessFunctor(
size_t dim,
VectorAccess* adjointInterface)
114 : dim(dim), adjointInterface(adjointInterface) {}
118 static void extract(
Type const& value,
Real& value_v, Identifier& value_id) {
120 value_id = value.getIdentifier();
124 struct ExtractAdjoint :
public VectorAccessFunctor {
126 using VectorAccessFunctor::VectorAccessFunctor;
128 void operator()(
Real& value_b, Identifier
const& value_id) {
129 value_b = this->adjointInterface->
getAdjoint(value_id, this->dim);
130 this->adjointInterface->
resetAdjoint(value_id, this->dim);
135 static void getOutput(
Type const& value,
Real& value_v) {
140 struct GetAdjoint :
public VectorAccessFunctor {
142 using VectorAccessFunctor::VectorAccessFunctor;
144 void operator()(
Real& value_b, Identifier
const& value_id) {
145 value_b = this->adjointInterface->
getAdjoint(value_id, this->dim);
150 struct GetPrimal :
public VectorAccessFunctor {
152 using VectorAccessFunctor::VectorAccessFunctor;
154 void operator()(
Real& value_v, Identifier
const& value_id) {
155 value_v = this->adjointInterface->
getPrimal(value_id);
160 struct GetPrimalAndGetAdjoint :
public VectorAccessFunctor {
162 using VectorAccessFunctor::VectorAccessFunctor;
164 void operator()(
Real& value_v,
Real& value_b, Identifier
const& value_id) {
165 value_v = this->adjointInterface->
getPrimal(value_id);
166 value_b = this->adjointInterface->
getAdjoint(value_id, this->dim);
171 using GetPrimalAndGetTangent = GetPrimalAndGetAdjoint;
174 using GetTangent = GetAdjoint;
177 static Real registerOutput(
Type& value,
Real& value_v, Identifier& value_id) {
179 Real oldTemp = Type::getTape().registerExternalFunctionOutput(value);
186 static void registerOutputWithPrimal(
Type& value,
Real& value_v, Identifier& value_id,
Real& oldValue) {
187 oldValue = registerOutput(value, value_v, value_id);
191 static void setOutput(
Type& value,
Real const& value_v) {
196 struct SetTangent :
public VectorAccessFunctor {
198 using VectorAccessFunctor::VectorAccessFunctor;
200 void operator()(
Real& value_d, Identifier
const& value_id) {
201 this->adjointInterface->
resetAdjoint(value_id, this->dim);
202 this->adjointInterface->
updateAdjoint(value_id, this->dim, value_d);
207 struct SetPrimal :
public VectorAccessFunctor {
209 using VectorAccessFunctor::VectorAccessFunctor;
211 void operator()(
Real& value_v, Identifier
const& value_id) {
212 this->adjointInterface->
setPrimal(value_id, value_v);
217 struct SetPrimalAndSetTangent :
public VectorAccessFunctor {
219 using VectorAccessFunctor::VectorAccessFunctor;
221 void operator()(
Real& value_v,
Real& value_d, Identifier
const& value_id) {
222 this->adjointInterface->
setPrimal(value_id, value_v);
223 this->adjointInterface->
resetAdjoint(value_id, this->dim);
224 this->adjointInterface->
updateAdjoint(value_id, this->dim, value_d);
229 struct SetPrimalAndSetTangentAndUpdateOldPrimal :
public VectorAccessFunctor {
231 using VectorAccessFunctor::VectorAccessFunctor;
233 void operator()(
Real& value_v,
Real& value_d, Identifier
const& value_id,
Real& oldValue) {
234 oldValue = this->adjointInterface->
getPrimal(value_id);
235 this->adjointInterface->
setPrimal(value_id, value_v);
236 this->adjointInterface->
resetAdjoint(value_id, this->dim);
237 this->adjointInterface->
updateAdjoint(value_id, this->dim, value_d);
242 struct SetPrimalAndUpdateOldPrimals :
public VectorAccessFunctor {
244 using VectorAccessFunctor::VectorAccessFunctor;
246 void operator()(
Real& value_v, Identifier
const& value_id,
Real& oldValue) {
247 oldValue = this->adjointInterface->
getPrimal(value_id);
248 this->adjointInterface->
setPrimal(value_id, value_v);
253 struct UpdateAdjoint :
public VectorAccessFunctor {
255 using VectorAccessFunctor::VectorAccessFunctor;
257 void operator()(
Real& value_b, Identifier
const& value_id) {
258 this->adjointInterface->
updateAdjoint(value_id, this->dim, value_b);
263 struct UpdateAdjointDyadic :
public VectorAccessFunctor {
265 using VectorAccessFunctor::VectorAccessFunctor;
267 void operator()(Identifier& mat_id,
Real const& x_v,
Real const& b_b) {
268 Real adjoint = -x_v * b_b;
269 this->adjointInterface->
updateAdjoint(mat_id, this->dim, adjoint);
277 static bool constexpr IsPrimalValueTape = TapeTraits::IsPrimalValueTape<Tape>::value;
278 static bool constexpr StoreOldPrimals = IsPrimalValueTape & !Tape::LinearIndexHandling;
312 lsi.deleteMatrixReal(A_v);
314 if (NULL != A_v_trans) {
315 lsi.deleteMatrixReal(A_v_trans);
318 lsi.deleteMatrixIdentifier(A_id);
321 lsi.deleteVectorIdentifier(b_id);
324 lsi.deleteVectorReal(x_v);
327 lsi.deleteVectorIdentifier(x_id);
329 if (NULL != oldPrimals) {
330 lsi.deleteVectorReal(oldPrimals);
342 static void solve_b(Tape* tape,
void* d,
VectorAccess* adjointInterface) {
346 CODI_EXCEPTION(
"Missing functionality for linear system reverse mode. iterateDyadic(%d), transposeMatrix(%d)",
350 ExtFuncData* data = (ExtFuncData*)d;
352 if (!data->hints.test(LinearSystemSolverFlags::ReverseEvaluation)) {
354 "Linear system reverse mode called without hint 'LinearSystemSolverFlags::ReverseEvaluation'.");
357 VectorReal* x_b = data->lsi.createVectorReal(data->x_id);
358 VectorReal* s = data->lsi.createVectorReal(data->b_id);
360 if (NULL != data->oldPrimals) {
361 data->lsi.iterateVector(SetPrimal(0, adjointInterface), data->oldPrimals, data->x_id);
365 for (
size_t curDim = 0; curDim < maxDim; curDim += 1) {
366 data->lsi.iterateVector(ExtractAdjoint(curDim, adjointInterface), x_b, data->x_id);
368 data->lsi.solveSystem(data->A_v_trans, x_b, s);
370 data->lsi.iterateDyadic(UpdateAdjointDyadic(curDim, adjointInterface), data->A_id, data->x_v, s);
371 data->lsi.iterateVector(UpdateAdjoint(curDim, adjointInterface), s, data->b_id);
374 data->lsi.deleteVectorReal(x_b);
375 data->lsi.deleteVectorReal(s);
383 static void solve_d(Tape* tape,
void* d,
VectorAccess* adjointInterface) {
386 ExtFuncData* data = (ExtFuncData*)d;
389 CODI_EXCEPTION(
"Missing functionality for linear system forward mode. subtractMultiply(%d)",
392 if (!data->hints.test(LinearSystemSolverFlags::ForwardEvaluation)) {
394 "Linear system forward mode called without hint 'LinearSystemSolverFlags::ForwardEvaluation'.");
397 bool const updatePrimals =
398 IsPrimalValueTape && data->hints.test(LinearSystemSolverFlags::RecomputePrimalInForwardEvaluation);
400 MatrixReal* A_d = data->lsi.createMatrixReal(data->A_id);
401 VectorReal* b_v = data->lsi.createVectorReal(data->b_id);
402 VectorReal* b_d = data->lsi.createVectorReal(data->b_id);
403 VectorReal* x_d = data->lsi.createVectorReal(data->x_id);
406 for (
size_t curDim = 0; curDim < maxDim; curDim += 1) {
407 if (0 == curDim && updatePrimals) {
408 data->lsi.iterateMatrix(GetPrimalAndGetTangent(curDim, adjointInterface), data->A_v, A_d, data->A_id);
409 data->lsi.iterateVector(GetPrimalAndGetTangent(curDim, adjointInterface), b_v, b_d, data->b_id);
411 data->lsi.iterateMatrix(GetTangent(curDim, adjointInterface), A_d, data->A_id);
412 data->lsi.iterateVector(GetTangent(curDim, adjointInterface), b_d, data->b_id);
415 if (0 == curDim && updatePrimals) {
417 if (NULL != data->A_v_trans) {
419 data->lsi.deleteMatrixReal(data->A_v_trans);
420 data->A_v_trans = data->lsi.transposeMatrix(data->A_v);
423 data->lsi.solveSystem(data->A_v, b_v, data->x_v);
426 data->lsi.subtractMultiply(b_v, b_d, A_d, data->x_v);
428 data->lsi.solveSystem(data->A_v, b_v , x_d);
431 if (NULL != data->oldPrimals) {
432 data->lsi.iterateVector(SetPrimalAndSetTangentAndUpdateOldPrimal(curDim, adjointInterface), data->x_v,
433 x_d, data->x_id, data->oldPrimals);
435 data->lsi.iterateVector(SetPrimalAndSetTangent(curDim, adjointInterface), data->x_v, x_d, data->x_id);
438 data->lsi.iterateVector(SetTangent(curDim, adjointInterface), x_d, data->x_id);
442 data->lsi.deleteMatrixReal(A_d);
443 data->lsi.deleteVectorReal(b_v);
444 data->lsi.deleteVectorReal(b_d);
445 data->lsi.deleteVectorReal(x_d);
452 static void solve_p(Tape* tape,
void* d,
VectorAccess* adjointInterface) {
455 ExtFuncData* data = (ExtFuncData*)d;
457 if (!data->hints.test(LinearSystemSolverFlags::PrimalEvaluation)) {
458 CODI_EXCEPTION(
"Linear system primal mode called without hint 'LinearSystemSolverFlags::PrimalEvaluation'.");
461 VectorReal* b_v = data->lsi.createVectorReal(data->b_id);
463 data->lsi.iterateMatrix(GetPrimal(0, adjointInterface), data->A_v, data->A_id);
464 data->lsi.iterateVector(GetPrimal(0, adjointInterface), b_v, data->b_id);
466 data->lsi.solveSystem(data->A_v, b_v, data->x_v);
468 if (NULL != data->A_v_trans) {
471 data->lsi.deleteMatrixReal(data->A_v_trans);
472 data->A_v_trans = data->lsi.transposeMatrix(data->A_v);
475 if (NULL != data->oldPrimals) {
476 data->lsi.iterateVector(SetPrimalAndUpdateOldPrimals(0, adjointInterface), data->x_v, data->x_id,
479 data->lsi.iterateVector(SetPrimal(0, adjointInterface), data->x_v, data->x_id);
482 data->lsi.deleteVectorReal(b_v);
485 static void deleteData(Tape* tape,
void* d) {
488 ExtFuncData* data = (ExtFuncData*)d;
500 Tape& tape = Type::getTape();
509 lsi.iterateMatrix(extract, A, A_v, A_id);
510 lsi.iterateVector(extract, b, b_v, b_id);
512 if (hints.
test(LinearSystemSolverFlags::ProvidePrimalSolution)) {
513 lsi.iterateVector(getOutput, x, x_v);
517 lsi.solveSystemPrimal(A_v, b_v, x_v);
519 lsi.solveSystem(A_v, b_v, x_v);
522 if (tape.isActive()) {
524 if (hints.
test(LinearSystemSolverFlags::ReverseEvaluation)) {
525 A_v_trans = lsi.transposeMatrix(A_v);
529 if (StoreOldPrimals && hints.
test(LinearSystemSolverFlags::ReverseEvaluation)) {
531 lsi.iterateVector(registerOutputWithPrimal, x, x_v, x_id, oldPrimals);
534 lsi.iterateVector(registerOutput, x, x_v, x_id);
537 ExtFuncData* data =
new ExtFuncData(lsi, hints);
538 if (hints.
test(LinearSystemSolverFlags::ForwardEvaluation) ||
539 hints.
test(LinearSystemSolverFlags::PrimalEvaluation)) {
543 data->A_v_trans = A_v_trans;
548 data->oldPrimals = oldPrimals;
553 lsi.deleteVectorReal(b_v);
556 lsi.deleteMatrixReal(A_v);
559 lsi.iterateVector(setOutput, x, x_v);
561 lsi.deleteMatrixReal(A_v);
562 lsi.deleteMatrixIdentifier(A_id);
563 lsi.deleteVectorReal(b_v);
564 lsi.deleteVectorIdentifier(b_id);
565 lsi.deleteVectorReal(x_v);
566 lsi.deleteVectorIdentifier(x_id);
571#ifndef DOXYGEN_DISABLE
574 template<
typename T_LinearSystem>
575 struct LinearSystemSolverHandler<T_LinearSystem, RealTraits::EnableIfPassiveReal<typename T_LinearSystem::Type>> {
580 using Matrix =
typename LinearSystem::Matrix;
581 using Vector =
typename LinearSystem::Vector;
585 using Overloads = LinearSystemSpecializationDetection<LinearSystem>;
597 lsi.solveSystemPrimal(A, b, x);
599 lsi.solveSystem(A, b, x);
606 template<
typename T_LinearSystem>
607 struct LinearSystemSolverHandler<T_LinearSystem,
613 CODI_T(LinearSystemInterface<LinearSystemInterfaceTypes>));
616 CODI_DEFAULT_LHS_EXPRESSION);
618 using Matrix =
typename LinearSystem::Matrix;
619 using MatrixReal =
typename LinearSystem::MatrixReal;
621 using Vector =
typename LinearSystem::Vector;
622 using VectorReal =
typename LinearSystem::VectorReal;
630 using Real =
typename Type::Real;
631 using Identifier =
typename Type::Identifier;
632 using Gradient =
typename Type::Gradient;
635 using Overloads = LinearSystemSpecializationDetection<LinearSystem>;
644 DimFunctor(
size_t dim) :
dim(
dim) {}
648 static void getOutput(
Type& value,
Real& value_v) {
653 struct GetPrimalAndGetTangent :
public DimFunctor {
655 using DimFunctor::DimFunctor;
657 void operator()(
Type const& value,
Real& value_v,
Real& value_d) {
664 struct GetTangent :
public DimFunctor {
666 using DimFunctor::DimFunctor;
668 void operator()(
Type const& value,
Real& value_d) {
674 struct SetPrimalAndSetTangent :
public DimFunctor {
676 using DimFunctor::DimFunctor;
678 void operator()(
Type& value,
Real const& value_v,
Real const& value_d) {
679 value.value() = value_v;
685 struct SetTangent :
public DimFunctor {
687 using DimFunctor::DimFunctor;
689 void operator()(
Type& value,
Real const& value_d) {
711 size_t maxDim = GradientTraits::dim<Gradient>();
713 if (hints.test(LinearSystemSolverFlags::ProvidePrimalSolution)) {
714 lsi.iterateVector(getOutput, x, x_v);
717 for (
size_t curDim = 0; curDim < maxDim; curDim += 1) {
719 lsi.iterateMatrix(GetPrimalAndGetTangent(curDim), A, A_v, A_d);
720 lsi.iterateVector(GetPrimalAndGetTangent(curDim), b, b_v, b_d);
722 lsi.iterateMatrix(GetTangent(curDim), A, A_d);
723 lsi.iterateVector(GetTangent(curDim), b, b_d);
729 lsi.solveSystemPrimal(A_v, b_v, x_v);
731 lsi.solveSystem(A_v, b_v, x_v);
736 lsi.subtractMultiply(x_d, b_d, A_d, x_v);
741 lsi.solveSystem(A_v, b_d, x_d);
744 lsi.iterateVector(SetPrimalAndSetTangent(curDim), x, x_v, x_d);
746 lsi.iterateVector(SetTangent(curDim), x, x_d);
750 lsi.deleteMatrixReal(A_v);
751 lsi.deleteMatrixReal(A_d);
752 lsi.deleteVectorReal(b_v);
753 lsi.deleteVectorReal(b_d);
754 lsi.deleteVectorReal(x_v);
755 lsi.deleteVectorReal(x_d);
769 template<
typename LSInterface>
770 void solveLinearSystem(LSInterface lsi,
typename LSInterface::Matrix& A,
typename LSInterface::Vector& b,
771 typename LSInterface::Vector& x,
774 handler.solve(lsi, &A, &b, &x, hints);
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:94
#define CODI_T(...)
Abbreviation for CODI_TEMPLATE.
Definition macros.hpp:111
size_t constexpr dim()
Number of dimensions this gradient value has.
Definition gradientTraits.hpp:96
TraitsImplementation< Gradient >::Real & at(Gradient &gradient, size_t dim)
Get the entry at the given index.
Definition gradientTraits.hpp:102
typename std::enable_if< IsForwardTape< Tape >::value >::type EnableIfForwardTape
Enable if wrapper for IsForwardTape.
Definition tapeTraits.hpp:93
CoDiPack - Code Differentiation Package.
Definition codi.hpp:91
void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:46
void solveLinearSystem(LSInterface lsi, typename LSInterface::Matrix &A, typename LSInterface::Vector &b, typename LSInterface::Vector &x, LinearSystemSolverHints hints=LinearSystemSolverHints::ALL())
Definition linearSystemHandler.hpp:770
EnumBitset< LinearSystemSolverFlags > LinearSystemSolverHints
All hints for the LinearSystemSolverHelper.
Definition linearSystemFlags.hpp:58
Identifier & getIdentifier()
Definition activeTypeBase.hpp:156
Represents a concrete lvalue in the CoDiPack expression tree.
Definition activeType.hpp:52
static constexpr EnumBitset ALL()
Constructor for a bitset with all values flagged as true.
Definition enumBitset.hpp:181
bool test(Enum pos) const
Test if the bit for the enum is set.
Definition enumBitset.hpp:93
User-defined evaluation functions for the taping process.
Definition externalFunction.hpp:102
Real const & getValue() const
Get the primal value of this lvalue.
Definition lhsExpressionInterface.hpp:125
Definition linearSystemInterface.hpp:109
Definition linearSystemHandler.hpp:69
T_LinearSystem LinearSystem
See LinearSystemSolverHandler.
Definition linearSystemHandler.hpp:73
typename LinearSystem::MatrixIdentifier MatrixIdentifier
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:79
void solve(LinearSystem lsi, Matrix *A, Vector *b, Vector *x, LinearSystemSolverHints hints)
Definition linearSystemHandler.hpp:499
typename LinearSystem::VectorIdentifier VectorIdentifier
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:82
typename LinearSystem::Vector Vector
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:80
typename LinearSystem::Type Type
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:75
typename LinearSystem::MatrixReal MatrixReal
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:78
typename LinearSystem::VectorReal VectorReal
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:81
typename LinearSystem::Matrix Matrix
See LinearSystemInterfaceTypes.
Definition linearSystemHandler.hpp:77
Definition linearSystemSpecializationDetection.hpp:57
static bool SupportsForwardMode()
True if all functions for the forward mode support are specialized.
Definition linearSystemSpecializationDetection.hpp:122
static bool IsSubtractMultiplyImplemented()
Checks if subtractMultiply is specialized in LinearSystem.
Definition linearSystemSpecializationDetection.hpp:103
static bool IsDyadicImplemented()
Checks if iterateDyadic is specialized in LinearSystem.
Definition linearSystemSpecializationDetection.hpp:90
static bool IsSolvePrimalImplemented()
Checks if solveSystemPrimal is specialized in LinearSystem.
Definition linearSystemSpecializationDetection.hpp:108
static bool SupportsReverseMode()
True if all functions for the reverse mode support are specialized.
Definition linearSystemSpecializationDetection.hpp:117
static bool IsTransposeImplemented()
Checks if transposeMatrix is specialized in LinearSystem.
Definition linearSystemSpecializationDetection.hpp:98
virtual void resetAdjoint(Identifier const &index, size_t dim)=0
Set the adjoint component to zero.
virtual void setPrimal(Identifier const &index, Real const &primal)=0
Set the primal value.
virtual size_t getVectorSize() const =0
Vector size in the current tape evaluation.
virtual Real getPrimal(Identifier const &index)=0
Get the primal value.
virtual void updateAdjoint(Identifier const &index, size_t dim, Real const &adjoint)=0
Update the adjoint component.
virtual Real getAdjoint(Identifier const &index, size_t dim)=0
Get the adjoint component.