40#include "../../expressions/lhsExpressionInterface.hpp"
42#include "../../tapes/interfaces/fullTapeInterface.hpp"
43#include "../../tapes/misc/vectorAccessInterface.hpp"
44#include "../../traits/tapeTraits.hpp"
45#include "../data/externalFunctionUserData.hpp"
46#include "../parallel/synchronizationInterface.hpp"
47#include "../parallel/threadInformationInterface.hpp"
114 using Real =
typename Type::Real;
137 std::vector<Identifier> inputIndices;
138 std::vector<Identifier> outputIndices;
140 std::vector<Real> inputValues;
142 std::vector<Real> oldPrimals;
144 std::vector<Real> x_d;
145 std::vector<Real> y_d;
146 std::vector<Real> x_b;
147 std::vector<Real> y_b;
155 bool provideInputValues;
156 bool provideOutputValues;
157 bool getPrimalsFromPrimalValueVector;
170 reverseFunc(
nullptr),
171 forwardFunc(
nullptr),
173 provideInputValues(
true),
174 provideOutputValues(
true),
175 getPrimalsFromPrimalValueVector(getPrimalsFromPrimalValueVector),
178 static void delFunc(
Tape* t,
void* d) {
181 EvalData*
data = (EvalData*)d;
187 EvalData*
data = (EvalData*)d;
189 if (
nullptr !=
data->forwardFunc) {
190 data->evalForwFunc(t, ra);
193 "Calling forward evaluation in external function helper without a forward function pointer.");
201 x_d.resize(inputIndices.size());
202 y_d.resize(outputIndices.size());
211 for (
size_t i = 0; i < inputIndices.size(); ++i) {
212 x_d[i] = ra->
getAdjoint(inputIndices[i], dim);
218 forwardFunc(inputValues.data(), x_d.data(), inputIndices.size(),
outputValues.data(), y_d.data(),
219 outputIndices.size(), &userData);
224 for (
size_t i = 0; i < outputIndices.size(); ++i) {
244 EvalData*
data = (EvalData*)d;
246 if (
nullptr !=
data->primalFunc) {
247 data->evalPrimFunc(t, ra);
250 "Calling primal evaluation in external function helper without a primal function pointer.");
263 primalFunc(inputValues.data(), inputIndices.size(),
outputValues.data(), outputIndices.size(), &userData);
275 EvalData*
data = (EvalData*)d;
277 if (
nullptr !=
data->reverseFunc) {
278 data->evalRevFunc(t, ra);
281 "Calling reverse evaluation in external function helper without a reverse function pointer.");
289 x_b.resize(inputIndices.size());
290 y_b.resize(outputIndices.size());
299 for (
size_t i = 0; i < outputIndices.size(); ++i) {
300 y_b[i] = ra->
getAdjoint(outputIndices[i], dim);
307 reverseFunc(inputValues.data(), x_b.data(), inputIndices.size(),
outputValues.data(), y_b.data(),
308 outputIndices.size(), &userData);
313 for (
size_t i = 0; i < inputIndices.size(); ++i) {
322 finalizeRun(ra,
true);
334 if (getPrimalsFromPrimalValueVector && provideOutputValues) {
340 for (
size_t i = 0; i < outputIndices.size(); ++i) {
347 if (isReverse && Tape::RequiresPrimalRestore) {
348 for (
size_t i = 0; i < outputIndices.size(); ++i) {
349 ra->
setPrimal(outputIndices[i], oldPrimals[i]);
353 if (getPrimalsFromPrimalValueVector && provideInputValues) {
355 inputValues.resize(inputIndices.size());
358 for (
size_t i = 0; i < inputIndices.size(); ++i) {
359 inputValues[i] = ra->
getPrimal(inputIndices[i]);
365 if (getPrimalsFromPrimalValueVector && !isReverse) {
366 for (
size_t i = 0; i < outputIndices.size(); ++i) {
367 if (Tape::RequiresPrimalRestore) {
368 oldPrimals[i] = ra->
getPrimal(outputIndices[i]);
375 if (getPrimalsFromPrimalValueVector && provideInputValues) {
377 inputValues.shrink_to_fit();
379 if (getPrimalsFromPrimalValueVector && provideOutputValues) {
426 if (IsPrimalValueTape) {
430 data->reallocatePrimalVectors =
true;
436 if (IsPrimalValueTape) {
438 data->getPrimalsFromPrimalValueVector =
false;
445 data->provideInputValues =
false;
451 data->provideOutputValues =
false;
456 if (Type::getTape().isActive()) {
457 Identifier identifier = input.getIdentifier();
458 if (!Type::getTape().isIdentifierActive(identifier)) {
462 if (
data->getPrimalsFromPrimalValueVector) {
464 Type::getTape().registerInput(temp);
465 identifier = temp.getIdentifier();
469 data->inputIndices.push_back(identifier);
475 data->inputValues.push_back(input.getValue());
482 Real oldPrimal = Type::getTape().registerExternalFunctionOutput(output);
484 data->outputIndices.push_back(output.getIdentifier());
486 data->outputValues.push_back(output.getValue());
488 if (Tape::RequiresPrimalRestore) {
489 data->oldPrimals.push_back(oldPrimal);
503 template<
typename Data>
505 this->data->userData.addData(
data);
511 return this->data->userData;
516 template<
typename FuncObj,
typename... Args>
518 bool isTapeActive = Type::getTape().isActive();
521 Type::getTape().setPassive();
524 func(std::forward<Args>(args)...);
529 Type::getTape().setActive();
548 data->primalFunc = func;
564 if (Type::getTape().isActive()) {
575 "callPrimalFunc() not available if external function helper is initialized with passive function mode "
576 "enabled. Use callPrimalFuncWithADType() instead.");
583 if (Type::getTape().isActive()) {
586 data->reverseFunc = reverseFunc;
587 data->forwardFunc = forwardFunc;
589 if (
nullptr != primalFunc) {
592 data->primalFunc = primalFunc;
597 data->inputValues.clear();
598 data->inputValues.shrink_to_fit();
608 delFunc = EvalData::delFunc;
612 EvalData::evalRevFuncStatic,
data, delFunc, EvalData::evalForwFuncStatic, EvalData::evalPrimFuncStatic));
#define CODI_INLINE
See codi::Config::ForcedInlines.
Definition config.h:469
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:96
CoDiPack - Code Differentiation Package.
Definition codi.hpp:94
inline void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:54
Default implementation of SynchronizationInterface for serial applications.
Definition synchronizationInterface.hpp:62
static inline void serialize(FunctionObject const &func)
Ensures that only one among the calling threads calls the given function object.
Definition synchronizationInterface.hpp:67
static inline void synchronize()
Does not return until called by all threads.
Definition synchronizationInterface.hpp:73
bool storeInputPrimals
Definition externalFunctionHelper.hpp:391
void(*)(Real const *x, Real const *x_d, size_t m, Real *y, Real *y_d, size_t n, ExternalFunctionUserData *d) ForwardFunc
Definition externalFunctionHelper.hpp:124
DefaultSynchronization Synchronization
Definition externalFunctionHelper.hpp:109
inline void addInput(Type const &input)
Add an input value.
Definition externalFunctionHelper.hpp:455
typename Type::Identifier Identifier
Definition externalFunctionHelper.hpp:115
void disableInputPrimalStore()
Do not store primal input values. In function calls, pointers to primal inputs will be null.
Definition externalFunctionHelper.hpp:443
void enableReallocationOfPrimalValueVectors()
Definition externalFunctionHelper.hpp:425
void disableRenewOfPrimalValues()
Do not update the inputs and outputs from the primal values of the tape. Has no effect on Jacobian tapes.
Definition externalFunctionHelper.hpp:435
ExternalFunctionHelper(bool primalFuncUsesADType=false)
Constructor.
Definition externalFunctionHelper.hpp:406
~ExternalFunctionHelper()
Destructor.
Definition externalFunctionHelper.hpp:419
bool storeOutputPrimals
Definition externalFunctionHelper.hpp:392
inline void addUserData(Data const &data)
Add user data. See ExternalFunctionUserData for details.
Definition externalFunctionHelper.hpp:504
typename Type::Real Real
Definition externalFunctionHelper.hpp:114
ExternalFunctionUserData & getExternalFunctionUserData()
Definition externalFunctionHelper.hpp:510
typename Type::Tape Tape
Definition externalFunctionHelper.hpp:117
void disableOutputPrimalStore()
Do not store primal output values. In function calls, pointers to primal outputs will be null.
Definition externalFunctionHelper.hpp:449
void(*)(Real const *x, Real *x_b, size_t m, Real const *y, Real const *y_b, size_t n, ExternalFunctionUserData *d) ReverseFunc
Definition externalFunctionHelper.hpp:120
bool storeInputOutputForPrimalEval
Definition externalFunctionHelper.hpp:393
EvalData * data
Definition externalFunctionHelper.hpp:399
void(*)(Real const *x, size_t m, Real *y, size_t n, ExternalFunctionUserData *d) PrimalFunc
Definition externalFunctionHelper.hpp:128
bool reallocatePrimalVectors
Definition externalFunctionHelper.hpp:394
inline void callPrimalFunc(PrimalFunc func)
Definition externalFunctionHelper.hpp:543
DefaultThreadInformation ThreadInformation
Definition externalFunctionHelper.hpp:112
bool getPrimalValuesFromPrimalValueVector
Definition externalFunctionHelper.hpp:396
inline void callPrimalFuncWithADType(FuncObj &func, Args &&... args)
Definition externalFunctionHelper.hpp:517
std::vector< Real > y
Definition externalFunctionHelper.hpp:401
Type Type
Definition externalFunctionHelper.hpp:106
inline void addOutput(Type &output)
Add an output value.
Definition externalFunctionHelper.hpp:496
std::vector< Type * > outputValues
Definition externalFunctionHelper.hpp:389
inline void addToTape(ReverseFunc reverseFunc, ForwardFunc forwardFunc=nullptr, PrimalFunc primalFunc=nullptr)
Add the external function to the tape.
Definition externalFunctionHelper.hpp:581
Ease of access structure for user-provided data on the tape for external functions....
Definition externalFunctionUserData.hpp:59
void(* DeleteFunction)(Tape *tape, void *data)
Delete function definition.
Definition externalFunction.hpp:113
static ExternalFunction create(CallFunction funcReverse, void *data, DeleteFunction funcDelete, CallFunction funcForward=nullptr, CallFunction funcPrimal=nullptr)
Helper function for the creation of an ExternalFunction object.
Definition externalFunction.hpp:124
If the tape inherits from PrimalValueBaseTape.
Definition tapeTraits.hpp:95
Unified access to the adjoint vector and primal vector in a tape evaluation.
Definition vectorAccessInterface.hpp:94
virtual void resetAdjoint(Identifier const &index, size_t dim)=0
Set the adjoint component to zero.
virtual void setPrimal(Identifier const &index, Real const &primal)=0
Set the primal value.
virtual size_t getVectorSize() const =0
Vector size in the current tape evaluation.
virtual Real getPrimal(Identifier const &index)=0
Get the primal value.
virtual void updateAdjoint(Identifier const &index, size_t dim, Real const &adjoint)=0
Update the adjoint component.
virtual Real getAdjoint(Identifier const &index, size_t dim)=0
Get the adjoint component.