CoDiPack  3.0.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
externalFunctionHelper.hpp
1/*
2 * CoDiPack, a Code Differentiation Package
3 *
4 * Copyright (C) 2015-2025 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://scicomp.rptu.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of CoDiPack (http://scicomp.rptu.de/software/codi).
11 *
12 * CoDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU General Public License
14 * as published by the Free Software Foundation, either version 3 of the
15 * License, or (at your option) any later version.
16 *
17 * CoDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty
19 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU General Public License for more details.
22 * You should have received a copy of the GNU
23 * General Public License along with CoDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * For other licensing options please contact us.
27 *
28 * Authors:
29 * - SciComp, University of Kaiserslautern-Landau:
30 * - Max Sagebaum
31 * - Johannes Blühdorn
32 * - Former members:
33 * - Tim Albring
34 */
35#pragma once
36
37#include <vector>
38
39#include "../../config.h"
40#include "../../expressions/lhsExpressionInterface.hpp"
41#include "../../misc/macros.hpp"
42#include "../../tapes/interfaces/fullTapeInterface.hpp"
43#include "../../tapes/misc/vectorAccessInterface.hpp"
44#include "../../traits/tapeTraits.hpp"
45#include "../data/externalFunctionUserData.hpp"
46#include "../parallel/synchronizationInterface.hpp"
47#include "../parallel/threadInformationInterface.hpp"
48
50namespace codi {
51
100 template<typename T_Type, typename T_Synchronization = DefaultSynchronization,
101 typename T_ThreadInformation = DefaultThreadInformation>
103 public:
104
/// Active AD type whose values are passed to addInput() and addOutput().
106 using Type = CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
107
/// Synchronization policy; defaults to DefaultSynchronization, the serial implementation of SynchronizationInterface.
109 using Synchronization = CODI_DD(T_Synchronization, DefaultSynchronization);
110
113
/// Primal value type of Type; used for the input/output value buffers handed to the user functions.
114 using Real = typename Type::Real;
/// Identifier type of Type; stored in inputIndices/outputIndices to address the tape's adjoint and primal vectors.
115 using Identifier = typename Type::Identifier;
116
/// Tape type of Type; the external function is pushed onto Type::getTape().
117 using Tape = CODI_DD(typename Type::Tape, CODI_DEFAULT_TAPE);
118
120 using ReverseFunc = void (*)(Real const* x, Real* x_b, size_t m, Real const* y, Real const* y_b, size_t n,
122
124 using ForwardFunc = void (*)(Real const* x, Real const* x_d, size_t m, Real* y, Real* y_d, size_t n,
126
/// Primal-only evaluation: computes the n output values y from the m input values x; d points to the user data.
128 using PrimalFunc = void (*)(Real const* x, size_t m, Real* y, size_t n, ExternalFunctionUserData* d);
129
130 private:
131
/// True if Tape is a primal value tape. Used as the default for getPrimalValuesFromPrimalValueVector and to gate
/// the primal-value-tape-only options (reallocation and renewal of primal values).
132 static constexpr bool IsPrimalValueTape = TapeTraits::IsPrimalValueTape<Tape>::value;
133
134 struct EvalData {
135 public:
136
137 std::vector<Identifier> inputIndices;
138 std::vector<Identifier> outputIndices;
139
140 std::vector<Real> inputValues;
141 std::vector<Real> outputValues;
142 std::vector<Real> oldPrimals;
143
144 std::vector<Real> x_d;
145 std::vector<Real> y_d;
146 std::vector<Real> x_b;
147 std::vector<Real> y_b;
148
149 ReverseFunc reverseFunc;
150 ForwardFunc forwardFunc;
151 PrimalFunc primalFunc;
152
154
155 bool provideInputValues;
156 bool provideOutputValues;
157 bool getPrimalsFromPrimalValueVector;
159
160 EvalData(bool getPrimalsFromPrimalValueVector, bool reallocatePrimalVectors)
161 : inputIndices(0),
162 outputIndices(0),
163 inputValues(0),
164 outputValues(0),
165 oldPrimals(0),
166 x_d(0),
167 y_d(0),
168 x_b(0),
169 y_b(0),
170 reverseFunc(nullptr),
171 forwardFunc(nullptr),
172 primalFunc(nullptr),
173 provideInputValues(true),
174 provideOutputValues(true),
175 getPrimalsFromPrimalValueVector(getPrimalsFromPrimalValueVector),
177
178 static void delFunc(Tape* t, void* d) {
179 CODI_UNUSED(t);
180
181 EvalData* data = (EvalData*)d;
182
183 delete data;
184 }
185
186 static void evalForwFuncStatic(Tape* t, void* d, VectorAccessInterface<Real, Identifier>* ra) {
187 EvalData* data = (EvalData*)d;
188
189 if (nullptr != data->forwardFunc) {
190 data->evalForwFunc(t, ra);
191 } else {
192 CODI_EXCEPTION(
193 "Calling forward evaluation in external function helper without a forward function pointer.");
194 }
195 }
196
197 CODI_INLINE void evalForwFunc(Tape* t, VectorAccessInterface<Real, Identifier>* ra) {
198 CODI_UNUSED(t);
199
201 x_d.resize(inputIndices.size());
202 y_d.resize(outputIndices.size());
203
204 initRun(ra);
205 });
206
208
209 for (size_t dim = 0; dim < ra->getVectorSize(); ++dim) {
211 for (size_t i = 0; i < inputIndices.size(); ++i) {
212 x_d[i] = ra->getAdjoint(inputIndices[i], dim);
213 }
214 });
215
217
218 forwardFunc(inputValues.data(), x_d.data(), inputIndices.size(), outputValues.data(), y_d.data(),
219 outputIndices.size(), &userData);
220
222
224 for (size_t i = 0; i < outputIndices.size(); ++i) {
225 ra->resetAdjoint(outputIndices[i], dim);
226 ra->updateAdjoint(outputIndices[i], dim, y_d[i]);
227 }
228 });
229
231 }
232
234 finalizeRun(ra);
235
236 x_d.resize(0);
237 y_d.resize(0);
238 });
239
241 }
242
243 static void evalPrimFuncStatic(Tape* t, void* d, VectorAccessInterface<Real, Identifier>* ra) {
244 EvalData* data = (EvalData*)d;
245
246 if (nullptr != data->primalFunc) {
247 data->evalPrimFunc(t, ra);
248 } else {
249 CODI_EXCEPTION(
250 "Calling primal evaluation in external function helper without a primal function pointer.");
251 }
252 }
253
254 CODI_INLINE void evalPrimFunc(Tape* t, VectorAccessInterface<Real, Identifier>* ra) {
255 CODI_UNUSED(t);
256
258 initRun(ra);
259 });
260
262
263 primalFunc(inputValues.data(), inputIndices.size(), outputValues.data(), outputIndices.size(), &userData);
264
266
268 finalizeRun(ra);
269 });
270
272 }
273
274 static void evalRevFuncStatic(Tape* t, void* d, VectorAccessInterface<Real, Identifier>* ra) {
275 EvalData* data = (EvalData*)d;
276
277 if (nullptr != data->reverseFunc) {
278 data->evalRevFunc(t, ra);
279 } else {
280 CODI_EXCEPTION(
281 "Calling reverse evaluation in external function helper without a reverse function pointer.");
282 }
283 }
284
286 CODI_UNUSED(t);
287
289 x_b.resize(inputIndices.size());
290 y_b.resize(outputIndices.size());
291
292 initRun(ra, true);
293 });
294
296
297 for (size_t dim = 0; dim < ra->getVectorSize(); ++dim) {
299 for (size_t i = 0; i < outputIndices.size(); ++i) {
300 y_b[i] = ra->getAdjoint(outputIndices[i], dim);
301 ra->resetAdjoint(outputIndices[i], dim);
302 }
303 });
304
306
307 reverseFunc(inputValues.data(), x_b.data(), inputIndices.size(), outputValues.data(), y_b.data(),
308 outputIndices.size(), &userData);
309
311
313 for (size_t i = 0; i < inputIndices.size(); ++i) {
314 ra->updateAdjoint(inputIndices[i], dim, x_b[i]);
315 }
316 });
317
319 }
320
322 finalizeRun(ra, true);
323
324 x_b.resize(0);
325 y_b.resize(0);
326 });
327
329 }
330
331 private:
332
333 CODI_INLINE void initRun(VectorAccessInterface<Real, Identifier>* ra, bool isReverse = false) {
334 if (getPrimalsFromPrimalValueVector && provideOutputValues) {
336 outputValues.resize(outputIndices.size());
337 }
338
339 if (isReverse) { // Provide result values for reverse evaluations.
340 for (size_t i = 0; i < outputIndices.size(); ++i) {
341 outputValues[i] = ra->getPrimal(outputIndices[i]);
342 }
343 }
344 }
345
346 // Restore the old primals for reverse evaluations, before the inputs are read.
347 if (isReverse && Tape::RequiresPrimalRestore) {
348 for (size_t i = 0; i < outputIndices.size(); ++i) {
349 ra->setPrimal(outputIndices[i], oldPrimals[i]);
350 }
351 }
352
353 if (getPrimalsFromPrimalValueVector && provideInputValues) {
355 inputValues.resize(inputIndices.size());
356 }
357
358 for (size_t i = 0; i < inputIndices.size(); ++i) {
359 inputValues[i] = ra->getPrimal(inputIndices[i]);
360 }
361 }
362 }
363
364 CODI_INLINE void finalizeRun(VectorAccessInterface<Real, Identifier>* ra, bool isReverse = false) {
365 if (getPrimalsFromPrimalValueVector && !isReverse) {
366 for (size_t i = 0; i < outputIndices.size(); ++i) {
367 if (Tape::RequiresPrimalRestore) {
368 oldPrimals[i] = ra->getPrimal(outputIndices[i]);
369 }
370 ra->setPrimal(outputIndices[i], outputValues[i]);
371 }
372 }
373
375 if (getPrimalsFromPrimalValueVector && provideInputValues) {
376 inputValues.clear();
377 inputValues.shrink_to_fit();
378 }
379 if (getPrimalsFromPrimalValueVector && provideOutputValues) {
380 outputValues.clear();
381 outputValues.shrink_to_fit();
382 }
383 }
384 }
385 };
386
387 protected:
388
389 std::vector<Type*> outputValues;
390
398
399 EvalData* data;
400
401 std::vector<Real> y;
402
403 public:
404
406 ExternalFunctionHelper(bool primalFuncUsesADType = false)
407 : outputValues(),
408 storeInputPrimals(true),
409 storeOutputPrimals(true),
410 storeInputOutputForPrimalEval(!primalFuncUsesADType),
412 getPrimalValuesFromPrimalValueVector(IsPrimalValueTape),
413 data(nullptr),
414 y(0) {
416 }
417
420 delete data;
421 }
422
426 if (IsPrimalValueTape) {
427 storeInputPrimals = false;
428 storeOutputPrimals = false;
430 data->reallocatePrimalVectors = true;
431 }
432 }
433
436 if (IsPrimalValueTape) {
438 data->getPrimalsFromPrimalValueVector = false;
439 }
440 }
441
444 storeInputPrimals = false;
445 data->provideInputValues = false;
446 }
447
450 storeOutputPrimals = false;
451 data->provideOutputValues = false;
452 }
453
455 CODI_INLINE void addInput(Type const& input) {
456 if (Type::getTape().isActive()) {
457 Identifier identifier = input.getIdentifier();
458 if (!Type::getTape().isIdentifierActive(identifier)) {
459 // Register input values for primal value tapes when they are restored from the tape, otherwise the primal
460 // values can not be restored. For a lot of inactive inputs, this can inflate the number of identifiers
461 // quite a lot. This is especially true for reuse index tapes.
462 if (data->getPrimalsFromPrimalValueVector) {
463 Type temp = input;
464 Type::getTape().registerInput(temp);
465 identifier = temp.getIdentifier();
466 }
467 }
468
469 data->inputIndices.push_back(identifier);
470 }
471
472 // Ignore the setting at this place and the active check,
473 // we might need the values for the evaluation.
475 data->inputValues.push_back(input.getValue());
476 }
477 }
478
479 private:
480
481 CODI_INLINE void addOutputToData(Type& output) {
482 Real oldPrimal = Type::getTape().registerExternalFunctionOutput(output);
483
484 data->outputIndices.push_back(output.getIdentifier());
485 if (storeOutputPrimals) {
486 data->outputValues.push_back(output.getValue());
487 }
488 if (Tape::RequiresPrimalRestore) {
489 data->oldPrimals.push_back(oldPrimal);
490 }
491 }
492
493 public:
494
496 CODI_INLINE void addOutput(Type& output) {
497 if (Type::getTape().isActive() || storeInputOutputForPrimalEval) {
498 outputValues.push_back(&output);
499 }
500 }
501
503 template<typename Data>
504 CODI_INLINE void addUserData(Data const& data) {
505 this->data->userData.addData(data);
506 }
507
511 return this->data->userData;
512 }
513
516 template<typename FuncObj, typename... Args>
517 CODI_INLINE void callPrimalFuncWithADType(FuncObj& func, Args&&... args) {
518 bool isTapeActive = Type::getTape().isActive();
519
520 if (isTapeActive) {
521 Type::getTape().setPassive();
522 }
523
524 func(std::forward<Args>(args)...);
525
527
528 if (isTapeActive) {
529 Type::getTape().setActive();
530
532 for (size_t i = 0; i < outputValues.size(); ++i) {
533 addOutputToData(*outputValues[i]);
534 }
535 });
536 }
537
539 }
540
546 // Store the primal function in the external function data so that it can be used for primal evaluations of
547 // the tape.
548 data->primalFunc = func;
549
550 y.resize(outputValues.size());
551 });
552
554
555 func(data->inputValues.data(), data->inputValues.size(), y.data(), outputValues.size(), &data->userData);
556
558
560 // Set the primal values on the output values and add them to the data for the reverse evaluation.
561 for (size_t i = 0; i < outputValues.size(); ++i) {
562 outputValues[i]->setValue(y[i]);
563
564 if (Type::getTape().isActive()) {
565 addOutputToData(*outputValues[i]);
566 }
567 }
568
569 y.resize(0);
570 });
571
573 } else {
574 CODI_EXCEPTION(
575 "callPrimalFunc() not available if external function helper is initialized with passive function mode "
576 "enabled. Use callPrimalFuncWithADType() instead.");
577 }
578 }
579
581 CODI_INLINE void addToTape(ReverseFunc reverseFunc, ForwardFunc forwardFunc = nullptr,
582 PrimalFunc primalFunc = nullptr) {
583 if (Type::getTape().isActive()) {
584 // Collect shared data in a serial manner.
586 data->reverseFunc = reverseFunc;
587 data->forwardFunc = forwardFunc;
588
589 if (nullptr != primalFunc) {
590 // Only overwrite the primal function if the user provides one, otherwise it is set in the callPrimalFunc
591 // method.
592 data->primalFunc = primalFunc;
593 }
594
595 // Clear the primal values if they are not required.
596 if (!storeInputPrimals) {
597 data->inputValues.clear();
598 data->inputValues.shrink_to_fit();
599 }
600 });
601
602 // Only push once everything is prepared.
604
605 // Push the delete handle on at most one thread's tape.
606 typename ExternalFunction<Tape>::DeleteFunction delFunc = nullptr;
608 delFunc = EvalData::delFunc;
609 });
610
611 Type::getTape().pushExternalFunction(ExternalFunction<Tape>::create(
612 EvalData::evalRevFuncStatic, data, delFunc, EvalData::evalForwFuncStatic, EvalData::evalPrimFuncStatic));
613
614 // Only begin the cleanup once all pushes are finished.
616
617 // Clear the assembled data in a serial manner.
619 data = nullptr;
620 });
621 } else {
622 // Clear the assembled data in a serial manner.
624 delete data;
625 });
626 }
627
628 // Create a new data object for the next call in a serial manner.
631 outputValues.clear();
632 });
633
634 // Return only after the preparations for the next call are done.
636 }
637 };
638}
#define CODI_INLINE
See codi::Config::ForcedInlines.
Definition config.h:469
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:96
CoDiPack - Code Differentiation Package.
Definition codi.hpp:94
inline void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:54
Default implementation of SynchronizationInterface for serial applications.
Definition synchronizationInterface.hpp:62
static inline void serialize(FunctionObject const &func)
Ensures that only one among the calling threads calls the given function object.
Definition synchronizationInterface.hpp:67
static inline void synchronize()
Does not return until called by all threads.
Definition synchronizationInterface.hpp:73
Default implementation of ThreadInformationInterface for serial applications.
Definition threadInformationInterface.hpp:63
bool storeInputPrimals
Definition externalFunctionHelper.hpp:391
void(*)(Real const *x, Real const *x_d, size_t m, Real *y, Real *y_d, size_t n, ExternalFunctionUserData *d) ForwardFunc
Definition externalFunctionHelper.hpp:124
DefaultSynchronization Synchronization
Definition externalFunctionHelper.hpp:109
inline void addInput(Type const &input)
Add an input value.
Definition externalFunctionHelper.hpp:455
typename Type::Identifier Identifier
Definition externalFunctionHelper.hpp:115
void disableInputPrimalStore()
Do not store primal input values. In function calls, pointers to primal inputs will be null.
Definition externalFunctionHelper.hpp:443
void enableReallocationOfPrimalValueVectors()
Definition externalFunctionHelper.hpp:425
void disableRenewOfPrimalValues()
Do not update the inputs and outputs from the primal values of the tape. Has no effect on Jacobian tapes.
Definition externalFunctionHelper.hpp:435
ExternalFunctionHelper(bool primalFuncUsesADType=false)
Constructor.
Definition externalFunctionHelper.hpp:406
~ExternalFunctionHelper()
Destructor.
Definition externalFunctionHelper.hpp:419
bool storeOutputPrimals
Definition externalFunctionHelper.hpp:392
inline void addUserData(Data const &data)
Add user data. See ExternalFunctionUserData for details.
Definition externalFunctionHelper.hpp:504
typename Type::Real Real
Definition externalFunctionHelper.hpp:114
ExternalFunctionUserData & getExternalFunctionUserData()
Definition externalFunctionHelper.hpp:510
typename Type::Tape Tape
Definition externalFunctionHelper.hpp:117
void disableOutputPrimalStore()
Do not store primal output values. In function calls, pointers to primal outputs will be null.
Definition externalFunctionHelper.hpp:449
void(*)(Real const *x, Real *x_b, size_t m, Real const *y, Real const *y_b, size_t n, ExternalFunctionUserData *d) ReverseFunc
Definition externalFunctionHelper.hpp:120
bool storeInputOutputForPrimalEval
Definition externalFunctionHelper.hpp:393
EvalData * data
Definition externalFunctionHelper.hpp:399
void(*)(Real const *x, size_t m, Real *y, size_t n, ExternalFunctionUserData *d) PrimalFunc
Definition externalFunctionHelper.hpp:128
bool reallocatePrimalVectors
Definition externalFunctionHelper.hpp:394
inline void callPrimalFunc(PrimalFunc func)
Definition externalFunctionHelper.hpp:543
DefaultThreadInformation ThreadInformation
Definition externalFunctionHelper.hpp:112
bool getPrimalValuesFromPrimalValueVector
Definition externalFunctionHelper.hpp:396
inline void callPrimalFuncWithADType(FuncObj &func, Args &&... args)
Definition externalFunctionHelper.hpp:517
std::vector< Real > y
Definition externalFunctionHelper.hpp:401
Type Type
Definition externalFunctionHelper.hpp:106
inline void addOutput(Type &output)
Add an output value.
Definition externalFunctionHelper.hpp:496
std::vector< Type * > outputValues
Definition externalFunctionHelper.hpp:389
inline void addToTape(ReverseFunc reverseFunc, ForwardFunc forwardFunc=nullptr, PrimalFunc primalFunc=nullptr)
Add the external function to the tape.
Definition externalFunctionHelper.hpp:581
Ease of access structure for user-provided data on the tape for external functions....
Definition externalFunctionUserData.hpp:59
void(* DeleteFunction)(Tape *tape, void *data)
Delete function definition.
Definition externalFunction.hpp:113
static ExternalFunction create(CallFunction funcReverse, void *data, DeleteFunction funcDelete, CallFunction funcForward=nullptr, CallFunction funcPrimal=nullptr)
Helper function for the creation of an ExternalFunction object.
Definition externalFunction.hpp:124
If the tape inherits from PrimalValueBaseTape.
Definition tapeTraits.hpp:95
Unified access to the adjoint vector and primal vector in a tape evaluation.
Definition vectorAccessInterface.hpp:94
virtual void resetAdjoint(Identifier const &index, size_t dim)=0
Set the adjoint component to zero.
virtual void setPrimal(Identifier const &index, Real const &primal)=0
Set the primal value.
virtual size_t getVectorSize() const =0
Vector size in the current tape evaluation.
virtual Real getPrimal(Identifier const &index)=0
Get the primal value.
virtual void updateAdjoint(Identifier const &index, size_t dim, Real const &adjoint)=0
Update the adjoint component.
virtual Real getAdjoint(Identifier const &index, size_t dim)=0
Get the adjoint component.