Easy, reproducible, and maintainable builds of JAX are straightforward with the Nix package manager:

  1. The packages are already defined in nixpkgs under pkgs/development/python-modules/jaxlib and pkgs/development/python-modules/jax. See https://nixos.wiki/wiki/JAX
  2. Bringing them up to the most recent version requires updating the versions and hashes in the default.nix files, and possibly some dependencies as well:

    a. The one dependency I’ve needed to update is Google's snappy compression library.

  3. jaxlib can easily be modified by adding a patches option to the bazel-build element of its default.nix file.
  4. Build with the usual command: nix-build -A python3Packages.jaxlib <mynixpkgs>. Add -K (--keep-failed) to keep the build directory of a failed build, which helps with debugging.
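
As a rough illustration of steps 2 and 3, the edits to the jaxlib default.nix might look like the fragment below. This is a sketch, not the actual file: the attribute layout differs between nixpkgs revisions, the version string is a placeholder, and my-local-fix.patch is a hypothetical file name. Everything not shown is assumed to stay as it is in your checkout.

```nix
  # Fragment of pkgs/development/python-modules/jaxlib/default.nix (sketch only).
  version = "0.0.0";  # placeholder: set to the jaxlib release you want

  bazel-build = buildBazelPackage {
    # ... existing attributes left unchanged ...

    src = fetchFromGitHub {
      # ... owner, repo and rev as already defined in the file ...
      # Temporary dummy hash: the first build fails and reports the real one,
      # which you then paste here.
      hash = lib.fakeHash;
    };

    # Step 3: local modifications applied to the jaxlib sources before building.
    patches = [
      ./my-local-fix.patch  # hypothetical patch file placed next to default.nix
    ];
  };
```

Depending on the nixpkgs revision, the Bazel dependency fetch may carry its own separate hash that also has to be refreshed after a version bump. Treat that as something to check in your copy of the file.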

Be prepared to wait, though: the compilation stage takes a substantial amount of processing resources.