CoDiPack  2.3.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
codiOpDiLibTool.hpp
1/*
2 * CoDiPack, a Code Differentiation Package
3 *
4 * Copyright (C) 2015-2024 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://scicomp.rptu.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of CoDiPack (http://scicomp.rptu.de/software/codi).
11 *
12 * CoDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU General Public License
14 * as published by the Free Software Foundation, either version 3 of the
15 * License, or (at your option) any later version.
16 *
17 * CoDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty
19 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU General Public License for more details.
22 * You should have received a copy of the GNU
23 * General Public License along with CoDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * For other licensing options please contact us.
27 *
28 * Authors:
29 * - SciComp, University of Kaiserslautern-Landau:
30 * - Max Sagebaum
31 * - Johannes Blühdorn
32 * - Former members:
33 * - Tim Albring
34 */
35#pragma once
36
37#include <opdi/tool/toolInterface.hpp>
38#include <sstream>
39
40#include "../../../expressions/lhsExpressionInterface.hpp"
41#include "../../../expressions/parallelActiveType.hpp"
43#include "../../../tapes/interfaces/editingTapeInterface.hpp"
44#include "../../../tapes/misc/vectorAccessInterface.hpp"
45#include "../../../traits/atomicTraits.hpp"
46#include "openMPReverseAtomic.hpp"
47
48#ifndef DOXYGEN_DISABLE
49
50template<typename T_CoDiType>
51struct CoDiOpDiLibTool : public opdi::ToolInterface {
52 public:
53 using CoDiType = CODI_DD(T_CoDiType, codi::CODI_DEFAULT_PARALLEL_ACTIVE_TYPE);
54 using Real = typename CoDiType::Real;
55 using Identifier = typename CoDiType::Identifier;
56 using Tape = typename CoDiType::Tape;
57 using Position = typename Tape::Position;
58
60
61 private:
62 static void callHandleReverse(Tape*, void* handlePtr, VAI*) {
63 opdi::Handle* handle = (opdi::Handle*)handlePtr;
64 handle->reverseFunc(handle->data);
65 }
66
67 static void callHandleDelete(Tape*, void* handlePtr) {
68 opdi::Handle* handle = (opdi::Handle*)handlePtr;
69 if (handle->deleteFunc != nullptr) {
70 handle->deleteFunc(handle->data);
71 }
72 delete handle;
73 }
74
75 public:
76 void init() {}
77
78 void finalize() {}
79
80 void* createTape() {
81 return (void*)new Tape;
82 }
83
84 void deleteTape(void* tapePtr) {
85 Tape* tape = (Tape*)tapePtr;
86 delete tape;
87 }
88
89 void* allocPosition() {
90 return new Position();
91 }
92
93 void freePosition(void* positionPtr) {
94 Position* position = (Position*)positionPtr;
95 delete position;
96 }
97
98 size_t getPositionSize() {
99 return sizeof(Position);
100 }
101
102 std::string positionToString(void* positionPtr) {
103 Position* position = (Position*)positionPtr;
104 std::stringstream conv;
105 conv << *position;
106 return conv.str();
107 }
108
109 void getTapePosition(void* tapePtr, void* positionPtr) {
110 Tape* tape = (Tape*)tapePtr;
111 Position* position = (Position*)positionPtr;
112
113 *position = tape->getPosition();
114 }
115
116 void getZeroPosition(void* tapePtr, void* positionPtr) {
117 Tape* tape = (Tape*)tapePtr;
118 Position* position = (Position*)positionPtr;
119
120 *position = tape->getZeroPosition();
121 }
122
123 void copyPosition(void* dstPtr, void* srcPtr) {
124 Position* dst = (Position*)dstPtr;
125 Position* src = (Position*)srcPtr;
126
127 *dst = *src;
128 }
129
130 int comparePosition(void* lhsPtr, void* rhsPtr) {
131 Position* lhs = (Position*)lhsPtr;
132 Position* rhs = (Position*)rhsPtr;
133
134 if (*lhs <= *rhs) {
135 if (*rhs <= *lhs) {
136 return 0;
137 } else {
138 return -1;
139 }
140 } else {
141 return 1;
142 }
143 }
144
145 bool isActive(void* tapePtr) {
146 Tape* tape = (Tape*)tapePtr;
147 return tape->isActive();
148 }
149
150 void setActive(void* tapePtr, bool active) {
151 Tape* tape = (Tape*)tapePtr;
152 if (active) {
153 tape->setActive();
154 } else {
155 tape->setPassive();
156 }
157 }
158
159 void evaluate(void* tapePtr, void* startPtr, void* endPtr, bool useAtomics = true) {
160 Tape* tape = (Tape*)tapePtr;
161 Position* start = (Position*)startPtr;
162 Position* end = (Position*)endPtr;
163
164 if (tape->isActive()) {
165 std::cerr << "Warning: OpDiLib evaluation of an active tape." << std::endl;
166 }
167
170
171 if (useAtomics) {
172 tape->evaluate(*start, *end, reinterpret_cast<AtomicGradient*>(&tape->gradient(0)));
173 } else {
174 tape->evaluate(*start, *end, reinterpret_cast<NonAtomicGradient*>(&tape->gradient(0)));
175 }
176 }
177
178 void reset(void* tapePtr, bool clearAdjoints) {
179 Tape* tape = (Tape*)tapePtr;
180 tape->reset(clearAdjoints);
181 }
182
183 void reset(void* tapePtr, void* positionPtr, bool clearAdjoints) {
184 Tape* tape = (Tape*)tapePtr;
185 Position* position = (Position*)positionPtr;
186 tape->resetTo(*position, clearAdjoints);
187 }
188
189 void* getThreadLocalTape() {
190 return (void*)CoDiType::getTapePtr();
191 }
192
193 void setThreadLocalTape(void* tapePtr) {
194 Tape* tape = (Tape*)tapePtr;
195 CoDiType::setTapePtr(tape);
196 }
197
198 void pushExternalFunction(void* tapePtr, opdi::Handle const* handle) {
199 Tape* tape = (Tape*)tapePtr;
200 tape->pushExternalFunction(codi::ExternalFunction<Tape>::create(CoDiOpDiLibTool::callHandleReverse, (void*)handle,
201 CoDiOpDiLibTool::callHandleDelete));
202 }
203
204 void erase(void* tapePtr, void* startPtr, void* endPtr) {
205 Tape* tape = (Tape*)tapePtr;
206 Position* start = (Position*)startPtr;
207 Position* end = (Position*)endPtr;
208
209 tape->erase(*start, *end);
210 }
211
212 void append(void* dstTapePtr, void* srcTapePtr, void* startPtr, void* endPtr) {
213 Tape* dstTape = (Tape*)dstTapePtr;
214 Tape* srcTape = (Tape*)srcTapePtr;
215 Position* start = (Position*)startPtr;
216 Position* end = (Position*)endPtr;
217
218 dstTape->append(*srcTape, *start, *end);
219 }
220};
221
222#endif
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:94
typename RemoveAtomicImpl< Type >::Type RemoveAtomic
Wrapper for removing atomic from a type.
Definition atomicTraits.hpp:83
Represents a concrete lvalue in the CoDiPack expression tree.
Definition activeType.hpp:52
User-defined evaluation functions for the taping process.
Definition externalFunction.hpp:102
Reverse atomic implementation for OpenMP.
Definition openMPReverseAtomic.hpp:61
Unified access to the adjoint vector and primal vector in a tape evaluation.
Definition vectorAccessInterface.hpp:91