CoDiPack  2.2.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
enzymeExternalFunctionHelper.hpp
1/*
2 * CoDiPack, a Code Differentiation Package
3 *
4 * Copyright (C) 2015-2024 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://scicomp.rptu.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of CoDiPack (http://scicomp.rptu.de/software/codi).
11 *
12 * CoDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU General Public License
14 * as published by the Free Software Foundation, either version 3 of the
15 * License, or (at your option) any later version.
16 *
17 * CoDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty
19 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU General Public License for more details.
22 * You should have received a copy of the GNU
23 * General Public License along with CoDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * For other licensing options please contact us.
27 *
28 * Authors:
29 * - SciComp, University of Kaiserslautern-Landau:
30 * - Max Sagebaum
31 * - Johannes Blühdorn
32 * - Former members:
33 * - Tim Albring
34 */
35#pragma once
36
37#include <vector>
38
39#include "../../config.h"
40#include "../../expressions/lhsExpressionInterface.hpp"
41#include "../../misc/macros.hpp"
42#include "../../tapes/interfaces/fullTapeInterface.hpp"
43#include "../../tapes/misc/vectorAccessInterface.hpp"
44#include "../../traits/tapeTraits.hpp"
45#include "../data/externalFunctionUserData.hpp"
46#include "externalFunctionHelper.hpp"
47
// Activity markers for Enzyme. Only declarations are given here; when one of these globals appears in the argument
// list of an __enzyme_* call, Enzyme is expected to use it to classify the argument(s) directly following it.
extern int enzyme_dup;
extern int enzyme_out;
extern int enzyme_const;

// Differentiation entry points for reverse mode (__enzyme_autodiff) and forward mode (__enzyme_fwddiff). They are
// declared but not defined here; the Enzyme compiler plugin is expected to replace each call with derivative code
// for the function whose address is passed as the first argument.
template<class... Args>
void __enzyme_autodiff(void* function, Args... args);

