Fun with Jax

Jax is a research project from Google aimed at high-performance machine learning research. It's been around for more than a year and is getting increasingly popular for research and prototyping. Quoting the developers, “Jax is a language for expressing and composing transformations of numerical programs.” Jax comes with transformations like jit, grad and vmap.
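Here is a small illustration of composing these transformations. The function f below is a made-up example, not one from this post. grad differentiates a scalar function, vmap maps that derivative over a batch, and jit compiles the combination:

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def f(x):
    # illustrative scalar function: f(x) = x^3 + 2x, so f'(x) = 3x^2 + 2
    return x**3 + 2.0 * x

# grad -> derivative of f, vmap -> apply it across a batch, jit -> compile it
df_batched = jit(vmap(grad(f)))

xs = jnp.array([0.0, 1.0, 2.0])
derivs = df_batched(xs)
print(derivs)  # derivatives 2, 5 and 14
```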

Let's explore ways to use these transformations in the context of optimization. Though grad was the first to catch my eye, I have come to appreciate many other features of Jax beyond its automatic differentiation routines. In fact, we won't use grad explicitly in this post and will instead use its more flexible siblings like jacfwd, jacrev, jvp and vjp. Before getting to optimization, let's start small with iterations.

# using the following imports for the rest of the post
import numpy as onp  # original numpy
import jax.numpy as np  # a (mostly) drop-in replacement
from jax import jit, vmap, jacfwd, jacrev, jvp, vjp, device_put, random


$$ 10 x_{n+1} = x^2_{n}+21 $$

Iterations offer a simple numerical method to converge to a solution of interest. The fixed points of the iteration above are 3 and 7, the roots of $10 x_{c} = x^2_{c}+21$, and the value the iteration ends up at is sensitive to the starting point. The fixed points can be classified as attracting or repelling, and the basin of attraction can be found using calculus. But let's find it with an experiment instead and see how the function transformation jit can help. Iterating using vanilla numpy,

def iterfun(start):
    for _ in range(500):
        start = (start**2+21)/10
    return start

start = onp.arange(-9, 10, 0.5)
%timeit -n 1000 iterfun(start)
1.74 ms per loop

Though we use numpy, the actual iterations are carried out by a for loop in Python, which is a known performance bottleneck. Jax offers just-in-time compilation through the function transformation jit, which (in this case) traces the for loop by statically unrolling it into a single compiled computation, making it faster. Further, the function is compiled for an accelerator (GPU/TPU) if one is available. So, in some ways, Jax is numpy on GPU!
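As an aside, statically unrolling means the compiled program contains one copy of the update per iteration, so compile time grows with the loop count. A hedged alternative, not used in the original notebook, is jax.lax.fori_loop. It keeps the loop as a single loop primitive inside the compiled computation. Below is a minimal sketch for the same iteration. The names iterfun_rolled and step are illustrative.

```python
import jax.numpy as jnp
from jax import jit, lax

def step(i, x):
    # one step of 10 x_{n+1} = x_n^2 + 21; the loop index i is unused
    return (x**2 + 21) / 10

@jit
def iterfun_rolled(start):
    # 500 iterations as one loop primitive rather than 500 unrolled copies
    return lax.fori_loop(0, 500, step, start)

start = jnp.arange(-9, 10, 0.5)
result = iterfun_rolled(start)
```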

fast_iterfun = jit(iterfun)  
# compilation happens when fast_iterfun is called for the first time
start = np.arange(-9, 10, 0.5)
# ensure that variable is in accelerator (GPU/TPU) using device_put()
start = device_put(start)  # normally not needed
# block until ready gives the real time taken (waits despite asynchronous dispatch)
%timeit -n 1000 fast_iterfun(start).block_until_ready()
201 µs per loop

As expected, Jax is much faster. However, note that the compile time for jit is not included in this measurement, and it's prudent to weigh the time spent compiling a function against the time saved by calling it repeatedly! Next, let's plot the values after $500$ iterations against the starting values.
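The calculus route mentioned earlier gives a useful cross-check on the experiment. A fixed point $x_c$ of the update $g(x) = (x^2+21)/10$ attracts nearby points when $|g^{\prime}(x_c)| < 1$ and repels them when $|g^{\prime}(x_c)| > 1$. Since $g^{\prime}(x) = x/5$, the slope is $0.6$ at $3$ and $1.4$ at $7$. Here is a small sketch that gets the same slopes from jacfwd. The names g and dg are illustrative.

```python
import jax.numpy as jnp
from jax import jacfwd

def g(x):
    # update rule of the iteration 10 x_{n+1} = x_n^2 + 21
    return (x**2 + 21) / 10

# forward-mode derivative of a scalar function: g'(x) = x / 5
dg = jacfwd(g)

slopes = {}
for x_c in (3.0, 7.0):
    slopes[x_c] = float(dg(x_c))
    kind = "attracting" if abs(slopes[x_c]) < 1 else "repelling"
    print(f"x_c = {x_c}: g'(x_c) = {slopes[x_c]:.2f} -> {kind}")
```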


It can be seen that $3$ is attracting with the basin of attraction $(-7, 7)$ and $7$ is repelling. Every other starting point diverges. However, this is not the only possible behaviour: some iterations neither converge nor diverge. Instead, they show cycles, Cantor sets and chaos, like the following family of iterations, which Gilbert Strang calls the most famous quadratic iteration in the world (when $a=4$).

$$ x_{n+1} = ax_{n} - ax^2_{n} $$

To analyze the effects of $a$, let's plot all values from $x_{1001}$ to $x_{2000}$ for different values of $a$ starting with $x_0=0.5$. It's easy to build the required output for a given value of $a$ using python lists and loops as follows

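def iterfun(start, a):
    result = []
    # discard the first 1000 iterates (x_1 to x_1000)
    for _ in range(1000):
        start = a*start - a*start**2
    # record the next 1000 iterates (x_1001 to x_2000)
    for _ in range(1000):
        start = a*start - a*start**2
        result.append(start)
    return np.array(result)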

%timeit iterfun(0.5, 3.4)
354 ms per loop

This is slow. Once again jit to the rescue!

fast_iterfun = jit(iterfun)
%timeit fast_iterfun(0.5, 3.4).block_until_ready()
681 µs per loop

This is fast, but our compiled function can only handle one value of $a$. Fortunately, Jax can vectorize the function without us rewriting any of it by hand, and the tool for this is vmap. Instead of running a naive outer loop over the existing function, vmap pushes the loop down to the primitive operations, even after the function has been transformed with jit. In fact, Jax supports arbitrary compositions of its transformations.
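Here is a hedged sketch that makes "pushing the loop down" concrete using a tiny stand-in function. first_step is a hypothetical name: it applies just one logistic update from $x_0 = 0.5$. The vmapped result matches an explicit Python loop. jax.make_jaxpr shows that the batched version operates on the whole vector of $a$ values instead of looping over elements.

```python
import jax
import jax.numpy as jnp
from jax import vmap

def first_step(a):
    # one logistic update x_1 = a x_0 - a x_0^2 starting from x_0 = 0.5
    x = 0.5
    return a * x - a * x**2

a_values = jnp.array([3.4, 3.7, 4.0])

batched = vmap(first_step)(a_values)
looped = jnp.stack([first_step(a) for a in a_values])

# the printed program applies each primitive once to the whole vector of a's
print(jax.make_jaxpr(vmap(first_step))(a_values))
```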

vec_iterfun = vmap(lambda a: fast_iterfun(0.5, a))
a_vector = np.arange(3.4, 4.0, 0.001)
a_vector = device_put(a_vector)
%timeit vec_iterfun(a_vector)
626 µs per loop
# reshaping for plotting
plot_y = vec_iterfun(a_vector).reshape((-1))
plot_x = a_vector.repeat(1000)

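Evaluating roughly 600 values of $a$ took about as long as the single value did (626 µs versus 681 µs), so vectorizing came with no appreciable increase in waiting time. Note, though, that this measurement omits block_until_ready(), so it may miss part of the asynchronously dispatched work. The vectorized version returns all the data we need for the plot (apart from reshaping) in a single call.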
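The cycles mentioned earlier can also be checked numerically rather than by eye. The sketch below is a hedged variant, not the post's code. Instead of a Python list, it uses jax.lax.fori_loop for the 1000 discarded iterates and jax.lax.scan to record the next 1000, so the loop is not unrolled during tracing. logistic_orbit is an illustrative name. For $a = 3.2$, which lies in the period-2 window of this map, the recorded values should repeat every second step. The other two values of $a$ are included only to exercise the batching.

```python
import jax.numpy as jnp
from jax import jit, vmap, lax

def logistic_orbit(a, burn_in=1000, keep=1000):
    step = lambda x: a * x - a * x**2
    # start from x_0 = 0.5 with the same dtype as a
    x0 = jnp.full_like(a, 0.5)
    # discard x_1 ... x_burn_in
    x = lax.fori_loop(0, burn_in, lambda _, x: step(x), x0)

    # record the next `keep` iterates; scan returns (final carry, stacked outputs)
    def body(x, _):
        x = step(x)
        return x, x

    _, orbit = lax.scan(body, x, None, length=keep)
    return orbit

orbits = jit(vmap(logistic_orbit))(jnp.array([3.2, 3.5, 3.9]))
print(orbits.shape)  # (3, 1000)

period2 = orbits[0]
# x_{n+2} == x_n while x_{n+1} != x_n indicates a 2-cycle
print(jnp.max(jnp.abs(period2[2:] - period2[:-2])))
print(jnp.abs(period2[1] - period2[0]))
```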


Plotting plot_y against plot_x shows the effect of $a$ on convergence (or the absence of it), and it looks cool! But let's stop here and pivot to the applications of iteration in optimization to see how Jax can further help us.


Calculus lets us express optimization as solving the equation $f^{\prime}(x)=0$, and iterations like the following give a numerical method to find the solution.

\begin{align} s(x_{n+1}-x_{n}) &= f^{\prime}(x_{n}) \\
\text{at convergence}\qquad s(x_{c}-x_{c}) &= f^{\prime}(x_{c}) = 0 \end{align}

Earlier iterations didn't have $s$, which can be thought of as an inverse step size, since the update is $x_{n+1} = x_{n} + f^{\prime}(x_{n})/s$. The sign of $s$ steers the iteration towards a maximum (or a minimum), and with a fixed $s$ this iteration is just good old gradient ascent (or descent). The goal here is to find the best $s$. Let's start by subtracting the two equations.

\begin{align} s(x_{n+1}-x_{c}) &= s(x_{n}-x_{c}) + f^{\prime}(x_{n}) - f^{\prime}(x_{c}) \\
\text{using linear approximation at }x_{c} \qquad s(x_{n+1}-x_{c}) &\approx (s + f^{\prime\prime}(x_{c})) (x_{n}-x_{c}) \end{align}

Thus, when $s=-f^{\prime\prime}(x_{c})$, we get $x_{n+1}=x_{c}$, i.e., the iteration converges in a single step. But we don't know $x_{c}$, and the whole exercise is to find it. Newton approximated $f^{\prime\prime}(x_{c})$ with $f^{\prime\prime}(x_{n})$ and ran the iterations as follows; this became Newton's method.

$$ -f^{\prime\prime}(x_{n})(x_{n+1}-x_{n}) = f^{\prime}(x_{n}) $$
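In one dimension this update reads $x_{n+1} = x_{n} - f^{\prime}(x_{n})/f^{\prime\prime}(x_{n})$. Here is a minimal sketch in which jacrev and jacfwd supply the two derivatives. The helper newton_1d and the test functions are illustrative choices. On the quadratic $-(x-2)^2+1$, a single step from $0$ lands on the maximum at $2$. $\sin x$ is not quadratic, so it needs a few steps from $x_0 = 1$ to approach its maximum at $\pi/2$.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

def newton_1d(f, x, steps):
    df = jacrev(f)     # f'
    d2f = jacfwd(df)   # f''
    for _ in range(steps):
        # -f''(x_n) (x_{n+1} - x_n) = f'(x_n)  =>  x_{n+1} = x_n - f'(x_n) / f''(x_n)
        x = x - df(x) / d2f(x)
    return x

quad = lambda x: -(x - 2.0)**2 + 1.0
x_quad = newton_1d(quad, 0.0, steps=1)    # one step suffices for a quadratic
x_sin = newton_1d(jnp.sin, 1.0, steps=5)  # a few steps towards pi / 2
print(x_quad, x_sin)
```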

When $f$ is a scalar-valued function of a vector $\mathbf{x}$, we have the following, where $J$ is the Jacobian (here, the gradient vector) and $H$ is the Hessian matrix

$$ H\cdot\Delta\mathbf{x} = -J $$

The intuition is simple: $\Delta\mathbf{x}$ is selected so that the directional derivative of the Jacobian along $\Delta\mathbf{x}$, namely $H\cdot\Delta\mathbf{x}$, exactly cancels the Jacobian, so every dimension of the Jacobian can be made 0 in just one step! Please be mindful of the earlier linear approximation: this one-step convergence happens only when the underlying function is quadratic. Let's check this update with the following quadratic function, which has its maximum at $(3, 7)$.

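def quad_fun(x):
    # shift the input so that the maximum sits at (3, 7)
    x = x - np.array([3, 7])
    # negated positive-definite quadratic form, hence concave
    return -1 * (x.T @ np.array([[3, 2], [2, 7]]) @ x - 2)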

Let's compute the single converging update from $(0, 0)$. As $f : \mathbb{R}^n \rightarrow \mathbb{R}$ is a scalar-valued function, its Jacobian can be found efficiently using reverse-mode automatic differentiation. The Jacobian $J : \mathbb{R}^n \rightarrow \mathbb{R}^n$ is in turn a vector-valued function with as many outputs as inputs, and its derivative can be found (slightly more efficiently) using forward-mode automatic differentiation. Just like the other transformations in Jax, these two modes can be composed arbitrarily to find the Hessian matrix.
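Composing the two modes this way can be sanity-checked on quad_fun itself. With $A = \begin{pmatrix}3 & 2\\ 2 & 7\end{pmatrix}$ from its definition, its Hessian should be $-2A$. Its gradient at the origin should be $2A\,(3, 7)^T = (46, 110)$. The sketch below is self-contained. It redefines quad_fun with jnp and names the matrix and the shift A and c for readability.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

A = jnp.array([[3.0, 2.0], [2.0, 7.0]])
c = jnp.array([3.0, 7.0])  # location of the maximum

def quad_fun(x):
    x = x - c
    return -1 * (x @ A @ x - 2)

J = jacrev(quad_fun)  # reverse mode: gradient of a scalar function
H = jacfwd(J)         # forward mode over reverse mode: Hessian

origin = jnp.zeros(2)
grad_at_origin = J(origin)  # expected 2 A c = (46, 110)
hessian = H(origin)         # expected -2 A, constant for a quadratic
print(grad_at_origin)
print(hessian)
```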

# naive function to solve linear equation Hx = J
def naive_solve(start):
    J = jacrev(quad_fun)
    H = jacfwd(J)
    return np.linalg.inv(H(start)) @ J(start)
# compile to make the function fast
naive_solve = jit(naive_solve)
start = np.array([0, 0], dtype='float32')
# the Newton step is the negative of H^{-1} J
update = -1 * naive_solve(start)
print(update)
%timeit naive_solve(start)
[3.        7.0000005]
123 µs per loop

This worked as expected. But we explicitly calculated the Hessian matrix and inverted it. In bigger problems (like neural networks), the Hessian has one row and one column per parameter. Explicitly calculating it is memory intensive, and inverting it would be far too costly. Instead, let's solve the linear system ($H\mathbf{x} = J$, negating the result as before) using conjugate gradients. This method never forms the Hessian explicitly and instead relies on lower-level Jax functions like jvp to compute directional derivatives. jvp computes the product of a Jacobian and a tangent vector, evaluated at a point, without explicitly forming the Jacobian. Applied to the gradient function $J$, it therefore yields Hessian-vector products.
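Here is a quick check that jvp applied to the gradient function really yields Hessian-vector products. It again uses a self-contained copy of quad_fun; hvp and the direction v are illustrative. The result should equal $H\mathbf{v} = -2A\mathbf{v}$, which for $\mathbf{v} = (1, -1)$ is $(-2, 10)$, and it is computed without ever building $H$.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev, jvp

A = jnp.array([[3.0, 2.0], [2.0, 7.0]])
c = jnp.array([3.0, 7.0])

def quad_fun(x):
    x = x - c
    return -1 * (x @ A @ x - 2)

J = jacrev(quad_fun)

def hvp(point, v):
    # jvp returns (J(point), derivative of J at point along v); the latter is H(point) @ v
    return jvp(J, (point,), (v,))[1]

point = jnp.zeros(2)
v = jnp.array([1.0, -1.0])
implicit = hvp(point, v)         # Hessian never materialized
explicit = jacfwd(J)(point) @ v  # same product via the full Hessian
print(implicit, explicit)
```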

# better function to solve linear equation Hx = J
def cg_solve(point, ndim):
    # point = point at which to evaluate J and H
    J = jacrev(quad_fun)
    H = lambda x: jvp(J, (point,), (x,))[1]
    x = np.zeros(ndim)
    b = J(point)
    r = b - H(x)
    p = r
    # search in ndim conjugate directions
    for _ in range(ndim):
        Ap = H(p)
        rr = r.T@r
        alpha = rr/(p.T@Ap)
        x = x + alpha*p
        r = r - alpha*Ap
        beta = r.T@r/(rr)
        p = r + beta*p
    return x
# compile with a fixed dimension size
fast_solve = jit(lambda x: cg_solve(x, 2))
start = np.array([0, 0], dtype='float32')
# the Newton step is the negative of H^{-1} J
update = -1 * fast_solve(start)
print(update)
%timeit fast_solve(start)
[3.        7.0000005]
146 µs per loop

This gives the same result without explicitly calculating the Hessian or inverting it. It is slightly slower here, but the advantages should become apparent in bigger problems.
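For larger problems, JAX also provides a conjugate-gradient solver, jax.scipy.sparse.linalg.cg, which accepts a function in place of a matrix. The sketch below swaps it in for the hand-written loop on a hypothetical 50-dimensional quadratic. The random matrix, the seed and newton_step are assumptions, not taken from the original notebook. That solver expects a positive-definite operator, and the Hessian here is negative definite, so the sketch solves the equivalent system $(-H)\,\Delta\mathbf{x} = J$. From the origin, the Newton step should land on the maximizer $c$.

```python
import jax.numpy as jnp
from jax import jit, jacrev, jvp, random
from jax.scipy.sparse.linalg import cg

n = 50
key_m, key_c = random.split(random.PRNGKey(0))
M = random.normal(key_m, (n, n))
A = M @ M.T + n * jnp.eye(n)    # symmetric positive definite by construction
c = random.normal(key_c, (n,))  # maximizer of quad_fun

def quad_fun(x):
    d = x - c
    return -1 * (d @ A @ d - 2)

J = jacrev(quad_fun)

@jit
def newton_step(point):
    g = J(point)
    # matrix-free product with -H, which is positive definite here
    neg_hvp = lambda v: -jvp(J, (point,), (v,))[1]
    # solve (-H) dx = g, which is equivalent to H dx = -g
    step, _ = cg(neg_hvp, g)
    return step

step = newton_step(jnp.zeros(n))
print(jnp.max(jnp.abs(step - c)))  # should be small: one Newton step reaches c
```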

To ensure reproducibility

The code was run in colab using a GPU runtime. The notebook can be found here. The package versions used are listed below

  • jax 0.1.52
  • jaxlib 0.1.36
  • numpy 1.17.5