37#include <medi/adToolInterface.h>
38#include <medi/ampi/ampiMisc.h>
40#include <medi/adToolImplCommon.hpp>
41#include <medi/adjointInterface.hpp>
42#include <medi/ampi/op.hpp>
43#include <medi/ampi/typeDefault.hpp>
44#include <medi/ampi/types/indexTypeHelper.hpp>
47#include "../../expressions/lhsExpressionInterface.hpp"
49#include "../../tapes/interfaces/fullTapeInterface.hpp"
50#include "../../tapes/misc/adjointVectorAccess.hpp"
55#ifndef DOXYGEN_DISABLE
57 template<
typename T_Type>
58 struct CoDiMeDiAdjointInterfaceWrapper :
public medi::AdjointInterface {
61 using Type =
CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
63 using Real =
typename Type::Real;
64 using Identifier =
typename Type::Identifier;
66 VectorAccessInterface<Real, Identifier>* codiInterface;
70 CoDiMeDiAdjointInterfaceWrapper(VectorAccessInterface<Real, Identifier>* interface)
71 : codiInterface(interface), vecSize((int)interface->getVectorSize()) {}
74 return elements * vecSize;
83 Identifier* indices = (Identifier*)i;
85 for (
int pos = 0; pos < elements; ++pos) {
86 codiInterface->getAdjointVec(indices[pos], &adjoints[pos * vecSize]);
87 codiInterface->resetAdjointVec(indices[pos]);
91 CODI_INLINE_NO_FA void updateAdjoints(
void const* i,
void const* a,
int elements)
const {
93 Identifier* indices = (Identifier*)i;
95 for (
int pos = 0; pos < elements; ++pos) {
96 codiInterface->updateAdjointVec(indices[pos], &adjoints[pos * vecSize]);
100 CODI_INLINE_NO_FA void getPrimals(
void const* i,
void const* p,
int elements)
const {
102 Identifier* indices = (Identifier*)i;
104 for (
int pos = 0; pos < elements; ++pos) {
105 primals[pos] = codiInterface->getPrimal(indices[pos]);
109 CODI_INLINE_NO_FA void setPrimals(
void const* i,
void const* p,
int elements)
const {
111 Identifier* indices = (Identifier*)i;
113 for (
int pos = 0; pos < elements; ++pos) {
114 codiInterface->setPrimal(indices[pos], primals[pos]);
118 CODI_INLINE_NO_FA void combineAdjoints(
void* b,
int const elements,
int const ranks)
const {
121 for (
int curRank = 1; curRank < ranks; ++curRank) {
122 for (
int curPos = 0; curPos < elements; ++curPos) {
123 for (
int dim = 0;
dim < vecSize; ++
dim) {
124 buf[curPos * vecSize +
dim] += buf[(elements * curRank + curPos) * vecSize + dim];
131 buf = (
void*)(
new Real[size * vecSize]);
143 buf = (
void*)(
new Real[size * vecSize]);
155 template<
typename T_Type>
156 struct CoDiPackReverseTool
157 :
public medi::ADToolImplCommon<CoDiPackReverseTool<T_Type>, T_Type::Tape::RequiresPrimalRestore, false, T_Type,
158 typename T_Type::Gradient, typename T_Type::Real, typename T_Type::Identifier> {
162 using Type =
CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
163 using PrimalType =
typename Type::Real;
164 using AdjointType = void;
165 using ModifiedType = Type;
166 using IndexType =
typename Type::Identifier;
169 using Tape =
CODI_DD(
typename Type::Tape, CODI_DEFAULT_TAPE);
170 using IterCallback =
typename ExternalFunction<Tape>::IterCallback;
173 medi::OperatorHelper<medi::FunctionHelper<Type, Type,
typename Type::PassiveReal,
typename Type::Gradient,
174 typename Type::Identifier, CoDiPackReverseTool> >;
176 using Base = medi::ADToolImplCommon<CoDiPackReverseTool, Tape::RequiresPrimalRestore,
false, Type,
177 typename Type::Gradient, PrimalType, IndexType>;
185 CoDiPackReverseTool(MPI_Datatype primalMpiType, MPI_Datatype adjointMpiType)
186 : Base(primalMpiType, adjointMpiType), opHelper() {
190 ~CoDiPackReverseTool() {
199 return getTape().isActive();
210 getTape().pushExternalFunction(
211 ExternalFunction<Tape>::create(callHandleReverse, h, deleteHandle, callHandleForward, callHandlePrimal,
212 callHandleIterateInputs, callHandleIterateOutputs));
216 medi::AMPI_Op convertOperator(medi::AMPI_Op op)
const {
217 return opHelper.convertOperator(op);
227 return value.getIdentifier();
230 static CODI_INLINE_NO_FA void registerValue(Type& value, PrimalType& oldPrimal, IndexType& index) {
231 bool wasActive = getTape().isIdentifierActive(value.getIdentifier());
232 value.getIdentifier() = IndexType();
236 if (Tape::LinearIndexHandling) {
238 value.getIdentifier() = index;
241 if (Tape::HasPrimalValues) {
242 getTape().setPrimal(index, value.getValue());
244 if (Tape::RequiresPrimalRestore) {
245 oldPrimal = PrimalType(0.0);
248 PrimalType primal = getTape().registerExternalFunctionOutput(value);
249 if (Tape::RequiresPrimalRestore) {
252 index = value.getIdentifier();
255 if (Tape::RequiresPrimalRestore) {
256 oldPrimal = PrimalType(0.0);
258 if (!Tape::LinearIndexHandling) {
259 index = getTape().getPassiveIndex();
265 IndexType oldIndex = value.getIdentifier();
267 value.getIdentifier() = oldIndex;
272 if (Tape::LinearIndexHandling) {
273 IndexType oldIndex = value.getIdentifier();
274 getTape().registerInput(value);
275 index = value.getIdentifier();
276 value.getIdentifier() = oldIndex;
282 return value.getValue();
285 static CODI_INLINE_NO_FA void setIntoModifyBuffer(ModifiedType& modValue, Type
const& value) {
291 static CODI_INLINE_NO_FA void getFromModifyBuffer(ModifiedType
const& modValue, Type& value) {
297 static PrimalType getPrimalFromMod(ModifiedType
const& modValue) {
298 return modValue.value();
301 static void setPrimalToMod(ModifiedType& modValue, PrimalType
const& value) {
302 modValue.value() = value;
305 static void modifyDependency(ModifiedType& inval, ModifiedType& inoutval) {
306 bool active = getTape().isIdentifierActive(inoutval.getIdentifier()) ||
307 getTape().isIdentifierActive(inval.getIdentifier());
309 inoutval.getIdentifier() = getTape().getInvalidIndex();
311 inoutval.getIdentifier() = getTape().getPassiveIndex();
317 static void callHandleReverse(Tape* tape,
void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
320 medi::HandleBase* handle =
static_cast<medi::HandleBase*
>(h);
321 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
322 handle->funcReverse(handle, &ahWrapper);
325 static void callHandleForward(Tape* tape,
void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
328 medi::HandleBase* handle =
static_cast<medi::HandleBase*
>(h);
329 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
330 handle->funcForward(handle, &ahWrapper);
333 static void callHandlePrimal(Tape* tape,
void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
336 medi::HandleBase* handle =
static_cast<medi::HandleBase*
>(h);
337 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
338 handle->funcPrimal(handle, &ahWrapper);
341 static void deleteHandle(Tape* tape,
void* h) {
344 medi::HandleBase* handle =
static_cast<medi::HandleBase*
>(h);
348 static void callHandleIterateInputs(Tape* tape,
void* h, IterCallback func,
void* userData) {
351 medi::HandleBase* handle =
static_cast<medi::HandleBase*
>(h);
352 handle->funcIterateInputIds(handle, (::medi::CallbackFunc)func, userData);
355 static void callHandleIterateOutputs(Tape* tape,
void* h, IterCallback func,
void* userData) {
358 medi::HandleBase* handle =
static_cast<medi::HandleBase*
>(h);
359 handle->funcIterateOutputIds(handle, (::medi::CallbackFunc)func, userData);
362 static Tape& getTape() {
363 return Type::getTape();
#define CODI_INLINE_NO_FA
See codi::Config::ForcedInlines.
Definition config.h:471
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:97
typename TraitsImplementation< Gradient >::Real Real
The base value used in the gradient entries.
Definition gradientTraits.hpp:92
inlinesize_t constexpr dim()
Number of dimensions this gradient value has.
Definition gradientTraits.hpp:96
inlinetypename DataExtraction< Type >::Real getValue(Type const &v)
Extract an aggregate of primal values from an aggregate of active types.
Definition realTraits.hpp:381
CoDiPack - Code Differentiation Package.
Definition codi.hpp:97
inlinevoid CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:55