graph_differentiation

Algebra

When we first learn about functions in algebra class, teachers often describe them as "black boxes" that transform inputs to outputs. This is usually accompanied by a picture of a box (representing the function) and some arrows (representing the inputs/outputs) like this,

[figure: a function drawn as a box that turns an input arrow x into an output arrow y = f(x)]

to help students make sense of the new abstract notation y=f(x). After that, the following math classes tend to use the algebraic notation almost exclusively. However, in many cases the graph representation of calculations (where functions are nodes, and values are edges) is the more useful of the two.

In this document we'll cover how thinking about programs as graphs helps us understand and implement algorithmic differentiation.


Calculus

Several years after that algebra class we wind up taking calculus, where we study how small changes in the inputs, $dx$, effect small changes in the output of the function, $dy$. When these changes in the input are "small enough", $dy$ and $dx$ are proportional to each other. The constant of proportionality, $\frac{dy}{dx}$, gets a special name ("the derivative") and it lets us write

$$dy = \frac{dy}{dx}\,dx = \frac{df}{dx}\,dx$$

If we take that output y and feed it into another function, g

[figure: composition graph x → f → y → g → z]

then the output z is (indirectly) a function of x as well: z(x)=g(f(x)). The "Chain Rule" describes how to differentiate nested functions like this:

$$\frac{dz}{dx} = \frac{dz}{dy}\,\frac{dy}{dx} = g'(y)\,f'(x)$$

That is, the derivative of the composite function is the product of the derivatives of the individual steps.

 

Forward Mode Interpretation

Let's take that chain rule expression above and multiply both sides of the equation by $dx$ in order to write

$$dz = \frac{dz}{dy}\,\frac{dy}{dx}\,dx = g'(y)\,f'(x)\,dx$$

From here, we can group certain terms together to keep track of the intermediate differential quantities

$$dz = \frac{dz}{dy}\,\underbrace{\left(\frac{dy}{dx}\,dx\right)}_{dy} = g'(y)\,\underbrace{\left(f'(x)\,dx\right)}_{dy}$$

With all of the intermediate differential quantities accounted for now, we can draw a graph representation for them just like we did for the original calculation:

[figure: forward mode graph, where the edges carry differentials (dx, dy, dz) and the nodes multiply by the local derivatives f'(x) and g'(y)]

This is the graph associated with forward mode differentiation. There are a few important things to notice about it:

If we want to compute derivatives of quantities w.r.t. x, we can evaluate this graph with the input $dx = 1$ (or the appropriate identity element for more complicated types).
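
To make this concrete, here's a small C++ sketch of forward mode for the two-function chain above. The particular choices of f and g are just illustrative, and the `Dual` type is a hand-rolled stand-in for what an AD library would provide:

```cpp
#include <cmath>
#include <cstdio>

// A value bundled with its differential; this pair travels along each edge
// of the forward mode graph.
struct Dual {
  double value;
  double d;  // differential of `value` w.r.t. the chosen input
};

// Example nodes (arbitrary choices): each computes its output value and also
// multiplies the incoming differential by its local derivative.
Dual f(Dual x) { return {std::sin(x.value), std::cos(x.value) * x.d}; }  // f(x) = sin(x)
Dual g(Dual y) { return {y.value * y.value, 2.0 * y.value * y.d}; }      // g(y) = y^2

int main() {
  Dual x{0.5, 1.0};  // seed dx = 1 to get derivatives w.r.t. x
  Dual z = g(f(x));  // z.d now holds dz/dx = g'(y) * f'(x)
  std::printf("z = %f, dz/dx = %f\n", z.value, z.d);
}
```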

Reverse Mode Interpretation

Now let's go back to the chain rule equation and multiply both sides by $\frac{d\square}{dz}$ on the left, where $\square$ is some placeholder representing any possible quantity

$$\frac{d\square}{dx} = \frac{d\square}{dz}\,\frac{dz}{dx} = \frac{d\square}{dz}\,\frac{dz}{dy}\,\frac{dy}{dx} = \frac{d\square}{dz}\,g'(y)\,f'(x)$$

Since this equation is the same for any such quantity $\square$, let's just leave it out of the expressions to make things a little more succinct.

$$\frac{d}{dx} = \frac{d}{dz}\,\frac{dz}{dy}\,\frac{dy}{dx} = \frac{d}{dz}\,g'(y)\,f'(x)$$

Just like with the forward mode derivation, we'll group terms together. Except this time, let's see what happens when grouping terms from the left, rather than the right:

$$\frac{d}{dx} = \underbrace{\left(\frac{d}{dz}\,\frac{dz}{dy}\right)}_{\frac{d}{dy}}\frac{dy}{dx} = \underbrace{\left(\frac{d}{dz}\,g'(y)\right)}_{\frac{d}{dy}}f'(x)$$

Again, this grouping of terms reveals information about the intermediate quantities. In this case, those intermediates are the derivatives w.r.t. edge data. I think this is easier to understand as a graph, so let's draw a picture

[figure: reverse mode graph, with the edges reversed, carrying derivatives w.r.t. z, y, and x, and the nodes multiplying by g'(y) and f'(x)]

This is the graph associated with reverse mode differentiation (also referred to as "the adjoint method" and "backpropagation"). There are a few important things to notice about it:

If we want to compute the derivative of some quantity w.r.t. x, we can evaluate this graph with the "input" derivative w.r.t. z. If the quantity we want to differentiate is z itself, then we feed in $\frac{dz}{dz} = 1$ as the input to the reverse mode graph.
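
Here's the same two-function chain written as a hand-rolled reverse mode pass in C++ (again with arbitrary f and g): run the original calculation forward, remember the intermediate values, then sweep backward multiplying by each node's local derivative.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // Forward pass: evaluate the original graph and keep the intermediates.
  double x = 0.5;
  double y = std::sin(x);  // y = f(x)
  double z = y * y;        // z = g(y)

  // Reverse pass: seed with dz/dz = 1 (we're differentiating z itself),
  // then walk the graph right-to-left, multiplying by local derivatives.
  double d_dz = 1.0;
  double d_dy = d_dz * (2.0 * y);    // d/dy = (d/dz) * g'(y)
  double d_dx = d_dy * std::cos(x);  // d/dx = (d/dy) * f'(x)

  std::printf("z = %f, dz/dx = %f\n", z, d_dx);
}
```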


Multivariable Calculus

1D calculus is a good start, but in practice, calculations frequently involve functions with multiple inputs and outputs. So, let's briefly review how derivatives work for that case. For example, here's an arbitrary C++ function that has 2 inputs and 3 outputs:
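
For concreteness, it might look something like this (the specific operations are arbitrary placeholders; only the 2-inputs/3-outputs shape matters):

```cpp
#include <array>
#include <cmath>

// An arbitrary function with 2 inputs (x1, x2) and 3 outputs (y1, y2, y3).
std::array<double, 3> f(double x1, double x2) {
  double y1 = x1 * x2;
  double y2 = std::sin(x1) + x2;
  double y3 = std::exp(x1 - x2);
  return {y1, y2, y3};
}
```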

In this case, there are a total of 6 independent partial derivatives: $\frac{\partial y_1}{\partial x_1}, \frac{\partial y_1}{\partial x_2}, \frac{\partial y_2}{\partial x_1}, \frac{\partial y_2}{\partial x_2}, \frac{\partial y_3}{\partial x_1}, \frac{\partial y_3}{\partial x_2}$.

Technically, partial derivatives use the symbol $\partial$ instead of $d$, but for the rest of the document I'll just use $d$ so all the notation on the graphs is consistent.

A natural way to arrange these partial derivatives is in a matrix (called the Jacobian) with as many rows as outputs and as many columns as inputs.

$$\frac{d\{y_1,y_2,y_3\}}{d\{x_1,x_2\}} = J(x_1,x_2) := \begin{bmatrix} \frac{dy_1}{dx_1} & \frac{dy_1}{dx_2} \\ \frac{dy_2}{dx_1} & \frac{dy_2}{dx_2} \\ \frac{dy_3}{dx_1} & \frac{dy_3}{dx_2} \end{bmatrix}$$

The reason for doing this is that it lets us use matrix multiplication notation to compactly represent both forward mode and reverse mode differentiation operations:

 

Forward mode: $$\begin{bmatrix} dy_1 \\ dy_2 \\ dy_3 \end{bmatrix} = d\{y_1,y_2,y_3\} = \frac{d\{y_1,y_2,y_3\}}{d\{x_1,x_2\}}\, d\{x_1,x_2\} = \underbrace{\begin{bmatrix} \frac{dy_1}{dx_1} & \frac{dy_1}{dx_2} \\ \frac{dy_2}{dx_1} & \frac{dy_2}{dx_2} \\ \frac{dy_3}{dx_1} & \frac{dy_3}{dx_2} \end{bmatrix}}_{J} \begin{bmatrix} dx_1 \\ dx_2 \end{bmatrix}$$

 

Reverse mode: $$\begin{bmatrix} \frac{d}{dx_1} & \frac{d}{dx_2} \end{bmatrix} = \frac{d}{d\{x_1,x_2\}} = \frac{d}{d\{y_1,y_2,y_3\}}\,\frac{d\{y_1,y_2,y_3\}}{d\{x_1,x_2\}} = \begin{bmatrix} \frac{d}{dy_1} & \frac{d}{dy_2} & \frac{d}{dy_3} \end{bmatrix} \underbrace{\begin{bmatrix} \frac{dy_1}{dx_1} & \frac{dy_1}{dx_2} \\ \frac{dy_2}{dx_1} & \frac{dy_2}{dx_2} \\ \frac{dy_3}{dx_1} & \frac{dy_3}{dx_2} \end{bmatrix}}_{J}$$

 

Some AD libraries use the terms "JVP" and "VJP" as an alternate way of referring to forward and reverse mode, respectively. Those terms stand for "Jacobian-Vector Product" and "Vector-Jacobian Product", referring to whether the vector appears on the right or left of the Jacobian in the product.
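
To make the JVP/VJP distinction concrete, here's a small sketch using plain arrays for the 3×2 Jacobian shape above (not tied to any particular AD library):

```cpp
#include <array>
#include <cstdio>

using Jacobian = std::array<std::array<double, 2>, 3>;  // 3 outputs x 2 inputs

// Forward mode / JVP: J (3x2) times a direction in input space (2x1) -> 3x1.
std::array<double, 3> jvp(const Jacobian& J, const std::array<double, 2>& dx) {
  std::array<double, 3> dy{};
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 2; ++j) dy[i] += J[i][j] * dx[j];
  return dy;
}

// Reverse mode / VJP: a row vector in output space (1x3) times J (3x2) -> 1x2.
std::array<double, 2> vjp(const std::array<double, 3>& d_dy, const Jacobian& J) {
  std::array<double, 2> d_dx{};
  for (int j = 0; j < 2; ++j)
    for (int i = 0; i < 3; ++i) d_dx[j] += d_dy[i] * J[i][j];
  return d_dx;
}

int main() {
  Jacobian J = {{{1, 2}, {3, 4}, {5, 6}}};
  auto dy = jvp(J, {1.0, 0.0});       // picks out the first column of J
  auto dx = vjp({1.0, 0.0, 0.0}, J);  // picks out the first row of J
  std::printf("JVP: %g %g %g   VJP: %g %g\n", dy[0], dy[1], dy[2], dx[0], dx[1]);
}
```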

 


 

More Complicated Example

Let's apply our new intuition for forward and reverse mode differentiation to a slightly more involved calculation given in C++ below:
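
Something with the following shape works for this discussion: two inputs, one output, with x1 used in two places. The specific operations here are placeholders chosen only for illustration:

```cpp
#include <cmath>

// Two inputs, one output. Note that x1 is used by two different operations,
// which will matter in the "Variable Reuse" aside below.
double calc(double x1, double x2) {
  double y = x1 * x2;           // first use of x1
  double z = std::sin(x1) + y;  // second use of x1
  double w = z * z;
  return w;
}
```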

Start off by drawing the graph representation of the calculation, where edges are the variables and nodes are the calculations

[figure: graph of the example calculation, with variables as edges and operations as nodes]

and then we can apply the rules we learned about forward and reverse mode from the earlier example.


Forward Mode

Recall that in order to make the graph for forward-mode, we replace each edge by a new edge carrying the differential of the original quantity and replace the nodes by their derivatives (evaluated at their respective function's inputs) to get:

[figure: forward mode graph for the example calculation]

From here, we can compute different kinds of derivatives by feeding in different values for $dx_1, dx_2$: seeding the graph with $(dx_1, dx_2) = (1, 0)$ gives derivatives w.r.t. $x_1$, while $(dx_1, dx_2) = (0, 1)$ gives derivatives w.r.t. $x_2$.
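
For instance, here's a hand-written forward mode version of the placeholder calculation sketched above, where the seed $(dx_1, dx_2)$ determines which derivative comes out the other end:

```cpp
#include <cmath>
#include <cstdio>

// Forward mode for the placeholder calculation: each value is carried along
// with its differential, and the seed (dx1, dx2) selects the derivative.
double calc_fwd(double x1, double x2, double dx1, double dx2, double* dw) {
  double y = x1 * x2;
  double dy = dx1 * x2 + x1 * dx2;

  double z = std::sin(x1) + y;
  double dz = std::cos(x1) * dx1 + dy;

  double w = z * z;
  *dw = 2.0 * z * dz;
  return w;
}

int main() {
  double dw_dx1 = 0.0, dw_dx2 = 0.0;
  calc_fwd(0.5, 2.0, 1.0, 0.0, &dw_dx1);  // seed (1, 0) -> dw/dx1
  calc_fwd(0.5, 2.0, 0.0, 1.0, &dw_dx2);  // seed (0, 1) -> dw/dx2
  std::printf("dw/dx1 = %f, dw/dx2 = %f\n", dw_dx1, dw_dx2);
}
```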

Reverse Mode

In reverse mode, we replace each edge by a new edge (in the opposite direction) that carries the derivative w.r.t. the original quantity and replace the nodes by their derivatives (evaluated at their respective function's inputs) to get:

[figure: reverse mode graph for the example calculation]

In reverse mode, if we provide a single value of $\frac{d}{dw}$ as the "input" (from the right), then when we carry out the calculations on this graph, we will have computed both $\frac{d}{dx_1}$ and $\frac{d}{dx_2}$. This is different from forward mode, which only gives a derivative w.r.t. one quantity at a time.

If we're after the quantities $\frac{dw}{dx_1}, \frac{dw}{dx_2}$, then we can just set $\square = w$ and pass in the number 1 as our input from the right.
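
And here's the corresponding hand-written reverse mode pass for the same placeholder calculation: one forward sweep to record values, then one backward sweep seeded with 1 that produces both derivatives at once. Note the `+=` where the two uses of x1 come back together; the aside below explains why.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  double x1 = 0.5, x2 = 2.0;

  // Forward pass: evaluate and remember the intermediates.
  double y = x1 * x2;
  double z = std::sin(x1) + y;
  double w = z * z;

  // Reverse pass: seed with dw/dw = 1, then walk the graph backwards.
  double d_dw = 1.0;
  double d_dz = d_dw * (2.0 * z);      // w = z^2

  double d_dy = d_dz;                  // z = sin(x1) + y, so dz/dy = 1
  double d_dx1 = d_dz * std::cos(x1);  // ...and dz/dx1 = cos(x1)

  d_dx1 += d_dy * x2;                  // y = x1 * x2: second contribution to x1
  double d_dx2 = d_dy * x1;

  std::printf("w = %f, dw/dx1 = %f, dw/dx2 = %f\n", w, d_dx1, d_dx2);
}
```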


Aside: Variable Reuse

If you were looking closely, you may have noticed that there was a small problem with the way I drew the graph for this "more complicated" example: the edge associated with input value x1 splits and goes off to two different functions. Technically, the edges in a graph should have a single starting point and a single endpoint. But in practice, that requirement feels awkward because programming languages have no problem letting you use a variable more than once.

The way I think about it is that drawings like the one below

[figure: graph in which the edge for x1 splits and feeds two different nodes]

are really just a convenient short-hand notation for a more "proper" graph definition, where the splitting point is itself a node in the graph as well:

[figure: the same graph with an explicit copy node where x1 splits]

The actual calculation happening at that node is trivial: it's just copying the input value to the outputs. We could define this copy function as

$$\mathrm{copy}(x) := \begin{bmatrix} x \\ x \\ \vdots \\ x \end{bmatrix}, \qquad \frac{d\,\mathrm{copy}}{dx} = \begin{bmatrix} 1 \\ 1 \\ \vdots \\ 1 \end{bmatrix}$$

This way, our graph is still well-formed. This also has some implications for reverse-mode, since it means that, when visiting a copy node on the reverse pass, we need to sum the relevant derivative terms:

$$\frac{d}{dx} = \begin{bmatrix} \frac{d}{dx^{(1)}} & \frac{d}{dx^{(2)}} & \cdots & \frac{d}{dx^{(n)}} \end{bmatrix} \underbrace{\begin{bmatrix} 1 \\ 1 \\ \vdots \\ 1 \end{bmatrix}}_{\frac{d\,\mathrm{copy}}{dx}} = \frac{d}{dx^{(1)}} + \frac{d}{dx^{(2)}} + \cdots + \frac{d}{dx^{(n)}}$$

Or, in graph form

[figure: reverse mode view of the copy node, summing the derivatives arriving from each output edge]
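
In code, the copy node's two rules look like the sketch below (the helper names are just for illustration); the reverse rule is exactly what the `+=` in the earlier reverse mode example was doing implicitly.

```cpp
#include <cstdio>
#include <numeric>
#include <vector>

// Forward rule for a copy node: one input fans out to n identical outputs.
std::vector<double> copy_node(double x, int n) {
  return std::vector<double>(n, x);
}

// Reverse rule for the same node: the derivatives arriving along the n output
// edges are summed to get the derivative w.r.t. the single input edge.
double copy_node_reverse(const std::vector<double>& d_outputs) {
  return std::accumulate(d_outputs.begin(), d_outputs.end(), 0.0);
}

int main() {
  std::vector<double> fan_out = copy_node(3.0, 4);         // x used in 4 places
  double d_dx = copy_node_reverse({0.5, 1.0, 2.0, 0.25});  // sum the 4 incoming terms
  std::printf("%zu copies, d/dx = %f\n", fan_out.size(), d_dx);
}
```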

 


When Should I Use Forward vs. Reverse Mode?

So far, we've only discussed how forward and reverse mode work, but haven't really described their strengths and weaknesses. Let's look at a couple specific examples to better understand the performance implications of forward vs. reverse mode.

Few Inputs, Many Outputs

Here is a graph featuring a scalar-valued input, scalar-valued intermediates and a vector-valued output.

[figure: function composition with a scalar input x, scalar intermediates y and z, and a vector-valued output {w1, …, wn}]

For this graph, the chain rule tells us that $\frac{d\{w_1,\ldots,w_n\}}{dx} = \frac{d\{w_1,\ldots,w_n\}}{dz}\,\frac{dz}{dy}\,\frac{dy}{dx}$.

 

Evaluating this expression with forward-mode means evaluating the products from right to left:

$$\frac{d\{w_1,\ldots,w_n\}}{dx} = \left(\frac{d\{w_1,\ldots,w_n\}}{dz}\left(\frac{dz}{dy}\,\frac{dy}{dx}\right)\right)$$

which involves multiplying two scalars $\frac{dz}{dy}\,\frac{dy}{dx}$ first (1 op), and then multiplying a vector by a scalar ($n$ ops) for a total of $n+1$ operations. In contrast, reverse-mode evaluates the products from left to right:

$$\frac{d\{w_1,\ldots,w_n\}}{dx} = \left(\left(\frac{d\{w_1,\ldots,w_n\}}{dz}\,\frac{dz}{dy}\right)\frac{dy}{dx}\right)$$

That means first multiplying a vector by a scalar ($n$ ops), and then multiplying a vector by a scalar ($n$ ops) for a total of $2n$ operations.

 

So, in this case, where the inputs and intermediates are smaller than the outputs, forward mode differentiation is preferred.


Many Inputs, Few Outputs

Here is another graph, except this one features a vector-valued input, vector-valued intermediates, and a scalar-valued output.

[figure: function composition with a vector-valued input {x1, …, xn}, vector-valued intermediates {y1, …, yn} and {z1, …, zn}, and a scalar output w]

For this graph, the chain rule tells us that $\frac{dw}{d\{x_1,\ldots,x_n\}} = \frac{dw}{d\{z_1,\ldots,z_n\}}\,\frac{d\{z_1,\ldots,z_n\}}{d\{y_1,\ldots,y_n\}}\,\frac{d\{y_1,\ldots,y_n\}}{d\{x_1,\ldots,x_n\}}$.

 

Evaluating this expression with forward-mode performs the matrix-matrix multiplication on the right first ($\sim 2n^3$ ops) and then a vector-matrix multiplication ($\sim 2n^2$ ops) for a total of about $2n^3$ operations.

$$\frac{dw}{d\{x_1,\ldots,x_n\}} = \left(\frac{dw}{d\{z_1,\ldots,z_n\}}\left(\frac{d\{z_1,\ldots,z_n\}}{d\{y_1,\ldots,y_n\}}\,\frac{d\{y_1,\ldots,y_n\}}{d\{x_1,\ldots,x_n\}}\right)\right)$$

With reverse mode, grouping from the left side leads to two vector-matrix products ($\sim 2n^2$ ops each) for a total of $\sim 4n^2$ operations. That is: for this kind of graph, where there is a single output and many intermediates, reverse mode is asymptotically faster by a factor of $n$.

This kind of graph structure is extremely important, as it describes practically every optimization problem (since they typically involve many input variables and a scalar-valued objective function as the output). This also includes machine learning training workflows, which are essentially just big optimization problems.

This is the main reason that reverse-mode differentiation is such an important part of machine learning: forward mode would be incredibly slow for training.


Similar Number of Inputs, Intermediates, Outputs

In this case, the operation counts of forward and reverse mode are roughly comparable. So, the tie is broken by the fact that forward-mode is generally simpler to implement and has a smaller memory footprint (since intermediate quantities can be discarded after use).


Summary

We covered a lot in this document, so let's recap some of the important ideas: