CoDiPack  2.2.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
codiReverseMeDiPackTool.hpp
1/*
2 * CoDiPack, a Code Differentiation Package
3 *
4 * Copyright (C) 2015-2024 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://scicomp.rptu.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of CoDiPack (http://scicomp.rptu.de/software/codi).
11 *
12 * CoDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU General Public License
14 * as published by the Free Software Foundation, either version 3 of the
15 * License, or (at your option) any later version.
16 *
17 * CoDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty
19 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU General Public License for more details.
22 * You should have received a copy of the GNU
23 * General Public License along with CoDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * For other licensing options please contact us.
27 *
28 * Authors:
29 * - SciComp, University of Kaiserslautern-Landau:
30 * - Max Sagebaum
31 * - Johannes Blühdorn
32 * - Former members:
33 * - Tim Albring
34 */
35#pragma once
36
37#include <medi/adToolInterface.h>
38#include <medi/ampi/ampiMisc.h>
39
40#include <medi/adToolImplCommon.hpp>
41#include <medi/adjointInterface.hpp>
42#include <medi/ampi/op.hpp>
43#include <medi/ampi/typeDefault.hpp>
44#include <medi/ampi/types/indexTypeHelper.hpp>
45
46#include "../../config.h"
47#include "../../expressions/lhsExpressionInterface.hpp"
48#include "../../misc/macros.hpp"
49#include "../../tapes/interfaces/fullTapeInterface.hpp"
50#include "../../tapes/misc/adjointVectorAccess.hpp"
51
53namespace codi {
54
55#ifndef DOXYGEN_DISABLE
56
57 template<typename T_Type>
58 struct CoDiMeDiAdjointInterfaceWrapper : public medi::AdjointInterface {
59 public:
60
61 using Type = CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
62
63 using Real = typename Type::Real;
64 using Identifier = typename Type::Identifier;
65
66 VectorAccessInterface<Real, Identifier>* codiInterface;
67
68 int vecSize;
69
70 CoDiMeDiAdjointInterfaceWrapper(VectorAccessInterface<Real, Identifier>* interface)
71 : codiInterface(interface), vecSize((int)interface->getVectorSize()) {}
72
73 CODI_INLINE_NO_FA int computeElements(int elements) const {
74 return elements * vecSize;
75 }
76
77 CODI_INLINE_NO_FA int getVectorSize() const {
78 return vecSize;
79 }
80
81 CODI_INLINE_NO_FA void getAdjoints(void const* i, void* a, int elements) const {
82 Real* adjoints = (Real*)a;
83 Identifier* indices = (Identifier*)i;
84
85 for (int pos = 0; pos < elements; ++pos) {
86 codiInterface->getAdjointVec(indices[pos], &adjoints[pos * vecSize]);
87 codiInterface->resetAdjointVec(indices[pos]);
88 }
89 }
90
91 CODI_INLINE_NO_FA void updateAdjoints(void const* i, void const* a, int elements) const {
92 Real* adjoints = (Real*)a;
93 Identifier* indices = (Identifier*)i;
94
95 for (int pos = 0; pos < elements; ++pos) {
96 codiInterface->updateAdjointVec(indices[pos], &adjoints[pos * vecSize]);
97 }
98 }
99
100 CODI_INLINE_NO_FA void getPrimals(void const* i, void const* p, int elements) const {
101 Real* primals = (Real*)p;
102 Identifier* indices = (Identifier*)i;
103
104 for (int pos = 0; pos < elements; ++pos) {
105 primals[pos] = codiInterface->getPrimal(indices[pos]);
106 }
107 }
108
109 CODI_INLINE_NO_FA void setPrimals(void const* i, void const* p, int elements) const {
110 Real* primals = (Real*)p;
111 Identifier* indices = (Identifier*)i;
112
113 for (int pos = 0; pos < elements; ++pos) {
114 codiInterface->setPrimal(indices[pos], primals[pos]);
115 }
116 }
117
118 CODI_INLINE_NO_FA void combineAdjoints(void* b, int const elements, int const ranks) const {
119 Real* buf = (Real*)b;
120
121 for (int curRank = 1; curRank < ranks; ++curRank) {
122 for (int curPos = 0; curPos < elements; ++curPos) {
123 for (int dim = 0; dim < vecSize; ++dim) {
124 buf[curPos * vecSize + dim] += buf[(elements * curRank + curPos) * vecSize + dim];
125 }
126 }
127 }
128 }
129
130 CODI_INLINE_NO_FA void createPrimalTypeBuffer(void*& buf, size_t size) const {
131 buf = (void*)(new Real[size * vecSize]);
132 }
133
134 CODI_INLINE_NO_FA void deletePrimalTypeBuffer(void*& b) const {
135 if (nullptr != b) {
136 Real* buf = (Real*)b;
137 delete[] buf;
138 b = nullptr;
139 }
140 }
141
142 CODI_INLINE_NO_FA void createAdjointTypeBuffer(void*& buf, size_t size) const {
143 buf = (void*)(new Real[size * vecSize]);
144 }
145
146 CODI_INLINE_NO_FA void deleteAdjointTypeBuffer(void*& b) const {
147 if (nullptr != b) {
148 Real* buf = (Real*)b;
149 delete[] buf;
150 b = nullptr;
151 }
152 }
153 };
154
155 template<typename T_Type>
156 struct CoDiPackReverseTool
157 : public medi::ADToolImplCommon<CoDiPackReverseTool<T_Type>, T_Type::Tape::RequiresPrimalRestore, false, T_Type,
158 typename T_Type::Gradient, typename T_Type::Real, typename T_Type::Identifier> {
159 public:
160
161 // All type definitions for the interface.
162 using Type = CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
163 using PrimalType = typename Type::Real;
164 using AdjointType = void;
165 using ModifiedType = Type;
166 using IndexType = typename Type::Identifier;
167
168 // Helper definition for CoDiPack.
169 using Tape = CODI_DD(typename Type::Tape, CODI_DEFAULT_TAPE);
170
171 using OpHelper =
172 medi::OperatorHelper<medi::FunctionHelper<Type, Type, typename Type::PassiveReal, typename Type::Gradient,
173 typename Type::Identifier, CoDiPackReverseTool> >;
174
175 using Base = medi::ADToolImplCommon<CoDiPackReverseTool, Tape::RequiresPrimalRestore, false, Type,
176 typename Type::Gradient, PrimalType, IndexType>;
177
178 private:
179 // Private structures for the implementation.
180
181 OpHelper opHelper;
182
183 public:
184 CoDiPackReverseTool(MPI_Datatype primalMpiType, MPI_Datatype adjointMpiType)
185 : Base(primalMpiType, adjointMpiType), opHelper() {
186 opHelper.init();
187 }
188
189 ~CoDiPackReverseTool() {
190 opHelper.finalize();
191 }
192
193 // Implementation of the interface.
194
195 CODI_INLINE_NO_FA bool isHandleRequired() const {
196 // Handle creation is based on the CoDiPack tape activity. Only if the tape is recording the adjoint
197 // communication needs to be evaluated.
198 return getTape().isActive();
199 }
200
201 CODI_INLINE_NO_FA void startAssembly(medi::HandleBase* h) const {
202 CODI_UNUSED(h);
203
204 // No preparation required for CoDiPack.
205 }
206
207 CODI_INLINE_NO_FA void addToolAction(medi::HandleBase* h) const {
208 if (nullptr != h) {
209 getTape().pushExternalFunction(
210 ExternalFunction<Tape>::create(callHandleReverse, h, deleteHandle, callHandleForward, callHandlePrimal));
211 }
212 }
213
214 medi::AMPI_Op convertOperator(medi::AMPI_Op op) const {
215 return opHelper.convertOperator(op);
216 }
217
218 CODI_INLINE_NO_FA void stopAssembly(medi::HandleBase* h) const {
219 CODI_UNUSED(h);
220
221 // No preparation required for CoDiPack.
222 }
223
224 static CODI_INLINE_NO_FA IndexType getIndex(Type const& value) {
225 return value.getIdentifier();
226 }
227
228 static CODI_INLINE_NO_FA void registerValue(Type& value, PrimalType& oldPrimal, IndexType& index) {
229 bool wasActive = getTape().isIdentifierActive(value.getIdentifier());
230 value.getIdentifier() = IndexType();
231
232 // Make the value active again if it has been active before on the other processor.
233 if (wasActive) {
234 if (Tape::LinearIndexHandling) {
235 // Value has been registered in createIndices.
236 value.getIdentifier() = index;
237
238 // In createIndices the primal value has been set to zero. So set now the correct value.
239 if (Tape::HasPrimalValues) {
240 getTape().setPrimal(index, value.getValue());
241 }
242 if (Tape::RequiresPrimalRestore) {
243 oldPrimal = PrimalType(0.0);
244 }
245 } else {
246 PrimalType primal = getTape().registerExternalFunctionOutput(value);
247 if (Tape::RequiresPrimalRestore) {
248 oldPrimal = primal;
249 }
250 index = value.getIdentifier();
251 }
252 } else {
253 if (Tape::RequiresPrimalRestore) {
254 oldPrimal = PrimalType(0.0);
255 }
256 if (!Tape::LinearIndexHandling) {
257 index = getTape().getPassiveIndex();
258 }
259 }
260 }
261
262 static CODI_INLINE_NO_FA void clearIndex(Type& value) {
263 IndexType oldIndex = value.getIdentifier();
264 value.~Type();
265 value.getIdentifier() = oldIndex; // Restore the index here so that the other side can decide of the
266 // communication was active or not.
267 }
268
269 static CODI_INLINE_NO_FA void createIndex(Type& value, IndexType& index) {
270 if (Tape::LinearIndexHandling) {
271 IndexType oldIndex = value.getIdentifier();
272 getTape().registerInput(value);
273 index = value.getIdentifier();
274 value.getIdentifier() = oldIndex; // Restore the index here so that the other side can decide of the
275 // communication was active or not.
276 }
277 }
278
279 static CODI_INLINE_NO_FA PrimalType getValue(Type const& value) {
280 return value.getValue();
281 }
282
283 static CODI_INLINE_NO_FA void setIntoModifyBuffer(ModifiedType& modValue, Type const& value) {
284 CODI_UNUSED(modValue, value);
285
286 // CoDiPack values are send in place. No modified buffer is crated.
287 }
288
289 static CODI_INLINE_NO_FA void getFromModifyBuffer(ModifiedType const& modValue, Type& value) {
290 CODI_UNUSED(modValue, value);
291
292 // CoDiPack values are send in place. No modified buffer is crated.
293 }
294
295 static PrimalType getPrimalFromMod(ModifiedType const& modValue) {
296 return modValue.value();
297 }
298
299 static void setPrimalToMod(ModifiedType& modValue, PrimalType const& value) {
300 modValue.value() = value;
301 }
302
303 static void modifyDependency(ModifiedType& inval, ModifiedType& inoutval) {
304 bool active = getTape().isIdentifierActive(inoutval.getIdentifier()) ||
305 getTape().isIdentifierActive(inval.getIdentifier());
306 if (active) {
307 inoutval.getIdentifier() = getTape().getInvalidIndex();
308 } else {
309 inoutval.getIdentifier() = getTape().getPassiveIndex();
310 }
311 }
312
313 private:
314
315 static void callHandleReverse(Tape* tape, void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
316 CODI_UNUSED(tape);
317
318 medi::HandleBase* handle = static_cast<medi::HandleBase*>(h);
319 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
320 handle->funcReverse(handle, &ahWrapper);
321 }
322
323 static void callHandleForward(Tape* tape, void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
324 CODI_UNUSED(tape);
325
326 medi::HandleBase* handle = static_cast<medi::HandleBase*>(h);
327 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
328 handle->funcForward(handle, &ahWrapper);
329 }
330
331 static void callHandlePrimal(Tape* tape, void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
332 CODI_UNUSED(tape);
333
334 medi::HandleBase* handle = static_cast<medi::HandleBase*>(h);
335 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
336 handle->funcPrimal(handle, &ahWrapper);
337 }
338
339 static void deleteHandle(Tape* tape, void* h) {
340 CODI_UNUSED(tape);
341
342 medi::HandleBase* handle = static_cast<medi::HandleBase*>(h);
343 delete handle;
344 }
345
346 static Tape& getTape() {
347 return Type::getTape();
348 }
349 };
350#endif
351}
#define CODI_INLINE_NO_FA
See codi::Config::ForcedInlines.
Definition config.h:459
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:94
size_t constexpr dim()
Number of dimensions this gradient value has.
Definition gradientTraits.hpp:96
DataExtraction< Type >::Real getValue(Type const &v)
Extract the primal values from a type of aggregated active types.
Definition realTraits.hpp:210
CoDiPack - Code Differentiation Package.
Definition codi.hpp:90
void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:46
Represents a concrete lvalue in the CoDiPack expression tree.
Definition activeType.hpp:52