JAX is a Python package for efficient computation and automatic differentiation. The question here is: will it always use all CPU cores efficiently?

Intra-op multi-threading

Some JAX operations are parallelised internally and will use multiple threads. This is true of many operations backed by, for example, the Eigen library, such as matrix multiplication. However, not all operations have internal multi-threading: for example, the FFT operations *do not have* multi-threading enabled.

PMap-based multi-threading

In situations where the internal multi-threading of individual operations does not use the CPU resources well enough, it is possible to parallelise at the JAX level through the use of the pmap function. However, it is first necessary to split the CPU into multiple apparent devices using the --xla_force_host_platform_device_count flag (see pmap-cpu multithreading issue). For example:

XLA_FLAGS="--xla_force_host_platform_device_count=8" python myscript.py

will split the CPU into 8 independent devices over which pmap will parallelise. See the pmap-cookbook for examples of how to use pmap.
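As a rough illustration of the approach described above, the sketch below sets the flag from inside the script rather than on the command line, then maps a small FFT-based function over the apparent devices. The function name per_device_fft and the array sizes are made up for this example, and it assumes the flag is set before JAX is first imported, since otherwise it is not picked up.

```python
import os

# Assumption: XLA_FLAGS must be in the environment before JAX is imported.
# Any flags that are already set are kept and the device-count flag is appended.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (
    existing + " --xla_force_host_platform_device_count=8"
).strip()

import jax
import jax.numpy as jnp

# Number of devices pmap can map over (8 on a CPU-only host if the flag took effect).
n_dev = jax.local_device_count()


def per_device_fft(x):
    """Hypothetical per-device workload: sum of FFT magnitudes of one row."""
    return jnp.abs(jnp.fft.fft(x)).sum()


# The leading axis must match the number of devices: one row per device.
batch = jnp.arange(n_dev * 1024, dtype=jnp.float32).reshape(n_dev, 1024)

# Each row is processed on its own apparent device.
result = jax.pmap(per_device_fft)(batch)
print(n_dev, result.shape)
```

The leading dimension of the input has to be no larger than the number of available devices, which is why the batch is shaped from jax.local_device_count() instead of a hard-coded 8.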