JAX is a Python package for efficient numerical computation and automatic differentiation. I’ve written about its automatic differentiation capabilities in the post on the Hessian and differentiating the Black-Scholes model. In essence, JAX traces the Python program and then JIT-compiles it using the XLA compiler.
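To see the tracing step on its own, jax.make_jaxpr returns the jaxpr, JAX's own intermediate representation, that tracing records for a function. A minimal sketch, assuming a recent JAX release and reusing the exp-based tanh from the example later in this post:

```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

# Tracing runs the Python function once on abstract values of the given
# shape and dtype, recording the primitive operations it performs.
closed_jaxpr = jax.make_jaxpr(tanh)(jnp.ones(3, dtype=jnp.float32))
print(closed_jaxpr)

# The recorded primitives, in program order.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```

The jaxpr is what JAX hands on to XLA, where it becomes the HLO discussed below.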

Sometimes it is useful to visualise the computational graph that tracing constructs. This is how it can be done with JAX and the XLA tools.

Export the HLO Intermediate Representation (IR)

The HLO intermediate representation, which is what XLA compiles, of a whole computation (i.e., not just of individual functions but of the composed computation) can be saved to a text file using the jax.tools.jax_to_hlo module.
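Recent JAX versions can also produce HLO text directly from jax.jit, without going through jax.tools. A sketch under that assumption, using the same tanh-based loss as the example in this post. Note that Compiled.as_text() returns the HLO after XLA's optimisation passes (fusion and so on), so it will not match the unoptimised jax_to_hlo output line for line, and I have not checked it with interactive_graphviz. The file name t_jit.txt is just an example.

```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return jnp.log(tanh(x).sum())

dlfn = jax.grad(lfn)

# An example argument fixes shape and dtype, like f32[100] for jax_to_hlo.
x = jnp.ones(100, dtype=jnp.float32)

lowered = jax.jit(dlfn).lower(x)  # traced and lowered, not yet compiled
compiled = lowered.compile()      # runs XLA's optimisation passes

hlo_text = compiled.as_text()     # optimised HLO module in text form
with open("t_jit.txt", "w") as f:
    f.write(hlo_text)
```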

As an example, here is how to dump the IR of the gradient of a simple composed function:


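import jax.tools.jax_to_hlo
from jax import numpy, grad
from jax.lib import xla_client

def tanh(x):
    # tanh written out via exp, giving a few elementwise ops in the graph
    y = numpy.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    # scalar loss: log of the sum of tanh over all elements
    return numpy.log(tanh(x).sum())

def dlfn(x):
    # gradient of the scalar loss with respect to the input vector
    return grad(lfn)(x)

# jax_to_hlo traces dlfn for the given named argument shapes and returns a
# (serialized HLO proto, HLO text) pair; index [1] keeps the text form.
hlo_text = jax.tools.jax_to_hlo.jax_to_hlo(
    dlfn, [("x", xla_client.Shape("f32[100]"))])[1]

with open("t.txt", "w") as f:
    f.write(hlo_text)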

This creates a text file t.txt containing the HLO in its text form.
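For a quick look at what ended up in the file before reaching for a visualiser, a few lines of Python can list its computations, roughly what the `list computations` command of interactive_graphviz shows. This is a rough sketch, not a real HLO parser: it assumes the usual text layout (a computation header starting at column 0 and ending in `{`, the entry computation marked `ENTRY`). To run on its own it generates HLO with jax.jit rather than reading t.txt.

```python
import re

import jax
import jax.numpy as jnp

def list_computations(hlo_text):
    """Return (name, is_entry) for each computation header in HLO text.

    Heuristic: headers start at column 0 and end with '{'; instructions
    are indented, and the module line starts with 'HloModule'.
    """
    computations = []
    for line in hlo_text.splitlines():
        if not line or line[0].isspace() or line.startswith("HloModule"):
            continue
        if not line.rstrip().endswith("{"):
            continue
        is_entry = line.startswith("ENTRY ")
        header = line[len("ENTRY "):] if is_entry else line
        # The name is the first token, e.g. '%main.20' or 'main.20'.
        name = re.split(r"[\s(]", header, maxsplit=1)[0].lstrip("%")
        computations.append((name, is_entry))
    return computations

def lfn(x):
    y = jnp.exp(-2.0 * x)
    return jnp.log(((1.0 - y) / (1.0 + y)).sum())

# With the file written above this would be: hlo = open("t.txt").read()
hlo = jax.jit(jax.grad(lfn)).lower(jnp.ones(100, jnp.float32)).compile().as_text()
for name, is_entry in list_computations(hlo):
    print(("ENTRY " if is_entry else "      ") + name)
```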

Visualise

A tool for visualising HLO IR is developed together with the rest of the TensorFlow/XLA toolchain, but I did not find it distributed as a binary, so it has to be built from source. Follow the instructions for building from source and then run the following bazel command:

bazel build //tensorflow/compiler/xla/tools:interactive_graphviz

The build is likely to take rather longer than it takes to make a coffee!
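If the wait is too long, a much rougher picture can be had without building anything: walk the jaxpr produced by jax.make_jaxpr and emit a Graphviz DOT description, which the `dot` program (if installed) can render. This is a swapped-in technique, not the HLO graph: it shows JAX's jaxpr before XLA sees it, constants are not drawn, and any nested call (for example a `pjit` equation) appears as a single node. The helper name jaxpr_to_dot is hypothetical, made up for this sketch.

```python
import jax
import jax.numpy as jnp

def jaxpr_to_dot(closed_jaxpr):
    """Hypothetical helper: render a jaxpr's dataflow graph as DOT text."""
    jaxpr = closed_jaxpr.jaxpr
    lines = ["digraph jaxpr {", "  rankdir=TB;"]
    producer = {}  # id(var) -> DOT node that defines the variable

    for i, var in enumerate(jaxpr.invars):
        node = f"in{i}"
        lines.append(f'  {node} [shape=box, label="input {i}\\n{var.aval.dtype}{list(var.aval.shape)}"];')
        producer[id(var)] = node

    for j, eqn in enumerate(jaxpr.eqns):
        node = f"op{j}"
        lines.append(f'  {node} [label="{eqn.primitive.name}"];')
        for var in eqn.invars:
            # Literals (inline constants) and constvars have no producer node.
            if id(var) in producer:
                lines.append(f"  {producer[id(var)]} -> {node};")
        for var in eqn.outvars:
            producer[id(var)] = node

    for k, var in enumerate(jaxpr.outvars):
        if id(var) in producer:
            lines.append(f'  out{k} [shape=box, label="output {k}"];')
            lines.append(f"  {producer[id(var)]} -> out{k};")

    lines.append("}")
    return "\n".join(lines)

def lfn(x):
    y = jnp.exp(-2.0 * x)
    return jnp.log(((1.0 - y) / (1.0 + y)).sum())

dot = jaxpr_to_dot(jax.make_jaxpr(jax.grad(lfn))(jnp.ones(100, jnp.float32)))
with open("dlfn.dot", "w") as f:
    f.write(dot)
# Render with Graphviz, e.g.: dot -Tsvg dlfn.dot -o dlfn.svg
```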

After a longish nap, you can start interactive_graphviz with --text_hlo="t.txt". The command

list computations

will show all computations, and the graph of any computation can be displayed by entering its name. This is the graph I got from this example: