Goal: Learn how to generalize the data extraction for external functions.
Prerequisites: Example 11 - External function user data
Function: Simple real-valued function
template<typename Type>
Type func(const Type& x) {
  return x * x;
}
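The reverse pass of the external function needs the adjoint rule for this function. With x_b and w_b denoting the adjoints of the input and the output, as in the full code, the chain rule gives the first line below. For a complex argument x = u + iv, assuming the adjoint is stored as its real and imaginary parts (x_b = u_b + i v_b, w_b = a + i c), the component-wise chain rule collapses to a multiplication with the complex conjugate. This is why the transpose has to be generalized.

```latex
\begin{aligned}
&\text{real } x: && w = x^2, \quad x_b \mathrel{+}= \Big(\tfrac{\partial w}{\partial x}\Big)^{T} w_b = 2\,x\,w_b, \quad w_b \leftarrow 0,\\
&\text{complex } x = u + i v: && w = (u^2 - v^2) + i\,(2uv), \quad w_b = a + i\,c,\\
& && u_b \mathrel{+}= 2u\,a + 2v\,c, \quad v_b \mathrel{+}= -2v\,a + 2u\,c,\\
& && \Longleftrightarrow\; x_b \mathrel{+}= 2\,\overline{x}\,w_b, \quad w_b \leftarrow 0.
\end{aligned}
```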
Full code:
#include <codi.hpp>
#include <iostream>
template<typename Type>
Type func(const Type& x) {
  return x * x;
}
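using Real = codi::RealReverse;
using Tape = Real::Tape;
using Identifier = Real::Identifier;
using RealBase = Real::Real;

// Reverse evaluation, written once for every Type that func is called with.
template<typename Type>
void extFunc_rev(Tape* t, void* d, codi::VectorAccessInterface<RealBase, Identifier>* va) {
  codi::ExternalFunctionUserData* data = (codi::ExternalFunctionUserData*)d;

  // Wrap the vector access interface so that adjoints are accessed as values of Type.
  using Factory = codi::AggregatedTypeVectorAccessWrapperFactory<Type>;
  using VectorWrapper = typename Factory::RType;
  VectorWrapper* vaType = Factory::create(va);
  using TypeIdentifier = typename VectorWrapper::Identifier;
  using TypeReal = typename VectorWrapper::Real;

  // Read the data in the order in which it was stored during recording.
  TypeReal x_v = data->getData<TypeReal>();
  TypeIdentifier x_i = data->getData<TypeIdentifier>();
  TypeIdentifier w_i = data->getData<TypeIdentifier>();

  // Reverse mode of w = x * x: x_b += 2 * transpose(x) * w_b, then w_b = 0.
  TypeReal w_b = vaType->getAdjoint(w_i, 0);
  TypeReal t_b = 2.0 * codi::ComputationTraits::transpose(x_v) * w_b;
  vaType->updateAdjoint(x_i, 0, t_b);
  vaType->resetAdjoint(w_i, 0);

  Factory::destroy(vaType);
}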
void extFunc_del(Tape* t, void* d) {
  codi::ExternalFunctionUserData* data = (codi::ExternalFunctionUserData*)d;
  delete data;
  std::cout << " Reset: data is deleted." << std::endl;
}
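// Records func(x) as an external function for Real as well as for aggregated types of Real.
template<typename Type>
Type addExternalFunc(Type const& x) {
  Tape& tape = Real::getTape();
  codi::ExternalFunctionUserData* data = new codi::ExternalFunctionUserData();

  // Store the primal value and the identifier(s) of the input.
  data->addData(codi::RealTraits::getValue(x));
  data->addData(codi::RealTraits::getIdentifier(x));

  // Evaluate the function without recording it on the tape.
  tape.setPassive();
  Type w = func(x);
  tape.setActive();

  // Register the output(s) of the external function and store their identifier(s).
  codi::RealTraits::registerExternalFunctionOutput(w);
  data->addData(codi::RealTraits::getIdentifier(w));

  tape.pushExternalFunction(codi::ExternalFunction<Tape>::create(extFunc_rev<Type>, data, extFunc_del));
  return w;
}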
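int main(int nargs, char** args) {
  Real x = 4.0;  // assumed input value

  Tape& tape = Real::getTape();
  tape.setActive();
  tape.registerInput(x);

  // The same generic external function handles Real and std::complex<Real>.
  Real t1 = addExternalFunc(x);
  std::complex<Real> c(t1, -t1);
  std::complex<Real> t2 = addExternalFunc(c);
  Real y = t2.imag();  // assumed: the imaginary part is used as the real-valued output

  tape.registerOutput(y);
  tape.setPassive();
  y.setGradient(1.0);
  tape.evaluate();

  std::cout << "x = " << x << std::endl;
  std::cout << "y = " << y << std::endl;
  std::cout << "dy/dx = " << x.getGradient() << std::endl;

  tape.reset();
  return 0;
}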
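The reverse function reads x_v, x_i and w_i back in exactly the order in which they were added during recording. The following is a minimal sketch of that first-in, first-out contract in plain C++. ToyUserData and storeSquareData are hypothetical stand-ins, not codi::ExternalFunctionUserData, and the sketch assumes that every stored type is trivially copyable.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for an external function data store (not the CoDiPack class).
// Values are appended as raw bytes and must be read back in the same order.
class ToyUserData {
 public:
  template<typename T>
  void addData(T const& value) {
    static_assert(std::is_trivially_copyable<T>::value, "ToyUserData stores raw bytes");
    std::size_t pos = bytes.size();
    bytes.resize(pos + sizeof(T));
    std::memcpy(bytes.data() + pos, &value, sizeof(T));
  }

  template<typename T>
  T getData() {
    static_assert(std::is_trivially_copyable<T>::value, "ToyUserData stores raw bytes");
    assert(readPos + sizeof(T) <= bytes.size() && "read past the stored data");
    T value{};
    std::memcpy(&value, bytes.data() + readPos, sizeof(T));
    readPos += sizeof(T);
    return value;
  }

 private:
  std::vector<unsigned char> bytes;
  std::size_t readPos = 0;
};

// Hypothetical helper mirroring the recording side: value of x, identifier(s) of x, identifier(s) of w.
template<typename Value, typename Id>
void storeSquareData(ToyUserData& data, Value const& x_v, Id const& x_i, Id const& w_i) {
  data.addData(x_v);
  data.addData(x_i);
  data.addData(w_i);
}
```

Because Id is a template parameter, the same helper stores a single identifier or an aggregate of identifiers, which is the property the generalized reverse function relies on when it reads TypeIdentifier values.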
The example shows how a function that can be called with both double and std::complex<double> can be differentiated with external functions. The implementation of the differentiation is generalized over the template parameter of the function. During recording, the helper structure codi::RealTraits provides the generalization: it extracts the primal values and identifiers of the argument and registers the outputs of the external function. For the reverse handling inside the external function, codi::AggregatedTypeVectorAccessWrapperFactory creates a wrapped version of the codi::VectorAccessInterface that reads and updates adjoints in terms of the aggregated type. In addition, codi::ComputationTraits generalizes the transpose of the local derivative. The advantage of using these traits and the wrapper is that aggregated types can be handled in the same way as standard CoDiPack types. In this example, the same code covers codi::RealReverse and std::complex<codi::RealReverse>.
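The idea behind the wrapper and the transpose trait can be imitated without CoDiPack: one reverse routine reads and updates the adjoints of an aggregated value as if it were a single value, and a trait picks the transpose of the local derivative. In this sketch, ToyAdjointWrapper and ToyTranspose are hypothetical stand-ins for the wrapper created by codi::AggregatedTypeVectorAccessWrapperFactory and for codi::ComputationTraits::transpose. It assumes a complex adjoint is stored as two real adjoints (real part, imaginary part), under which the transpose of a complex derivative is its conjugate.

```cpp
#include <array>
#include <cassert>
#include <complex>
#include <vector>

// Hypothetical toy tape: one real adjoint per integer identifier.
using Identifier = int;
using AdjointVector = std::vector<double>;

// Transpose of a scalar derivative: identity for real values, conjugate for complex ones
// (assuming a complex adjoint is stored as its real and imaginary parts).
template<typename T>
struct ToyTranspose {
  static T apply(T const& v) { return v; }
};
template<typename T>
struct ToyTranspose<std::complex<T>> {
  static std::complex<T> apply(std::complex<T> const& v) { return std::conj(v); }
};

// Hypothetical wrapper: presents the adjoints behind one value of type T as a single T.
template<typename T>
struct ToyAdjointWrapper;

template<>
struct ToyAdjointWrapper<double> {
  using Real = double;
  using Id = Identifier;
  AdjointVector& adj;
  Real getAdjoint(Id i) const { return adj[i]; }
  void updateAdjoint(Id i, Real v) { adj[i] += v; }
  void resetAdjoint(Id i) { adj[i] = 0.0; }
};

template<>
struct ToyAdjointWrapper<std::complex<double>> {
  using Real = std::complex<double>;
  using Id = std::array<Identifier, 2>;  // identifiers of the real and imaginary parts
  AdjointVector& adj;
  Real getAdjoint(Id const& i) const { return Real(adj[i[0]], adj[i[1]]); }
  void updateAdjoint(Id const& i, Real const& v) {
    adj[i[0]] += v.real();
    adj[i[1]] += v.imag();
  }
  void resetAdjoint(Id const& i) {
    adj[i[0]] = 0.0;
    adj[i[1]] = 0.0;
  }
};

// Reverse step of w = x * x, written once for every supported T.
template<typename T>
void reverseSquare(ToyAdjointWrapper<T>& va, T const& x_v,
                   typename ToyAdjointWrapper<T>::Id const& x_i,
                   typename ToyAdjointWrapper<T>::Id const& w_i) {
  T w_b = va.getAdjoint(w_i);
  T t_b = 2.0 * ToyTranspose<T>::apply(x_v) * w_b;
  va.updateAdjoint(x_i, t_b);
  va.resetAdjoint(w_i);
}
```

Seeding w_b = 1 with x = 4 yields x_b = 8. Seeding w_b = 1 with x = 16 - 16i yields x_b = 32 + 32i, which matches the component-wise chain rule for w = x^2.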