MeDiPack  1.3.1
A Message Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
indexTypeHelper.hpp
Go to the documentation of this file.
1/*
2 * MeDiPack, a Message Differentiation Package
3 *
4 * Copyright (C) 2015-2025 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://scicomp.rptu.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of MeDiPack (http://scicomp.rptu.de/software/medi).
11 *
12 * MeDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU Lesser General Public
14 * License as published by the Free Software Foundation, either
15 * version 3 of the License, or (at your option) any later version.
16 *
17 * MeDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty of
19 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU Lesser General Public License for more details.
22 * You should have received a copy of the GNU
23 * Lesser General Public License along with MeDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * Authors: Max Sagebaum, Tim Albring (SciComp, University of Kaiserslautern-Landau)
27 */
28
29#pragma once
30
31#include <algorithm>
32
33#include "../ampiMisc.h"
34#include "../../macros.h"
35#include "../typeInterface.hpp"
36#include "../op.hpp"
37
41namespace medi {
42
43 template <typename Type, typename ModifiedType, typename PrimalType, typename IndexType>
45 static PrimalType getPrimalFromMod(const ModifiedType& mod);
46 static void setPrimalToMod(ModifiedType& mod, const PrimalType& value);
47
48 static void modifyDependency(const ModifiedType& in, ModifiedType& inout);
49 };
50
60 template <typename Type, typename ModifiedType, typename PrimalType, typename IndexType, typename AdjointType, INTERFACE_ARG(Tool)>
62 INTERFACE_DEF(ToolInterface, Tool, Type, ModifiedType, PrimalType, IndexType)
63
/// Value/index pair used as the element type of the unmodified MINLOC/MAXLOC
/// reductions (see unmodifiedMinLoc/unmodifiedMaxLoc). Member order is value, then index.
64 struct TypeInt {
65 Type value;  ///< Value that is compared by the reduction.
66 int index;   ///< Index carried along with the value; on equal values the smaller index is kept.
67 };
68
/// Value/index pair used as the element type of the modified MINLOC/MAXLOC
/// reductions (see modifiedMinLoc/modifiedMaxLoc). The primal of `value` is read
/// through Tool::getPrimalFromMod for the comparisons.
69 struct ModTypeInt {
70 ModifiedType value;  ///< Modified (AD) value whose primal is compared.
71 int index;           ///< Index carried along with the value; on equal primals the smaller index is kept.
72 };
73
74 static void unmodifiedAdd(Type* invec, Type* inoutvec, int* len, MPI_Datatype* datatype) {
75 MEDI_UNUSED(datatype);
76
77 for(int i = 0; i < *len; ++i) {
78 inoutvec[i] += invec[i];
79 }
80 }
81
82 static void unmodifiedMul(Type* invec, Type* inoutvec, int* len, MPI_Datatype* datatype) {
83 MEDI_UNUSED(datatype);
84
85 for(int i = 0; i < *len; ++i) {
86 inoutvec[i] *= invec[i];
87 }
88 }
89
90 static void unmodifiedMax(Type* invec, Type* inoutvec, int* len, MPI_Datatype* datatype) {
91 MEDI_UNUSED(datatype);
92
93 using std::max;
94 for(int i = 0; i < *len; ++i) {
95 inoutvec[i] = max(inoutvec[i], invec[i]);
96 }
97 }
98
99 static void unmodifiedMin(Type* invec, Type* inoutvec, int* len, MPI_Datatype* datatype) {
100 MEDI_UNUSED(datatype);
101
102 using std::min;
103 for(int i = 0; i < *len; ++i) {
104 inoutvec[i] = min(inoutvec[i], invec[i]);
105 }
106 }
107
108 static void unmodifiedMaxLoc(TypeInt* invec, TypeInt* inoutvec, int* len, MPI_Datatype* datatype) {
109 MEDI_UNUSED(datatype);
110
111 using std::max;
112 for(int i = 0; i < *len; ++i) {
113 // first determine the index
114 if(invec[i].value > inoutvec[i].value) {
115 inoutvec[i].index = invec[i].index;
116 } else if(invec[i].value < inoutvec[i].value){
117 // empty operation: inoutvec[i].index = inoutvec[i].index;
118 } else {
119 inoutvec[i].index = std::min(invec[i].index, inoutvec[i].index);
120 }
121
122 // second determine the value
123 inoutvec[i].value = max(inoutvec[i].value, invec[i].value);
124 }
125 }
126
127 static void unmodifiedMinLoc(TypeInt* invec, TypeInt* inoutvec, int* len, MPI_Datatype* datatype) {
128 MEDI_UNUSED(datatype);
129
130 using std::min;
131 for(int i = 0; i < *len; ++i) {
132 // first determine the index
133 if(invec[i].value < inoutvec[i].value) {
134 inoutvec[i].index = invec[i].index;
135 } else if(invec[i].value > inoutvec[i].value){
136 // empty operation: inoutvec[i].index = inoutvec[i].index;
137 } else {
138 inoutvec[i].index = std::min(invec[i].index, inoutvec[i].index);
139 }
140
141 // second determine the value
142 inoutvec[i].value = min(inoutvec[i].value, invec[i].value);
143 }
144 }
145
146 static void modifiedAdd(ModifiedType* invec, ModifiedType* inoutvec, int* len, MPI_Datatype* datatype) {
147 MEDI_UNUSED(datatype);
148
149 for(int i = 0; i < *len; ++i) {
150 Tool::modifyDependency(invec[i], inoutvec[i]);
151 Tool::setPrimalToMod(inoutvec[i], Tool::getPrimalFromMod(invec[i]) + Tool::getPrimalFromMod(inoutvec[i]));
152 }
153 }
154
155 static void modifiedMul(ModifiedType* invec, ModifiedType* inoutvec, int* len, MPI_Datatype* datatype) {
156 MEDI_UNUSED(datatype);
157
158 for(int i = 0; i < *len; ++i) {
159 Tool::modifyDependency(invec[i], inoutvec[i]);
160 Tool::setPrimalToMod(inoutvec[i], Tool::getPrimalFromMod(invec[i]) * Tool::getPrimalFromMod(inoutvec[i]));
161 }
162 }
163
164 static void modifiedMax(ModifiedType* invec, ModifiedType* inoutvec, int* len, MPI_Datatype* datatype) {
165 MEDI_UNUSED(datatype);
166
167 using std::max;
168 for(int i = 0; i < *len; ++i) {
169 Tool::modifyDependency(invec[i], inoutvec[i]);
170 Tool::setPrimalToMod(inoutvec[i], max(Tool::getPrimalFromMod(invec[i]), Tool::getPrimalFromMod(inoutvec[i])));
171 }
172 }
173
174 static void modifiedMin(ModifiedType* invec, ModifiedType* inoutvec, int* len, MPI_Datatype* datatype) {
175 MEDI_UNUSED(datatype);
176
177 using std::min;
178 for(int i = 0; i < *len; ++i) {
179 Tool::modifyDependency(invec[i], inoutvec[i]);
180 Tool::setPrimalToMod(inoutvec[i], min(Tool::getPrimalFromMod(invec[i]), Tool::getPrimalFromMod(inoutvec[i])));
181 }
182 }
183
184 static void modifiedMaxLoc(ModTypeInt* invec, ModTypeInt* inoutvec, int* len, MPI_Datatype* datatype) {
185 MEDI_UNUSED(datatype);
186
187 using std::max;
188 for(int i = 0; i < *len; ++i) {
189
190 PrimalType inPrimal = Tool::getPrimalFromMod(invec[i].value);
191 PrimalType inoutPrimal = Tool::getPrimalFromMod(inoutvec[i].value);
192
193 // first determine the index
194 if(inPrimal > inoutPrimal) {
195 inoutvec[i].index = invec[i].index;
196 } else if(inPrimal < inoutPrimal){
197 // empty operation: inoutvec[i].index = inoutvec[i].index;
198 } else {
199 inoutvec[i].index = std::min(invec[i].index, inoutvec[i].index);
200 }
201
202 Tool::modifyDependency(invec[i].value, inoutvec[i].value);
203 Tool::setPrimalToMod(inoutvec[i].value, max(inPrimal, inoutPrimal));
204 }
205 }
206
207 static void modifiedMinLoc(ModTypeInt* invec, ModTypeInt* inoutvec, int* len, MPI_Datatype* datatype) {
208 MEDI_UNUSED(datatype);
209
210 using std::min;
211 for(int i = 0; i < *len; ++i) {
212 PrimalType inPrimal = Tool::getPrimalFromMod(invec[i].value);
213 PrimalType inoutPrimal = Tool::getPrimalFromMod(inoutvec[i].value);
214
215 // first determine the index
216 if(inPrimal < inoutPrimal) {
217 inoutvec[i].index = invec[i].index;
218 } else if(inPrimal > inoutPrimal){
219 // empty operation: inoutvec[i].index = inoutvec[i].index;
220 } else {
221 inoutvec[i].index = std::min(invec[i].index, inoutvec[i].index);
222 }
223
224 Tool::modifyDependency(invec[i].value, inoutvec[i].value);
225 Tool::setPrimalToMod(inoutvec[i].value, min(inPrimal, inoutPrimal));
226 }
227 }
228
229// TODO: These are currently unused because multiplications that contain zero factors cannot be handled yet.
230// Implementing them requires tracking how many zeros the multiplication contained.
231//
232// void preAdjMul(AdjointType* adjoints, PrimalType* primals, int count) {
233// for(int i = 0; i < count; ++i) {
234// adjoints[i] *= primals[i];
235// }
236// }
237//
238// void postAdjMul(AdjointType* adjoints, PrimalType* primals, PrimalType* rootPrimals, int count) {
239// MEDI_UNUSED(rootPrimals);
240//
241// for(int i = 0; i < count; ++i) {
242// if(0.0 != primals[i]) {
243// adjoints[i] /= primals[i];
244// }
245// }
246// }
247
248 static void postAdjMinMax(AdjointType* adjoints, PrimalType* primals, PrimalType* rootPrimals, int count, int vecSize) {
249 for(int i = 0; i < count; ++i) {
250 if(rootPrimals[i] != primals[i]) {
251 for(int dim = 0; dim < vecSize; ++dim) {
252 adjoints[i * vecSize + dim] = AdjointType(); // the primal of this process was not the minimum or maximum so do not perfrom the adjoint update
253 }
254 }
255 }
256 }
257 };
258
263 template<INTERFACE_ARG(FuncHelp)>
265 public:
266
267 INTERFACE_DEF(FunctionHelper, FuncHelp, void, void, void, void)
268
275
277 AMPI_Op_create(false, false,
278 (MPI_User_function*)FuncHelp::unmodifiedAdd, 1,
279 (MPI_User_function*)FuncHelp::modifiedAdd, 1,
280 medi::noPreAdjointOperation,
281 medi::noPostAdjointOperation,
282 &OP_SUM);
283 AMPI_Op_create((MPI_User_function*)FuncHelp::unmodifiedMul, 1, &OP_PROD);
284 AMPI_Op_create((MPI_User_function*)FuncHelp::unmodifiedMin, 1, &OP_MIN);
285 AMPI_Op_create((MPI_User_function*)FuncHelp::unmodifiedMax, 1, &OP_MAX);
286 AMPI_Op_create((MPI_User_function*)FuncHelp::unmodifiedMinLoc, 1, &OP_MINLOC);
287 AMPI_Op_create((MPI_User_function*)FuncHelp::unmodifiedMaxLoc, 1, &OP_MAXLOC);
288 }
289
291 if(MPI_SUM == op.primalFunction) {
292 return OP_SUM;
293 } else if(MPI_PROD == op.primalFunction) {
294 return OP_PROD;
295 } else if(MPI_MIN == op.primalFunction) {
296 return OP_MIN;
297 } else if(MPI_MAX == op.primalFunction) {
298 return OP_MAX;
299 } else if(MPI_MINLOC == op.primalFunction) {
300 return OP_MINLOC;
301 } else if(MPI_MAXLOC == op.primalFunction) {
302 return OP_MAXLOC;
303 } else {
304 // do not change the type if it is not one of the above
305 return op;
306 }
307 }
308
309 void init() {
311 }
312
313 void finalize() {
314 OP_SUM.free();
315 OP_PROD.free();
316 OP_MIN.free();
317 OP_MAX.free();
318 OP_MINLOC.free();
319 OP_MAXLOC.free();
320 }
321
323
324 AMPI_Datatype intType;
325 AMPI_Aint offsets[3] = {
326 offsetof(typename FuncHelp::TypeInt, value),
327 offsetof(typename FuncHelp::TypeInt, index),
328 offsetof(typename FuncHelp::TypeInt, index) + sizeof(int)
329 };
330 int blockLength[3] = {1, 1, (int)(sizeof(typename FuncHelp::TypeInt) - offsets[2])};
331 const AMPI_Datatype types[3] = {type, AMPI_INT, AMPI_BYTE};
332
333 AMPI_Type_create_struct(3, blockLength, offsets, types, &intType);
334 AMPI_Type_commit(&intType);
335
336 return intType;
337 }
338
339 static void freeIntType(AMPI_Datatype& type) {
340 AMPI_Type_free(&type);
341 }
342 };
343}
#define AMPI_Aint
Definition ampiDefinitions.h:760
Wrapper interface for MPI types in communications.
Definition typeInterface.hpp:63
#define INTERFACE_DEF(interface, name,...)
Definition macros.h:116
#define MEDI_UNUSED(name)
Definition macros.h:108
Global namespace for MeDiPack - Message Differentiation Package.
Definition adjointInterface.hpp:37
int AMPI_Type_free(MpiTypeInterface **d)
Definition constructedDatatypes.hpp:812
int AMPI_Type_commit(MpiTypeInterface **d)
Definition constructedDatatypes.hpp:800
int AMPI_Type_create_struct(int count, const int *array_of_blocklengths, const MPI_Aint *array_of_displacements, MpiTypeInterface *const *array_of_types, MpiTypeInterface **newtype)
Definition constructedDatatypes.hpp:787
AMPI_BYTE_Type * AMPI_BYTE
Definition ampiDefinitions.cpp:169
int AMPI_Op_create(MPI_User_function *user_fn, int commute, AMPI_Op *op)
Default forward of the operator creation.
Definition operatorFunctions.hpp:53
AMPI_INT_Type * AMPI_INT
Definition ampiDefinitions.cpp:91
Structure for the special handling of the MPI_Op structure.
Definition op.hpp:50
int free()
Definition op.hpp:186
MPI_Op primalFunction
The mpi operator for the unmodified AD types. The AD tool needs to record all operations that are eva...
Definition op.hpp:69
Definition indexTypeHelper.hpp:69
int index
Definition indexTypeHelper.hpp:71
ModifiedType value
Definition indexTypeHelper.hpp:70
Definition indexTypeHelper.hpp:64
int index
Definition indexTypeHelper.hpp:66
Type value
Definition indexTypeHelper.hpp:65
Provides all methods required for the creation of operators for AD types.
Definition indexTypeHelper.hpp:61
static void unmodifiedAdd(Type *invec, Type *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:74
static void modifiedAdd(ModifiedType *invec, ModifiedType *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:146
static void unmodifiedMul(Type *invec, Type *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:82
static void unmodifiedMaxLoc(TypeInt *invec, TypeInt *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:108
static void postAdjMinMax(AdjointType *adjoints, PrimalType *primals, PrimalType *rootPrimals, int count, int vecSize)
Definition indexTypeHelper.hpp:248
static void unmodifiedMinLoc(TypeInt *invec, TypeInt *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:127
static void modifiedMin(ModifiedType *invec, ModifiedType *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:174
static void unmodifiedMax(Type *invec, Type *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:90
static void modifiedMaxLoc(ModTypeInt *invec, ModTypeInt *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:184
static void modifiedMinLoc(ModTypeInt *invec, ModTypeInt *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:207
static void modifiedMax(ModifiedType *invec, ModifiedType *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:164
static void unmodifiedMin(Type *invec, Type *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:99
static void modifiedMul(ModifiedType *invec, ModifiedType *inoutvec, int *len, MPI_Datatype *datatype)
Definition indexTypeHelper.hpp:155
Definition indexTypeHelper.hpp:264
AMPI_Op OP_SUM
Definition indexTypeHelper.hpp:269
AMPI_Op OP_MIN
Definition indexTypeHelper.hpp:271
void init()
Definition indexTypeHelper.hpp:309
static void freeIntType(AMPI_Datatype &type)
Definition indexTypeHelper.hpp:339
AMPI_Op OP_MAX
Definition indexTypeHelper.hpp:272
AMPI_Op OP_MAXLOC
Definition indexTypeHelper.hpp:274
AMPI_Op OP_MINLOC
Definition indexTypeHelper.hpp:273
AMPI_Op OP_PROD
Definition indexTypeHelper.hpp:270
void createOperators()
Definition indexTypeHelper.hpp:276
static AMPI_Datatype createIntType(const AMPI_Datatype type)
Definition indexTypeHelper.hpp:322
void finalize()
Definition indexTypeHelper.hpp:313
AMPI_Op convertOperator(AMPI_Op op) const
Definition indexTypeHelper.hpp:290
Definition indexTypeHelper.hpp:44
static void setPrimalToMod(ModifiedType &mod, const PrimalType &value)
static PrimalType getPrimalFromMod(const ModifiedType &mod)
static void modifyDependency(const ModifiedType &in, ModifiedType &inout)