Accelerated JAX training on Mac
Metal plug-in
JAX uses the new Metal plug-in to provide Metal acceleration on Mac platforms. The Metal plug-in uses the OpenXLA compiler and PjRT runtime to accelerate JAX machine learning workloads on the GPU. The OpenXLA compiler lowers JAX graphs to StableHLO, which the plug-in then converts into MPSGraph executables and Metal runtime API calls that dispatch work to the GPU.
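To see the StableHLO stage of this pipeline, you can ask `jax.jit` to lower a function without compiling or running it. This is a minimal sketch using JAX's ahead-of-time lowering API (`jax.jit(fn).lower(...).as_text()`). The function `scale_and_shift` and the helper `stablehlo_text` are hypothetical names used for illustration. The later conversion to MPSGraph happens inside the plug-in and is not visible at this level.

```python
import jax
import jax.numpy as jnp


def scale_and_shift(x):
    # Hypothetical example function: one multiply and one add, elementwise.
    return x * 2.0 + 1.0


def stablehlo_text(fn, *args):
    # Trace and lower `fn` for the given example arguments, then return the
    # lowered module as StableHLO text. Nothing is compiled or executed.
    return jax.jit(fn).lower(*args).as_text(dialect="stablehlo")


if __name__ == "__main__":
    x = jnp.arange(4, dtype=jnp.float32)
    print(stablehlo_text(scale_and_shift, x))
```

The printed module contains operations such as `stablehlo.multiply` and `stablehlo.add`; this is the form the Metal plug-in receives before converting it for the GPU.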
Requirements
- Mac computers with Apple silicon or AMD GPUs
- macOS 13.4 or later
- Python 3.9 or later
- Xcode command-line tools:
xcode-select --install
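The OS and Python requirements above can be checked from a script before installing anything. This is a minimal sketch using only the standard library (`platform`, `sys`); the helper names are hypothetical, and it does not check the GPU type or whether the Xcode command-line tools are installed.

```python
from __future__ import annotations

import platform
import sys


def version_at_least(version: str, minimum: tuple[int, ...]) -> bool:
    """Compare a dotted version string such as "13.4.1" against a minimum tuple."""
    parts = [int(p) for p in version.split(".")]
    width = max(len(parts), len(minimum))
    # Pad both sides with zeros so "13.4" and (13, 4, 0) compare as equal.
    have = tuple(parts + [0] * (width - len(parts)))
    need = tuple(list(minimum) + [0] * (width - len(minimum)))
    return have >= need


def check_requirements() -> list[str]:
    """Return the unmet requirements; an empty list means none were found."""
    problems = []
    if platform.system() != "Darwin":
        problems.append("not running on macOS")
    else:
        mac_version = platform.mac_ver()[0]  # e.g. "14.0"; may be empty
        if not mac_version or not version_at_least(mac_version, (13, 4)):
            problems.append(f"macOS 13.4 or later required, found {mac_version or 'unknown'}")
    if sys.version_info < (3, 9):
        problems.append(f"Python 3.9 or later required, found {platform.python_version()}")
    return problems


if __name__ == "__main__":
    for problem in check_requirements():
        print("unmet requirement:", problem)
```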
The table below tracks jax-metal versions and compatible versions of macOS, jax, and jaxlib:
jax-metal | macOS | jaxlib | jax |
---|---|---|---|
0.0.4 | Sonoma 14.0 | v0.4.11 | v0.4.11 |
0.0.3 | Ventura 13.4.1+, Sonoma 14.0 beta | v0.4.10 | v0.4.11 |
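The table above can also be encoded as data and compared against what is installed in the active environment. This is a sketch, assuming that installed versions are read with `importlib.metadata` and that a source-built jaxlib may carry a local or `.dev` suffix (an assumption, so the suffix is stripped before comparing). The names `COMPATIBILITY` and `mismatches` are hypothetical.

```python
from __future__ import annotations

from importlib import metadata

# jax-metal version -> pinned jaxlib and jax versions, from the compatibility table.
COMPATIBILITY = {
    "0.0.4": {"jaxlib": "0.4.11", "jax": "0.4.11"},
    "0.0.3": {"jaxlib": "0.4.10", "jax": "0.4.11"},
}


def base_version(version: str) -> str:
    # Strip a local ("+...") or development (".dev...") suffix that a
    # source build may add (assumption).
    return version.split("+")[0].split(".dev")[0]


def installed_version(package: str) -> str | None:
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def mismatches(plugin_version: str, installed: dict[str, str | None]) -> list[str]:
    """List packages whose installed version differs from the table row."""
    expected = COMPATIBILITY.get(plugin_version)
    if expected is None:
        return [f"jax-metal {plugin_version} is not listed in the table"]
    problems = []
    for package, wanted in expected.items():
        found = installed.get(package)
        if found is None or base_version(found) != wanted:
            problems.append(f"{package}: expected {wanted}, found {found}")
    return problems


if __name__ == "__main__":
    plugin = installed_version("jax-metal")
    if plugin is None:
        print("jax-metal is not installed")
    else:
        found = {name: installed_version(name) for name in ("jax", "jaxlib")}
        print(mismatches(plugin, found) or "installed versions match the table")
```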
Get started
1. Set up
python3 -m venv ~/jax-metal
source ~/jax-metal/bin/activate
python -m pip install -U pip
python -m pip install numpy wheel ml-dtypes==0.2.0
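After activating the environment, you can confirm that `python` resolves to the virtual environment rather than a system interpreter: inside a venv, `sys.prefix` points at the environment while `sys.base_prefix` still points at the base installation. A minimal sketch; the helper name `in_virtualenv` is hypothetical.

```python
import sys


def in_virtualenv() -> bool:
    # In a venv, sys.prefix is the environment directory and
    # sys.base_prefix is the interpreter it was created from.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("active prefix:", sys.prefix)
    if not in_virtualenv():
        print("warning: no virtual environment is active; run `source ~/jax-metal/bin/activate`")
```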
2. Installation
jax-metal 0.0.4 or later
A custom build of jaxlib isn’t required for these versions, because jax-metal pins compatible versions of jax and jaxlib as package dependencies and pip installs them along with the plug-in.
python -m pip install jax-metal
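To see which jax and jaxlib versions the installed plug-in pins, you can read its declared dependencies with `importlib.metadata.requires`. This is a simplified sketch: the `pinned` helper is hypothetical, only recognizes exact `==` pins, and is not a full requirement parser (the `packaging` library's `Requirement` class handles the general case). The output depends on the installed jax-metal release; the sketch makes no claim about what any particular release pins.

```python
from importlib import metadata


def pinned(requirements):
    """Map package name -> exact version for requirement strings pinned with '=='."""
    pins = {}
    for req in requirements or []:
        spec = req.split(";")[0].strip()  # drop environment markers
        if "==" in spec:
            name, version = spec.split("==", 1)
            pins[name.strip().lower()] = version.strip()
    return pins


if __name__ == "__main__":
    try:
        print(pinned(metadata.requires("jax-metal")))
    except metadata.PackageNotFoundError:
        print("jax-metal is not installed")
```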
jax-metal 0.0.3 or earlier
First, build a compatible JAX from source. This version of the plug-in is compatible with jax v0.4.11 and the pinned jaxlib v0.4.10. To enable plug-in loading with that jaxlib version, jaxlib must be built from source.
For prerequisites and general scripts for building JAX from source, visit https://jax.readthedocs.io/en/latest/developer.html. Use the steps below to build a jaxlib that is compatible with the Metal plug-in:
# obtain JAX source code
git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch
cd jax
# build jaxlib from source, with capability to load plugin
python build/build.py --bazel_options=--@xla//xla/python:enable_tpu=true
# install jaxlib
python -m pip install dist/*.whl
# install jax v0.4.11 (pinned so pip doesn't replace the custom jaxlib with a newer release)
python -m pip install jax==0.4.11
Then, install the jax-metal plug-in.
python -m pip install jax-metal==0.0.3
3. Verification
python -c 'import jax; print(jax.numpy.arange(10))'
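A slightly fuller check reports which backend and devices JAX selected and runs a small jitted computation. This is a sketch using standard JAX calls (`jax.default_backend`, `jax.devices`, `jax.jit`); with the plug-in loaded, the backend is expected to report Metal rather than `cpu`, but the exact device names may vary between versions (assumption). The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def describe_backend():
    # Name of the default backend and the devices it exposes.
    return jax.default_backend(), [str(d) for d in jax.devices()]


@jax.jit
def matmul_sum(a, b):
    # Small compiled workload: matrix product followed by a full reduction.
    return (a @ b).sum()


if __name__ == "__main__":
    backend, devices = describe_backend()
    print("backend:", backend)
    print("devices:", devices)
    a = jnp.ones((128, 128), dtype=jnp.float32)
    # Each entry of a @ a is 128, so the sum is 128 * 128 * 128 = 2097152.
    print(matmul_sum(a, a))
```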
Questions and feedback
To ask questions and share feedback about the jax-metal plug-in, visit the Apple Developer Forums.