75 using Type =
CODI_DD(
typename LinearSystem::Type, CODI_DEFAULT_LHS_EXPRESSION);
77 using Matrix =
typename LinearSystem::Matrix;
80 using Vector =
typename LinearSystem::Vector;
89 using Real =
typename Type::Real;
90 using Identifier =
typename Type::Identifier;
91 using Gradient =
typename Type::Gradient;
94 using Tape =
CODI_DD(
typename Type::Tape, CODI_DEFAULT_TAPE);
/// Common base of the element-wise functors below. It stores the vector
/// dimension `dim` and the VectorAccess pointer through which the functors
/// read and write primal/adjoint values by identifier.
/// NOTE(review): the declaration of the `dim` member and the closing `};` are
/// not visible in this listing — restore them from the original source.
106 struct VectorAccessFunctor {
110 VectorAccess* adjointInterface;
113 VectorAccessFunctor(
size_t dim, VectorAccess* adjointInterface)
114 : dim(dim), adjointInterface(adjointInterface) {}
/// Per-element callback for iterateMatrix/iterateVector. It splits an active
/// value into its parts and writes the identifier of `value` to `value_id`.
/// NOTE(review): `value_v` is not assigned in the visible lines and the closing
/// brace is missing — presumably a `value_v = ...` line was lost from this
/// listing; confirm against the original source.
118 static void extract(
Type const& value, Real& value_v, Identifier& value_id) {
120 value_id = value.getIdentifier();
/// Reads the adjoint of `value_id` in dimension `dim` into `value_b`, then
/// calls resetAdjoint on the same entry (presumably zeroing it), so the adjoint
/// is consumed. Used by solve_b to fetch the adjoints of x.
/// NOTE(review): closing braces are not visible in this listing.
124 struct ExtractAdjoint :
public VectorAccessFunctor {
126 using VectorAccessFunctor::VectorAccessFunctor;
// Fetch the adjoint first, then reset it on the tape.
128 void operator()(Real& value_b, Identifier
const& value_id) {
129 value_b = this->adjointInterface->getAdjoint(value_id, this->dim);
130 this->adjointInterface->resetAdjoint(value_id, this->dim);
/// Per-element callback applied to `x` with iterateVector when the
/// ProvidePrimalSolution hint is set.
/// NOTE(review): the body is not visible in this listing. By its name and
/// signature it presumably copies the value of `value` into `value_v` — confirm.
135 static void getOutput(
Type const& value, Real& value_v) {
/// Reads the adjoint of `value_id` in dimension `dim` into `value_b`.
/// Unlike ExtractAdjoint, it does not call resetAdjoint.
/// NOTE(review): closing braces are not visible in this listing.
140 struct GetAdjoint :
public VectorAccessFunctor {
142 using VectorAccessFunctor::VectorAccessFunctor;
144 void operator()(Real& value_b, Identifier
const& value_id) {
145 value_b = this->adjointInterface->getAdjoint(value_id, this->dim);
/// Reads the primal value stored for `value_id` into `value_v`.
/// `dim` is not used; solve_p constructs it with dimension 0.
/// NOTE(review): closing braces are not visible in this listing.
150 struct GetPrimal :
public VectorAccessFunctor {
152 using VectorAccessFunctor::VectorAccessFunctor;
154 void operator()(Real& value_v, Identifier
const& value_id) {
155 value_v = this->adjointInterface->getPrimal(value_id);
/// Reads both the primal value of `value_id` into `value_v` and its adjoint in
/// dimension `dim` into `value_b`. In forward mode it is used under the alias
/// GetPrimalAndGetTangent.
/// NOTE(review): closing braces are not visible in this listing.
160 struct GetPrimalAndGetAdjoint :
public VectorAccessFunctor {
162 using VectorAccessFunctor::VectorAccessFunctor;
164 void operator()(Real& value_v, Real& value_b, Identifier
const& value_id) {
165 value_v = this->adjointInterface->getPrimal(value_id);
166 value_b = this->adjointInterface->getAdjoint(value_id, this->dim);
171 using GetPrimalAndGetTangent = GetPrimalAndGetAdjoint;
174 using GetTangent = GetAdjoint;
/// Registers `value` as an output of the external function on the tape
/// returned by Type::getTape(). The Real returned by
/// registerExternalFunctionOutput is kept in `oldTemp`.
/// NOTE(review): the rest of the body (the use of `value_v`/`value_id` and the
/// return statement) is not visible in this listing.
177 static Real registerOutput(
Type& value, Real& value_v, Identifier& value_id) {
179 Real oldTemp = Type::getTape().registerExternalFunctionOutput(value);
/// Like registerOutput, but also stores the Real it returns in `oldValue`.
/// It is used when StoreOldPrimals is set and reverse evaluation is requested,
/// so that solve_b can restore these values from oldPrimals.
/// NOTE(review): the closing brace is not visible in this listing.
186 static void registerOutputWithPrimal(
Type& value, Real& value_v, Identifier& value_id, Real& oldValue) {
187 oldValue = registerOutput(value, value_v, value_id);
/// Per-element callback applied to `x` with iterateVector in the solve entry
/// point.
/// NOTE(review): the body is not visible in this listing. It presumably writes
/// `value_v` into `value` — confirm.
191 static void setOutput(
Type& value, Real
const& value_v) {
/// Overwrites the tangent (adjoint slot `dim`) of `value_id` with `value_d` by
/// calling resetAdjoint followed by updateAdjoint. This assumes updateAdjoint
/// adds to the slot, so the preceding reset turns it into an assignment.
/// NOTE(review): closing braces are not visible in this listing.
196 struct SetTangent :
public VectorAccessFunctor {
198 using VectorAccessFunctor::VectorAccessFunctor;
200 void operator()(Real& value_d, Identifier
const& value_id) {
201 this->adjointInterface->resetAdjoint(value_id, this->dim);
202 this->adjointInterface->updateAdjoint(value_id, this->dim, value_d);
/// Writes `value_v` as the primal value of `value_id`; `dim` is not used.
/// solve_b uses it to restore the saved old primals of x, and solve_p uses it
/// to publish new ones.
/// NOTE(review): closing braces are not visible in this listing.
207 struct SetPrimal :
public VectorAccessFunctor {
209 using VectorAccessFunctor::VectorAccessFunctor;
211 void operator()(Real& value_v, Identifier
const& value_id) {
212 this->adjointInterface->setPrimal(value_id, value_v);
/// Writes `value_v` as the primal value of `value_id` and overwrites its tangent
/// in dimension `dim` with `value_d` (resetAdjoint, then updateAdjoint).
/// NOTE(review): closing braces are not visible in this listing.
217 struct SetPrimalAndSetTangent :
public VectorAccessFunctor {
219 using VectorAccessFunctor::VectorAccessFunctor;
221 void operator()(Real& value_v, Real& value_d, Identifier
const& value_id) {
222 this->adjointInterface->setPrimal(value_id, value_v);
223 this->adjointInterface->resetAdjoint(value_id, this->dim);
224 this->adjointInterface->updateAdjoint(value_id, this->dim, value_d);
/// Same as SetPrimalAndSetTangent, but first saves the current primal value of
/// `value_id` into `oldValue` before overwriting it.
/// NOTE(review): closing braces are not visible in this listing.
229 struct SetPrimalAndSetTangentAndUpdateOldPrimal :
public VectorAccessFunctor {
231 using VectorAccessFunctor::VectorAccessFunctor;
233 void operator()(Real& value_v, Real& value_d, Identifier
const& value_id, Real& oldValue) {
234 oldValue = this->adjointInterface->getPrimal(value_id);
235 this->adjointInterface->setPrimal(value_id, value_v);
236 this->adjointInterface->resetAdjoint(value_id, this->dim);
237 this->adjointInterface->updateAdjoint(value_id, this->dim, value_d);
/// Saves the current primal value of `value_id` into `oldValue`, then writes
/// `value_v` as its new primal value. `dim` is not used.
/// NOTE(review): closing braces are not visible in this listing.
242 struct SetPrimalAndUpdateOldPrimals :
public VectorAccessFunctor {
244 using VectorAccessFunctor::VectorAccessFunctor;
246 void operator()(Real& value_v, Identifier
const& value_id, Real& oldValue) {
247 oldValue = this->adjointInterface->getPrimal(value_id);
248 this->adjointInterface->setPrimal(value_id, value_v);
/// Passes `value_b` to updateAdjoint for `value_id` in dimension `dim`
/// (presumably accumulating it). solve_b uses it to propagate the transposed
/// solve result into the adjoints of b.
/// NOTE(review): closing braces are not visible in this listing.
253 struct UpdateAdjoint :
public VectorAccessFunctor {
255 using VectorAccessFunctor::VectorAccessFunctor;
257 void operator()(Real& value_b, Identifier
const& value_id) {
258 this->adjointInterface->updateAdjoint(value_id, this->dim, value_b);
/// Callback for iterateDyadic. For a matrix entry `mat_id` it receives the
/// corresponding vector entries `b_b` and `x_v` (presumably s[i] and x[j] for
/// entry (i, j)) and passes -b_b * x_v to updateAdjoint. In solve_b this
/// contributes the -s * x^T term to the adjoint of A.
/// NOTE(review): closing braces are not visible in this listing.
263 struct UpdateAdjointDyadic :
public VectorAccessFunctor {
265 using VectorAccessFunctor::VectorAccessFunctor;
267 void operator()(Identifier& mat_id, Real
const& b_b, Real
const& x_v) {
// Negated outer-product entry.
268 Real adjoint = -b_b * x_v;
269 this->adjointInterface->updateAdjoint(mat_id, this->dim, adjoint);
277 static bool constexpr IsPrimalValueTape = TapeTraits::IsPrimalValueTape<Tape>::value;
278 static bool constexpr StoreOldPrimals = IsPrimalValueTape & !Tape::LinearIndexHandling;
// NOTE(review): the header of the enclosing function is not visible in this
// listing. This appears to be the cleanup of the buffers held by ExtFuncData.
// It frees A_v, A_id, b_id, x_v and x_id through the lsi delete functions,
// and frees the optional A_v_trans and oldPrimals only when they are non-NULL.
312 lsi.deleteMatrixReal(A_v);
// The transpose is only created when the ReverseEvaluation hint is set.
314 if (NULL != A_v_trans) {
315 lsi.deleteMatrixReal(A_v_trans);
318 lsi.deleteMatrixIdentifier(A_id);
321 lsi.deleteVectorIdentifier(b_id);
324 lsi.deleteVectorReal(x_v);
327 lsi.deleteVectorIdentifier(x_id);
// Presumably only allocated when StoreOldPrimals is set and reverse evaluation was requested.
329 if (NULL != oldPrimals) {
330 lsi.deleteVectorReal(oldPrimals);
/// Reverse-mode callback of the recorded linear-system solve.
///
/// For every adjoint dimension it:
/// 1. reads and resets the adjoints of x (x_b);
/// 2. solves with the stored transpose A_v_trans to obtain s, i.e. A^T s = x_b
///    (assuming solveSystem(A, b, x) solves A x = b, as in the primal solve);
/// 3. adds -s * x^T to the adjoints of A and s to the adjoints of b.
///
/// It errors out when the ReverseEvaluation hint was not set at recording time.
///
/// @param tape              Not used in the visible lines.
/// @param d                 The ExtFuncData* stored when the solve was recorded.
/// @param adjointInterface  Access to the tape's primal and adjoint vectors.
///
/// NOTE(review): several lines are missing from this listing: if-conditions,
/// the rest of the CODI_EXCEPTION argument lists, the declaration of maxDim,
/// and closing braces.
342 static void solve_b(Tape* tape,
void* d, VectorAccess* adjointInterface) {
// NOTE(review): the condition guarding this exception is not visible here.
346 CODI_EXCEPTION(
"Missing functionality for linear system reverse mode. iterateDyadic(%d), transposeMatrix(%d)",
350 ExtFuncData* data = (ExtFuncData*)d;
// Reverse evaluation is only valid if it was requested through the hints.
352 if (!data->hints.test(LinearSystemSolverFlags::ReverseEvaluation)) {
354 "Linear system reverse mode called without hint 'LinearSystemSolverFlags::ReverseEvaluation'.");
// Work vectors: x_b holds the adjoints of x, s the transposed-system solution.
357 VectorReal* x_b = data->lsi.createVectorReal(data->x_id);
358 VectorReal* s = data->lsi.createVectorReal(data->b_id);
// Restore the primal values of x that were saved in oldPrimals.
360 if (NULL != data->oldPrimals) {
361 data->lsi.iterateVector(SetPrimal(0, adjointInterface), data->oldPrimals, data->x_id);
// maxDim is declared on a line not visible here — presumably the adjoint vector dimension.
365 for (
size_t curDim = 0; curDim < maxDim; curDim += 1) {
366 data->lsi.iterateVector(ExtractAdjoint(curDim, adjointInterface), x_b, data->x_id);
368 data->lsi.solveSystem(data->A_v_trans, x_b, s);
370 data->lsi.iterateDyadic(UpdateAdjointDyadic(curDim, adjointInterface), data->A_id, s, data->x_v);
371 data->lsi.iterateVector(UpdateAdjoint(curDim, adjointInterface), s, data->b_id);
// Release the work vectors.
374 data->lsi.deleteVectorReal(x_b);
375 data->lsi.deleteVectorReal(s);
/// Forward-mode callback of the recorded linear-system solve.
///
/// For every tangent dimension it:
/// 1. reads the tangents of A (A_d) and b (b_d);
/// 2. forms b_v = b_d - A_d * x_v with subtractMultiply (assumed semantics —
///    confirm against the LinearSystem interface);
/// 3. solves A_v x_d = b_v;
/// 4. writes x_d as the tangent of x.
///
/// On a primal value tape with the RecomputePrimalInForwardEvaluation hint,
/// the first dimension also re-reads the primals of A and b and re-solves for
/// x_v. It then refreshes A_v_trans (if one is stored) and writes the new
/// primals of x, saving the previous ones in oldPrimals when that buffer exists.
/// It errors out when the ForwardEvaluation hint was not set.
///
/// @param tape              Not used in the visible lines.
/// @param d                 The ExtFuncData* stored when the solve was recorded.
/// @param adjointInterface  Access to the tape's primal and tangent vectors.
///
/// NOTE(review): several lines are missing from this listing: if-conditions,
/// the rest of the CODI_EXCEPTION argument lists, maxDim, the `else` branches,
/// and closing braces.
383 static void solve_d(Tape* tape,
void* d, VectorAccess* adjointInterface) {
386 ExtFuncData* data = (ExtFuncData*)d;
// NOTE(review): the condition guarding this exception is not visible here.
389 CODI_EXCEPTION(
"Missing functionality for linear system forward mode. subtractMultiply(%d)",
392 if (!data->hints.test(LinearSystemSolverFlags::ForwardEvaluation)) {
394 "Linear system forward mode called without hint 'LinearSystemSolverFlags::ForwardEvaluation'.");
// Primals are only recomputed for primal value tapes and when explicitly requested.
397 bool const updatePrimals =
398 IsPrimalValueTape && data->hints.test(LinearSystemSolverFlags::RecomputePrimalInForwardEvaluation);
// Work buffers: A_d/b_d hold tangents, x_d the tangent of x. b_v first receives
// the primal of b (first dimension, updatePrimals only) and then the
// right-hand side of the tangent system.
400 MatrixReal* A_d = data->lsi.createMatrixReal(data->A_id);
401 VectorReal* b_v = data->lsi.createVectorReal(data->b_id);
402 VectorReal* b_d = data->lsi.createVectorReal(data->b_id);
403 VectorReal* x_d = data->lsi.createVectorReal(data->x_id);
// maxDim is declared on a line not visible here.
406 for (
size_t curDim = 0; curDim < maxDim; curDim += 1) {
// Read tangents, plus primals on the first dimension when updatePrimals is set.
407 if (0 == curDim && updatePrimals) {
408 data->lsi.iterateMatrix(GetPrimalAndGetTangent(curDim, adjointInterface), data->A_v, A_d, data->A_id);
409 data->lsi.iterateVector(GetPrimalAndGetTangent(curDim, adjointInterface), b_v, b_d, data->b_id);
// NOTE(review): the `} else {` separating this from the branch above is not visible in this listing.
411 data->lsi.iterateMatrix(GetTangent(curDim, adjointInterface), A_d, data->A_id);
412 data->lsi.iterateVector(GetTangent(curDim, adjointInterface), b_d, data->b_id);
// Keep the stored transpose and the primal solution consistent with the re-read A and b.
415 if (0 == curDim && updatePrimals) {
417 if (NULL != data->A_v_trans) {
419 data->lsi.deleteMatrixReal(data->A_v_trans);
420 data->A_v_trans = data->lsi.transposeMatrix(data->A_v);
// Presumably inside the updatePrimals branch (braces missing from this listing).
423 data->lsi.solveSystem(data->A_v, b_v, data->x_v);
// Build the tangent right-hand side into b_v, then solve for x_d.
426 data->lsi.subtractMultiply(b_v, b_d, A_d, data->x_v);
428 data->lsi.solveSystem(data->A_v, b_v , x_d);
// Write back the results for x. The enclosing condition is not visible here;
// presumably the primal-writing variants run only on the first dimension
// with updatePrimals.
431 if (NULL != data->oldPrimals) {
432 data->lsi.iterateVector(SetPrimalAndSetTangentAndUpdateOldPrimal(curDim, adjointInterface), data->x_v,
433 x_d, data->x_id, data->oldPrimals);
435 data->lsi.iterateVector(SetPrimalAndSetTangent(curDim, adjointInterface), data->x_v, x_d, data->x_id);
// Tangent-only write-back (enclosing else not visible in this listing).
438 data->lsi.iterateVector(SetTangent(curDim, adjointInterface), x_d, data->x_id);
// Release the work buffers.
442 data->lsi.deleteMatrixReal(A_d);
443 data->lsi.deleteVectorReal(b_v);
444 data->lsi.deleteVectorReal(b_d);
445 data->lsi.deleteVectorReal(x_d);
/// Primal re-evaluation callback of the recorded linear-system solve.
///
/// It reads the current primals of A and b, re-solves A_v x_v = b_v, and
/// refreshes A_v_trans if one is stored. It then writes x_v as the primals of
/// x, saving the previous primals in oldPrimals when that buffer exists.
/// It raises a CODI_EXCEPTION when the PrimalEvaluation hint was not set.
///
/// @param tape              Not used in the visible lines.
/// @param d                 The ExtFuncData* stored when the solve was recorded.
/// @param adjointInterface  Access to the tape's primal vector.
///
/// NOTE(review): continuation lines, the `else` branch and closing braces are
/// missing from this listing.
452 static void solve_p(Tape* tape,
void* d, VectorAccess* adjointInterface) {
455 ExtFuncData* data = (ExtFuncData*)d;
// Primal evaluation is only valid if it was requested through the hints.
457 if (!data->hints.test(LinearSystemSolverFlags::PrimalEvaluation)) {
458 CODI_EXCEPTION(
"Linear system primal mode called without hint 'LinearSystemSolverFlags::PrimalEvaluation'.");
// Temporary for the primal right-hand side.
461 VectorReal* b_v = data->lsi.createVectorReal(data->b_id);
463 data->lsi.iterateMatrix(GetPrimal(0, adjointInterface), data->A_v, data->A_id);
464 data->lsi.iterateVector(GetPrimal(0, adjointInterface), b_v, data->b_id);
466 data->lsi.solveSystem(data->A_v, b_v, data->x_v);
// The transpose used by solve_b must match the updated A_v.
468 if (NULL != data->A_v_trans) {
471 data->lsi.deleteMatrixReal(data->A_v_trans);
472 data->A_v_trans = data->lsi.transposeMatrix(data->A_v);
// Publish the new primals of x, saving the old ones if a buffer exists.
475 if (NULL != data->oldPrimals) {
476 data->lsi.iterateVector(SetPrimalAndUpdateOldPrimals(0, adjointInterface), data->x_v, data->x_id,
479 data->lsi.iterateVector(SetPrimal(0, adjointInterface), data->x_v, data->x_id);
// Release the temporary.
482 data->lsi.deleteVectorReal(b_v);
/// Callback that receives the recorded ExtFuncData as `d`. By its name it
/// presumably deletes that object.
/// NOTE(review): only the cast is visible; the delete statement and the closing
/// brace are missing from this listing — confirm against the original source.
485 static void deleteData(Tape* tape,
void* d) {
488 ExtFuncData* data = (ExtFuncData*)d;
// NOTE(review): the header of the enclosing function is not visible in this
// listing. This appears to be the body of the solve entry point that computes
// x for A x = b and, when the tape is active, registers x as external-function
// outputs. Many lines are missing: buffer declarations, else-branches, closing
// braces, and the hand-off of ExtFuncData to the tape.
500 Tape& tape = Type::getTape();
// Split the entries of A and b into value and identifier buffers.
509 lsi.iterateMatrix(extract, A, A_v, A_id);
510 lsi.iterateVector(extract, b, b_v, b_id);
// Either take the primal solution from x (ProvidePrimalSolution) or compute it.
// The condition choosing between solveSystemPrimal and solveSystem is not
// visible here.
512 if (hints.
test(LinearSystemSolverFlags::ProvidePrimalSolution)) {
513 lsi.iterateVector(getOutput, x, x_v);
517 lsi.solveSystemPrimal(A_v, b_v, x_v);
519 lsi.solveSystem(A_v, b_v, x_v);
// The following bookkeeping only happens when the tape is active.
522 if (tape.isActive()) {
// The transposed matrix is only needed for reverse evaluation (solve_b).
524 if (hints.
test(LinearSystemSolverFlags::ReverseEvaluation)) {
525 A_v_trans = lsi.transposeMatrix(A_v);
// Register x as outputs. Old primals are kept only when StoreOldPrimals is set
// and reverse evaluation is requested.
529 if (StoreOldPrimals && hints.
test(LinearSystemSolverFlags::ReverseEvaluation)) {
531 lsi.iterateVector(registerOutputWithPrimal, x, x_v, x_id, oldPrimals);
534 lsi.iterateVector(registerOutput, x, x_v, x_id);
// Hand the buffers needed by the callbacks over to ExtFuncData. Whether A_v is
// kept depends on the Forward/PrimalEvaluation hints; the assignment inside that
// branch is not visible here.
537 ExtFuncData* data =
new ExtFuncData(lsi, hints);
538 if (hints.
test(LinearSystemSolverFlags::ForwardEvaluation) ||
539 hints.
test(LinearSystemSolverFlags::PrimalEvaluation)) {
543 data->A_v_trans = A_v_trans;
548 data->oldPrimals = oldPrimals;
// Free buffers that ExtFuncData does not keep (the guarding conditions are not visible here).
553 lsi.deleteVectorReal(b_v);
556 lsi.deleteMatrixReal(A_v);
// Write the primal solution into x (enclosing branch not visible in this listing).
559 lsi.iterateVector(setOutput, x, x_v);
// Release all buffers. NOTE(review): A_v is also deleted above; presumably the
// two deletions sit in mutually exclusive branches whose braces are missing —
// verify there is no double delete.
561 lsi.deleteMatrixReal(A_v);
562 lsi.deleteMatrixIdentifier(A_id);
563 lsi.deleteVectorReal(b_v);
564 lsi.deleteVectorIdentifier(b_id);
565 lsi.deleteVectorReal(x_v);
566 lsi.deleteVectorIdentifier(x_id);