37#include <medi/adToolInterface.h>
38#include <medi/ampi/ampiMisc.h>
40#include <medi/adToolImplCommon.hpp>
41#include <medi/adjointInterface.hpp>
42#include <medi/ampi/op.hpp>
43#include <medi/ampi/typeDefault.hpp>
44#include <medi/ampi/types/indexTypeHelper.hpp>
47#include "../../expressions/lhsExpressionInterface.hpp"
49#include "../../tapes/interfaces/fullTapeInterface.hpp"
50#include "../../tapes/misc/adjointVectorAccess.hpp"
55#ifndef DOXYGEN_DISABLE
// CoDiMeDiAdjointInterfaceWrapper: adapts a CoDiPack
// VectorAccessInterface<Real, Identifier> (held in `codiInterface`) to the
// medi::AdjointInterface base class, so MeDiPack code can read and write
// adjoint and primal values through that access object.
//
// NOTE(review): this listing is a lossy extract of the original source.
// Several lines are not visible here, among them most function headers, the
// local casts that declare `adjoints`, `primals` and `buf`, the declaration
// of the `vecSize` member, and closing braces. The comments below describe
// only the statements that are visible.
57 template<
typename T_Type>
58 struct CoDiMeDiAdjointInterfaceWrapper :
public medi::AdjointInterface {
// Active CoDiPack type; CODI_DD (abbreviation for CODI_DECLARE_DEFAULT)
// names CODI_DEFAULT_LHS_EXPRESSION as its declared default.
61 using Type =
CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
// Primal value type and identifier type taken from the active type.
63 using Real =
typename Type::Real;
64 using Identifier =
typename Type::Identifier;
// Access object passed to the constructor; every adjoint/primal operation
// below is forwarded to it. No deletion of it is visible in this extract.
66 VectorAccessInterface<Real, Identifier>* codiInterface;
// Stores the access object and caches its vector size as an int. `vecSize`
// is the number of Real entries per identifier in the flat adjoint buffers
// used below.
70 CoDiMeDiAdjointInterfaceWrapper(VectorAccessInterface<Real, Identifier>* interface)
71 : codiInterface(interface), vecSize((int)interface->getVectorSize()) {}
// Body of an element-count helper (header not visible): `elements` entries
// occupy `elements * vecSize` Real values.
74 return elements * vecSize;
// Adjoint extraction (header not visible). For each of the `elements`
// identifiers in `i`, copy its vector adjoint into the flat buffer at offset
// pos * vecSize, then call resetAdjointVec on it. That call presumably
// clears the tape-side adjoint; confirm against the VectorAccessInterface
// documentation.
83 Identifier* indices = (Identifier*)i;
85 for (
int pos = 0; pos < elements; ++pos) {
86 codiInterface->getAdjointVec(indices[pos], &adjoints[pos * vecSize]);
87 codiInterface->resetAdjointVec(indices[pos]);
// Passes each vector adjoint stored in `a` (entry `pos` at offset
// pos * vecSize) to codiInterface->updateAdjointVec for the matching
// identifier in `i`. This is presumably an accumulation; confirm against
// the VectorAccessInterface documentation.
// NOTE(review): the cast from `a` that defines `adjoints` is not visible.
91 CODI_INLINE_NO_FA void updateAdjoints(
void const* i,
void const* a,
int elements)
const {
93 Identifier* indices = (Identifier*)i;
95 for (
int pos = 0; pos < elements; ++pos) {
96 codiInterface->updateAdjointVec(indices[pos], &adjoints[pos * vecSize]);
// Reads the primal value of each identifier in `i` into `primals[pos]`.
// Primals use one Real per element (no vecSize stride).
// NOTE(review): `p` is declared `void const*`, yet `primals` is written to,
// so the cast (not visible) presumably drops const. The signature appears
// to be fixed by medi::AdjointInterface; confirm.
100 CODI_INLINE_NO_FA void getPrimals(
void const* i,
void const* p,
int elements)
const {
102 Identifier* indices = (Identifier*)i;
104 for (
int pos = 0; pos < elements; ++pos) {
105 primals[pos] = codiInterface->getPrimal(indices[pos]);
// Writes `primals[pos]` as the primal value of each identifier in `i`.
109 CODI_INLINE_NO_FA void setPrimals(
void const* i,
void const* p,
int elements)
const {
111 Identifier* indices = (Identifier*)i;
113 for (
int pos = 0; pos < elements; ++pos) {
114 codiInterface->setPrimal(indices[pos], primals[pos]);
// Sums per-rank adjoint blocks into the first block. The buffer holds at
// least `ranks` consecutive blocks of `elements * vecSize` entries: entry
// `dim` of element `curPos` from rank `curRank` sits at
// (elements * curRank + curPos) * vecSize + dim. Blocks 1..ranks-1 are added
// onto block 0, which afterwards holds the combined values.
118 CODI_INLINE_NO_FA void combineAdjoints(
void* b,
int const elements,
int const ranks)
const {
121 for (
int curRank = 1; curRank < ranks; ++curRank) {
122 for (
int curPos = 0; curPos < elements; ++curPos) {
123 for (
int dim = 0;
dim < vecSize; ++
dim) {
124 buf[curPos * vecSize +
dim] += buf[(elements * curRank + curPos) * vecSize + dim];
// Buffer allocation (header not visible): `size * vecSize` Real values are
// allocated with new[] and handed back through `buf` as void*. Such memory
// must be released with delete[] through a Real*. The matching delete
// function is not visible in this extract.
131 buf = (
void*)(
new Real[size * vecSize]);
// A second allocation of the same shape (header not visible).
143 buf = (
void*)(
new Real[size * vecSize]);
// CoDiPackReverseTool: MeDiPack AD-tool implementation for CoDiPack active
// types.
// - It derives from medi::ADToolImplCommon, parameterised with the tape's
//   RequiresPrimalRestore flag and the active type's gradient, real and
//   identifier types.
// - It connects MeDiPack handles to external functions on the tape
//   returned by Type::getTape().
//
// NOTE(review): this listing is a lossy extract. Many lines are not
// visible, including several method headers, `if`/`else` lines, the
// `opHelper` member declaration, the first line of the alias that original
// line 172 continues, and closing braces. The comments below describe only
// visible statements.
155 template<
typename T_Type>
156 struct CoDiPackReverseTool
157 :
public medi::ADToolImplCommon<CoDiPackReverseTool<T_Type>, T_Type::Tape::RequiresPrimalRestore, false, T_Type,
158 typename T_Type::Gradient, typename T_Type::Real, typename T_Type::Identifier> {
// Type aliases that describe the CoDiPack type to MeDiPack.
162 using Type =
CODI_DD(T_Type, CODI_DEFAULT_LHS_EXPRESSION);
163 using PrimalType =
typename Type::Real;
// No concrete adjoint type is exposed through this alias.
164 using AdjointType = void;
// Modify-buffer entries keep the full active type.
165 using ModifiedType = Type;
166 using IndexType =
typename Type::Identifier;
169 using Tape =
CODI_DD(
typename Type::Tape, CODI_DEFAULT_TAPE);
// Tail of an alias for a medi::OperatorHelper instantiation (its first line
// is not visible). It is presumably the type of the `opHelper` member used
// in the constructor and in convertOperator(); confirm.
172 medi::OperatorHelper<medi::FunctionHelper<Type, Type,
typename Type::PassiveReal,
typename Type::Gradient,
173 typename Type::Identifier, CoDiPackReverseTool> >;
175 using Base = medi::ADToolImplCommon<CoDiPackReverseTool, Tape::RequiresPrimalRestore,
false, Type,
176 typename Type::Gradient, PrimalType, IndexType>;
// Forwards the primal and adjoint MPI datatypes to the base class and
// value-initialises `opHelper`. The rest of the body is not visible.
184 CoDiPackReverseTool(MPI_Datatype primalMpiType, MPI_Datatype adjointMpiType)
185 : Base(primalMpiType, adjointMpiType), opHelper() {
// Destructor; its body is not visible in this extract.
189 ~CoDiPackReverseTool() {
// Body of a query whose header is not visible: returns getTape().isActive().
198 return getTape().isActive();
// Body of a function whose header is not visible. It pushes the MeDiPack
// handle `h` onto the tape as an ExternalFunction, with callHandleReverse,
// callHandleForward and callHandlePrimal as callbacks and deleteHandle as
// the deleter.
209 getTape().pushExternalFunction(
210 ExternalFunction<Tape>::create(callHandleReverse, h, deleteHandle, callHandleForward, callHandlePrimal));
// Delegates MPI operator conversion to `opHelper`.
214 medi::AMPI_Op convertOperator(medi::AMPI_Op op)
const {
215 return opHelper.convertOperator(op);
// Body of an accessor whose header is not visible: returns the identifier
// stored in `value`.
225 return value.getIdentifier();
// Registers `value` and reports the identifier to use through `index`.
// `oldPrimal` is set to PrimalType(0.0) on the RequiresPrimalRestore paths
// that are visible.
// NOTE(review): several condition and brace lines are missing, so the
// exact branch nesting cannot be read from this extract.
228 static CODI_INLINE_NO_FA void registerValue(Type& value, PrimalType& oldPrimal, IndexType& index) {
// Remember whether the incoming identifier was active, then clear it.
229 bool wasActive = getTape().isIdentifierActive(value.getIdentifier());
230 value.getIdentifier() = IndexType();
// Linear-index-handling path: the caller-provided `index` is written back
// into `value`. Under Tape::HasPrimalValues, the current value is stored as
// the primal at `index`.
234 if (Tape::LinearIndexHandling) {
236 value.getIdentifier() = index;
239 if (Tape::HasPrimalValues) {
240 getTape().setPrimal(index, value.getValue());
242 if (Tape::RequiresPrimalRestore) {
243 oldPrimal = PrimalType(0.0);
// Registers `value` as an external-function output and reports its new
// identifier through `index`. The statement inside the following `if` is
// not visible.
246 PrimalType primal = getTape().registerExternalFunctionOutput(value);
247 if (Tape::RequiresPrimalRestore) {
250 index = value.getIdentifier();
// A path that creates no new identifier (presumably when `wasActive` is
// false; confirm). Without linear index handling, the tape's passive index
// is reported.
253 if (Tape::RequiresPrimalRestore) {
254 oldPrimal = PrimalType(0.0);
256 if (!Tape::LinearIndexHandling) {
257 index = getTape().getPassiveIndex();
// Fragment of a function whose header is not visible: it saves the
// identifier of `value` and restores it afterwards. The statement in
// between is not visible.
263 IndexType oldIndex = value.getIdentifier();
265 value.getIdentifier() = oldIndex;
// With linear index handling: registers `value` as a tape input to obtain a
// fresh identifier, reports it through `index`, and restores the original
// identifier of `value`.
270 if (Tape::LinearIndexHandling) {
271 IndexType oldIndex = value.getIdentifier();
272 getTape().registerInput(value);
273 index = value.getIdentifier();
274 value.getIdentifier() = oldIndex;
// Body of an accessor whose header is not visible: returns the primal value
// of `value`.
280 return value.getValue();
// Transfers between a modify-buffer entry and a user value; the bodies are
// not visible in this extract.
283 static CODI_INLINE_NO_FA void setIntoModifyBuffer(ModifiedType& modValue, Type
const& value) {
289 static CODI_INLINE_NO_FA void getFromModifyBuffer(ModifiedType
const& modValue, Type& value) {
// Read and write access to the primal stored in a modify-buffer entry.
295 static PrimalType getPrimalFromMod(ModifiedType
const& modValue) {
296 return modValue.value();
299 static void setPrimalToMod(ModifiedType& modValue, PrimalType
const& value) {
300 modValue.value() = value;
// `active` is true if either operand's identifier is active. The visible
// assignments give `inoutval` either the tape's invalid index or its
// passive index. The if/else lines are not visible; pairing "active ->
// invalid index" is presumed from statement order. Confirm.
303 static void modifyDependency(ModifiedType& inval, ModifiedType& inoutval) {
304 bool active = getTape().isIdentifierActive(inoutval.getIdentifier()) ||
305 getTape().isIdentifierActive(inval.getIdentifier());
307 inoutval.getIdentifier() = getTape().getInvalidIndex();
309 inoutval.getIdentifier() = getTape().getPassiveIndex();
// Reverse callback: casts `h` back to medi::HandleBase*, wraps `ah` in a
// CoDiMeDiAdjointInterfaceWrapper and calls handle->funcReverse. `tape` is
// not used in the visible statements.
315 static void callHandleReverse(Tape* tape,
void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
318 medi::HandleBase* handle =
static_cast<medi::HandleBase*
>(h);
319 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
320 handle->funcReverse(handle, &ahWrapper);
// Forward callback: same as above, but calls handle->funcForward.
323 static void callHandleForward(Tape* tape,
void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
326 medi::HandleBase* handle =
static_cast<medi::HandleBase*
>(h);
327 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
328 handle->funcForward(handle, &ahWrapper);
// Primal callback: same as above, but calls handle->funcPrimal.
331 static void callHandlePrimal(Tape* tape,
void* h, VectorAccessInterface<PrimalType, IndexType>* ah) {
334 medi::HandleBase* handle =
static_cast<medi::HandleBase*
>(h);
335 CoDiMeDiAdjointInterfaceWrapper<Type> ahWrapper(ah);
336 handle->funcPrimal(handle, &ahWrapper);
// Deleter callback: casts `h` back to medi::HandleBase*. The statement that
// releases the handle is not visible in this extract.
339 static void deleteHandle(Tape* tape,
void* h) {
342 medi::HandleBase* handle =
static_cast<medi::HandleBase*
>(h);
// Tape used by every method above: the one returned by Type::getTape().
346 static Tape& getTape() {
347 return Type::getTape();
#define CODI_INLINE_NO_FA
See codi::Config::ForcedInlines.
Definition config.h:459
#define CODI_DD(Type, Default)
Abbreviation for CODI_DECLARE_DEFAULT.
Definition macros.hpp:94
size_t constexpr dim()
Number of dimensions this gradient value has.
Definition gradientTraits.hpp:96
DataExtraction< Type >::Real getValue(Type const &v)
Extract the primal values from a type of aggregated active types.
Definition realTraits.hpp:210
CoDiPack - Code Differentiation Package.
Definition codi.hpp:90
void CODI_UNUSED(Args const &...)
Disable unused warnings for an arbitrary number of arguments.
Definition macros.hpp:46
Represents a concrete lvalue in the CoDiPack expression tree.
Definition activeType.hpp:52