In the previous post we detailed how we can derive derivative formulae for any function constructed from elementary functions and operations, and how derivatives of such functions are themselves constructed from elementary functions/operations. These facts have far-reaching consequences for the practical computing of derivatives - allowing us to construct a very effective derivative calculator called Automatic Differentiation which we detail here.
More specifically, we describe how one can quickly code up the so-called forward mode of Automatic Differentiation. This is a natural and direct implementation of the method for calculating derivatives 'by hand' using a computation graph, as discussed in the previous post. Because it leverages closed form derivative rules (instead of, e.g., numerical approximation), the Automatic Differentiator is typically more accurate than the Numerical Differentiation calculator discussed in the second post of this series, and it has no parameters that require tuning.
```python
# imports from custom library
import sys
sys.path.append('../../')
import numpy as np
import matplotlib.pyplot as plt
from mlrefined_libraries import calculus_library as calclib
from mlrefined_libraries import basics_library as baslib

# IPython/Jupyter magics: automatically reload edited library code
%load_ext autoreload
%autoreload 2
```
In this Section we outline the basic design elements of an Automatic Differentiation (AD) calculator. This largely involves discussing how to organize and implement the derivative rules described in the previous post. Interestingly, this approach does not require us to parse an input equation to construct its computation graph. Instead, the graph is built implicitly when we simply define a function for derivative computation.
In addition we also describe a simple way to significantly improve the usability of the AD calculator via operator overloading.
The computation graph is a powerful tool for recursively describing how a mathematical function is constructed, how it is evaluated, and how its derivative is formed. In particular we have seen how recursively sweeping forward through a computation graph breaks down large derivative calculations into a long sequence of much smaller ones.
When performing these calculations by hand our main concerns were in computing each step accurately, and keeping the many computations organized. Thankfully because recursive algorithms are so naturally dealt with using for loops / while loops, the computation graph provides a bridge for thinking about how to perform / organize computations involving mathematical functions - like e.g., derivative calculations - on a computer instead of by hand.
Because AD calculations will be made using the computation graph, our first task is to decide how the graph will be constructed by our AD algorithm. Essentially we have two choices: we can either parse the input function ahead of time and construct its computation graph explicitly, or build our AD algorithm such that the graph is constructed implicitly and is never stored as a whole.
Each approach has strengths and weaknesses [1]. We go with the latter and derive an AD algorithm that constructs the computation graph implicitly, because this is simple and lightweight and requires fewer moving parts (like, e.g., a parser).
In the previous post we described an algorithmic process, based on the elementary function/operation derivative rules, for deriving an algebraic formula for the derivative of an input function. The AD calculator employs these rules in precisely the same manner, but it produces a program-based function of the derivative instead. That is, it produces a subroutine that gives mathematically correct evaluations of the function and its derivative for any input desired. In other words, the fundamental object the AD calculator operates on is a tuple of values: the evaluation of the function and of its derivative at a point [2].
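To see that nothing more than a pair of numbers is needed, here is a minimal sketch using plain Python tuples. The function names `sin_rule` and `square_rule` are our own, for illustration only. Each rule maps an input (value, derivative) pair to an output pair. Composing the rules therefore evaluates $g(w) = \text{sin}^2(w)$ and its derivative $2\,\text{sin}(w)\text{cos}(w)$ together.

```python
import math

# Forward-mode AD on plain (value, derivative) pairs - an illustrative sketch.
# Each rule maps an input pair to an output pair, applying the chain rule.
def sin_rule(pair):
    val, der = pair
    return (math.sin(val), math.cos(val) * der)   # sinusoid rule times incoming derivative

def square_rule(pair):
    val, der = pair
    return (val ** 2, 2 * val * der)              # power rule (p = 2) times incoming derivative

# seed at w = 0.5: the value is 0.5 and dw/dw = 1
w = (0.5, 1.0)
g = square_rule(sin_rule(w))                     # g(w) = sin(w)^2
print(g)
```

The second entry of `g` equals $2\,\text{sin}(0.5)\text{cos}(0.5) = \text{sin}(1)$. Bundling the pair into a class, as done next, simply makes this bookkeeping more convenient.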
As we further develop the AD calculator we will see that a very convenient way to store, transport, and update these two values is by incorporating them into a simple Python class-object.
We define this simple class below - again instances of this class are simply a convenient way of holding onto the current evaluation of a function (called val for short) as well as its derivative value (called der). For those wanting a good introduction to Python classes in the context of mathematical functions see e.g., this excellent book.
```python
class MyTuple:
    '''
    The basic object manipulated by an AD calculator.
    Each instance holds one function evaluation (val) and
    one derivative evaluation (der) at a single input point.
    '''
    def __init__(self,**kwargs):
        # variables for the value (val) and derivative (der) of our input function;
        # der defaults to 1 because the input variable w has derivative dw/dw = 1
        self.val = 0
        self.der = 1

        # re-assign these default values if they are provided
        if 'val' in kwargs:
            self.val = kwargs['val']
        if 'der' in kwargs:
            self.der = kwargs['der']
```
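The small sketch below makes the role of the default `der = 1` concrete. The input variable $w$ has derivative $\frac{\mathrm{d}w}{\mathrm{d}w} = 1$, which is why `der` defaults to 1. A quantity that does not depend on $w$ would instead be seeded with `der = 0`. The block repeats a minimal copy of `MyTuple` and anticipates the sinusoid rule developed below, so that it runs on its own.

```python
import numpy as np

class MyTuple:
    # minimal copy of the MyTuple class above, so this block runs on its own
    def __init__(self, **kwargs):
        self.val = kwargs.get('val', 0)
        self.der = kwargs.get('der', 1)

def sin(a):
    # sinusoid rule with the chain rule (same as the sin function defined below)
    b = MyTuple()
    b.val = np.sin(a.val)
    b.der = np.cos(a.val)*a.der
    return b

w = MyTuple(val = 0.5)            # the input variable, seeded with dw/dw = 1
c = MyTuple(val = 0.5, der = 0)   # a constant with the same value, seeded with dc/dw = 0

print(sin(w).der)   # equals cos(0.5): the derivative of sin(w)
print(sin(c).der)   # 0.0: sin(c) does not change with w
```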
With our base object taken care of we can quickly code up the derivative rules for elementary functions contained in Table 1 of the previous post, which we repeat here for convenience.
elementary function | equation | derivative |
---|---|---|
constant | $c$ | $0$ |
monomial (degree $p\neq 0$) | $w^p$ | $pw^{p-1}$ |
sine | $\text{sin}(w)$ | $\text{cos}(w)$ |
cosine | $\text{cos}(w)$ | $-\text{sin}(w)$ |
exponential | $e^w$ | $e^w$ |
logarithm | $\text{log}(w)$ | $\frac{1}{w}$ |
hyperbolic tangent | $\text{tanh}(w)$ | $1 - \text{tanh}^2(w)$ |
rectified linear unit (ReLU) | $\text{max}\left(0,w\right)$ | $\begin{cases}0 & w\leq0\\1 & w>0\end{cases}$ |
Since we are tracking both the function value and the derivative value, implementing one of these rules as a Python function means defining two updates. The first is how the elementary function transforms the input function value. The second is how its corresponding derivative rule transforms the input derivative value.
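Every rule in Table 1 follows this same two-update pattern, so it can be captured once. The sketch below uses a hypothetical helper `make_rule` (our own name, not part of the code in this post). It takes an elementary function and its derivative formula and returns an AD rule. The post instead writes each rule out by hand.

```python
import numpy as np

class MyTuple:
    # minimal copy of the MyTuple class above, so this block runs on its own
    def __init__(self, **kwargs):
        self.val = kwargs.get('val', 0)
        self.der = kwargs.get('der', 1)

def make_rule(f, dfdw):
    # hypothetical helper: builds an AD rule from an elementary function f and
    # its derivative formula dfdw, both acting on plain numbers or arrays
    def rule(a):
        b = MyTuple()
        b.val = f(a.val)               # how the function value is transformed
        b.der = dfdw(a.val)*a.der      # derivative rule times incoming derivative (chain rule)
        return b
    return rule

# e.g., the hyperbolic tangent row of Table 1
tanh = make_rule(np.tanh, lambda v: 1 - np.tanh(v)**2)

result = tanh(MyTuple(val = 0.0))
print(result.val, result.der)   # 0.0 1.0
```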
We walk through several examples below.
For example, we have a Python function for the sinusoid update rule in the next cell.
```python
# our implementation of the sinusoid rule from Table 1
def sin(a):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = np.sin(a.val)

    # Produce new derivative value - we need to use the chain rule here!
    b.der = np.cos(a.val)*a.der

    # Return updated object
    return b
```
Here we input `a`, a `MyTuple` object holding a current function and derivative value, and create a new instance called `b` to hold their updates. To get the new function value, `b.val = np.sin(a.val)`, we simply pass the current value through a sinusoid. The corresponding derivative update, `b.der = np.cos(a.val)*a.der`, involves two parts. The sinusoid derivative rule alone would have us update the derivative value by simply passing `a.val` through cosine. But remember, as discussed in the previous post, that every time we apply an elementary derivative formula we must apply the chain rule as well. This is why we multiply `np.cos(a.val)` by `a.der` in the update.
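The need for the `a.der` factor shows up as soon as rules are nested. The sketch below compares our sinusoid rule with a deliberately broken variant, `sin_no_chain` (a hypothetical name, for illustration only), that drops the factor. Both are applied to $\text{sin}(\text{sin}(w))$, whose derivative by the chain rule is $\text{cos}(\text{sin}(w))\text{cos}(w)$.

```python
import numpy as np

class MyTuple:
    # minimal copy of the MyTuple class above, so this block runs on its own
    def __init__(self, **kwargs):
        self.val = kwargs.get('val', 0)
        self.der = kwargs.get('der', 1)

def sin(a):
    # correct sinusoid rule, as in the text
    b = MyTuple()
    b.val = np.sin(a.val)
    b.der = np.cos(a.val)*a.der
    return b

def sin_no_chain(a):
    # deliberately WRONG variant for comparison: forgets to multiply by a.der
    b = MyTuple()
    b.val = np.sin(a.val)
    b.der = np.cos(a.val)
    return b

w0 = 0.7
good = sin(sin(MyTuple(val = w0)))
bad = sin_no_chain(sin_no_chain(MyTuple(val = w0)))

# by hand, the chain rule gives d/dw sin(sin(w)) = cos(sin(w)) * cos(w)
by_hand = np.cos(np.sin(w0))*np.cos(w0)
print(good.der, bad.der, by_hand)   # good.der matches by_hand, bad.der does not
```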
We can now test our sinusoid function over a few input points, as is done in the next cell.
```python
# define the function to differentiate - note this uses our homemade sin function, not numpy's
g = lambda w: sin(w)

# initialize a MyTuple object at each input point
a1 = MyTuple(val = 0); a2 = MyTuple(val = 0.5)

# evaluate
result1 = g(a1); result2 = g(a2)

# print results
print('function value at ' + str(0) + ' = ' + str(result1.val))
print('derivative value at ' + str(0) + ' = ' + str(result1.der))
print('function value at ' + str(0.5) + ' = ' + str(result2.val))
print('derivative value at ' + str(0.5) + ' = ' + str(result2.der))
```
This looks fine, so let's evaluate over a large range of inputs and visualize both the function and derivative values. In the next Python cell we use a short custom plotting function that performs the above evaluations over a large range of input values and then plots the resulting function/derivative values.
```python
# define a function using our AD components
g = lambda w: sin(w)

# use custom plotter to evaluate function/derivative over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
```
And indeed this is correct: the function evaluation is sine and the derivative evaluation is cosine over the entire input range.
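The plot can also be checked numerically rather than by eye. The sketch below is our own check and does not use the plotting library. NumPy functions act elementwise, so a `MyTuple` can hold a whole array of input points in `val`, and the result can be compared directly against `np.sin` and `np.cos`.

```python
import numpy as np

class MyTuple:
    # minimal copy of the MyTuple class above, so this block runs on its own
    def __init__(self, **kwargs):
        self.val = kwargs.get('val', 0)
        self.der = kwargs.get('der', 1)

def sin(a):
    # sinusoid rule with the chain rule, as in the text
    b = MyTuple()
    b.val = np.sin(a.val)
    b.der = np.cos(a.val)*a.der
    return b

# MyTuple does not care whether val holds a float or a NumPy array,
# so we can push a whole grid of input points through in one call
w = np.linspace(-5, 5, 200)
result = sin(MyTuple(val = w))   # der keeps its scalar default 1, which broadcasts

print(np.allclose(result.val, np.sin(w)))   # True
print(np.allclose(result.der, np.cos(w)))   # True
```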
We can just as easily define a cosine function as well, and do so in the next Python cell.
```python
# our implementation of the cosine rule from Table 1
def cos(a):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = np.cos(a.val)

    # Produce new derivative value - we need to use the chain rule here!
    b.der = -np.sin(a.val)*a.der

    # Return updated object
    return b
```
We can define a power rule precisely as done with the sinusoid function in the previous example: updating the current function evaluation using the elementary function, and the current derivative value using the corresponding derivative rule from Table 1. Once again we must include an instance of the chain rule with our derivative update.
```python
# our implementation of the power rule from Table 1
def power(a,n):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = a.val**n

    # Produce new derivative value - we need to use the chain rule here!
    b.der = n*(a.val**(n-1))*a.der

    # Return updated object
    return b
```
We can test this function by evaluating and plotting it over a large range of inputs, say for the second degree monomial $w^2$.
```python
# define the function to differentiate - note this uses our homemade power function
g = lambda w: power(w,2)

# use the custom plotting function to plot function/derivative over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
```
Again everything looks good.
As with the previous examples we can define the $\text{tanh}$ function, making sure to include the chain rule with our update of the derivative value.
```python
# our implementation of the hyperbolic tangent rule from Table 1
def tanh(a):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = np.tanh(a.val)

    # Produce new derivative value - we need to use the chain rule here!
    b.der = (1 - np.tanh(a.val)**2)*a.der

    # Return updated object
    return b
```
And we can test our new tanh function over a range of values.
```python
# define the function to differentiate - note this uses our homemade tanh function, not numpy's
g = lambda w: tanh(w)

# use the custom plotting function to plot function/derivative over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
```
This looks correct. We can verify it by plotting the function and derivative equations given in Table 1, which we do in the next cell using a custom plotting function.
```python
# define function and its derivative equations using numpy
g = lambda w: np.tanh(w)
dgdw = lambda w: (1 - np.tanh(w)**2)

# plot both
calclib.plotter.derval_eq_plot(g,dgdw)
```
Mirroring the previous examples we can code up the derivative rule for log as follows.
```python
# our implementation of the logarithm rule from Table 1
def log(a):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = np.log(a.val)

    # Produce new derivative value - we need to use the chain rule here!
    b.der = (1/a.val)*a.der

    # Return updated object
    return b
```
And quickly testing it out, we see that it indeed works.
```python
# define the function to differentiate - note this uses our homemade log function, not numpy's
g = lambda w: log(w)

# log is only defined for positive inputs, so we use 1000 points in [0.01, 2]
w = np.linspace(0.01,2,1000)

# use the custom plotting function to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g,w=w)
```
We can very easily continue, defining a function for each elementary derivative rule precisely as we have done in these examples. The only wrinkle to remember is that we must include an instance of the chain rule with each derivative update from Table 1.
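The sketch below continues the pattern with the two remaining rows of Table 1 that have not yet been coded, the exponential and ReLU. The names `exp` and `relu` are our own choices. At $w = 0$ we follow the table's convention that the ReLU derivative is 0.

```python
import numpy as np

class MyTuple:
    # minimal copy of the MyTuple class above, so this block runs on its own
    def __init__(self, **kwargs):
        self.val = kwargs.get('val', 0)
        self.der = kwargs.get('der', 1)

# exponential rule from Table 1: (e^w)' = e^w, times a.der for the chain rule
def exp(a):
    b = MyTuple()
    b.val = np.exp(a.val)
    b.der = np.exp(a.val)*a.der
    return b

# ReLU rule from Table 1: derivative is 0 for w <= 0 and 1 for w > 0
def relu(a):
    b = MyTuple()
    b.val = np.maximum(0, a.val)
    b.der = np.where(a.val > 0, 1.0, 0.0)*a.der
    return b

for w0 in [-1.0, 0.0, 2.0]:
    r = relu(MyTuple(val = w0))
    print('relu at', w0, ': val =', r.val, ', der =', r.der)
```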
In the previous post we described derivative rules for each elementary operation, which we repeat below for convenience.
operation | equation | derivative rule |
---|---|---|
addition of a constant $c$ | $g(w) + c$ | $\frac{\mathrm{d}}{\mathrm{d}w}\left(g(w) + c\right)= \frac{\mathrm{d}}{\mathrm{d}w}g(w)$ |
multiplication by a constant $c$ | $cg(w)$ | $\frac{\mathrm{d}}{\mathrm{d}w}\left(cg(w)\right)= c\frac{\mathrm{d}}{\mathrm{d}w}g(w)$ |
addition of functions (often called the summation rule) | $f(w) + g(w)$ | $\frac{\mathrm{d}}{\mathrm{d}w}(f(w) + g(w))= \frac{\mathrm{d}}{\mathrm{d}w}f(w) + \frac{\mathrm{d}}{\mathrm{d}w}g(w)$ |
multiplication of functions (often called the product rule) | $f(w)g(w)$ | $\frac{\mathrm{d}}{\mathrm{d}w}(f(w)\cdot g(w))= \left(\frac{\mathrm{d}}{\mathrm{d}w}f(w)\right)\cdot g(w) + f(w)\cdot \left(\frac{\mathrm{d}}{\mathrm{d}w}g(w)\right)$ |
composition of functions (often called the chain rule) | $f(g(w))$ | $\frac{\mathrm{d}}{\mathrm{d}w}(f(g(w)))= \frac{\mathrm{d}}{\mathrm{d}g}f(g) \cdot \frac{\mathrm{d}}{\mathrm{d}w}g(w)$ |
maximum of two functions | $\text{max}(f(w),\,g(w))$ | $\frac{\mathrm{d}}{\mathrm{d}w}(\text{max}(f(w),\,g(w))) = \begin{cases}\frac{\mathrm{d}}{\mathrm{d}w}f\left(w\right) & \text{if}\,\,\,f\left(w\right)\geq g\left(w\right)\\\frac{\mathrm{d}}{\mathrm{d}w}g\left(w\right) & \text{otherwise}\end{cases}$ |
As with the derivative formulae for elementary functions, implementing these rules means providing - in each case - an update for the function and derivative value. We walk through several examples below.
Besides updating both function and derivative values, when coding up addition we may as well handle both addition rules in Table 2 at once: the sum of a function and a constant, and the sum of two functions. A simple if/else statement is used below to sort between these two cases.
```python
# our implementation of the addition rules from Table 2
def add(a,b):
    # Create output evaluation and derivative object
    c = MyTuple()

    # switch to determine if a or b is a constant
    if type(a) != MyTuple:
        c.val = a + b.val
        c.der = b.der
    elif type(b) != MyTuple:
        c.val = a.val + b
        c.der = a.der
    else: # both inputs are MyTuple objects, i.e., functions
        c.val = a.val + b.val
        c.der = a.der + b.der

    # Return updated object
    return c
```
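As a quick sanity check of the branching logic, the sketch below applies `add` to a function plus a constant (in both orders) and to a function plus itself. A constant contributes nothing to the derivative, so $w + 5$ and $5 + w$ should have derivative 1, while $w + w$ should have derivative 2.

```python
class MyTuple:
    # minimal copy of the MyTuple class above, so this block runs on its own
    def __init__(self, **kwargs):
        self.val = kwargs.get('val', 0)
        self.der = kwargs.get('der', 1)

def add(a,b):
    # same logic as the add function above
    c = MyTuple()
    if type(a) != MyTuple:      # a is a constant
        c.val = a + b.val
        c.der = b.der
    elif type(b) != MyTuple:    # b is a constant
        c.val = a.val + b
        c.der = a.der
    else:                       # both are functions
        c.val = a.val + b.val
        c.der = a.der + b.der
    return c

w = MyTuple(val = 2.0)
for label, r in [('w + 5', add(w, 5)), ('5 + w', add(5, w)), ('w + w', add(w, w))]:
    print(label, '-> val =', r.val, ', der =', r.der)
```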
With the addition rules taken care of above, we can try them out together with one of the elementary function rules coded in the previous subsection. In particular we test with the sum

$$ \text{sin}(w) + w $$

since the sinusoid update rule was coded previously and $w$ itself is simply the input MyTuple object.
```python
# define the function to differentiate - note this uses our homemade sin function, not numpy's
g = lambda w: add(sin(w),w)

# use the custom plotting function to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
```
Notice that the computation graph for our input function, which involves derivative rules for both elementary functions and operations, is implicitly constructed and traversed when we pass a MyTuple object through the function `add(sin(w),w)`. In other words, while we never store the computation graph itself, we traverse it to construct the function/derivative values just as we did by hand in the previous post.
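To make this implicit traversal visible, the sketch below evaluates $\text{sin}(w) + w$ at $w = 0.5$ one node at a time. It keeps each intermediate MyTuple so that the (value, derivative) pair at every node of the graph can be printed. The node names are our own labels.

```python
import numpy as np

class MyTuple:
    # minimal copy of the MyTuple class above, so this block runs on its own
    def __init__(self, **kwargs):
        self.val = kwargs.get('val', 0)
        self.der = kwargs.get('der', 1)

def sin(a):
    # sinusoid rule (with chain rule), as defined above
    b = MyTuple()
    b.val = np.sin(a.val)
    b.der = np.cos(a.val)*a.der
    return b

def add(a, b):
    # addition rules, as defined above
    c = MyTuple()
    if type(a) != MyTuple:
        c.val, c.der = a + b.val, b.der
    elif type(b) != MyTuple:
        c.val, c.der = a.val + b, a.der
    else:
        c.val, c.der = a.val + b.val, a.der + b.der
    return c

# evaluate g(w) = sin(w) + w node by node, keeping each intermediate object
node_w   = MyTuple(val = 0.5)          # input node: (w, 1)
node_sin = sin(node_w)                 # (sin(w), cos(w)*1)
node_out = add(node_sin, node_w)       # (sin(w) + w, cos(w) + 1)

for name, node in [('w', node_w), ('sin(w)', node_sin), ('sin(w) + w', node_out)]:
    print(f'{name:>12}: val = {node.val:.4f}, der = {node.der:.4f}')
```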
We also have two multiplication rules: one for the product of a constant and a function, and one for the product of two functions. We can likewise wrap both into one Python function, using an if/else to distinguish between them.
```python
# our implementation of the multiplication rules from Table 2
def multiply(a,b):
    # Create output evaluation and derivative object
    c = MyTuple()

    # switch to determine if a or b is a constant
    if type(a) != MyTuple:
        c.val = a*b.val
        c.der = a*b.der
    elif type(b) != MyTuple:
        c.val = a.val*b
        c.der = a.der*b
    else: # both inputs are MyTuple objects, i.e., functions
        c.val = a.val*b.val
        c.der = a.der*b.val + a.val*b.der # product rule

    # Return updated object
    return c
```
With the multiplication rules taken care of above, we can try them out using two of the elementary function rules coded in the previous subsection. In particular we test with the product

$$ \text{sin}(w)\times w^2 $$

since both the sinusoid and the power function update rules were coded previously.
```python
# define the function to differentiate - note this uses our homemade sin and power functions
g = lambda w: multiply(sin(w),power(w,2))

# use the custom plotting function to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
```
Once again, the computation graph for our input function, which involves derivative rules for both elementary functions and operations, is implicitly constructed and traversed when we pass a MyTuple object through the function `multiply(sin(w),power(w,2))`.
It's easy to check that this is correct by plotting the derivative equation itself. Using the elementary function/operation rules, it can be written as

$$ \frac{\mathrm{d}}{\mathrm{d}w}g(w) = 2\text{sin}(w)w + \text{cos}(w)w^2 $$

We plot this equation directly, along with the original function, in the next Python cell.
```python
# define function and its derivative equations
g = lambda w: np.sin(w)*w**2
dgdw = lambda w: 2*np.sin(w)*w + np.cos(w)*w**2

# plot both
calclib.plotter.derval_eq_plot(g,dgdw)
```
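The same comparison can be made with an assertion instead of a plot. The sketch below is our own check, independent of the plotting library. It passes one MyTuple holding a whole grid of points through `multiply(sin(w),power(w,2))` and compares the resulting derivative values with the equation above.

```python
import numpy as np

class MyTuple:
    # minimal copy of the MyTuple class above, so this block runs on its own
    def __init__(self, **kwargs):
        self.val = kwargs.get('val', 0)
        self.der = kwargs.get('der', 1)

def sin(a):
    b = MyTuple()
    b.val, b.der = np.sin(a.val), np.cos(a.val)*a.der
    return b

def power(a, n):
    b = MyTuple()
    b.val, b.der = a.val**n, n*(a.val**(n-1))*a.der
    return b

def multiply(a, b):
    c = MyTuple()
    if type(a) != MyTuple:
        c.val, c.der = a*b.val, a*b.der
    elif type(b) != MyTuple:
        c.val, c.der = a.val*b, a.der*b
    else:
        c.val, c.der = a.val*b.val, a.der*b.val + a.val*b.der   # product rule
    return c

# evaluate AD and the hand-derived formula on the same grid and compare
w = np.linspace(-3, 3, 500)
x = MyTuple(val = w)                      # one input node holding the whole grid
ad = multiply(sin(x), power(x, 2))        # g(w) = sin(w)*w^2
dgdw = 2*np.sin(w)*w + np.cos(w)*w**2     # derivative from the text
print(np.allclose(ad.der, dgdw))   # True
```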
With our current AD calculator setup, we have just seen that to compute the derivative of

$$ g(w) = \text{sin}(w) + w $$

we must use the somewhat clunky Python notation `add(sin(w),w)` to invoke the summation derivative rule. In this subsection we briefly discuss how one can overload operators in Python to make this call look more like the algebraic form of the function.
Python reserves symbols like `+` and `*` to denote functions that perform standard addition and multiplication on its built-in numeric types. This means that if we try to use one of these symbols on our MyTuple objects directly, we will receive an error like

`TypeError: unsupported operand type(s) for +: 'MyTuple' and 'MyTuple'`

because Python does not know how to apply `+` to objects of our class. We try this in the next Python cell, and indeed we provoke an error.
```python
# create two MyTuple objects and try to use Python's built-in + operator on them
# (this raises a TypeError, since + is not yet defined for MyTuple objects)
a = MyTuple(); b = MyTuple()
a + b
```
We have already defined an addition function for MyTuple objects, one that updates both the function and derivative values. We can force Python to use this function whenever we use the symbol `+` with MyTuple objects. This is called operator overloading: we redefine the function Python calls when it sees a particular operator like `+`.

Because we want the operator `+` to behave differently for MyTuple objects, we overload it in the class definition of MyTuple. That is, we can go back and place our `add` function inside the MyTuple class definition. To tell Python to use this function for the `+` operator, we give it the special method name `__add__`, as shown in the abbreviated version of the `add` function below.
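```python
# abbreviated illustration only - not meant to be run as-is
class MyTuple:
    # ... __init__ exactly as defined earlier ...

    # our implementation of the addition rules from Table 2, written as a method;
    # Python calls it for the expression self + other
    def __add__(self, other):
        # Create output evaluation and derivative object
        c = MyTuple()
        # ... same val/der updates as in add(), with a = self and b = other ...
        # Return updated object
        return c
```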
We can also do this on the fly, attaching our existing `add` function to the class definition, as we do in the next Python cell.
```python
# overload the addition operator for MyTuple objects,
# i.e., attach the 'add' function to the MyTuple class on the fly
MyTuple.__add__ = add

# overload the reflected addition so that expressions like 3 + a also work
MyTuple.__radd__ = add
```
We also overload `__radd__` above, because Python can call different functions for `+` depending on which operand comes first:

- `__add__`: called for `a + b` when `a` is a MyTuple
- `__radd__`: called for `b + a` when `b` (e.g., a plain Python number) does not know how to add a MyTuple, in which case the MyTuple `a` is passed in as `self`

Since addition is commutative, we want both orderings to behave the same way. For example, `3 + a` should equal `a + 3`, so we overload `__radd__` with our `add` function as well.
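The dispatch order can be seen with a tiny hypothetical class, `Probe`, whose operators simply return the name of the special method Python chose. Two objects of the class always go through `__add__`. The method `__radd__` only comes into play when the left operand is something like a plain number, whose own `__add__` returns `NotImplemented`.

```python
class Probe:
    '''Hypothetical class whose operators only report which special method ran.'''
    def __add__(self, other):
        return '__add__'
    def __radd__(self, other):
        return '__radd__'

p, q = Probe(), Probe()
print(p + q)   # __add__  : the left operand is a Probe
print(p + 3)   # __add__  : the left operand is still a Probe
print(3 + p)   # __radd__ : int's __add__ returns NotImplemented, so Python tries p.__radd__(3)
```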
Now we can use the `+` symbol with our MyTuple objects, and Python will employ our own `add` function to combine them. We demonstrate this in the next Python cell.
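```python
# create two MyTuple objects and combine them with the overloaded + operator
a = MyTuple(val = 1.0); b = MyTuple(val = 2.0)
c = a + b    # calls MyTuple.__add__, i.e., add(a, b)
d = 3 + a    # int cannot add a MyTuple, so Python calls MyTuple.__radd__, i.e., add(a, 3)
print(c.val, c.der)
print(d.val, d.der)
```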
We can do precisely the same thing with other Python operators, such as those reserved for multiplication `*`, subtraction `-`, raising to a power `**`, etc. A full list of the operators that can be overloaded is given in the "special method names" section of the Python data model documentation. The more of these we overload appropriately, the more user-friendly our AD calculator becomes.
Since we have already written functions for multiplying MyTuple objects and raising them to a power, we overload these two operators on the fly in the next Python cell.
```python
# overload the multiplication operator * for MyTuple objects
MyTuple.__mul__ = multiply

# overload the reflected multiplication so that expressions like 20*w also work
MyTuple.__rmul__ = multiply

# overload the power operator ** (the exponent is assumed to be a plain number)
MyTuple.__pow__ = power
```
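Subtraction and division can be added without any new derivative rules, since $a - b = a + (-1)b$ and $a/b = a\,b^{-1}$. The sketch below shows one way to do this. The helpers `subtract`, `rsubtract`, `divide` and `rdivide` are our own hypothetical additions built from `add`, `multiply` and `power`. With division overloaded, $\frac{\text{cos}(20w)}{w^2+1}$ can be written directly. The result is checked against the quotient rule worked out by hand.

```python
import numpy as np

class MyTuple:
    # minimal copy of the MyTuple class above, so this block runs on its own
    def __init__(self, **kwargs):
        self.val = kwargs.get('val', 0)
        self.der = kwargs.get('der', 1)

def add(a, b):
    c = MyTuple()
    if type(a) != MyTuple:
        c.val, c.der = a + b.val, b.der
    elif type(b) != MyTuple:
        c.val, c.der = a.val + b, a.der
    else:
        c.val, c.der = a.val + b.val, a.der + b.der
    return c

def multiply(a, b):
    c = MyTuple()
    if type(a) != MyTuple:
        c.val, c.der = a*b.val, a*b.der
    elif type(b) != MyTuple:
        c.val, c.der = a.val*b, a.der*b
    else:
        c.val, c.der = a.val*b.val, a.der*b.val + a.val*b.der   # product rule
    return c

def power(a, n):
    return MyTuple(val = a.val**n, der = n*(a.val**(n-1))*a.der)

def cos(a):
    return MyTuple(val = np.cos(a.val), der = -np.sin(a.val)*a.der)

# hypothetical extensions built from the existing rules
def subtract(a, b):
    # a - b = a + (-1)*b
    return add(a, multiply(-1, b))

def rsubtract(a, b):
    # Python calls a.__rsub__(b) for b - a when b is a plain number
    return add(b, multiply(-1, a))

def divide(a, b):
    # a / b = a * b**(-1): the quotient rule via the product and power rules
    if type(b) != MyTuple:
        return multiply(a, 1.0/b)
    return multiply(a, power(b, -1))

def rdivide(a, b):
    # Python calls a.__rtruediv__(b) for b / a when b is a plain number
    return multiply(b, power(a, -1))

MyTuple.__add__, MyTuple.__radd__ = add, add
MyTuple.__mul__, MyTuple.__rmul__ = multiply, multiply
MyTuple.__pow__ = power
MyTuple.__sub__, MyTuple.__rsub__ = subtract, rsubtract
MyTuple.__truediv__, MyTuple.__rtruediv__ = divide, rdivide

w0 = 0.3
w = MyTuple(val = w0)
g = cos(20*w)/(w**2 + 1)     # no need for (w**2 + 1)**(-1) any more
print(g.val, g.der)
```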
With these operators overloaded we can write out several of the previous examples more naturally, which we do in the next few Python cells.
For example, above we calculated the derivative of

$$ g(w) = \text{sin}(w)w^2 $$

which we first had to write as `multiply(sin(w),power(w,2))`. Now we can write this derivative calculation much more naturally as `sin(w)*w**2`.
```python
# define the function to differentiate - this uses our homemade sin function
# together with the overloaded * and ** operators
g = lambda w: sin(w)*w**2

# use the custom plotting function to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
```
In the previous post we also computed the derivatives of the following functions by hand, and plotted each derivative function explicitly.

$$ \begin{array}{l} g(w) = \text{sin}(w^3) \\ g(w) = \text{tanh}(w)\text{cos}(w) + \text{log}(w) \\ g(w) = \frac{\text{cos}(20w)}{w^2 + 1} \end{array} $$

In the next three Python cells we use our AD calculator to compute these derivatives, and plot the results.
```python
# define the function to differentiate - note this uses our homemade sin function, not numpy's
g = lambda w: sin(w**3)
w = np.linspace(-3,3,1000)

# use the custom plotting function to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g,w=w)
```

```python
# define the function to differentiate - uses our homemade tanh, cos, and log functions
g = lambda w: tanh(w)*cos(w) + log(w)

# log requires positive inputs
w = np.linspace(0.01,3,1000)

# use the custom plotting function to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g,w=w)
```

```python
# define the function to differentiate - division is not overloaded,
# so 1/(w^2 + 1) is written as (w**2 + 1)**(-1)
g = lambda w: cos(20*w)*(w**2 + 1)**(-1)
w = np.linspace(-3,3,1000)

# use the custom plotting function to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g,w=w)
```
Correct again!
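For reference, applying the rules of Tables 1 and 2 by hand gives the following derivatives for these three functions. This is our own derivation, and the AD plots above should agree with it.

$$ \begin{array}{l} \frac{\mathrm{d}}{\mathrm{d}w}\text{sin}(w^3) = 3w^2\text{cos}(w^3) \\ \frac{\mathrm{d}}{\mathrm{d}w}\left(\text{tanh}(w)\text{cos}(w) + \text{log}(w)\right) = \left(1 - \text{tanh}^2(w)\right)\text{cos}(w) - \text{tanh}(w)\text{sin}(w) + \frac{1}{w} \\ \frac{\mathrm{d}}{\mathrm{d}w}\frac{\text{cos}(20w)}{w^2 + 1} = -\frac{20\,\text{sin}(20w)}{w^2 + 1} - \frac{2w\,\text{cos}(20w)}{\left(w^2 + 1\right)^2} \end{array} $$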
The content of this notebook is supplementary material for the textbook Machine Learning Refined (Cambridge University Press, 2016). Visit http://mlrefined.com for free chapter downloads and tutorials, and our Amazon site for details regarding a hard copy of the text.
[1] In the former case, explicitly constructing the graph allows one to perform other computations (e.g., higher order derivatives) more conveniently. However, it requires building a fairly sophisticated parser, as well as storing the graph, which can take significant memory for complex functions. In the latter case, constructing the graph implicitly yields a lighter-weight AD calculator with fewer moving parts, but the core AD algorithm must then be adjusted to produce higher order derivatives.
[2] That AD does not provide an algebraic description of the derivative function is a moot point for applications in machine learning / deep learning. The program-based form is entirely equivalent to the algebraic equation, and what we want is not to produce equations but a way to automate precise derivative calculations.
One can build an algorithm that employs the basic derivative rules to provide an algebraic derivative, but this requires the implementation of a computer algebra system. Such a derivative calculator, which deals with derivatives using symbolic computation (i.e., algebra on the computer), is called a Symbolic Differentiator. However, there are a few reasons why Automatic Differentiation is a better choice, at least for machine learning / deep learning applications.
First, for our applications we only need a calculator that provides a program-based description of the derivative, i.e., one that gives precise derivative values at selected input points. This is exactly what the AD calculator provides.
Secondly the AD calculator requires fewer tools to build, as it requires only basic coding methods and no computer algebra systems.
Thirdly, expressing derivative equations algebraically can be quite unwieldy. For example, the rather complicated looking function

$$ g(w) = \text{sin}\left(e^{\,5\text{tanh}^2(w) + w^5}\right)\text{log}\left(\frac{1}{w^2 + 1} \right)\frac{w^2 + 1}{\text{cos}(\text{sin}(w))} $$

has an expansive algebraic derivative. Below are just the first few terms:

$$ \frac{\mathrm{d}}{\mathrm{d}w}g(w) = -2w\,\text{sin}\left(e^{\,5\text{tanh}^2(w) + w^5}\right) \, \frac{1}{\text{cos}(\text{sin}(w))} + 2w\,\text{log}\left(\frac{1}{w^2 + 1}\right)\text{sin}\left(e^{\,5\text{tanh}^2(w) + w^5}\right)\frac{1}{\text{cos}(\text{sin}(w))} + \cdots $$

This sort of problem becomes far worse, to the point of being a considerable computational burden, when dealing with multivariable functions. Such an example also illustrates the need for automatic simplification of algebraic expressions, ideally during the differentiation process, to make sure things do not get too far out of hand. AD, while not an algebraic method, essentially achieves this automatically. It carries only numerical function and derivative values through each step rather than a growing expression.
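To illustrate this last point, the sketch below evaluates the derivative of the function above at a single point with forward-mode AD and compares it with a central finite difference. For brevity it uses a compact variant of MyTuple that is our own, not the class built in this post. In this variant, constants are promoted to objects with derivative 0 and the operators are defined inside the class body. A hypothetical `unary` helper builds the elementary rules. No derivative expression is ever formed: only two numbers flow through each step.

```python
import numpy as np

class MyTuple:
    # compact variant (our own): constants are promoted to MyTuple(c, 0)
    def __init__(self, val=0.0, der=1.0):
        self.val, self.der = val, der

    def __add__(self, other):
        if not isinstance(other, MyTuple):
            other = MyTuple(other, 0.0)
        return MyTuple(self.val + other.val, self.der + other.der)
    __radd__ = __add__

    def __mul__(self, other):
        if not isinstance(other, MyTuple):
            other = MyTuple(other, 0.0)
        return MyTuple(self.val*other.val, self.der*other.val + self.val*other.der)
    __rmul__ = __mul__

    def __pow__(self, n):            # n is assumed to be a plain number
        return MyTuple(self.val**n, n*self.val**(n-1)*self.der)

    def __truediv__(self, other):    # self / other
        return self * other**(-1)

    def __rtruediv__(self, other):   # other / self, when other is a plain number
        return other * self**(-1)

def unary(f, df):
    # hypothetical helper: elementary rule = function update + derivative rule times incoming der
    return lambda a: MyTuple(f(a.val), df(a.val)*a.der)

sin  = unary(np.sin, np.cos)
cos  = unary(np.cos, lambda v: -np.sin(v))
exp  = unary(np.exp, np.exp)
log  = unary(np.log, lambda v: 1.0/v)
tanh = unary(np.tanh, lambda v: 1 - np.tanh(v)**2)

def g(w):
    return sin(exp(5*tanh(w)**2 + w**5))*log(1/(w**2 + 1))*(w**2 + 1)/cos(sin(w))

def g_plain(w):
    # the same function on plain floats, for the finite-difference comparison
    return np.sin(np.exp(5*np.tanh(w)**2 + w**5))*np.log(1/(w**2 + 1))*(w**2 + 1)/np.cos(np.sin(w))

w0, h = 0.3, 1e-6
ad = g(MyTuple(w0)).der
fd = (g_plain(w0 + h) - g_plain(w0 - h))/(2*h)
print(ad, fd)   # the two values should agree to several decimal places
```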