template<class... Args>
void __enzyme_fwddiff(void* function, Args... args);
57
59namespace codi {
60
// Helper class for the implementation of an external function with Enzyme in CoDiPack: from one primal function
// `func`, the reverse (enzymeDiff_b) and forward (enzymeDiff_d) callbacks that are registered on the tape are
// produced by Enzyme through __enzyme_autodiff and __enzyme_fwddiff.
// NOTE(review): this listing was extracted from a documentation page and is incomplete here — lines 61-82, 84
// (class head), 87-88 (Type), 90, 95-96 (Tape), 102-103 (constructor) and 108-109 are missing. Consult the actual
// header before changing these declarations.
83 template<typename T_Type>
85 public:
86
89
91
92 using Real = typename Type::Real;  // See LhsExpressionInterface.
93 using Identifier = typename Type::Identifier;  // See LhsExpressionInterface.
94
97
// See ExternalFunctionHelper: void(*)(Real const* x, size_t m, Real* y, size_t n, ExternalFunctionUserData* d).
98 using PrimalFunc = typename Base::PrimalFunc;
99
100 public:
101
104
105 private:
106 // Hide base class functionality.
// Re-declaring addToTape in a private section makes the base-class method inaccessible to users of this helper;
// callAndAddToTape<func>() still calls it internally.
107 using Base::addToTape;
110
111 public:
112 // Implement enzyme specific functionality.
113
// Registers the external function on the tape via the (otherwise hidden) base-class addToTape, passing
// enzymeDiff_b<func> as the reverse callback, enzymeDiff_d<func> as the forward callback and func itself as the
// primal callback.
// NOTE(review): lines 118-119 are missing from this listing; according to the page index, line 118 is the signature
// `void callAndAddToTape()`, and line 119 presumably evaluates func before registration — confirm in the header.
117 template<PrimalFunc func>
120 Base::addToTape(enzymeDiff_b<func>, enzymeDiff_d<func>, func);
121 }
122
124 template<PrimalFunc func>
125 void callAndAddToTape(Type const* x, size_t m, Type* y, size_t n) {
126 for (size_t i = 0; i < m; i += 1) {
127 Base::addInput(x[i]);
128 }
129
130 for (size_t i = 0; i < n; i += 1) {
131 Base::addOutput(y[i]);
132 }
133
134 callAndAddToTape<func>();
135 }
136
137 private:
// Reverse-mode callback for func: invokes Enzyme's __enzyme_autodiff on func with (x, x_b) and (y, y_b) marked
// enzyme_dup (primal pointer followed by its shadow/adjoint pointer) and m, n and d marked enzyme_const (no
// derivative propagated through them).
// NOTE(review): line 140 is missing from this listing — presumably the trailing `ExternalFunctionUserData* d) {`
// parameter, matching the signature of enzymeDiff_d; confirm in the header.
// NOTE(review): Enzyme's reverse pass may write to the shadow of an overwritten output (e.g. resetting output
// adjoints), although y_b is declared `Real const*` here — confirm the buffer behind y_b is writable.
138 template<PrimalFunc func>
139 static void enzymeDiff_b(Real const* x, Real* x_b, size_t m, Real const* y, Real const* y_b, size_t n,
141 // clang-format off
142 __enzyme_autodiff(
143 (void*) func,
144 enzyme_dup, x, x_b,
145 enzyme_const, m,
146 enzyme_dup, y, y_b,
147 enzyme_const, n,
148 enzyme_const, d);
149 // clang-format on
150 }
151
152 template<PrimalFunc func>
153 static void enzymeDiff_d(Real const* x, Real const* x_d, size_t m, Real* y, Real* y_d, size_t n,
154 ExternalFunctionUserData* d) {
155 // clang-format off
156 __enzyme_fwddiff(
157 (void*) func,
158 enzyme_dup, x, x_d,
159 enzyme_const, m,
160 enzyme_dup, y, y_d,
161 enzyme_const, n,
162 enzyme_const, d);
163 // clang-format on
164 }
165 };
166}
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:94
#define CODI_T(...)
Abbreviation for CODI_TEMPLATE.
Definition macros.hpp:111
CoDiPack - Code Differentiation Package.
Definition codi.hpp:90
Represents a concrete lvalue in the CoDiPack expression tree.
Definition activeType.hpp:52
Helper class for the implementation of an external function with Enzyme in CoDiPack.
Definition enzymeExternalFunctionHelper.hpp:84
typename Type::Real Real
See LhsExpressionInterface.
Definition enzymeExternalFunctionHelper.hpp:92
typename Type::Tape Tape
See LhsExpressionInterface.
Definition enzymeExternalFunctionHelper.hpp:96
T_Type Type
See ExternalFunctionHelper.
Definition enzymeExternalFunctionHelper.hpp:88
typename Base::PrimalFunc PrimalFunc
See ExternalFunctionHelper.
Definition enzymeExternalFunctionHelper.hpp:98
typename Type::Identifier Identifier
See LhsExpressionInterface.
Definition enzymeExternalFunctionHelper.hpp:93
void callAndAddToTape()
Definition enzymeExternalFunctionHelper.hpp:118
EnzymeExternalFunctionHelper()
Constructor.
Definition enzymeExternalFunctionHelper.hpp:103
void callAndAddToTape(Type const *x, size_t m, Type *y, size_t n)
Adds all inputs in x and outputs in y to the external function and then calls callAndAddToTape().
Definition enzymeExternalFunctionHelper.hpp:125
Helper class for the implementation of an external function in CoDiPack.
Definition externalFunctionHelper.hpp:102
void addInput(Type const &input)
Add an input value.
Definition externalFunctionHelper.hpp:451
void callPrimalFuncWithADType(FuncObj &func, Args &&... args)
Definition externalFunctionHelper.hpp:513
void addOutput(Type &output)
Add an output value.
Definition externalFunctionHelper.hpp:492
void(*)(Real const *x, size_t m, Real *y, size_t n, ExternalFunctionUserData *d) PrimalFunc
Function interface for the primal call of an external function.
Definition externalFunctionHelper.hpp:128
void callPrimalFunc(PrimalFunc func)
Definition externalFunctionHelper.hpp:539
std::vector< Real > y
Shared vector of output variables.
Definition externalFunctionHelper.hpp:397
void addToTape(ReverseFunc reverseFunc, ForwardFunc forwardFunc=nullptr, PrimalFunc primalFunc=nullptr)
Add the external function to the tape.
Definition externalFunctionHelper.hpp:577
Ease-of-access structure for user-provided data on the tape for external functions. ...
Definition externalFunctionUserData.hpp:59
Full tape interface that supports all features of CoDiPack.
Definition fullTapeInterface.hpp:82
Base class for all CoDiPack lvalue expressions.
Definition lhsExpressionInterface.hpp:63