CoDiPack  2.3.0
A Code Differentiation Package
SciComp TU Kaiserslautern
Loading...
Searching...
No Matches
aggregatedTypeVectorAccessWrapper.hpp
1/*
2 * CoDiPack, a Code Differentiation Package
3 *
4 * Copyright (C) 2015-2024 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
5 * Homepage: http://scicomp.rptu.de
6 * Contact: Prof. Nicolas R. Gauger (codi@scicomp.uni-kl.de)
7 *
8 * Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
9 *
10 * This file is part of CoDiPack (http://scicomp.rptu.de/software/codi).
11 *
12 * CoDiPack is free software: you can redistribute it and/or
13 * modify it under the terms of the GNU General Public License
14 * as published by the Free Software Foundation, either version 3 of the
15 * License, or (at your option) any later version.
16 *
17 * CoDiPack is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty
19 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
20 *
21 * See the GNU General Public License for more details.
22 * You should have received a copy of the GNU
23 * General Public License along with CoDiPack.
24 * If not, see <http://www.gnu.org/licenses/>.
25 *
26 * For other licensing options please contact us.
27 *
28 * Authors:
29 * - SciComp, University of Kaiserslautern-Landau:
30 * - Max Sagebaum
31 * - Johannes Blühdorn
32 * - Former members:
33 * - Tim Albring
34 */
35#pragma once
36
37#include <complex>
38#include <vector>
39
40#include "../../config.h"
41#include "../../expressions/lhsExpressionInterface.hpp"
42#include "../../misc/macros.hpp"
43#include "../../tapes/misc/vectorAccessInterface.hpp"
44#include "../../traits/computationTraits.hpp"
45#include "../../traits/expressionTraits.hpp"
46#include "../../traits/realTraits.hpp"
47
49namespace codi {
50
/// Generalized wrapper of the VectorAccessInterface for aggregated data types, e.g. std::complex of another type.
///
/// This primary template is only a placeholder: its static assertion has an always-false condition and the
/// message "Instantiation of unspecialized AggregatedTypeVectorAccessWrapper.", so a specialization for T_Type
/// (such as the std::complex one further down in this file) has to be selected instead.
///
/// @tparam T_Type  The aggregated type for which the wrapper is requested.
76 template<typename T_Type, typename = void>
77 struct AggregatedTypeVectorAccessWrapper : public VectorAccessInterface<CODI_ANY, CODI_ANY> {
// The condition is constant false but still mentions T_Type.
// NOTE(review): presumably this keeps the assertion dependent on the template argument so that it only
// triggers on instantiation - confirm against the definition of CODI_STATIC_ASSERT.
78 CODI_STATIC_ASSERT(false && std::is_void<T_Type>::value,
79 "Instantiation of unspecialized AggregatedTypeVectorAccessWrapper.");
80
/// The aggregated type this wrapper was requested for.
81 using Type = CODI_DD(T_Type, CODI_ANY);
82 };
83
92 template<typename T_Real, typename T_Identifier, typename T_InnerInterface>
93 struct AggregatedTypeVectorAccessWrapperBase : public VectorAccessInterface<T_Real, T_Identifier> {
94 public:
95
96 using Real = CODI_DD(T_Real, double);
97 using Identifier = CODI_DD(T_Identifier, int);
98
100 CODI_DD(T_InnerInterface,
102
103 protected:
104
106
107 std::vector<Real> lhs;
108 std::vector<Real> buffer;
109
110 public:
111
117
118 /*******************************************************************************/
120
122 size_t getVectorSize() const {
123 return innerInterface.getVectorSize();
124 }
125
127 bool isLhsZero() {
128 bool isZero = true;
129
130 for (size_t curDim = 0; isZero && curDim < lhs.size(); curDim += 1) {
131 isZero &= RealTraits::isTotalZero(lhs[curDim]);
132 }
133
134 return isZero;
135 }
136
137 /*******************************************************************************/
139
141 void setLhsAdjoint(Identifier const& index) {
142 getAdjointVec(index, lhs.data());
143 this->resetAdjointVec(index);
144 }
145
147 void updateAdjointWithLhs(Identifier const& index, Real const& jacobian) {
148 for (size_t curDim = 0; curDim < lhs.size(); curDim += 1) {
149 Real update = jacobian * lhs[curDim];
150 this->updateAdjoint(index, curDim, update);
151 }
152 }
153
154 /*******************************************************************************/
156
158 void setLhsTangent(Identifier const& index) {
159 updateAdjointVec(index, lhs.data());
160
161 for (size_t curDim = 0; curDim < lhs.size(); curDim += 1) {
162 lhs[curDim] = Real();
163 }
164 }
165
167 void updateTangentWithLhs(Identifier const& index, Real const& jacobian) {
168 for (size_t curDim = 0; curDim < lhs.size(); curDim += 1) {
169 lhs[curDim] += jacobian * this->getAdjoint(index, curDim);
170 }
171 }
172
173 /*******************************************************************************/
175
177 void getAdjointVec(Identifier const& index, Real* const vec) {
178 for (size_t curDim = 0; curDim < lhs.size(); curDim += 1) {
179 vec[curDim] = this->getAdjoint(index, curDim);
180 }
181 }
182
184 Real const* getAdjointVec(Identifier const& index) {
185 getAdjointVec(index, buffer.data());
186 return buffer.data();
187 }
188
190 void updateAdjointVec(Identifier const& index, Real const* const vec) {
191 for (size_t curDim = 0; curDim < lhs.size(); curDim += 1) {
192 this->updateAdjoint(index, curDim, vec[curDim]);
193 }
194 }
195
196 /*******************************************************************************/
198
200 bool hasPrimals() {
201 return innerInterface.hasPrimals();
202 }
203 };
204
216 template<typename T_Type, typename = void>
218 public:
219 using Type = CODI_DD(T_Type, CODI_ANY);
220
222
226 template<typename Real, typename Identifier>
228 return new RType(access);
229 }
230
/// Delete the AggregatedTypeVectorAccessWrapper instance created by the create method.
///
/// @param access  Instance to release; it is passed to delete.
232 static void destroy(RType* access) {
233 delete access;
234 }
235 };
236
237#ifndef DOXYGEN_DISABLE
241 template<typename T_InnerType>
242 struct AggregatedTypeVectorAccessWrapper<std::complex<T_InnerType>>
243 : public AggregatedTypeVectorAccessWrapperBase<
244 std::complex<typename T_InnerType::Real>, std::complex<typename T_InnerType::Identifier>,
245 VectorAccessInterface<typename T_InnerType::Real, typename T_InnerType::Identifier>> {
246 public:
247
248 using InnerType = CODI_DD(T_InnerType, CODI_DEFAULT_LHS_EXPRESSION);
249 using Type = std::complex<InnerType>;
250
251 using InnerInterface = VectorAccessInterface<
252 typename InnerType::Real,
253 typename InnerType::Identifier>;
254
255 using Real = std::complex<typename InnerType::Real>;
256 using Identifier = std::complex<typename InnerType::Identifier>;
257
258 using Base =
259 AggregatedTypeVectorAccessWrapperBase<Real, Identifier, InnerInterface>;
260
262 AggregatedTypeVectorAccessWrapper(InnerInterface* innerInterface) : Base(innerInterface) {}
263
264 /*******************************************************************************/
266
268 void resetAdjoint(Identifier const& index, size_t dim) {
269 Base::innerInterface.resetAdjoint(std::real(index), dim);
270 Base::innerInterface.resetAdjoint(std::imag(index), dim);
271 }
272
274 void resetAdjointVec(Identifier const& index) {
275 Base::innerInterface.resetAdjointVec(std::real(index));
276 Base::innerInterface.resetAdjointVec(std::imag(index));
277 }
278
280 Real getAdjoint(Identifier const& index, size_t dim) {
281 return Real(Base::innerInterface.getAdjoint(std::real(index), dim),
282 Base::innerInterface.getAdjoint(std::imag(index), dim));
283 }
284
286 void updateAdjoint(Identifier const& index, size_t dim, Real const& adjoint) {
287 Base::innerInterface.updateAdjoint(std::real(index), dim, std::real(adjoint));
288 Base::innerInterface.updateAdjoint(std::imag(index), dim, std::imag(adjoint));
289 }
290
291 /*******************************************************************************/
293
295 void setPrimal(Identifier const& index, Real const& primal) {
296 Base::innerInterface.setPrimal(std::real(index), std::real(primal));
297 Base::innerInterface.setPrimal(std::imag(index), std::imag(primal));
298 }
299
301 Real getPrimal(Identifier const& index) {
302 return Real(Base::innerInterface.getPrimal(std::real(index)), Base::innerInterface.getPrimal(std::imag(index)));
303 }
304
306 VectorAccessInterface<Real, Identifier>* clone() const {
307 return new AggregatedTypeVectorAccessWrapper(Base::innerInterface.clone());
308 }
309 };
310
314 template<typename T_Type>
315 struct AggregatedTypeVectorAccessWrapperFactory<T_Type, ExpressionTraits::EnableIfLhsExpression<T_Type>> {
316 public:
317 using Type = CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
318
319 using RType = VectorAccessInterface<typename Type::Real, typename Type::Identifier>;
320
322 static RType* create(RType* access) {
323 return access;
324 }
325
327 static void destroy(RType* access) {
328 CODI_UNUSED(access);
329
330 // Do nothing
331 }
332 };
333#endif
334}
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:94
#define CODI_ANY
Used in default declarations of expression templates.
Definition macros.hpp:98
#define CODI_STATIC_ASSERT(cond, message)
Static assert in CoDiPack.
Definition macros.hpp:123
#define CODI_T(...)
Abbreviation for CODI_TEMPLATE.
Definition macros.hpp:111
typename std::enable_if< IsLhsExpression< Expr >::value, T >::type EnableIfLhsExpression
Enable if wrapper for IsLhsExpression.
Definition expressionTraits.hpp:137
bool isTotalZero(Type const &v)
Function for checking if the value of the type is completely zero.
Definition realTraits.hpp:139
CoDiPack - Code Differentiation Package.
Definition codi.hpp:91
void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:46
Represents a concrete lvalue in the CoDiPack expression tree.
Definition activeType.hpp:52
Implements all methods from AggregatedTypeVectorAccessWrapper, that can be implemented with combinati...
Definition aggregatedTypeVectorAccessWrapper.hpp:93
void updateAdjointWithLhs(Identifier const &index, Real const &jacobian)
Definition aggregatedTypeVectorAccessWrapper.hpp:147
bool isLhsZero()
True if the adjoint set with setLhsAdjoint is zero.
Definition aggregatedTypeVectorAccessWrapper.hpp:127
void setLhsAdjoint(Identifier const &index)
Definition aggregatedTypeVectorAccessWrapper.hpp:141
T_Real Real
See RealTraits::DataExtraction::Real.
Definition aggregatedTypeVectorAccessWrapper.hpp:96
T_InnerInterface InnerInterface
See AggregatedTypeVectorAccessWrapperBase.
Definition aggregatedTypeVectorAccessWrapper.hpp:99
T_Identifier Identifier
See RealTraits::DataExtraction::Identifier.
Definition aggregatedTypeVectorAccessWrapper.hpp:97
size_t getVectorSize() const
Vector size in the current tape evaluation.
Definition aggregatedTypeVectorAccessWrapper.hpp:122
void getAdjointVec(Identifier const &index, Real *const vec)
Get the adjoint entry.
Definition aggregatedTypeVectorAccessWrapper.hpp:177
void updateTangentWithLhs(Identifier const &index, Real const &jacobian)
Definition aggregatedTypeVectorAccessWrapper.hpp:167
InnerInterface & innerInterface
Reference to the accessor of the underlying tape.
Definition aggregatedTypeVectorAccessWrapper.hpp:105
void updateAdjointVec(Identifier const &index, Real const *const vec)
Update the adjoint entry.
Definition aggregatedTypeVectorAccessWrapper.hpp:190
bool hasPrimals()
True if the tape/vector interface has primal values.
Definition aggregatedTypeVectorAccessWrapper.hpp:200
void setLhsTangent(Identifier const &index)
Definition aggregatedTypeVectorAccessWrapper.hpp:158
AggregatedTypeVectorAccessWrapperBase(InnerInterface *innerInterface)
Constructor.
Definition aggregatedTypeVectorAccessWrapper.hpp:113
std::vector< Real > lhs
Temporary storage for indirect adjoint or tangent updates.
Definition aggregatedTypeVectorAccessWrapper.hpp:107
Real const * getAdjointVec(Identifier const &index)
Get the adjoint entry.
Definition aggregatedTypeVectorAccessWrapper.hpp:184
std::vector< Real > buffer
Temporary storage for getAdjointVec access.
Definition aggregatedTypeVectorAccessWrapper.hpp:108
Factory for the creation of AggregatedTypeVectorAccessWrapper instances.
Definition aggregatedTypeVectorAccessWrapper.hpp:217
AggregatedTypeVectorAccessWrapper< Type > RType
Which instances this factory creates.
Definition aggregatedTypeVectorAccessWrapper.hpp:221
static void destroy(RType *access)
Delete the AggregatedTypeVectorAccessWrapper instance created by the create method.
Definition aggregatedTypeVectorAccessWrapper.hpp:232
T_Type Type
See AggregatedTypeVectorAccessWrapperBase.
Definition aggregatedTypeVectorAccessWrapper.hpp:219
static RType * create(VectorAccessInterface< Real, Identifier > *access)
Definition aggregatedTypeVectorAccessWrapper.hpp:227
Generalized wrapper of the VectorAccessInterface for aggregated data types, e.g. std::complex<codi::R...
Definition aggregatedTypeVectorAccessWrapper.hpp:77
T_Type Type
See AggregatedTypeVectorAccessWrapperBase.
Definition aggregatedTypeVectorAccessWrapper.hpp:81
Unified access to the adjoint vector and primal vector in a tape evaluation.
Definition vectorAccessInterface.hpp:91
virtual void resetAdjoint(Identifier const &index, size_t dim)=0
Set the adjoint component to zero.
virtual void setPrimal(Identifier const &index, Real const &primal)=0
Set the primal value.
int Real
See VectorAccessInterface.
Definition vectorAccessInterface.hpp:94
virtual void resetAdjointVec(Identifier const &index)=0
Set the adjoint entry to zero.
virtual Real getPrimal(Identifier const &index)=0
Get the primal value.
virtual VectorAccessInterface * clone() const=0
virtual void updateAdjoint(Identifier const &index, size_t dim, Real const &adjoint)=0
Update the adjoint component.
virtual Real getAdjoint(Identifier const &index, size_t dim)=0
Get the adjoint component.