Goal: Learn how side effects can affect the tangent computation.
Prerequisite: Tutorial 1 - Forward mode AD
Function:
Real func(const Real& x, bool updateGlobal) {
  if(updateGlobal) {
    global = x * x;
  }
  return x * global;
}
Full code:
#include <codi.hpp>
#include <iostream>

using Real = codi::RealForward;

Real global = Real(0.0);

Real func(const Real& x, bool updateGlobal) {
  if(updateGlobal) {
    global = x * x;
  }
  return x * global;
}

int main(int nargs, char** args) {
  Real x = 4.0;
  x.setGradient(1.0);

  Real y = func(x, true);
  std::cout << "Update global:" << std::endl;
  std::cout << "f(4.0, true) = " << y << std::endl;
  std::cout << "df/dx(4.0, true) = " << y.getGradient() << std::endl << std::endl;

  y = func(x, false);
  std::cout << "No update global:" << std::endl;
  std::cout << "f(4.0, false) = " << y << std::endl;
  std::cout << "df/dx(4.0, false) = " << y.getGradient() << std::endl << std::endl;

  global.setGradient(0.0);  // Reset the tangent of the global variable.
  y = func(x, false);
  std::cout << "No update global with reset:" << std::endl;
  std::cout << "f(4.0, false) = " << y << std::endl;
  std::cout << "df/dx(4.0, false) = " << y.getGradient() << std::endl << std::endl;

  return 0;
}
The computational path of the function is changed via the parameter updateGlobal. If this parameter is true, then x also enters the computation of the global variable. If the parameter is false, then the global variable is treated as a constant with respect to AD.
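The effect on the tangent follows from the product rule applied to y = x * global. Writing g for global, and assuming that x = 4 is seeded with the tangent 1 and that g already holds the value 16 from an earlier call with updateGlobal = true (the seed value is an assumption for this illustration), the derivation reads:

```latex
\begin{align*}
\dot y &= \dot x\, g + x\, \dot g \\
\text{updateGlobal true:} \quad & g = x^2 = 16,\ \dot g = 2x\dot x = 8,\ \dot y = 1\cdot 16 + 4\cdot 8 = 48 \\
\text{updateGlobal false, } g \text{ constant:} \quad & \dot g = 0,\ \dot y = 1\cdot 16 + 4\cdot 0 = 16 \\
\text{updateGlobal false, stale } \dot g = 8: \quad & \dot y = 1\cdot 16 + 4\cdot 8 = 48
\end{align*}
```

The last line is the problematic case: the value of g is treated as a constant, but its tangent from the earlier call still contributes to the result.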
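The three calls demonstrate the error in the tangent computation. The first and third calls are correct; the second one is wrong. In the second call, global is not updated, so it should act as a constant, but it still carries the tangent computed in the first call, and this stale tangent enters the derivative of x * global. Here, we fix the issue by directly resetting the tangent of the global variable before the third call. Another option would be to call func(x, true) again with a tangent value of zero for x.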
Notes
This problem is not specific to CoDiPack. Nearly all tapeless forward AD tools suffer from it. One possible general fix is to implement a Gradient type that tags tangents with an epoch. The user would then need to manage the currently valid epoch, and the Gradient type can ignore tangents from older epochs.