Dual Numbers for FWD Mode AD
Introduction
Automatic Differentiation (AD) is a core algorithm in modern machine learning; implementations like PyTorch's autograd and JAX's jax.grad dominate the space. Both primarily compute derivatives using backpropagation (backprop), a case of reverse-mode AD. Forward-mode automatic differentiation is effectively the opposite of reverse mode: we calculate the derivatives
from beginning to end, as we perform the forward pass, instead of from the end to the beginning, traversing back up the graph.
Dual Numbers
Very simply, you can think of the dual numbers as being similar to the complex numbers in this sense: the complex numbers take the form $a + bi$ where $i^2 = -1$, and the dual numbers take the form $a + b\varepsilon$ where $\varepsilon^2 = 0$ (with $\varepsilon \neq 0$).
Now, both $i^2 = -1$ and $\varepsilon^2 = 0$ are somewhat odd-sounding statements at first, but they have really interesting properties that make these number systems nice to study. The complex numbers are an algebraically closed field, meaning all polynomials (excluding order 0) with complex coefficients have complex roots, a property the reals do not have.
We will get to the special property of the duals soon, but first we can just look at performing basic operations with them; $\varepsilon^2 = 0$ turns out to be a really nice property. Addition is easy:

$$(a + b\varepsilon) + (c + d\varepsilon) = (a + c) + (b + d)\varepsilon$$

but multiplication (think FOIL) gets more interesting:

$$(a + b\varepsilon)(c + d\varepsilon) = ac + ad\varepsilon + bc\varepsilon + bd\varepsilon^2 = ac + (ad + bc)\varepsilon$$

as does division (the first step is to multiply the numerator and denominator by the conjugate of the denominator):

$$\frac{a + b\varepsilon}{c + d\varepsilon} = \frac{(a + b\varepsilon)(c - d\varepsilon)}{(c + d\varepsilon)(c - d\varepsilon)} = \frac{ac + (bc - ad)\varepsilon}{c^2} = \frac{a}{c} + \frac{bc - ad}{c^2}\,\varepsilon$$
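For a quick concrete check (numbers of my own choosing, just to illustrate), multiplying two arbitrary duals:

$$(2 + 3\varepsilon)(4 + 5\varepsilon) = 8 + 10\varepsilon + 12\varepsilon + 15\varepsilon^2 = 8 + 22\varepsilon$$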
Okay, the Dual numbers still provide a pretty simple system of arithmetic, but what does this get us? What is the special property?
It's the Derivative!
Okay, let's look at an example. Consider $f(x) = x^2$; if we evaluate this at $x = 5$ we see $f(5) = 25$, pretty simple. We can also find the derivative $f'(x) = 2x$, so $f'(5) = 10$. Now let's extend that to the duals, where instead of $x = 5$, we will use $x = 5 + 1\varepsilon$:

$$f(5 + 1\varepsilon) = (5 + 1\varepsilon)^2 = 25 + 10\varepsilon + 1\varepsilon^2 = 25 + 10\varepsilon$$
We can see that the real component is just the result of $f(5)$, while the dual component is $f'(5)$. When we apply a function that is ordinarily only real valued to a dual number, it is easy to imagine simply lifting any real constant $c$ to $c + 0\varepsilon$; this will ensure that we end up with the derivative that we desire.
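To see that lifting in action (my own extension of the example above), take $g(x) = x^2 + 3x$, lifting the constant $3$ to $3 + 0\varepsilon$:

$$g(5 + 1\varepsilon) = (5 + 1\varepsilon)^2 + (3 + 0\varepsilon)(5 + 1\varepsilon) = 25 + 10\varepsilon + 15 + 3\varepsilon = 40 + 13\varepsilon$$

which matches $g(5) = 40$ and $g'(5) = 2 \cdot 5 + 3 = 13$.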
We can now show that this property holds for any sufficiently smooth (analytic) function $f$. Remember first the definition of the Taylor series of $f$ around some point $a$:

$$f(x) = \sum_{n=0}^{\infty} \frac{f^{(n)}(a)}{n!}(x - a)^n$$

(reminder: $f^{(n)}$ is the $n$th derivative of $f$). Now we can think about evaluating $f(a + b\varepsilon)$, but before we do that we must choose our expansion point; I propose that we let it be $a$, the real part of our dual number. This gives us:

$$f(a + b\varepsilon) = \sum_{n=0}^{\infty} \frac{f^{(n)}(a)}{n!}(b\varepsilon)^n = f(a) + f'(a)\,b\varepsilon$$

where every term for $n \geq 2$ just goes to 0 by $\varepsilon^2 = 0$.
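As a quick sanity check of this rule on a non-polynomial function (my own example, not part of the original derivation), applying it to $\sin$ gives

$$\sin(a + b\varepsilon) = \sin(a) + \cos(a)\,b\varepsilon,$$

exactly the chain-rule behavior you want when implementing $\sin$ for a dual-number type.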
Implementation
Implementing the duals in your language of choice is quite easy: just create a new type with a real part and a dual part, define all your operations, and you're good to go! (I chose Rust.)
use std::ops::Mul;

struct Dual {
    real: f32,
    dual: f32,
}
...
impl Mul for &Dual {
    type Output = Dual;

    fn mul(self, rhs: Self) -> Self::Output {
        Dual {
            real: self.real * rhs.real,
            // product rule: (f * g)' = f' * g + f * g'
            dual: rhs.real * self.dual + self.real * rhs.dual,
        }
    }
}
...
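For completeness, an Add implementation could follow the same pattern, and unary functions like sin fall straight out of the Taylor-series rule above. This is my own sketch rather than code from the actual crate:

use std::ops::Add;

impl Add for &Dual {
    type Output = Dual;

    fn add(self, rhs: Self) -> Self::Output {
        Dual {
            real: self.real + rhs.real,
            // derivatives add component-wise
            dual: self.dual + rhs.dual,
        }
    }
}

impl Dual {
    // sin(a + bε) = sin(a) + cos(a)·bε
    fn sin(&self) -> Dual {
        Dual {
            real: self.real.sin(),
            dual: self.real.cos() * self.dual,
        }
    }
}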
Those are just a few of the operations you will need to define, of course, but they illustrate the simplicity. Here is an example evaluating the function $f(a, b) = a \cdot b$ at $a = b = 5$, seeding only $a$ so we get the partial derivative with respect to $a$:
#[test]
fn test_mul_grad() {
let a = Node::with_grad(5.0);
let b = Node::with_no_grad(5.0);
let r = mul(a, b);
assert_eq!(r.eval().real, 25_f32);
assert_eq!(r.eval().dual, 5_f32);
println!("{} = {}", r.trace(), r.eval().to_string());
}
And its output
(5 + 1ε * 5 + 0ε) = 25 + 5ε
You might notice there's a lot of extra fluff here: Node, eval, trace. The reason for this is that my actual goal is to create a system for lazily building the computation graph, then evaluating and/or tracing it later; the dual numbers are just a fun rabbit hole I fell down while developing that system.
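To give a rough idea of what "lazily building the graph" means, here is a toy sketch of the concept (entirely my own illustration; the real Node type and its eval/trace methods may be structured quite differently):

// A toy lazy expression graph: nothing is computed until eval() is called.
enum Expr {
    Leaf(Dual),
    Mul(Box<Expr>, Box<Expr>),
}

impl Expr {
    fn eval(&self) -> Dual {
        match self {
            Expr::Leaf(d) => Dual { real: d.real, dual: d.dual },
            // evaluate both children, then combine with dual-number multiplication
            Expr::Mul(a, b) => &a.eval() * &b.eval(),
        }
    }
}

Building the graph is then just constructing Expr values, and evaluating (or tracing) it is a walk over that tree.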
Benefits and Limitations
There is a reason why we don't use forward-mode AD for training neural networks. As you may have noticed, the derivatives flow from the inputs to the outputs, and a single forward pass only carries the derivative along one seed direction: if you give several inputs non-zero dual parts at once, their contributions mix together into a single directional derivative. Because of this, in a machine learning system you would need to run multiple forward passes, one for each component of the gradient, and that falls quite short relative to reverse-mode AD, which recovers the whole gradient in a single backward pass. That's not to say there are no benefits or use cases for forward AD in machine learning. The primary benefit is memory usage: forward AD is much more memory efficient, at the cost of many times more computation under usual circumstances. One specific scenario I imagine could potentially benefit from forward AD is fine-tuning a model with a LoRA, especially in a reinforcement learning setting. If you only had a single vector or matrix to optimize, you could spend the memory saved by not storing intermediate results on more simultaneous rollouts, or on rollouts that last much longer, etc. I would like to explore this more in the future.
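To make the one-pass-per-component limitation concrete, here is a small self-contained sketch (my own illustration, using a free-standing dual type named D rather than the Node API above) that recovers both partial derivatives of f(x, y) = x·y + x with two forward passes:

#[derive(Clone, Copy)]
struct D {
    real: f32,
    dual: f32,
}

// f(x, y) = x * y + x, written directly over dual numbers:
// the dual part applies the product and sum rules.
fn f(x: D, y: D) -> D {
    D {
        real: x.real * y.real + x.real,
        dual: x.dual * y.real + x.real * y.dual + x.dual,
    }
}

fn main() {
    let (x, y) = (3.0_f32, 4.0_f32);
    // pass 1: seed x with dual = 1 to get df/dx = y + 1 = 5
    let dfdx = f(D { real: x, dual: 1.0 }, D { real: y, dual: 0.0 }).dual;
    // pass 2: seed y with dual = 1 to get df/dy = x = 3
    let dfdy = f(D { real: x, dual: 0.0 }, D { real: y, dual: 1.0 }).dual;
    println!("df/dx = {dfdx}, df/dy = {dfdy}");
}

With n inputs this scales to n forward passes, while reverse mode produces the entire gradient from one backward pass, which is exactly the asymmetry described above.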
Future Work
I have several goals for this codebase in the future. The first is increasing my fluency with Rust and building more complex systems in it than I have before. My second goal is to practice creating libraries with a good developer experience, specifically for machine learning tasks. My last goal is to implement some system for optimizing/lowering the graph into a more performant form, i.e., an ML compiler. I use torch.compile every day, and I have a working understanding of how it and similar systems work, but I want a much deeper understanding, both to optimize my existing work that relies on these systems and to be able to contribute to improving ML compiler projects and the libraries they are a part of.