1. NNX
https://flax.readthedocs.io/en/latest/why.html
The new high-level Flax neural network API, NNX (Neural Network library for JAX), is going to replace the current Linen API. Remember, Linen is a second-generation neural network API; the older API was flax.nn. So, we can say NNX is the third-generation API!
NNX represents objects as PyGraphs (instead of PyTrees), enabling reference sharing and mutability. This design lets your models resemble familiar object-oriented Python code, which is particularly appealing to users of frameworks like PyTorch. Besides offering a more straightforward, PyTorch-like syntax, NNX supports all JAX transformations. It can also use Penzai for visualization; check out how great it looks in the documentation.
The main Flax documentation now uses NNX by default. The Linen API was moved to a separate domain: https://flax-linen.readthedocs.io.
2. shard_map
https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
shard_map is a single-program multiple-data (SPMD) multi-device parallelism API that maps a function over shards of data. It is an evolution of the existing and classic pmap (see the doc) and the deprecated experimental xmap (which has already been removed entirely, but you can find a thorough description with examples in Appendix D of my book Deep Learning with JAX). It solves issues that pmap and xmap could not.
shard_map is still experimental but already widely used. It is the likely place to go if you want explicit control over exactly what the per-device code and communication collectives look like (as opposed to the other school of thought: automatic parallelization via tensor sharding). I spent a good deal of my JAX book describing parallelization in the “before shard_map” era; please read it if you want a deeper understanding of the topic.
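Here is a minimal sketch of what explicit per-device code looks like. It assumes a JAX version where `shard_map` is importable either as `jax.shard_map` (newer releases) or from `jax.experimental.shard_map` (older releases); the function names are illustrative. Each device sees only its own block of `x`, and the cross-device reduction is an explicit `jax.lax.psum` collective:

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("i",))


@partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
def block_sums(x_block):
    # Runs once per device on its local block (shape (8,) here).
    return jnp.sum(x_block, keepdims=True)


@partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P())
def global_sum(x_block):
    # Explicit collective: add up the per-device partial sums over axis "i".
    return jax.lax.psum(jnp.sum(x_block, keepdims=True), "i")


x = jnp.arange(8 * n, dtype=jnp.float32)  # 8 elements per device
print(block_sums(x))  # one partial sum per device, shape (n,)
print(global_sum(x))  # replicated total, shape (1,)
```

The same code runs on a single CPU device (a mesh of size 1) and on a multi-device machine; only the number of blocks changes.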
3. Pallas: a JAX kernel language
https://jax.readthedocs.io/en/latest/pallas/index.html
It is an extension to JAX that enables writing custom kernels for GPU and TPU. In some sense, it's similar to OpenAI’s Triton (in Greek mythology, Pallas is Triton’s daughter). On GPUs, Pallas lowers to Triton, and on TPUs, Pallas lowers to Mosaic.
Specifically, Pallas requires users to think about memory access and how to divide computations across multiple compute units in a hardware accelerator. Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and express computations that operate on “blocks” of those arrays that fit in SRAM.
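A minimal blocked vector-add kernel illustrates the idea. This sketch assumes `jax.experimental.pallas` is available; `BLOCK`, `add_kernel`, and `add` are illustrative names, and the input length is assumed to be a multiple of `BLOCK`. The grid splits the input into `BLOCK`-sized chunks, and each kernel invocation reads and writes only its own block through refs. `interpret=True` runs the kernel in interpreter mode so it also works on CPU; on real hardware you would drop it, and TPUs may impose extra constraints on block shapes:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

BLOCK = 128  # elements per block (assumed to divide the input length)


def add_kernel(x_ref, y_ref, o_ref):
    # Each invocation sees only one BLOCK-sized slice of each array.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y):
    n_blocks = x.shape[0] // BLOCK
    # Block i of the grid maps to block index i of the array.
    spec = pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,))
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(n_blocks,),
        in_specs=[spec, spec],
        out_specs=spec,
        interpret=True,  # interpreter mode, so the sketch runs on CPU too
    )(x, y)


x = jnp.arange(1024, dtype=jnp.float32)
y = jnp.ones(1024, dtype=jnp.float32)
out = add(x, y)
```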
If you need a specialized kernel (e.g. for a special attention implementation like FlashAttention, a new recurrent architecture like RecurrentGemma, or maybe even cryptography?), you now have an option to do it on TPU! This is also the way to build TPU kernels in PyTorch/XLA. So, Pallas is much more than just a part of JAX.
Pallas is also experimental but widely used.
4. JAX AI Stack
https://github.com/jax-ml/jax-ai-stack
There is a movement towards having a complete stack for developing neural network models (remember, JAX is not only for deep learning; it’s also a high-performance numerical computing library that is very popular in physics and other fields).
The newly emerged JAX AI stack is intended to be a single point of entry for this suite of libraries, so you can install and begin using many of the same open-source packages that Google developers use in their everyday work.
The JAX AI stack pins particular versions of its component projects that are known to work correctly together, as verified by the integration tests in its repository.
Right now it includes the following:
JAX: the core JAX package, which includes array operations and program transformations like jit, vmap, grad, etc.
flax: build neural networks with JAX.
ml_dtypes: NumPy dtype extensions for machine learning.
optax: gradient processing and optimization in JAX.
orbax: checkpointing and persistence utilities for JAX.
There are also some optional packages for data loading, namely grain and TensorFlow Datasets.
It also includes a collection of tutorials in the docs directory. As of now, there are examples for Vision Transformer (ViT), text classification, image segmentation, LLM pretraining, machine translation, a simple diffusion model, and a variational autoencoder (VAE), as well as a basic introduction to JAX for AI and a quick overview of JAX and the JAX AI stack written for those who are familiar with PyTorch.
This is a work in progress and will likely be updated soon.
5. Grain
https://github.com/google/grain
Originally, the message was that there was no point in building every possible battery into JAX: you could take whatever data loader you wanted and use it with JAX. Now, however, there is finally a data loader for JAX.
PyGrain serves as the pure Python backend for Grain, making it especially valuable for JAX users. This powerful and flexible framework allows users to implement arbitrary Python transformations while maintaining a modular design that's easy to customize. What sets PyGrain apart is its commitment to deterministic processing - ensuring consistent outputs across multiple runs - combined with efficient checkpoint management that enables seamless recovery from preemptions. The framework delivers robust performance across various data types, including text, audio, images, and videos, while maintaining a lightweight footprint with minimal dependencies.
The PyGrain backend differs from traditional tf.data pipelines. Instead of starting from filenames that need to be shuffled and interleaved to shuffle the data, the PyGrain pipeline starts by sampling indices.
Indices are globally unique, monotonically increasing values used to track the pipeline's progress (for checkpointing). These indices are then mapped into record keys in the range [0, len(dataset)). Doing so enables global transformations to be performed (e.g., global shuffling, mixing, repeating for multiple epochs, sharding across multiple machines) before reading any records. Local transformations that map/filter (aka preprocessing) a single example or combine multiple consecutive records happen after reading.
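The index-first design is easy to miss in prose, so here is a toy, pure-Python sketch of it. This is not Grain's API: `sample_record_keys`, `run`, and the in-memory `records` list are hypothetical stand-ins, and sharding by `index % shard_count` is just one simple choice. Global transformations (per-epoch shuffling, repeating, sharding) touch only indices, records are read afterwards, and the last consumed index is all you need to checkpoint and resume deterministically:

```python
import random


def sample_record_keys(num_records, num_epochs, seed, shard_index=0, shard_count=1):
    """Hypothetical sketch: yield (global_index, record_key) pairs before any I/O."""
    index = 0  # globally unique, monotonically increasing
    for epoch in range(num_epochs):  # repeat for multiple epochs
        keys = list(range(num_records))  # record keys in [0, num_records)
        random.Random(seed + epoch).shuffle(keys)  # deterministic global shuffle
        for key in keys:
            if index % shard_count == shard_index:  # shard across machines
                yield index, key
            index += 1


records = [f"record-{i}" for i in range(6)]  # stand-in for a data source


def run(start_after=-1, **shard_kwargs):
    out = []
    for index, key in sample_record_keys(len(records), num_epochs=2, seed=42, **shard_kwargs):
        if index <= start_after:  # skip what was consumed before a checkpoint
            continue
        out.append((index, records[key].upper()))  # read, then a local map
    return out
```

Because everything upstream of the read is a pure function of `(seed, index)`, resuming with `run(start_after=last_index)` reproduces exactly the tail of an uninterrupted run.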
There is a basic tutorial here, and the documentation is here.
MaxText already uses Grain.
6. MaxText
https://github.com/AI-Hypercomputer/maxtext/tree/main
Everyone loves LLMs now, and many train their own LLMs. Can you do it with JAX? Surely.
MaxText is a simple, high-performance, and highly scalable JAX LLM library targeting Google Cloud TPUs and GPUs for training and inference. MaxText achieves high MFUs and scales from single host to very large clusters.
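MFU (model FLOPs utilization) is the fraction of the hardware's peak throughput that goes into the model's own arithmetic. A back-of-the-envelope sketch, assuming the common approximation of about 6 FLOPs per parameter per training token (forward plus backward pass, ignoring attention) and purely hypothetical numbers, not MaxText measurements:

```python
def mfu(num_params, tokens_per_sec, peak_flops_per_sec):
    # ~6 FLOPs per parameter per token for a training step (rough approximation).
    model_flops_per_sec = 6 * num_params * tokens_per_sec
    return model_flops_per_sec / peak_flops_per_sec


# Hypothetical: a 7B-parameter model at 3,000 tokens/s on a 200 TFLOP/s chip.
print(mfu(7e9, 3_000, 200e12))  # ≈ 0.63
```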
It was used to demonstrate high-performance, well-converging training in int8 and to scale training to ~51K chips.
MaxText supports training and inference of various open models, including Llama2, Mixtral, and Gemma.
MaxText is heavily inspired by MinGPT/NanoGPT, but is more similar to Nvidia/Megatron-LM, a very well tuned LLM implementation targeting Nvidia GPUs. The two implementations achieve comparable MFUs. MaxText is pure Python, relying heavily on the XLA compiler to achieve high performance. By contrast, Megatron-LM is a mix of Python and CUDA, relying on well-optimized CUDA kernels to achieve high performance.
A set of instructions is located here.
7. Job market
Last but not least, why bother to learn JAX?
Besides being beautiful (especially if you like functional programming and clear code), highly performant, and widely applicable (at least everywhere you already use NumPy and want to make things faster, not only execution but also development), JAX has good prospects on the job market.
Many companies already use it extensively. To name a few: Google and Google DeepMind, obviously, as well as Anthropic, xAI, Cohere, Apple, and many others. For a Senior Software Engineer, JAX position, NVIDIA offers a base salary range of USD 180,000 - 339,250. Not bad.
For an easy start, I’d still recommend my book on JAX. I gathered a lot of useful knowledge in a single place, prepared tons of easy-to-digest examples to help build your intuition, published notebooks with code, and basically did everything to help you quickly start. That’s the resource I wish I had when I started learning JAX. I hope you’ll do it much faster with this help.