MeDiPack  1.4.0
A Message Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
adToolInterface.h
Go to the documentation of this file.
1/*
2 * MeDiPack, a Message Differentiation Package
3 *
4 * Copyright (C) 2015-2026 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://scicomp.rptu.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of MeDiPack (http://scicomp.rptu.de/software/medi).
11 *
12 * MeDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU Lesser General Public
14 * License as published by the Free Software Foundation, either
15 * version 3 of the License, or (at your option) any later version.
16 *
17 * MeDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty of
19 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU Lesser General Public License for more details.
22 * You should have received a copy of the GNU
23 * Lesser General Public License along with MeDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * Authors: Max Sagebaum, Tim Albring (SciComp, University of Kaiserslautern-Landau)
27 */
28
29#pragma once
30
31#include "ampi/op.hpp"
32#include "typeDefinitions.h"
33
37namespace medi {
38
43
44 MPI_Datatype primalMpiType;
45 MPI_Datatype adjointMpiType;
46
47 public:
48
52 typedef void Type;
53
57 typedef void ModifiedType;
58
62 typedef void AdjointType;
63
67 typedef void PrimalType;
68
72 typedef void IndexType;
73
78 ADToolInterface(MPI_Datatype primalMpiType, MPI_Datatype adjointMpiType) :
79 primalMpiType(primalMpiType),
80 adjointMpiType(adjointMpiType) {}
81
82 virtual ~ADToolInterface() {}
83
84
89 MPI_Datatype getPrimalMpiType() const {
90 return primalMpiType;
91 }
92
97 MPI_Datatype getAdjointMpiType() const {
98 return adjointMpiType;
99 }
100
105 virtual bool isActiveType() const = 0;
106
112 virtual bool isHandleRequired() const = 0;
113
118 virtual bool isModifiedBufferRequired() const = 0;
119
124 virtual bool isOldPrimalsRequired() const = 0;
125
130 virtual void startAssembly(HandleBase* h) const = 0;
131
136 virtual void stopAssembly(HandleBase* h) const = 0;
137
155 virtual void addToolAction(HandleBase* h) const = 0;
156
163 virtual AMPI_Op convertOperator(AMPI_Op op) const = 0;
164
171 virtual void createPrimalTypeBuffer(void* &buf, size_t size) const = 0;
172
179 virtual void createIndexTypeBuffer(void* &buf, size_t size) const = 0;
180
186 virtual void deletePrimalTypeBuffer(void* &buf) const = 0;
187
193 virtual void deleteIndexTypeBuffer(void* &buf) const = 0;
194
206 virtual void iterateIdentifiers(void* indices, int elements, CallbackFunc func, void* userData) const = 0;
207
208 };
209
216
/**
 * @brief The actual type that the AD implementation uses.
 *
 * NOTE(review): the concrete double/int choices in this listing look like documentation
 * placeholders for the static AD tool interface — confirm.
 */
using Type = double;

/** @brief The type that is sent through the modified buffers. */
using ModifiedType = double;

/** @brief The data type used for the floating point data. */
using PrimalType = double;

/** @brief The data type that is used for the adjoint variables. */
using AdjointType = double;

/** @brief The data type from the AD tool for the identification of AD variables. */
using IndexType = int;

/**
 * @brief Copies the necessary data from the user buffer into the buffer created by MeDiPack.
 * @param modValue  Entry in the MeDiPack buffer that is written.
 * @param value     Entry in the user buffer that is read.
 */
static void setIntoModifyBuffer(ModifiedType& modValue, const Type& value);

/**
 * @brief Copies the necessary data from the received MeDiPack buffer into the user buffer.
 * @param modValue  Entry in the received MeDiPack buffer that is read.
 * @param value     Entry in the user buffer that is written.
 */
static void getFromModifyBuffer(const ModifiedType& modValue, Type& value);

/**
 * @brief Get the AD identifier for this value.
 * @param value  The AD value.
 * @return Its identifier.
 */
static IndexType getIndex(const Type& value);

/**
 * @brief Register an AD value on the receiving side of the communication.
 * @param value      The received AD value.
 * @param oldPrimal  Presumably receives the overwritten primal value — confirm.
 * @param index      Presumably receives the new identifier — confirm.
 */
