40#include "../../expressions/lhsExpressionInterface.hpp"
42#include "../../tapes/interfaces/fullTapeInterface.hpp"
43#include "../../tapes/misc/vectorAccessInterface.hpp"
44#include "../../traits/tapeTraits.hpp"
45#include "../data/externalFunctionUserData.hpp"
46#include "../parallel/synchronizationInterface.hpp"
47#include "../parallel/threadInformationInterface.hpp"
100 template<
typename T_Type,
typename T_Synchronization = DefaultSynchronization,
101 typename T_ThreadInformation = DefaultThreadInformation>
114 using Real =
typename Type::Real;
137 std::vector<Identifier> inputIndices;
138 std::vector<Identifier> outputIndices;
140 std::vector<Real> inputValues;
141 std::vector<Real> outputValues;
142 std::vector<Real> oldPrimals;
144 std::vector<Real> x_d;
145 std::vector<Real> y_d;
146 std::vector<Real> x_b;
147 std::vector<Real> y_b;
155 bool provideInputValues;
156 bool provideOutputValues;
157 bool getPrimalsFromPrimalValueVector;
158 bool reallocatePrimalVectors;
160 EvalData(
bool getPrimalsFromPrimalValueVector,
bool reallocatePrimalVectors)
170 reverseFunc(
nullptr),
171 forwardFunc(
nullptr),
173 provideInputValues(
true),
174 provideOutputValues(
true),
175 getPrimalsFromPrimalValueVector(getPrimalsFromPrimalValueVector),
176 reallocatePrimalVectors(reallocatePrimalVectors) {}
178 static void delFunc(
Tape* t,
void* d) {
181 EvalData*
data = (EvalData*)d;
187 EvalData*
data = (EvalData*)d;
189 if (
nullptr !=
data->forwardFunc) {
190 data->evalForwFunc(t, ra);
193 "Calling forward evaluation in external function helper without a forward function pointer.");
200 Synchronization::serialize([&]() {
201 x_d.resize(inputIndices.size());
202 y_d.resize(outputIndices.size());
207 Synchronization::synchronize();
210 Synchronization::serialize([&]() {
211 for (
size_t i = 0; i < inputIndices.size(); ++i) {
212 x_d[i] = ra->
getAdjoint(inputIndices[i], dim);
216 Synchronization::synchronize();
218 forwardFunc(inputValues.data(), x_d.data(), inputIndices.size(), outputValues.data(), y_d.data(),
219 outputIndices.size(), &userData);
221 Synchronization::synchronize();
223 Synchronization::serialize([&]() {
224 for (
size_t i = 0; i < outputIndices.size(); ++i) {
230 Synchronization::synchronize();
233 Synchronization::serialize([&]() {
240 Synchronization::synchronize();
244 EvalData*
data = (EvalData*)d;
246 if (
nullptr !=
data->primalFunc) {
247 data->evalPrimFunc(t, ra);
250 "Calling primal evaluation in external function helper without a primal function pointer.");
257 Synchronization::serialize([&]() { initRun(ra); });
259 Synchronization::synchronize();
261 primalFunc(inputValues.data(), inputIndices.size(), outputValues.data(), outputIndices.size(), &userData);
263 Synchronization::synchronize();
265 Synchronization::serialize([&]() { finalizeRun(ra); });
267 Synchronization::synchronize();
271 EvalData*
data = (EvalData*)d;
273 if (
nullptr !=
data->reverseFunc) {
274 data->evalRevFunc(t, ra);
277 "Calling reverse evaluation in external function helper without a reverse function pointer.");
284 Synchronization::serialize([&]() {
285 x_b.resize(inputIndices.size());
286 y_b.resize(outputIndices.size());
291 Synchronization::synchronize();
294 Synchronization::serialize([&]() {
295 for (
size_t i = 0; i < outputIndices.size(); ++i) {
296 y_b[i] = ra->
getAdjoint(outputIndices[i], dim);
301 Synchronization::synchronize();
303 reverseFunc(inputValues.data(), x_b.data(), inputIndices.size(), outputValues.data(), y_b.data(),
304 outputIndices.size(), &userData);
306 Synchronization::synchronize();
308 Synchronization::serialize([&]() {
309 for (
size_t i = 0; i < inputIndices.size(); ++i) {
314 Synchronization::synchronize();
317 Synchronization::serialize([&]() {
318 finalizeRun(ra,
true);
324 Synchronization::synchronize();
330 if (getPrimalsFromPrimalValueVector && provideOutputValues) {
331 if (reallocatePrimalVectors) {
332 outputValues.resize(outputIndices.size());
336 for (
size_t i = 0; i < outputIndices.size(); ++i) {
337 outputValues[i] = ra->
getPrimal(outputIndices[i]);
343 if (isReverse && Tape::RequiresPrimalRestore) {
344 for (
size_t i = 0; i < outputIndices.size(); ++i) {
345 ra->
setPrimal(outputIndices[i], oldPrimals[i]);
349 if (getPrimalsFromPrimalValueVector && provideInputValues) {
350 if (reallocatePrimalVectors) {
351 inputValues.resize(inputIndices.size());
354 for (
size_t i = 0; i < inputIndices.size(); ++i) {
355 inputValues[i] = ra->
getPrimal(inputIndices[i]);
361 if (getPrimalsFromPrimalValueVector && !isReverse) {
362 for (
size_t i = 0; i < outputIndices.size(); ++i) {
363 if (Tape::RequiresPrimalRestore) {
364 oldPrimals[i] = ra->
getPrimal(outputIndices[i]);
366 ra->
setPrimal(outputIndices[i], outputValues[i]);
370 if (reallocatePrimalVectors) {
371 if (getPrimalsFromPrimalValueVector && provideInputValues) {
373 inputValues.shrink_to_fit();
375 if (getPrimalsFromPrimalValueVector && provideOutputValues) {
376 outputValues.clear();
377 outputValues.shrink_to_fit();
422 if (IsPrimalValueTape) {
426 data->reallocatePrimalVectors =
true;
432 if (IsPrimalValueTape) {
434 data->getPrimalsFromPrimalValueVector =
false;
441 data->provideInputValues =
false;
447 data->provideOutputValues =
false;
452 if (Type::getTape().isActive()) {
453 Identifier identifier = input.getIdentifier();
454 if (!Type::getTape().isIdentifierActive(identifier)) {
458 if (
data->getPrimalsFromPrimalValueVector) {
460 Type::getTape().registerInput(temp);
461 identifier = temp.getIdentifier();
465 data->inputIndices.push_back(identifier);
471 data->inputValues.push_back(input.getValue());
478 Real oldPrimal = Type::getTape().registerExternalFunctionOutput(output);
480 data->outputIndices.push_back(output.getIdentifier());
482 data->outputValues.push_back(output.getValue());
484 if (Tape::RequiresPrimalRestore) {
485 data->oldPrimals.push_back(oldPrimal);
499 template<
typename Data>
501 this->data->userData.addData(
data);
507 return this->data->userData;
512 template<
typename FuncObj,
typename... Args>
514 bool isTapeActive = Type::getTape().isActive();
517 Type::getTape().setPassive();
520 func(std::forward<Args>(args)...);
522 Synchronization::synchronize();
525 Type::getTape().setActive();
527 Synchronization::serialize([&]() {
534 Synchronization::synchronize();
541 Synchronization::serialize([&]() {
544 data->primalFunc = func;
549 Synchronization::synchronize();
553 Synchronization::synchronize();
555 Synchronization::serialize([&]() {
560 if (Type::getTape().isActive()) {
568 Synchronization::synchronize();
571 "callPrimalFunc() not available if external function helper is initialized with passive function mode "
572 "enabled. Use callPrimalFuncWithADType() instead.");
579 if (Type::getTape().isActive()) {
581 Synchronization::serialize([&]() {
582 data->reverseFunc = reverseFunc;
583 data->forwardFunc = forwardFunc;
585 if (
nullptr != primalFunc) {
588 data->primalFunc = primalFunc;
593 data->inputValues.clear();
594 data->inputValues.shrink_to_fit();
599 Synchronization::synchronize();
603 0 == ThreadInformation::getThreadId() ? EvalData::delFunc :
nullptr;
605 EvalData::evalRevFuncStatic,
data, delFunc, EvalData::evalForwFuncStatic, EvalData::evalPrimFuncStatic));
608 Synchronization::synchronize();
611 Synchronization::serialize([&]() {
data =
nullptr; });
614 Synchronization::serialize([&]() {
delete data; });
618 Synchronization::serialize([&]() {
624 Synchronization::synchronize();
#define CODI_INLINE
See codi::Config::ForcedInlines.
Definition config.h:457
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:94
CoDiPack - Code Differentiation Package.
Definition codi.hpp:90
void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:46
Represents a concrete lvalue in the CoDiPack expression tree.
Definition activeType.hpp:52
Default implementation of SynchronizationInterface for serial applications.
Definition synchronizationInterface.hpp:62
Helper class for the implementation of an external function in CoDiPack.
Definition externalFunctionHelper.hpp:102
bool storeInputPrimals
If input primals are stored. Can be disabled by the user.
Definition externalFunctionHelper.hpp:387
void(*)(Real const *x, Real const *x_d, size_t m, Real *y, Real *y_d, size_t n, ExternalFunctionUserData *d) ForwardFunc
Function interface for the forward AD call of an external function.
Definition externalFunctionHelper.hpp:124
T_Synchronization Synchronization
See ExternalFunctionHelper.
Definition externalFunctionHelper.hpp:109
typename Type::Identifier Identifier
See LhsExpressionInterface.
Definition externalFunctionHelper.hpp:115
void disableInputPrimalStore()
Do not store primal input values. In function calls, pointers to primal inputs will be null.
Definition externalFunctionHelper.hpp:439
void enableReallocationOfPrimalValueVectors()
Definition externalFunctionHelper.hpp:421
void disableRenewOfPrimalValues()
Do not update the inputs and outputs from the primal values of the tape. Has no effect on Jacobian tapes.
Definition externalFunctionHelper.hpp:431
ExternalFunctionHelper(bool primalFuncUsesADType=false)
Constructor.
Definition externalFunctionHelper.hpp:402
~ExternalFunctionHelper()
Destructor.
Definition externalFunctionHelper.hpp:415
void addInput(Type const &input)
Add an input value.
Definition externalFunctionHelper.hpp:451
bool storeOutputPrimals
If output primals are stored. Can be disabled by the user.
Definition externalFunctionHelper.hpp:388
typename Type::Real Real
See LhsExpressionInterface.
Definition externalFunctionHelper.hpp:114
ExternalFunctionUserData & getExternalFunctionUserData()
Definition externalFunctionHelper.hpp:506
typename Type::Tape Tape
See LhsExpressionInterface.
Definition externalFunctionHelper.hpp:117
void disableOutputPrimalStore()
Do not store primal output values. In function calls, pointers to primal outputs will be null.
Definition externalFunctionHelper.hpp:445
void callPrimalFuncWithADType(FuncObj &func, Args &&... args)
Definition externalFunctionHelper.hpp:513
void(*)(Real const *x, Real *x_b, size_t m, Real const *y, Real const *y_b, size_t n, ExternalFunctionUserData *d) ReverseFunc
Function interface for the reverse AD call of an external function.
Definition externalFunctionHelper.hpp:120
void addUserData(Data const &data)
Add user data. See ExternalFunctionUserData for details.
Definition externalFunctionHelper.hpp:500
void addOutput(Type &output)
Add an output value.
Definition externalFunctionHelper.hpp:492
bool storeInputOutputForPrimalEval
If a primal call with a self-implemented function will be done.
Definition externalFunctionHelper.hpp:389
EvalData * data
External function data.
Definition externalFunctionHelper.hpp:395
void(*)(Real const *x, size_t m, Real *y, size_t n, ExternalFunctionUserData *d) PrimalFunc
Function interface for the primal call of an external function.
Definition externalFunctionHelper.hpp:128
bool reallocatePrimalVectors
Definition externalFunctionHelper.hpp:390
T_ThreadInformation ThreadInformation
See ExternalFunctionHelper.
Definition externalFunctionHelper.hpp:112
void callPrimalFunc(PrimalFunc func)
Definition externalFunctionHelper.hpp:539
bool getPrimalValuesFromPrimalValueVector
Definition externalFunctionHelper.hpp:392
std::vector< Real > y
Shared vector of output variables.
Definition externalFunctionHelper.hpp:397
T_Type Type
See ExternalFunctionHelper.
Definition externalFunctionHelper.hpp:106
void addToTape(ReverseFunc reverseFunc, ForwardFunc forwardFunc=nullptr, PrimalFunc primalFunc=nullptr)
Add the external function to the tape.
Definition externalFunctionHelper.hpp:577
std::vector< Type * > outputValues
References to output values.
Definition externalFunctionHelper.hpp:385
Ease of access structure for user-provided data on the tape for external functions.
Definition externalFunctionUserData.hpp:59
User-defined evaluation functions for the taping process.
Definition externalFunction.hpp:102
If the tape inherits from PrimalValueBaseTape.
Definition tapeTraits.hpp:92
Unified access to the adjoint vector and primal vector in a tape evaluation.
Definition vectorAccessInterface.hpp:91
virtual void resetAdjoint(Identifier const &index, size_t dim)=0
Set the adjoint component to zero.
virtual void setPrimal(Identifier const &index, Real const &primal)=0
Set the primal value.
virtual size_t getVectorSize() const =0
Vector size in the current tape evaluation.
virtual Real getPrimal(Identifier const &index)=0
Get the primal value.
virtual void updateAdjoint(Identifier const &index, size_t dim, Real const &adjoint)=0
Update the adjoint component.
virtual Real getAdjoint(Identifier const &index, size_t dim)=0
Get the adjoint component.