TL;DR: I talk about a technique called automatic differentiation, going through a mathematical derivation before examining two different implementations: one in Rust and one in Python.
About a year ago, I read a blog post on automatic differentiation, a cool technique which automatically computes derivatives (and generalizations like gradients and Jacobians). That might not seem that interesting – after all, we could just use finite differences to calculate derivatives:
\[\frac{df}{dx} \approx \frac{f(x + h) - f(x)}{h} \]
choosing a small \(h\). Unfortunately, this kind of numerical differentiation usually doesn’t work too well in practice. If you make \(h\) too small, then your accuracy gets killed by floating point roundoff, and if \(h\) is too big, then approximation errors start ballooning.1
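To make the trade-off concrete, here’s a quick numerical sketch (my own illustration, not taken from any particular library): we estimate the derivative of \(e^x\) at \(x = 1\) for a range of step sizes, and the error first shrinks as \(h\) decreases, then blows up again once roundoff takes over.

```python
import math

# Finite differences on f(x) = exp(x) at x = 1, where the true derivative
# is exp(1). The error is smallest for an intermediate h and grows again
# once floating point roundoff dominates.
def finite_difference(f, x, h):
    return (f(x + h) - f(x)) / h

true_value = math.exp(1.0)
for h in (1e-1, 1e-4, 1e-8, 1e-12, 1e-16):
    approx = finite_difference(math.exp, 1.0, h)
    print(f"h = {h:.0e}   error = {abs(approx - true_value):.3e}")
```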
Automatic differentiation avoids these problems entirely: it calculates exact derivatives, so your accuracy is only limited by floating point error.
The applications of automatic differentiation should be pretty obvious. But just in case they aren’t, I’ll point out that Google’s new machine learning framework TensorFlow and its competitor (and inspiration?) Theano both use automatic differentiation under the hood.
What is Automatic Differentiation?
Automatic differentiation is really just a jumped-up chain rule. When you implement a function on a computer, you only have a small number of primitive operations available (e.g. addition, multiplication, logarithm). Any complicated function, like \(\frac{\log 2x}{x ^ x}\), is just a combination of these simple functions.
In other words, any complicated function \(f\) can be rewritten as the composition of a sequence of primitive functions \(f_k\):
\[ f = f_0 \circ f_1 \circ f_2 \circ \ldots \circ f_n \]
Because each primitive function \(f_k\) has a simple derivative, we can use the chain rule to find \(\frac{df}{dx}\) pretty easily.2
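For example, the function \(\frac{\log 2x}{x^x}\) from above breaks down into primitive steps, each applying one basic operation to earlier results:

\[ w_1 = x, \quad w_2 = 2 w_1, \quad w_3 = \log w_2, \quad w_4 = w_1 ^ {w_1}, \quad f = \frac{w_3}{w_4} \]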
Although I’ve used a single-variable function \(f: \mathbb{R} \rightarrow \mathbb{R}\) as my example here, it’s straightforward to extend this idea to multivariate functions \(f: \mathbb{R}^n \rightarrow \mathbb{R}^m\).
Forward Mode
There are actually two different modes of automatic differentiation, based on how you apply the chain rule. We’ll start with forward mode automatic differentiation, which I find a little more intuitive.
Basics: Partial Derivatives
Given a function \(f\), we can construct a computational graph (a directed acyclic one) representing our function. For example, given the function \(f(x, y) = \cos x \sin y + \frac{x}{y}\), we can construct the graph:
Each node of the graph represents a primitive function, while the edges represent the flow of information. In this example, the top-most node \(w_7\) represents the value of \(f(x, y)\) while the bottom-most nodes \(w_1\) and \(w_2\) represent our input variables.
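Concretely, the nodes in this example are:

\[ w_1 = x, \quad w_2 = y, \quad w_3 = \cos w_1, \quad w_4 = \sin w_2, \quad w_5 = w_3 w_4, \quad w_6 = \frac{w_1}{w_2}, \quad w_7 = w_5 + w_6 = f(x, y) \]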
Forward differentiation works by recursively defining the derivative of each node in terms of the derivatives of its children (the nodes feeding into it). For example, suppose we want to calculate the partial \(\frac{\partial f}{\partial x}\). For reasons that’ll become clear in two paragraphs, let’s denote \(\frac{\partial f}{\partial x}\) using a derivative operator (i.e. \(\frac{\partial f}{\partial x} = D f\)). Then \(D f = D w_7\), and:
\[\begin{aligned} D w_7 &= D (w_5 + w_6) = D w_5 + D w_6 \\ D w_6 &= D \frac{w_1}{w_2} = \frac{w_2 D w_1 - w_1 D w_2}{w_2 ^ 2} \\ D w_5 &= D w_3 w_4 = w_3 D w_4 + w_4 D w_3 \\ D w_4 &= D \sin w_2 = \cos w_2 \cdot D w_2 \\ D w_3 &= D \cos w_1 = -\sin w_1 \cdot D w_1 \\ D w_2 &= D y \\ D w_1 &= D x \end{aligned} \]
The final value of \(D f\) depends only on \(x\), \(y\), \(D w_1\) and \(D w_2\). In this case, we’ve let \(D = \frac{\partial}{\partial x}\), so \(D x = 1\) and \(D y = 0\). But if we let \(D = \frac{\partial}{\partial y}\), all we have to change is \(D x\) (from 1 to 0) and \(D y\) (from 0 to 1). Everything else stays the same.
This neatly extends to calculating arbitrary directional derivatives by defining \(\langle D x, D y \rangle\) to be the unit vector in the direction of our derivative. We can also set \(D = \nabla\) (i.e. \(D x = \langle 1, 0 \rangle\) and \(D y = \langle 0, 1 \rangle\)) and use vector operations instead of scalar ones to calculate gradients.
Why is this called forward mode? Well, the derivatives above start at the bottom of our graph and flow to the top, just like the values do when we evaluate the expression. Because information flows in the same direction as evaluation (bottom-up), we call this “forward mode”. Predictably, information flows top-down in “reverse mode”.
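Here’s a minimal sketch of forward mode on this example (my own illustration, not the code from either implementation discussed later): each node propagates a (value, derivative) pair bottom-up, and the seeds \(D x\) and \(D y\) choose which directional derivative comes out.

```python
import math

# Forward mode on f(x, y) = cos(x) sin(y) + x / y: every node carries its
# value together with its derivative, seeded by dx and dy at the inputs.
def f_forward(x, y, dx, dy):
    w1, dw1 = x, dx
    w2, dw2 = y, dy
    w3, dw3 = math.cos(w1), -math.sin(w1) * dw1
    w4, dw4 = math.sin(w2), math.cos(w2) * dw2
    w5, dw5 = w3 * w4, w3 * dw4 + w4 * dw3
    w6, dw6 = w1 / w2, (w2 * dw1 - w1 * dw2) / w2 ** 2
    return w5 + w6, dw5 + dw6

value, df_dx = f_forward(1.5, 2.0, 1.0, 0.0)  # seed for the partial in x
_, df_dy = f_forward(1.5, 2.0, 0.0, 1.0)      # seed for the partial in y
```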
Runtime Complexity
Calculating derivatives of certain primitive functions requires both the values and the derivatives of their component parts. For example:
\[D w_i w_j = w_i D w_j + w_j D w_i \]
requires both the values of \(w_i\) and \(w_j\) along with the derivatives \(D w_i\) and \(D w_j\). If \(CD(w)\) is the cost of computing the derivative of node \(w\) and \(CV(w)\) is the cost of computing the value of node \(w\), then:
\[CD(w_i w_j) = CD(w_i) + CD(w_j) + CV(w_i) + CV(w_j) + 3 \]
where the 3 is for the 3 extra arithmetic operations (two multiplications and one addition). More generally:
\[CD(w) \le \sum_{w_k \in \text{children}(w)} CD(w_k) + \sum_{w_k \in \text{children}(w)} CV(w_k) + c \]
for some constant \(c\).
Consider an expression like \(x x x x x x x x x x x x x x x x x\). Under a naive differentiation scheme, we might have to recompute the value for the node \(xxx\) 8 or 9 times, leading to a quadratic blow-up. Under a smarter scheme, we could calculate the value and the derivative of a given node at the same time, so that:
\[CD(w) + CV(w) \le \sum_{w_k \in \text{children}(w)} (CD(w_k) + CV(w_k) ) + c + 1 \]
If our function is \(f: \mathbb{R}^n \rightarrow \mathbb{R}\), composed of \(P(f)\) primitive operations, then \(CD(f) + CV(f)\) is clearly linear in \(P(f)\). Calculating directional derivatives, then, is linear in the number of primitive operations.
For gradients, all our scalar operations become vector operations, so \(c\) becomes \(cn\) (vector operations are linear in the vector’s dimensionality). The cost of computing gradients is thus linear in \(n P(f)\).
Reverse Mode
Reverse mode automatic differentiation lets you calculate gradients much more efficiently than forward mode automatic differentiation.
Consider our old function \(f(x, y) = \cos x \sin y + \frac{x}{y}\) and its computation graph from the forward mode section.
Suppose we want to calculate the gradient \(\nabla f = \langle \frac{\partial f}{\partial x}, \frac{\partial f}{\partial y} \rangle = \langle \frac{\partial w_7}{\partial w_1}, \frac{\partial w_7}{\partial w_2} \rangle\). Then:
\[\newcommand{\pder}[2]{\frac{\partial#1}{\partial#2}} \begin{aligned} \pder{w_7}{w_1} &= \pder{w_3}{w_1} \pder{w_7}{w_3} + \pder{w_6}{w_1} \pder{w_7}{w_6} = - \sin{w_1} \pder{w_7}{w_3} + \frac{1}{w_2} \pder{w_7}{w_6} \\ \pder{w_7}{w_2} &= \pder{w_4}{w_2} \pder{w_7}{w_4} + \pder{w_6}{w_2} \pder{w_7}{w_6} = \cos{w_2} \pder{w_7}{w_4} - \frac{w_1}{w_2 ^ 2} \pder{w_7}{w_6} \\ \pder{w_7}{w_3} &= \pder{w_5}{w_3} \pder{w_7}{w_5} = w_4 \pder{w_7}{w_5} \\ \pder{w_7}{w_4} &= \pder{w_5}{w_4} \pder{w_7}{w_5} = w_3 \pder{w_7}{w_5} \\ \pder{w_7}{w_5} &= 1 \\ \pder{w_7}{w_6} &= 1 \end{aligned} \]
As the name “reverse mode” suggests, things are reversed here: information flows top-down. Instead of using the chain rule to find the derivative of each node in terms of the derivatives of its children, we find the derivative of \(f\) with respect to each node in terms of the derivatives with respect to its parents. Because information flows top-down, the reverse of normal evaluation, we call this “reverse mode” automatic differentiation.
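And here’s the matching reverse mode sketch (again my own illustration): one forward pass computes the node values, then a backward pass pushes the adjoints \(\frac{\partial w_7}{\partial w_k}\) from the top down, yielding the whole gradient in a single sweep.

```python
import math

# Reverse mode on f(x, y) = cos(x) sin(y) + x / y, following the equations
# above: values flow bottom-up, adjoints flow top-down.
def f_reverse(x, y):
    # Forward pass: node values.
    w1, w2 = x, y
    w3, w4 = math.cos(w1), math.sin(w2)
    w5, w6 = w3 * w4, w1 / w2
    w7 = w5 + w6

    # Backward pass: adjoints of w7 with respect to each node.
    dw7_dw5 = 1.0
    dw7_dw6 = 1.0
    dw7_dw4 = w3 * dw7_dw5
    dw7_dw3 = w4 * dw7_dw5
    dw7_dw2 = math.cos(w2) * dw7_dw4 - (w1 / w2 ** 2) * dw7_dw6
    dw7_dw1 = -math.sin(w1) * dw7_dw3 + (1.0 / w2) * dw7_dw6
    return w7, (dw7_dw1, dw7_dw2)  # value and gradient

value, gradient = f_reverse(1.5, 2.0)
```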
Runtime Complexity
If we reuse our notation of \(CD\) for the cost of calculating a gradient and \(CV\) for the cost of calculating a value, we see:
\[CD(w) + CV(w) \le \sum_{w_k \in \text{children}(w)} (CD(w_k) + CV(w_k) ) + c \]
Notice that all of our operations in reverse mode differentiation are scalar operations, even though we’re calculating the (vector) gradient. Thus, the last term is a \(+ c\) and not a \(+ c n\). Computing gradients via reverse mode is thus linear in \(P(f)\), and not in \(nP(f)\), which can be a big speed-up if \(f\) takes lots of input variables.
The only way I can see to calculate directional derivatives via reverse mode differentiation is to take a dot product of the gradient. In that case, the runtime cost of finding a directional derivative is linear in \(P(f) + n\): \(P(f)\) for the gradient and \(n\) for the dot product.
Thoughts on Implementation
I went ahead and implemented some basic automatic differentiation. It only supports arithmetic operators (addition, subtraction, multiplication, division, and exponentiation), although it should be pretty trivial to add support for unary functions like \(\sin\) or \(\log\).
I went with the easy approach of creating an Expr class and forcing the user to manually build their computation graph (via operator overloading). While that’s certainly feasible for something where people are already building these computation graphs (e.g. TensorFlow or Theano), it isn’t ideal. After all, why should someone have to completely replace their compute engine with yours just to do automatic differentiation?
A better solution would be something that parses the source code of a function and uses the abstract syntax tree to build the computation graph automatically. Even better would be to use dual numbers. Both of these are significantly more challenging to implement, so I decided to stick with static, manual computation graph construction.
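To give a flavor of the dual number idea, here’s a tiny sketch (mine, unrelated to either implementation below): overload arithmetic so every value carries its derivative along with it, and forward mode differentiation falls out of ordinary-looking code.

```python
# Dual numbers: (value, derivative) pairs with overloaded arithmetic.
class Dual:
    def __init__(self, value, deriv=0.0):
        self.value, self.deriv = value, deriv

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value * other.value,
                    self.value * other.deriv + self.deriv * other.value)

    __rmul__ = __mul__

# d/dx of x*x + 3*x at x = 2 is 2*2 + 3 = 7.
x = Dual(2.0, 1.0)        # seed the input's derivative with 1
y = x * x + 3 * x
print(y.value, y.deriv)   # 10.0 7.0
```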
Rust
I originally started implementing things in Rust. Algebraic data types seemed like the perfect way to represent computation graphs, which narrowed my initial choices to Rust or Haskell. I’m a little too rusty with Haskell to be really productive, and I wanted to play around with Rust some more anyways, so I chose Rust.
You can see the code for this on Github.
Ideally, I’d want Expr to be a recursive enum whose variants (e.g. an addition variant holding its two sub-expressions) contain other Exprs directly. Unfortunately for us, Rust is a systems language that doesn’t support this kind of directly recursive structure. After all, how many bytes should Rust allocate for an Expr in memory? There’s no good answer, because an Expr would have to be big enough to hold another Expr along with other stuff. That’s why we have pointers.
It took a fair amount of struggling before I realized that Rust’s reference-counted pointer std::rc::Rc was perfect for the job: shared immutable ownership of data (i.e. a node in a graph can have multiple parents). Luckily, computation graphs are acyclic, so I didn’t have to mess around with weak references or any other cycle-breaking mechanism.
To make sure that the user doesn’t need to worry about any lifetime stuff, I actually made a private enum InnerExpr to store the computation graph and made the public-facing Expr struct a thin wrapper around a std::rc::Rc<InnerExpr>.
I also decided to split off the Add, Sub, Mul, Div, and Pow variants into their own Arithmetic sub-enum. I thought that this would make the code a little more modular by separating the computation out from the plumbing, so to speak. In retrospect, the sub-enum was way more trouble than it was worth.
I represented points as hashmaps mapping the variable names (strings) to their values. It’s a little more verbose than I’d like, but I couldn’t think of any better alternatives.
I only ended up implementing forward mode directional differentiation in Rust before moving on to Python (Rust is great, but developing in it is definitely slower than developing in Python, and I don’t really care about performance or memory efficiency here). To prevent the quadratic blow-up I mentioned earlier, I calculated both the value and the derivative of a node together and used a lightweight struct to bubble them up. Otherwise, this was a pretty straightforward recursive implementation.
Reflections
I should start by saying that for a systems language, Rust was surprisingly nice to develop in. Much nicer than C (shudder). The type system was a giant plus, and the memory management was surprisingly easy once I got the hang of it. Still, it’s definitely more verbose than Python is, and there were a bunch of papercuts that really irritated me.
Operator Overloading
My biggest irritation with Rust was its operator overloading. In Rust, operators take their arguments by value, and thus claim ownership of their arguments: writing a + b moves both a and b, so you can’t use either of them afterwards.
The only solution I could think of was to implement operator overloading for references (they’re kind of like const pointers), so that expressions like &a + &b work. It’s a little annoying to have to write & everywhere, but that’s not a huge deal. Infinitely more annoying is the fact that because operators return an Expr and not an &Expr, proper operator chaining forces me to implement everything four times: Expr + Expr, Expr + &Expr, &Expr + Expr, and &Expr + &Expr. That’s really annoying. I eventually wrote a macro to implement the overloading (so I didn’t have to literally copy-paste code 4 times), but I suck at reading macros, and it takes me a couple extra seconds to figure out what the hell my macro is doing every time I read it.
Strings
This complaint’s probably more because of my lack of knowledge than any shortcoming of Rust, but I’ll make it anyways. The interaction between String and &str and &String can get very annoying very quickly. I don’t remember the exact details, but when I made InnerExpr use &str to store variable names, I ran into all kinds of irritating lifetime issues, so I had to make InnerExpr use String. But when I made my hashmaps use keys of type &str, I ran into all kinds of ownership problems trying to look up my String names. In the end, I just used String everywhere, which meant that I needed to call to_owned all over the place. Not the end of the world, but definitely another papercut.
Testing
This is more of a complaint about the immaturity of Rust’s ecosystem than about the language itself. Testing in Rust is much more annoying than what I’m used to from Python.
Part of it is that I’m spoiled by py.test --pdb, which opens up an interactive debugger when your Python tests fail. Not being able to just open up an interpreter made the information Rust gives you with assert! seem painfully limited.
I also couldn’t find anything providing set-up/tear-down style testing, let alone anything like py.test’s fixtures. It didn’t make a huge difference for this project, but it’s still kind of an irritation.
There also doesn’t seem to be a built-in float comparator. I understand that technically nan != nan, but for testing purposes, it’d be nice if there were an equality function that handles edge cases like nan, inf, and +0.0 vs -0.0 correctly, along with having some tolerance for floating-point roundoff errors.
I also looked into a port of QuickCheck for Rust. It’s nice, but the test data it generated for floats didn’t include any of the wonky edge cases like +0.0 versus -0.0 or nan. I’m not sure why, especially because those are the things which usually make everything go to hell.
Python
Eventually, I got frustrated with Rust’s verbosity, so I switched over to my go-to language, Python. My Python implementation is pretty different from my Rust implementation: I used inheritance to control method dispatch instead of pattern matching and algebraic data types, and I definitely took advantage of Python’s dynamic capabilities.
You can look at my implementation here.
The high-level design of the Python implementation was pretty similar to that of my Rust implementation. We have an Expr class that represents a given node and exposes the user-facing interface (the operator overloading plus eval, forward_diff, and reverse_diff). The actual implementation is handled by the subclasses (e.g. Add). I used dictionaries mapping names (strings) to values (floats) to store points.
The subclasses all implemented a private _eval method which recursively populated a dictionary mapping node ids (ints) to their evaluated values (floats). eval just used _eval to fill out this cache and then looked up the current node’s value. Normally, I would’ve just used the built-in functools.lru_cache for caching, but the inputs (points, i.e. dictionaries) aren’t hashable, so lru_cache wouldn’t have worked.
forward_diff first calls _eval to populate the cache, which it passes to _forward_diff to do the actual computation. _forward_diff is a pretty standard recursive implementation.
For reverse_diff, we again use _eval to populate the value cache before dispatching to _reverse_diff for the actual work. _reverse_diff takes advantage of the fact that Python dictionaries are mutable: we pass the dictionary that reverse_diff will eventually return down into _reverse_diff, which modifies it in place. Otherwise, it’s a pretty standard recursive algorithm, where each node merely calculates its children’s adjoints and passes them down. Only variable nodes actually modify the output dictionary.
There is one subtlety regarding the chain rule. Imagine a situation where a node \(w_i\) has two parents, \(w_j\) and \(w_k\). Then:
\[\newcommand{\pder}[2]{\frac{\partial#1}{\partial#2}} \pder{f}{w_i} = \pder{f}{w_j} \pder{w_j}{w_i} + \pder{f}{w_k} \pder{w_k}{w_i} \]
Our implementation calculates \(\frac{\partial f}{\partial w_j} \frac{\partial w_j}{\partial w_i}\) and \(\frac{\partial f}{\partial w_k} \frac{\partial w_k}{\partial w_i}\) separately. Luckily, addition is associative and commutative, so it doesn’t matter and we compute the right answer anyways.
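Putting the last few paragraphs together, the accumulation pattern looks roughly like this (a simplified sketch whose method names follow the description above but whose stripped-down classes are made up, not the actual code from the repository):

```python
# Sketch of the reverse-mode pass: a forward pass fills `cache` (node id ->
# value), then the backward pass threads a mutable `output` dict (variable
# name -> partial derivative) down through the graph.
class Var:
    def __init__(self, name):
        self.name = name

    def _eval(self, point, cache):
        cache[id(self)] = point[self.name]

    def _reverse_diff(self, adjoint, cache, output):
        # Contributions from different parents simply accumulate here,
        # which is why the order of the additions doesn't matter.
        output[self.name] = output.get(self.name, 0.0) + adjoint

class Mul:
    def __init__(self, left, right):
        self.left, self.right = left, right

    def _eval(self, point, cache):
        self.left._eval(point, cache)
        self.right._eval(point, cache)
        cache[id(self)] = cache[id(self.left)] * cache[id(self.right)]

    def _reverse_diff(self, adjoint, cache, output):
        # d(l*r)/dl = r and d(l*r)/dr = l, each scaled by this node's adjoint.
        self.left._reverse_diff(adjoint * cache[id(self.right)], cache, output)
        self.right._reverse_diff(adjoint * cache[id(self.left)], cache, output)

def reverse_diff(expr, point):
    cache, output = {}, {}
    expr._eval(point, cache)
    expr._reverse_diff(1.0, cache, output)
    return output

x, y = Var("x"), Var("y")
print(reverse_diff(Mul(Mul(x, x), y), {"x": 3.0, "y": 2.0}))
# {'x': 12.0, 'y': 9.0}, i.e. 2xy and x^2 at (3, 2)
```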
1. Admittedly, I’m no expert on numerical differentiation, so it’s entirely possible that these problems have been solved through more complicated formulas. On the other hand, numerical differentiation packages never really worked for me, which makes me suspect that this is a problem inherent with numerical differentiation. ↩︎
2. This idea is basically the same as backpropagation, a method to efficiently train neural networks. I’d bet money that there’s some kind of historical connection between them, but I don’t know enough to be certain. ↩︎