static void registerValue(Type& value, PrimalType& oldPrimal, IndexType& index);

/**
 * @brief Delete the index in a buffer such that the buffer can be overwritten.
 * @param value  The buffer entry whose index is cleared.
 */
static void clearIndex(Type& value);

/**
 * @brief Create an index for the given item in the buffer.
 * @param value  The buffer entry.
 * @param index  Presumably receives the created identifier — confirm.
 */
static void createIndex(Type& value, IndexType& index);

/**
 * @brief Get the primal floating point value of the AD value.
 * @param value  The AD value.
 * @return Its primal value.
 */
static PrimalType getValue(const Type& value);
294 };
295
296
307 template <typename Impl, typename AdjointTypeB, typename PrimalTypeB, typename IndexTypeB>
309 public:
310
311 using CallbackFuncTyped = void (*)(IndexTypeB* id, void* userData);
312
318 ADToolBase(MPI_Datatype primalMpiType, MPI_Datatype adjointMpiType) :
319 ADToolInterface(primalMpiType, adjointMpiType) {}
320
321 void createPrimalTypeBuffer(void* &buf, size_t size) const {
322 cast().createPrimalTypeBuffer(castBuffer<PrimalTypeB>(buf), size);
323 }
324
325 void createIndexTypeBuffer(void* &buf, size_t size) const {
326 cast().createIndexTypeBuffer(castBuffer<IndexTypeB>(buf), size);
327 }
328
329 void deletePrimalTypeBuffer(void* &buf) const {
330 cast().deletePrimalTypeBuffer(castBuffer<PrimalTypeB>(buf));
331 }
332
333 void deleteIndexTypeBuffer(void* &buf) const {
334 cast().deleteIndexTypeBuffer(castBuffer<IndexTypeB>(buf));
335 }
336
337 void iterateIdentifiers(void* indices, int elements, CallbackFunc func, void* userData) const {
338 cast().iterateIdentifiers(castBuffer<IndexTypeB>(indices), elements, (CallbackFuncTyped)func, userData);
339 }
340
341
342 private:
343
344 inline Impl& cast() {
345 return *reinterpret_cast<Impl*>(this);
346 }
347
348 inline const Impl& cast() const {
349 return *reinterpret_cast<const Impl*>(this);
350 }
351
352 template <typename T>
353 inline T*& castBuffer(void*& buf) const {
354 return reinterpret_cast<T*&>(buf);
355 }
356
357 template <typename T>
358 inline const T*& castBuffer(const void* &buf) const {
359 return reinterpret_cast<const T*&>(buf);
360 }
361 };
362
363 inline ADToolInterface const* selectADTool(ADToolInterface const& tool) {
364 return &tool;
365 }
366
367 inline ADToolInterface const* selectADTool(ADToolInterface const& toolA, ADToolInterface const& toolB) {
368 return toolA.isActiveType() ? &toolA : &toolB;
369 }
370}
ADToolBase(MPI_Datatype primalMpiType, MPI_Datatype adjointMpiType)
Construct the type-safe wrapper.
Definition adToolInterface.h:318
void iterateIdentifiers(void *indices, int elements, CallbackFunc func, void *userData) const
Iterate over the identifiers of the AD tool. That is, the AD tool should perform the operation:
Definition adToolInterface.h:337
void(*)(IndexType *id, void *userData) CallbackFuncTyped
Definition adToolInterface.h:311
void deletePrimalTypeBuffer(void *&buf) const
Delete the array of the passive variables.
Definition adToolInterface.h:329
void deleteIndexTypeBuffer(void *&buf) const
Delete the array of the index variables.
Definition adToolInterface.h:333
void createPrimalTypeBuffer(void *&buf, size_t size) const
Create an array for the passive variables.
Definition adToolInterface.h:321
void createIndexTypeBuffer(void *&buf, size_t size) const
Create an array for the index variables.
Definition adToolInterface.h:325
The interface for the AD tool that is accessed by MeDiPack.
Definition adToolInterface.h:42
ADToolInterface(MPI_Datatype primalMpiType, MPI_Datatype adjointMpiType)
Create an interface for the AD type.
Definition adToolInterface.h:78
virtual AMPI_Op convertOperator(AMPI_Op op) const =0
Convert the MPI intrinsic operators, like MPI_SUM, to the specific ones for the AD tool.
virtual void createIndexTypeBuffer(void *&buf, size_t size) const =0
Create an array for the index variables.
virtual void startAssembly(HandleBase *h) const =0
Indicates to the AD tool that an adjoint action is in the process of being recorded.
virtual void stopAssembly(HandleBase *h) const =0
Indicates to the AD tool that an adjoint action is being finished.
MPI_Datatype getAdjointMpiType() const
The MPI data type for the adjoint type.
Definition adToolInterface.h:97
virtual void addToolAction(HandleBase *h) const =0
Register the handle so that the AD tool can evaluate it in the reverse sweep.
virtual ~ADToolInterface()
Definition adToolInterface.h:82
virtual bool isHandleRequired() const =0
The handle needs to be created if an adjoint action is required by the AD tool.
virtual bool isActiveType() const =0
Indicates if this AD interface represents an AD type.
virtual void deletePrimalTypeBuffer(void *&buf) const =0
Delete the array of the passive variables.
virtual void deleteIndexTypeBuffer(void *&buf) const =0
Delete the array of the index variables.
virtual void createPrimalTypeBuffer(void *&buf, size_t size) const =0
Create an array for the passive variables.
MPI_Datatype getPrimalMpiType() const
The MPI data type for the primal type.
Definition adToolInterface.h:89
void PrimalType
The data type used for the floating point data.
Definition adToolInterface.h:67
virtual bool isOldPrimalsRequired() const =0
Indicates if MeDiPack needs to store the overwritten primal values for the AD tool.
virtual bool isModifiedBufferRequired() const =0
Indicates if the AD tool needs to modify the buffer in order to send the correct data.
virtual void iterateIdentifiers(void *indices, int elements, CallbackFunc func, void *userData) const =0
Iterate over the identifiers of the AD tool. That is, the AD tool should perform the operation:
void Type
The actual type that the AD implementation uses.
Definition adToolInterface.h:52
void IndexType
The data type from the AD tool for the identification of AD variables.
Definition adToolInterface.h:72
void AdjointType
The data type that is used for the adjoint variables.
Definition adToolInterface.h:62
void ModifiedType
The type that is sent through the modified buffers.
Definition adToolInterface.h:57
Global namespace for MeDiPack - Message Differentiation Package.
Definition adjointInterface.hpp:37
ADToolInterface const * selectADTool(ADToolInterface const &tool)
Definition adToolInterface.h:363
void(*)(void *id, void *userData) CallbackFunc
Definition typeDefinitions.h:45
Structure for the special handling of the MPI_Op structure.
Definition op.hpp:50
Definition typeDefinitions.h:57
The static methods for the AD tool interface.
Definition adToolInterface.h:215
static void getFromModifyBuffer(const ModifiedType &modValue, Type &value)
Copies the necessary data from the received MeDiPack buffer into the user buffer.
double AdjointType
The data type that is used for the adjoint variables.
Definition adToolInterface.h:235
double ModifiedType
The type that is sent through the modified buffers.
Definition adToolInterface.h:225
double PrimalType
The data type used for the floating point data.
Definition adToolInterface.h:230
static IndexType getIndex(const Type &value)
Get the AD identifier for this value.
static void setIntoModifyBuffer(ModifiedType &modValue, const Type &value)
Copies the necessary data from the user buffer into the buffer created by MeDiPack.
static void clearIndex(Type &value)
Delete the index in a buffer such that the buffer can be overwritten.
double Type
The actual type that the AD implementation uses.
Definition adToolInterface.h:220
static PrimalType getValue(const Type &value)
Get the primal floating point value of the AD value.
int IndexType
The data type from the AD tool for the identification of AD variables.
Definition adToolInterface.h:240
static void createIndex(Type &value, IndexType &index)
Create an index for the given item in the buffer.
static void registerValue(Type &value, PrimalType &oldPrimal, IndexType &index)
Register an AD value on the receiving side of the communication.