CoDiPack  3.1.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
codiReverseMeDiPackTool.hpp
1/*
2 * CoDiPack, a Code Differentiation Package
3 *
4 * Copyright (C) 2015-2026 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://scicomp.rptu.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of CoDiPack (http://scicomp.rptu.de/software/codi).
11 *
12 * CoDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU General Public License
14 * as published by the Free Software Foundation, either version 3 of the
15 * License, or (at your option) any later version.
16 *
17 * CoDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty
19 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU General Public License for more details.
22 * You should have received a copy of the GNU
23 * General Public License along with CoDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * For other licensing options please contact us.
27 *
28 * Authors:
29 * - SciComp, University of Kaiserslautern-Landau:
30 * - Max Sagebaum
31 * - Johannes Blühdorn
32 * - Former members:
33 * - Tim Albring
34 */
35#pragma once
36
37#include <medi/adToolInterface.h>
38#include <medi/ampi/ampiMisc.h>
39
40#include <medi/adToolImplCommon.hpp>
41#include <medi/adjointInterface.hpp>
42#include <medi/ampi/op.hpp>
43#include <medi/ampi/typeDefault.hpp>
44#include <medi/ampi/types/indexTypeHelper.hpp>
45
46#include "../../config.h"
47#include "../../expressions/lhsExpressionInterface.hpp"
48#include "../../misc/macros.hpp"
49#include "../../tapes/interfaces/fullTapeInterface.hpp"
50#include "../../tapes/misc/adjointVectorAccess.hpp"
51
53namespace codi {
54
55#ifndef DOXYGEN_DISABLE
56
57 template<typename T_Type>
58 struct CoDiMeDiAdjointInterfaceWrapper : public medi::AdjointInterface {
59 public:
60
61 using Type = CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
62
63 using Real = typename Type::Real;
64 using Identifier = typename Type::Identifier;
65
66 VectorAccessInterface<Real, Identifier>* codiInterface;
67
68 int vecSize;
69
70 CoDiMeDiAdjointInterfaceWrapper(VectorAccessInterface<Real, Identifier>* interface)
71 : codiInterface(interface), vecSize((int)interface->getVectorSize()) {}
72
73 CODI_INLINE_NO_FA int computeElements(int elements) const {
74 return elements * vecSize;
75 }
76
77 CODI_INLINE_NO_FA int getVectorSize() const {
78 return vecSize;
79 }
80
81 CODI_INLINE_NO_FA void getAdjoints(void const* i, void* a, int elements) const {
82 Real* adjoints = (Real*)a;
83 Identifier* indices = (Identifier*)i;
84
85 for (int pos = 0; pos < elements; ++pos) {
86 codiInterface->getAdjointVec(indices[pos], &adjoints[pos * vecSize]);
87 codiInterface->resetAdjointVec(indices[pos]);
88 }
89 }
90
91 CODI_INLINE_NO_FA void updateAdjoints(void const* i, void const* a, int elements) const {
92 Real* adjoints = (Real*)a;
93 Identifier* indices = (Identifier*)i;
94
95 for (int pos = 0; pos < elements; ++pos) {
96 codiInterface->updateAdjointVec(indices[pos], &adjoints[pos * vecSize]);
97 }
98 }
99
100 CODI_INLINE_NO_FA void getPrimals(void const* i, void const* p, int elements) const {
101 Real* primals = (Real*)p;
102 Identifier* indices = (Identifier*)i;
103
104 for (int pos = 0; pos < elements; ++pos) {
105 primals[pos] = codiInterface->getPrimal(indices[pos]);
106 }
107 }
108
109 CODI_INLINE_NO_FA void setPrimals(void const* i, void const* p, int elements) const {
110 Real* primals = (Real*)p;
111 Identifier* indices = (Identifier*)i;
112
113 for (int pos = 0; pos < elements; ++pos) {
114 codiInterface->setPrimal(indices[pos], primals[pos]);
115 }
116 }
117
118 CODI_INLINE_NO_FA void combineAdjoints(void* b, int const elements, int const ranks) const {
119 Real* buf = (Real*)b;
120
121 for (int curRank = 1; curRank < ranks; ++curRank) {
122 for (int curPos = 0; curPos < elements; ++curPos) {
123 for (int dim = 0; dim < vecSize; ++dim) {
124 buf[curPos * vecSize + dim] += buf[(elements * curRank + curPos) * vecSize + dim];
125 }
126 }
127 }
128 }
129
130 CODI_INLINE_NO_FA void createPrimalTypeBuffer(void*& buf, size_t size) const {
131 buf = (void*)(new Real[size * vecSize]);
132 }
133
134 CODI_INLINE_NO_FA void deletePrimalTypeBuffer(void*& b) const {
135 if (nullptr != b) {
136 Real* buf = (Real*)b;
137 delete[] buf;
138 b = nullptr;
139 }
140 }
141
142 CODI_INLINE_NO_FA void createAdjointTypeBuffer(void*& buf, size_t size) const {
143 buf = (void*)(new Real[size * vecSize]);
144 }
145
146 CODI_INLINE_NO_FA void deleteAdjointTypeBuffer(void*& b) const {
147 if (nullptr != b) {
148 Real* buf = (Real*)b;
149 delete[] buf;
150 b = nullptr;
151 }
152 }
153 };
154
155 template<typename T_Type>
156 struct CoDiPackReverseTool
157 : public medi::ADToolImplCommon<CoDiPackReverseTool<T_Type>, T_Type::Tape::RequiresPrimalRestore, false, T_Type,
158 typename T_Type::Gradient, typename T_Type::Real, typename T_Type::Identifier> {
159 public:
160
161 // All type definitions for the interface.
162 using Type = CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
163 using PrimalType = typename Type::Real;
164 using AdjointType = void;
165 using ModifiedType = Type;
166 using IndexType = typename Type::Identifier;
167
168 // Helper definition for CoDiPack.
169 using Tape = CODI_DD(typename Type::Tape, CODI_DEFAULT_TAPE);
170 using IterCallback = typename ExternalFunction<Tape>::IterCallback;
171
172 using OpHelper =
173 medi::OperatorHelper<medi::FunctionHelper<Type, Type, typename Type::PassiveReal, typename Type::Gradient,
174 typename Type::Identifier, CoDiPackReverseTool> >;
175
176 using Base = medi::ADToolImplCommon<CoDiPackReverseTool, Tape::RequiresPrimalRestore, false, Type,
177 typename Type::Gradient, PrimalType, IndexType>;
178
179 private:
180 // Private structures for the implementation.
181
182 OpHelper opHelper;
183
184 public:
185 CoDiPackReverseTool(MPI_Datatype primalMpiType, MPI_Datatype adjointMpiType)
186 : Base(primalMpiType, adjointMpiType), opHelper() {
187 opHelper.init();
188 }
189
190 ~CoDiPackReverseTool() {
191 opHelper.finalize();
192 }
193
194 // Implementation of the interface.
195
196 CODI_INLINE_NO_FA bool isHandleRequired() const {
197 // Handle creation is based on the CoDiPack tape activity. Only if the tape is recording the adjoint
198 // communication needs to be evaluated.
199 return getTape().isActive();
200 }
201
202 CODI_INLINE_NO_FA void startAssembly(medi::HandleBase* h) const {
203 CODI_UNUSED(h);
204
205 // No preparation required for CoDiPack.
206 }
207
208 CODI_INLINE_NO_FA void addToolAction(medi::HandleBase* h) const {
209 if (nullptr != h) {
210 getTape().pushExternalFunction(
211 ExternalFunction<Tape>::create(callHandleReverse, h, deleteHandle, callHandleForward, callHandlePrimal,
212 callHandleIterateInputs, callHandleIterateOutputs));
213 }
214 }
215
216 medi::AMPI_Op convertOperator(medi::AMPI_Op op) const {
217 return opHelper.convertOperator(op);
218 }
219
220 CODI_INLINE_NO_FA void stopAssembly(medi::HandleBase* h) const {
221 CODI_UNUSED(h);
222
223 // No preparation required for CoDiPack.
224 }
225
226 static CODI_INLINE_NO_FA IndexType getIndex(Type const& value) {
227 return value.getIdentifier();
228 }
229
230 static CODI_INLINE_NO_FA void registerValue(Type& value, PrimalType& oldPrimal, IndexType& index) {
231 bool wasActive = getTape().isIdentifierActive(value.getIdentifier());
232 value.getIdentifier() = IndexType();
233
234 // Make the value active again if it has been active before on the other processor.
235 if (wasActive) {
236 if (Tape::LinearIndexHandling) {
237 // Value has been registered in createIndices.
238 value.getIdentifier() = index;
239
240 // In createIndices the primal value has been set to zero. So set now the correct value.
241 if (Tape::HasPrimalValues) {
242 getTape().setPrimal(index, value.getValue());
243 }
244 if (Tape::RequiresPrimalRestore) {
245 oldPrimal = PrimalType(0.0);
246 }
247 } else {
248 PrimalType primal = getTape().registerExternalFunctionOutput(value);
249 if (Tape::RequiresPrimalRestore) {
250 oldPrimal = primal;
251 }
252 index = value.getIdentifier();
253 }
254 } else {
255 if (Tape::RequiresPrimalRestore) {
256 oldPrimal = PrimalType(0.0);
257 }
258 if (!Tape::LinearIndexHandling) {
259 index = getTape().getPassiveIndex();
260 }
261 }
262 }
263
264 static CODI_INLINE_NO_FA void clearIndex(Type& value) {
265 IndexType oldIndex = value.getIdentifier();
266 value.~Type();
267 value.getIdentifier() = oldIndex; // Restore the index here so that the other side can decide of the
268 // communication was active or not.
269 }
270
271 static CODI_INLINE_NO_FA void createIndex(Type& value, IndexType& index) {
272 if (Tape::LinearIndexHandling) {
273 IndexType oldIndex = value.getIdentifier();
274 getTape().registerInput(value);
275 index = value.getIdentifier();
276 value.getIdentifier() = oldIndex; // Restore the index here so that the other side can decide of the
277 // communication was active or not.
278 }
279 }
280
281 static CODI_INLINE_NO_FA PrimalType getValue(Type const& value) {
282 return value.getValue();
283 }
284
285 static CODI_INLINE_NO_FA void setIntoModifyBuffer(ModifiedType& modValue, Type const& value) {
286 CODI_UNUSED(modValue, value);
287
288 // CoDiPack values are send in place. No modified buffer is crated.
289 }
290
291 static CODI_INLINE_NO_FA void getFromModifyBuffer(ModifiedType const& modValue, Type& value) {
292 CODI_UNUSED(modValue, value);
293
294 // CoDiPack values are send in place. No modified buffer is crated.
295 }
296
297 static PrimalType getPrimalFromMod(ModifiedType const& modValue) {
298 return modValue.value();
299 }
300
301 static void setPrimalToMod(ModifiedType& modValue, PrimalType const& value) {
302 modValue.value() = value;
303 }
304
305 static void modifyDependency(ModifiedType& inval, ModifiedType& inoutval) {
306 bool active = getTape().isIdentifierActive(inoutval.getIdentifier()) ||
307 getTape().isIdentifierActive(inval.getIdentifier());
308 if (active) {
309 inoutval.getIdentifier() = getTape().getInvalidIndex();
310 } else {
311 inoutval.getIdentifier() = getTape().getPassiveIndex();
312 }
313 }
314
315 private:
316
317 static void callHandleReverse(Tape* tape, void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
318 CODI_UNUSED(tape);
319
320 medi::HandleBase* handle = static_cast<medi::HandleBase*>(h);
321 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
322 handle->funcReverse(handle, &ahWrapper);
323 }
324
325 static void callHandleForward(Tape* tape, void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
326 CODI_UNUSED(tape);
327
328 medi::HandleBase* handle = static_cast<medi::HandleBase*>(h);
329 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
330 handle->funcForward(handle, &ahWrapper);
331 }
332
333 static void callHandlePrimal(Tape* tape, void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
334 CODI_UNUSED(tape);
335
336 medi::HandleBase* handle = static_cast<medi::HandleBase*>(h);
337 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
338 handle->funcPrimal(handle, &ahWrapper);
339 }
340
341 static void deleteHandle(Tape* tape, void* h) {
342 CODI_UNUSED(tape);
343
344 medi::HandleBase* handle = static_cast<medi::HandleBase*>(h);
345 delete handle;
346 }
347
348 static void callHandleIterateInputs(Tape* tape, void* h, IterCallback func, void* userData) {
349 CODI_UNUSED(tape);
350
351 medi::HandleBase* handle = static_cast<medi::HandleBase*>(h);
352 handle->funcIterateInputIds(handle, (::medi::CallbackFunc)func, userData);
353 }
354
355 static void callHandleIterateOutputs(Tape* tape, void* h, IterCallback func, void* userData) {
356 CODI_UNUSED(tape);
357
358 medi::HandleBase* handle = static_cast<medi::HandleBase*>(h);
359 handle->funcIterateOutputIds(handle, (::medi::CallbackFunc)func, userData);
360 }
361
362 static Tape& getTape() {
363 return Type::getTape();
364 }
365 };
366#endif
367}
#define CODI_INLINE_NO_FA
See codi::Config::ForcedInlines.
Definition config.h:471
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:97
typename TraitsImplementation< Gradient >::Real Real
The base value used in the gradient entries.
Definition gradientTraits.hpp:92
inline size_t constexpr dim()
Number of dimensions this gradient value has.
Definition gradientTraits.hpp:96
inline typename DataExtraction< Type >::Real getValue(Type const &v)
Extract an aggregate of primal values from an aggregate of active types.
Definition realTraits.hpp:381
CoDiPack - Code Differentiation Package.
Definition codi.hpp:97
inline void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:55