CoDiPack  2.2.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
externalFunctionHelper.hpp
1/*
2 * CoDiPack, a Code Differentiation Package
3 *
4 * Copyright (C) 2015-2024 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://scicomp.rptu.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of CoDiPack (http://scicomp.rptu.de/software/codi).
11 *
12 * CoDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU General Public License
14 * as published by the Free Software Foundation, either version 3 of the
15 * License, or (at your option) any later version.
16 *
17 * CoDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty
19 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU General Public License for more details.
22 * You should have received a copy of the GNU
23 * General Public License along with CoDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * For other licensing options please contact us.
27 *
28 * Authors:
29 * - SciComp, University of Kaiserslautern-Landau:
30 * - Max Sagebaum
31 * - Johannes Blühdorn
32 * - Former members:
33 * - Tim Albring
34 */
35#pragma once
36
37#include <vector>
38
39#include "../../config.h"
40#include "../../expressions/lhsExpressionInterface.hpp"
41#include "../../misc/macros.hpp"
42#include "../../tapes/interfaces/fullTapeInterface.hpp"
43#include "../../tapes/misc/vectorAccessInterface.hpp"
44#include "../../traits/tapeTraits.hpp"
45#include "../data/externalFunctionUserData.hpp"
46#include "../parallel/synchronizationInterface.hpp"
47#include "../parallel/threadInformationInterface.hpp"
48
50namespace codi {
51
100 template<typename T_Type, typename T_Synchronization = DefaultSynchronization,
101 typename T_ThreadInformation = DefaultThreadInformation>
103 public:
104
106 using Type = CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
107
109 using Synchronization = CODI_DD(T_Synchronization, DefaultSynchronization);
110
113
114 using Real = typename Type::Real;
115 using Identifier = typename Type::Identifier;
116
117 using Tape = CODI_DD(typename Type::Tape, CODI_DEFAULT_TAPE);
118
120 using ReverseFunc = void (*)(Real const* x, Real* x_b, size_t m, Real const* y, Real const* y_b, size_t n,
122
124 using ForwardFunc = void (*)(Real const* x, Real const* x_d, size_t m, Real* y, Real* y_d, size_t n,
126
128 using PrimalFunc = void (*)(Real const* x, size_t m, Real* y, size_t n, ExternalFunctionUserData* d);
129
130 private:
131
132 static constexpr bool IsPrimalValueTape = TapeTraits::IsPrimalValueTape<Tape>::value;
133
134 struct EvalData {
135 public:
136
137 std::vector<Identifier> inputIndices;
138 std::vector<Identifier> outputIndices;
139
140 std::vector<Real> inputValues;
141 std::vector<Real> outputValues;
142 std::vector<Real> oldPrimals;
143
144 std::vector<Real> x_d;
145 std::vector<Real> y_d;
146 std::vector<Real> x_b;
147 std::vector<Real> y_b;
148
149 ReverseFunc reverseFunc;
150 ForwardFunc forwardFunc;
151 PrimalFunc primalFunc;
152
154
155 bool provideInputValues;
156 bool provideOutputValues;
157 bool getPrimalsFromPrimalValueVector;
158 bool reallocatePrimalVectors;
159
160 EvalData(bool getPrimalsFromPrimalValueVector, bool reallocatePrimalVectors)
161 : inputIndices(0),
162 outputIndices(0),
163 inputValues(0),
164 outputValues(0),
165 oldPrimals(0),
166 x_d(0),
167 y_d(0),
168 x_b(0),
169 y_b(0),
170 reverseFunc(nullptr),
171 forwardFunc(nullptr),
172 primalFunc(nullptr),
173 provideInputValues(true),
174 provideOutputValues(true),
175 getPrimalsFromPrimalValueVector(getPrimalsFromPrimalValueVector),
176 reallocatePrimalVectors(reallocatePrimalVectors) {}
177
178 static void delFunc(Tape* t, void* d) {
179 CODI_UNUSED(t);
180
181 EvalData* data = (EvalData*)d;
182
183 delete data;
184 }
185
186 static void evalForwFuncStatic(Tape* t, void* d, VectorAccessInterface<Real, Identifier>* ra) {
187 EvalData* data = (EvalData*)d;
188
189 if (nullptr != data->forwardFunc) {
190 data->evalForwFunc(t, ra);
191 } else {
192 CODI_EXCEPTION(
193 "Calling forward evaluation in external function helper without a forward function pointer.");
194 }
195 }
196
197 CODI_INLINE void evalForwFunc(Tape* t, VectorAccessInterface<Real, Identifier>* ra) {
198 CODI_UNUSED(t);
199
200 Synchronization::serialize([&]() {
201 x_d.resize(inputIndices.size());
202 y_d.resize(outputIndices.size());
203
204 initRun(ra);
205 });
206
207 Synchronization::synchronize();
208
209 for (size_t dim = 0; dim < ra->getVectorSize(); ++dim) {
210 Synchronization::serialize([&]() {
211 for (size_t i = 0; i < inputIndices.size(); ++i) {
212 x_d[i] = ra->getAdjoint(inputIndices[i], dim);
213 }
214 });
215
216 Synchronization::synchronize();
217
218 forwardFunc(inputValues.data(), x_d.data(), inputIndices.size(), outputValues.data(), y_d.data(),
219 outputIndices.size(), &userData);
220
221 Synchronization::synchronize();
222
223 Synchronization::serialize([&]() {
224 for (size_t i = 0; i < outputIndices.size(); ++i) {
225 ra->resetAdjoint(outputIndices[i], dim);
226 ra->updateAdjoint(outputIndices[i], dim, y_d[i]);
227 }
228 });
229
230 Synchronization::synchronize();
231 }
232
233 Synchronization::serialize([&]() {
234 finalizeRun(ra);
235
236 x_d.resize(0);
237 y_d.resize(0);
238 });
239
240 Synchronization::synchronize();
241 }
242
243 static void evalPrimFuncStatic(Tape* t, void* d, VectorAccessInterface<Real, Identifier>* ra) {
244 EvalData* data = (EvalData*)d;
245
246 if (nullptr != data->primalFunc) {
247 data->evalPrimFunc(t, ra);
248 } else {
249 CODI_EXCEPTION(
250 "Calling primal evaluation in external function helper without a primal function pointer.");
251 }
252 }
253
254 CODI_INLINE void evalPrimFunc(Tape* t, VectorAccessInterface<Real, Identifier>* ra) {
255 CODI_UNUSED(t);
256
257 Synchronization::serialize([&]() { initRun(ra); });
258
259 Synchronization::synchronize();
260
261 primalFunc(inputValues.data(), inputIndices.size(), outputValues.data(), outputIndices.size(), &userData);
262
263 Synchronization::synchronize();
264
265 Synchronization::serialize([&]() { finalizeRun(ra); });
266
267 Synchronization::synchronize();
268 }
269
270 static void evalRevFuncStatic(Tape* t, void* d, VectorAccessInterface<Real, Identifier>* ra) {
271 EvalData* data = (EvalData*)d;
272
273 if (nullptr != data->reverseFunc) {
274 data->evalRevFunc(t, ra);
275 } else {
276 CODI_EXCEPTION(
277 "Calling reverse evaluation in external function helper without a reverse function pointer.");
278 }
279 }
280
282 CODI_UNUSED(t);
283
284 Synchronization::serialize([&]() {
285 x_b.resize(inputIndices.size());
286 y_b.resize(outputIndices.size());
287
288 initRun(ra, true);
289 });
290
291 Synchronization::synchronize();
292
293 for (size_t dim = 0; dim < ra->getVectorSize(); ++dim) {
294 Synchronization::serialize([&]() {
295 for (size_t i = 0; i < outputIndices.size(); ++i) {
296 y_b[i] = ra->getAdjoint(outputIndices[i], dim);
297 ra->resetAdjoint(outputIndices[i], dim);
298 }
299 });
300
301 Synchronization::synchronize();
302
303 reverseFunc(inputValues.data(), x_b.data(), inputIndices.size(), outputValues.data(), y_b.data(),
304 outputIndices.size(), &userData);
305
306 Synchronization::synchronize();
307
308 Synchronization::serialize([&]() {
309 for (size_t i = 0; i < inputIndices.size(); ++i) {
310 ra->updateAdjoint(inputIndices[i], dim, x_b[i]);
311 }
312 });
313
314 Synchronization::synchronize();
315 }
316
317 Synchronization::serialize([&]() {
318 finalizeRun(ra, true);
319
320 x_b.resize(0);
321 y_b.resize(0);
322 });
323
324 Synchronization::synchronize();
325 }
326
327 private:
328
329 CODI_INLINE void initRun(VectorAccessInterface<Real, Identifier>* ra, bool isReverse = false) {
330 if (getPrimalsFromPrimalValueVector && provideOutputValues) {
331 if (reallocatePrimalVectors) {
332 outputValues.resize(outputIndices.size());
333 }
334
335 if (isReverse) { // Provide result values for reverse evaluations.
336 for (size_t i = 0; i < outputIndices.size(); ++i) {
337 outputValues[i] = ra->getPrimal(outputIndices[i]);
338 }
339 }
340 }
341
342 // Restore the old primals for reverse evaluations, before the inputs are read.
343 if (isReverse && Tape::RequiresPrimalRestore) {
344 for (size_t i = 0; i < outputIndices.size(); ++i) {
345 ra->setPrimal(outputIndices[i], oldPrimals[i]);
346 }
347 }
348
349 if (getPrimalsFromPrimalValueVector && provideInputValues) {
350 if (reallocatePrimalVectors) {
351 inputValues.resize(inputIndices.size());
352 }
353
354 for (size_t i = 0; i < inputIndices.size(); ++i) {
355 inputValues[i] = ra->getPrimal(inputIndices[i]);
356 }
357 }
358 }
359
360 CODI_INLINE void finalizeRun(VectorAccessInterface<Real, Identifier>* ra, bool isReverse = false) {
361 if (getPrimalsFromPrimalValueVector && !isReverse) {
362 for (size_t i = 0; i < outputIndices.size(); ++i) {
363 if (Tape::RequiresPrimalRestore) {
364 oldPrimals[i] = ra->getPrimal(outputIndices[i]);
365 }
366 ra->setPrimal(outputIndices[i], outputValues[i]);
367 }
368 }
369
370 if (reallocatePrimalVectors) {
371 if (getPrimalsFromPrimalValueVector && provideInputValues) {
372 inputValues.clear();
373 inputValues.shrink_to_fit();
374 }
375 if (getPrimalsFromPrimalValueVector && provideOutputValues) {
376 outputValues.clear();
377 outputValues.shrink_to_fit();
378 }
379 }
380 }
381 };
382
383 protected:
384
385 std::vector<Type*> outputValues;
386
394
395 EvalData* data;
396
397 std::vector<Real> y;
398
399 public:
400
402 ExternalFunctionHelper(bool primalFuncUsesADType = false)
403 : outputValues(),
404 storeInputPrimals(true),
405 storeOutputPrimals(true),
406 storeInputOutputForPrimalEval(!primalFuncUsesADType),
408 getPrimalValuesFromPrimalValueVector(IsPrimalValueTape),
409 data(nullptr),
410 y(0) {
412 }
413
416 delete data;
417 }
418
422 if (IsPrimalValueTape) {
423 storeInputPrimals = false;
424 storeOutputPrimals = false;
426 data->reallocatePrimalVectors = true;
427 }
428 }
429
432 if (IsPrimalValueTape) {
434 data->getPrimalsFromPrimalValueVector = false;
435 }
436 }
437
440 storeInputPrimals = false;
441 data->provideInputValues = false;
442 }
443
446 storeOutputPrimals = false;
447 data->provideOutputValues = false;
448 }
449
451 CODI_INLINE void addInput(Type const& input) {
452 if (Type::getTape().isActive()) {
453 Identifier identifier = input.getIdentifier();
454 if (!Type::getTape().isIdentifierActive(identifier)) {
455 // Register input values for primal value tapes when they are restored from the tape, otherwise the primal
456 // values can not be restored. For a lot of inactive inputs, this can inflate the number of identifiers
457 // quite a lot. This is especially true for reuse index tapes.
458 if (data->getPrimalsFromPrimalValueVector) {
459 Type temp = input;
460 Type::getTape().registerInput(temp);
461 identifier = temp.getIdentifier();
462 }
463 }
464
465 data->inputIndices.push_back(identifier);
466 }
467
468 // Ignore the setting at this place and the active check,
469 // we might need the values for the evaluation.
471 data->inputValues.push_back(input.getValue());
472 }
473 }
474
475 private:
476
477 CODI_INLINE void addOutputToData(Type& output) {
478 Real oldPrimal = Type::getTape().registerExternalFunctionOutput(output);
479
480 data->outputIndices.push_back(output.getIdentifier());
481 if (storeOutputPrimals) {
482 data->outputValues.push_back(output.getValue());
483 }
484 if (Tape::RequiresPrimalRestore) {
485 data->oldPrimals.push_back(oldPrimal);
486 }
487 }
488
489 public:
490
492 CODI_INLINE void addOutput(Type& output) {
493 if (Type::getTape().isActive() || storeInputOutputForPrimalEval) {
494 outputValues.push_back(&output);
495 }
496 }
497
499 template<typename Data>
500 CODI_INLINE void addUserData(Data const& data) {
501 this->data->userData.addData(data);
502 }
503
507 return this->data->userData;
508 }
509
512 template<typename FuncObj, typename... Args>
513 CODI_INLINE void callPrimalFuncWithADType(FuncObj& func, Args&&... args) {
514 bool isTapeActive = Type::getTape().isActive();
515
516 if (isTapeActive) {
517 Type::getTape().setPassive();
518 }
519
520 func(std::forward<Args>(args)...);
521
522 Synchronization::synchronize();
523
524 if (isTapeActive) {
525 Type::getTape().setActive();
526
527 Synchronization::serialize([&]() {
528 for (size_t i = 0; i < outputValues.size(); ++i) {
529 addOutputToData(*outputValues[i]);
530 }
531 });
532 }
533
534 Synchronization::synchronize();
535 }
536
541 Synchronization::serialize([&]() {
542 // Store the primal function in the external function data so that it can be used for primal evaluations of
543 // the tape.
544 data->primalFunc = func;
545
546 y.resize(outputValues.size());
547 });
548
549 Synchronization::synchronize();
550
551 func(data->inputValues.data(), data->inputValues.size(), y.data(), outputValues.size(), &data->userData);
552
553 Synchronization::synchronize();
554
555 Synchronization::serialize([&]() {
556 // Set the primal values on the output values and add them to the data for the reverse evaluation.
557 for (size_t i = 0; i < outputValues.size(); ++i) {
558 outputValues[i]->setValue(y[i]);
559
560 if (Type::getTape().isActive()) {
561 addOutputToData(*outputValues[i]);
562 }
563 }
564
565 y.resize(0);
566 });
567
568 Synchronization::synchronize();
569 } else {
570 CODI_EXCEPTION(
571 "callPrimalFunc() not available if external function helper is initialized with passive function mode "
572 "enabled. Use callPrimalFuncWithADType() instead.");
573 }
574 }
575
577 CODI_INLINE void addToTape(ReverseFunc reverseFunc, ForwardFunc forwardFunc = nullptr,
578 PrimalFunc primalFunc = nullptr) {
579 if (Type::getTape().isActive()) {
580 // Collect shared data in a serial manner.
581 Synchronization::serialize([&]() {
582 data->reverseFunc = reverseFunc;
583 data->forwardFunc = forwardFunc;
584
585 if (nullptr != primalFunc) {
586 // Only overwrite the primal function if the user provides one, otherwise it is set in the callPrimalFunc
587 // method.
588 data->primalFunc = primalFunc;
589 }
590
591 // Clear the primal values if they are not required.
592 if (!storeInputPrimals) {
593 data->inputValues.clear();
594 data->inputValues.shrink_to_fit();
595 }
596 });
597
598 // Only push once everything is prepared.
599 Synchronization::synchronize();
600
601 // Push the delete handle on at most one thread's tape.
603 0 == ThreadInformation::getThreadId() ? EvalData::delFunc : nullptr;
604 Type::getTape().pushExternalFunction(ExternalFunction<Tape>::create(
605 EvalData::evalRevFuncStatic, data, delFunc, EvalData::evalForwFuncStatic, EvalData::evalPrimFuncStatic));
606
607 // Only begin the cleanup once all pushes are finished.
608 Synchronization::synchronize();
609
610 // Clear the assembled data in a serial manner.
611 Synchronization::serialize([&]() { data = nullptr; });
612 } else {
613 // Clear the assembled data in a serial manner.
614 Synchronization::serialize([&]() { delete data; });
615 }
616
617 // Create a new data object for the next call in a serial manner.
618 Synchronization::serialize([&]() {
620 outputValues.clear();
621 });
622
623 // Return only after the preparations for the next call are done.
624 Synchronization::synchronize();
625 }
626 };
627}
#define CODI_INLINE
See codi::Config::ForcedInlines.
Definition config.h:457
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:94
CoDiPack - Code Differentiation Package.
Definition codi.hpp:90
void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:46
Represents a concrete lvalue in the CoDiPack expression tree.
Definition activeType.hpp:52
Default implementation of SynchronizationInterface for serial applications.
Definition synchronizationInterface.hpp:62
Default implementation of ThreadInformationInterface for serial applications.
Definition threadInformationInterface.hpp:63
Helper class for the implementation of an external function in CoDiPack.
Definition externalFunctionHelper.hpp:102
bool storeInputPrimals
If input primals are stored. Can be disabled by the user.
Definition externalFunctionHelper.hpp:387
void(*)(Real const *x, Real const *x_d, size_t m, Real *y, Real *y_d, size_t n, ExternalFunctionUserData *d) ForwardFunc
Function interface for the forward AD call of an external function.
Definition externalFunctionHelper.hpp:124
T_Synchronization Synchronization
See ExternalFunctionHelper.
Definition externalFunctionHelper.hpp:109
typename Type::Identifier Identifier
See LhsExpressionInterface.
Definition externalFunctionHelper.hpp:115
void disableInputPrimalStore()
Do not store primal input values. In function calls, pointers to primal inputs will be null.
Definition externalFunctionHelper.hpp:439
void enableReallocationOfPrimalValueVectors()
Definition externalFunctionHelper.hpp:421
void disableRenewOfPrimalValues()
Do not update the inputs and outputs from the primal values of the tape. Has no effect on Jacobian tapes.
Definition externalFunctionHelper.hpp:431
ExternalFunctionHelper(bool primalFuncUsesADType=false)
Constructor.
Definition externalFunctionHelper.hpp:402
~ExternalFunctionHelper()
Destructor.
Definition externalFunctionHelper.hpp:415
void addInput(Type const &input)
Add an input value.
Definition externalFunctionHelper.hpp:451
bool storeOutputPrimals
If output primals are stored. Can be disabled by the user.
Definition externalFunctionHelper.hpp:388
typename Type::Real Real
See LhsExpressionInterface.
Definition externalFunctionHelper.hpp:114
ExternalFunctionUserData & getExternalFunctionUserData()
Definition externalFunctionHelper.hpp:506
typename Type::Tape Tape
See LhsExpressionInterface.
Definition externalFunctionHelper.hpp:117
void disableOutputPrimalStore()
Do not store primal output values. In function calls, pointers to primal outputs will be null.
Definition externalFunctionHelper.hpp:445
void callPrimalFuncWithADType(FuncObj &func, Args &&... args)
Definition externalFunctionHelper.hpp:513
void(*)(Real const *x, Real *x_b, size_t m, Real const *y, Real const *y_b, size_t n, ExternalFunctionUserData *d) ReverseFunc
Function interface for the reverse AD call of an external function.
Definition externalFunctionHelper.hpp:120
void addUserData(Data const &data)
Add user data. See ExternalFunctionUserData for details.
Definition externalFunctionHelper.hpp:500
void addOutput(Type &output)
Add an output value.
Definition externalFunctionHelper.hpp:492
bool storeInputOutputForPrimalEval
If a primal call with a self-implemented function will be done.
Definition externalFunctionHelper.hpp:389
EvalData * data
External function data.
Definition externalFunctionHelper.hpp:395
void(*)(Real const *x, size_t m, Real *y, size_t n, ExternalFunctionUserData *d) PrimalFunc
Function interface for the primal call of an external function.
Definition externalFunctionHelper.hpp:128
bool reallocatePrimalVectors
Definition externalFunctionHelper.hpp:390
T_ThreadInformation ThreadInformation
See ExternalFunctionHelper.
Definition externalFunctionHelper.hpp:112
void callPrimalFunc(PrimalFunc func)
Definition externalFunctionHelper.hpp:539
bool getPrimalValuesFromPrimalValueVector
Definition externalFunctionHelper.hpp:392
std::vector< Real > y
Shared vector of output variables.
Definition externalFunctionHelper.hpp:397
T_Type Type
See ExternalFunctionHelper.
Definition externalFunctionHelper.hpp:106
void addToTape(ReverseFunc reverseFunc, ForwardFunc forwardFunc=nullptr, PrimalFunc primalFunc=nullptr)
Add the external function to the tape.
Definition externalFunctionHelper.hpp:577
std::vector< Type * > outputValues
References to output values.
Definition externalFunctionHelper.hpp:385
Ease of access structure for user-provided data on the tape for external functions.
Definition externalFunctionUserData.hpp:59
User-defined evaluation functions for the taping process.
Definition externalFunction.hpp:102
If the tape inherits from PrimalValueBaseTape.
Definition tapeTraits.hpp:92
Unified access to the adjoint vector and primal vector in a tape evaluation.
Definition vectorAccessInterface.hpp:91
virtual void resetAdjoint(Identifier const &index, size_t dim)=0
Set the adjoint component to zero.
virtual void setPrimal(Identifier const &index, Real const &primal)=0
Set the primal value.
virtual size_t getVectorSize() const =0
Vector size in the current tape evaluation.
virtual Real getPrimal(Identifier const &index)=0
Get the primal value.
virtual void updateAdjoint(Identifier const &index, size_t dim, Real const &adjoint)=0
Update the adjoint component.
virtual Real getAdjoint(Identifier const &index, size_t dim)=0
Get the adjoint component.