I wrote a while ago about visualising the computational graph of a JAX program here. JAX has evolved since then, so here is an update for the current (as of the time of writing) version of JAX (0.3.1).

Export the HLO Intermediate Representation (IR)

The main change is that the graph is now produced from the object returned by jax.xla_computation. The other change is that instead of specifying the input shape I supply an example array (numpy.ones(100)).

Here is the updated program:


import jax
from jax import numpy, grad

def tanh(x):
    # tanh written out explicitly in terms of exp
    y = numpy.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def lfn(x):
    return numpy.log(tanh(x).sum())

def dlfn(x):
    return grad(lfn)(x)

# Trace the gradient function with an example input to get an XLA computation
z = jax.xla_computation(dlfn)(numpy.ones(100))

with open("t.txt", "w") as f:
    f.write(z.as_hlo_text())

This creates a text file t.txt containing the HLO in its textual form.

Visualise

The simplest way to visualise the computation is to dump it as a Graphviz dot graph and run dot:

with open("t.dot", "w") as f:
    f.write(z.as_hlo_dot_graph())
dot t.dot  -Tpng > t.png
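
Alternatively, the rendering can be done without leaving Python. A minimal sketch, assuming the graphviz Python package and the Graphviz binaries are installed:

import graphviz

# Render the dot source directly; this writes "t" (the source) and "t.png"
graphviz.Source(z.as_hlo_dot_graph()).render("t", format="png")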

Either way, this produces the following image:

[Image: HLO dot graph of the gradient computation]

Visualise the optimised graph

(added on 2023-08-28)

It is also possible to visualise the optimised graph produced by the XLA compiler.

The mechanism for getting the HLO representation of a compiled graph is shown in the test_compiler_ir test in the file api_test.py of the JAX repository. Basically it consists of calling jax.jit, lower()ing with a specific example data structure as the argument, compile()ing, and then using the as_text() method. So for example:

def ff(x):
    x = x * 3
    x = x + 2
    return x

# Lower with a concrete example input, compile, then dump the optimised HLO as text
opt_hlo = jax.jit(ff).lower(numpy.ones(100)).compile().as_text()

To visualise it, it is necessary to re-parse the HLO into an XLA computation and then use the XLA functionality to generate a dot graph. This can easily be achieved using the raw XLA Python bindings with the following function:

from jax.lib import xla_client

def todotgraph(x):
    return xla_client._xla.hlo_module_to_dot_graph(xla_client._xla.hlo_module_from_text(x))
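
Putting the pieces together, here is a minimal sketch that writes dot files for both the un-optimised and the compiled graph of ff (the file names ff_unopt.dot and ff_opt.dot are illustrative):

# Un-optimised graph, straight from tracing
with open("ff_unopt.dot", "w") as f:
    f.write(jax.xla_computation(ff)(numpy.ones(100)).as_hlo_dot_graph())

# Optimised graph, from the compiled HLO text via todotgraph
with open("ff_opt.dot", "w") as f:
    f.write(todotgraph(jax.jit(ff).lower(numpy.ones(100)).compile().as_text()))

Both files can then be rendered with dot as before.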

As an example of what this makes visible, the compiled graph of ff shows XLA loop fusion. The un-optimised graph needs two loops over the arrays:

[Image: un-optimised HLO graph with two loops]

But the optimised graph of the jit-compiled function only needs one:

[Image: optimised HLO graph with a single fused loop]

Visualise a neural network

The same basic approach can be used to visualise a flax neural network, e.g.:

import functools
import flax.linen as nn

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

model = CNN()
batch = numpy.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)

# Bind the parameters so that the traced function only takes the input batch
apply_fn = functools.partial(model.apply, variables)
z = jax.xla_computation(apply_fn)(batch)

with open("t2.dot", "w") as f:
    f.write(z.as_hlo_dot_graph())
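
As before, the dot file can be rendered from the shell (the output name t2.png is illustrative):

dot -Tpng t2.dot > t2.png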

This produces the following image:

[Image: HLO dot graph of the CNN]