Hooray! My book Deep Learning with JAX (formerly JAX in Action) is now in print! I just received my physical copies 🙂
For those who haven't been following, JAX is a Python library for high-performance computing and large-scale machine learning, with excellent support for accelerators like TPUs.
At this point, JAX is a legitimate alternative to TensorFlow and PyTorch (torch.func, which started as functorch, is still trying to catch up and remains in beta), and many companies, including Google DeepMind, Cohere, and xAI, have already adopted it. Well-known models built with JAX include AlphaFold, GraphCast, Gemini, Gemma, and Grok, along with countless research projects.
JAX is more than just a library for ML; it’s a tool for all kinds of high-performance, parallel, and distributed computation. There’s a reason it’s called “NumPy on steroids.” Beyond ML/DL, for instance, JAX is widely used for physical simulations, and a huge number of libraries built on top of it are already available on GitHub.
Functional programming enthusiasts will be especially thrilled, because JAX is the first large-scale framework to operate in this paradigm. It’s cool to use functions to transform other functions. Write a function that processes a single element, then transform it into a function that processes a batch. Write a complex mathematical function, then transform it into a function that computes its derivative. In the same way, you can transform functions for compilation and parallelization. No hidden state or side effects: everything is clean, beautiful, and easy to understand. And it’s also FAST! (see this post by François Chollet)
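To make the "functions transforming functions" idea concrete, here is a minimal sketch (not taken from the book) that composes three of JAX's transformations: jax.grad for derivatives, jax.vmap for batching, and jax.jit for compilation. The toy squared-error loss and the data are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss, written for a SINGLE example:
# a linear prediction w·x compared against a scalar target y.
def loss(w, x, y):
    pred = jnp.dot(w, x)
    return (pred - y) ** 2

# grad: a new function that returns d(loss)/dw (first argument by default).
loss_grad = jax.grad(loss)

# vmap: a new function that maps over a batch of (x, y) pairs,
# while w (in_axes=None) is shared across the batch.
batched_loss = jax.vmap(loss, in_axes=(None, 0, 0))

# Transformations compose: per-example gradients, compiled with XLA via jit.
per_example_grads = jax.jit(jax.vmap(loss_grad, in_axes=(None, 0, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # batch of 3 inputs
ys = jnp.array([0.0, 1.0, 2.0])                       # batch of 3 targets

print(batched_loss(w, xs, ys))       # one loss per example, shape (3,)
print(per_example_grads(w, xs, ys))  # one gradient per example, shape (3, 2)
```

Note that `loss` itself never mentions batches or derivatives; both behaviors come entirely from the wrappers, which is what makes the functional style pleasant to work with.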
Now is a great time to carve out a bit of the future for yourself 🙂
The book is divided into three parts, spanning over 370 pages.
Part 1: First steps.
A high-level introduction to JAX for managers and everyone else, explaining where and why to use JAX. Plus, there’s a dedicated chapter for those who love to see code, which walks through the full cycle of implementing a simple neural network using most of JAX’s features.
Part 2: Core JAX.
The main part of the book covers the JAX fundamentals step by step. Topics include working with arrays (tensors), autodiff, compilation, vectorization, parallelization and sharding, as well as random number generation (NumPy’s traditional approach relies on hidden global state, which doesn’t fit functional programming, but in JAX everything is explicit and reproducible!) and pytrees.
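As a rough illustration of the random-number point, here is a minimal sketch (not from the book) of JAX's explicit PRNG handling using jax.random.PRNGKey and jax.random.split. The seed 42 and the array shapes are arbitrary choices.

```python
import jax

# Randomness is driven by an explicit key value, not a hidden global generator
# (unlike legacy np.random.normal, which silently advances global state).
key = jax.random.PRNGKey(42)

# The same key always produces the same numbers: calls are pure functions.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# For fresh randomness, split the key into a new key and an independent subkey.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(a)  # identical to b
print(c)  # different numbers, drawn with the new subkey
```

Because the key is just a value passed in and out, the same code gives the same results on every run and inside jit or vmap, which is what makes the approach reproducible.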
Part 3: Ecosystem.
A large chapter provides a practical introduction to high-level deep learning libraries (Flax, Optax, Orbax, CLU, etc.) and examples of using Hugging Face Transformers/Diffusers, which have long supported JAX. There’s also a separate chapter with a broad, high-level overview of what’s available in and around JAX beyond mainstream neural networks.
The supporting notebooks with code are available on GitHub.
Many smart and talented people have read and reviewed my book, and I’m grateful to the many GDEs (Google Developer Experts) and others who did. Special thanks to François Chollet for the kind words 🙂
“A comprehensive guide to mastering JAX, whether you’re a seasoned deep learning practitioner or just venturing into the realm of differentiable programming and large-scale numerical simulations.”
— François Chollet, Software Engineer, Google
👏 Some more quotes:
“A must-read! The emphasis on functional programming has transformed the way I approach building models.”
— Stephen Oates, Data Scientist, Allianz Australia
“I thoroughly enjoyed this excellent book! I feel confident that I can now apply JAX in my own work.”
— Tony Holdroyd, Retired Senior Lecturer in Computer Science and Mathematics
“Great, modular code. Helpful explanations. This book is a treasure.”
— Ritobrata Ghosh, Independent Deep Learning Consultant
Overall, it was an incredible experience, and I’m happy with the result. I hope you’ll like it too.