Accelerated JAX on Mac
Metal plug-in
JAX uses the new Metal plug-in to provide Metal acceleration on Mac platforms. The Metal plug-in uses the OpenXLA compiler and the PjRT runtime to accelerate JAX machine learning workloads on the GPU. The OpenXLA compiler lowers JAX primitives to StableHLO, which is then converted to MPSGraph executables and Metal runtime API calls that dispatch work to the GPU.
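To see the StableHLO stage described above, you can ask JAX to lower a jitted function and print the module it hands to the compiler. This is a small sketch using jax.jit(...).lower(...).as_text(); the function f is an arbitrary example, and the exact text printed varies between JAX versions.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A tiny function built from JAX primitives (sin, multiply, add).
    return jnp.sin(x) * 2.0 + 1.0


x = jnp.arange(4, dtype=jnp.float32)

# Lower (but do not yet compile) the jitted function for this input shape/dtype.
lowered = jax.jit(f).lower(x)

# Text form of the lowered module; in current JAX this is StableHLO MLIR.
stablehlo_text = lowered.as_text()
print(stablehlo_text)

# Devices exposed by the active backend (the GPU when jax-metal is in use).
print(jax.devices())
```

Look for ops such as stablehlo.sine and stablehlo.multiply in the output; that module is what the plug-in translates into MPSGraph work.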
Requirements
- Mac computers with Apple silicon or AMD GPUs
- Python 3.9 or later
- Xcode command-line tools:
xcode-select --install
Get started
Install
The table below tracks jax-metal versions and the compatible versions of macOS and jaxlib. Compatibility between jax and jaxlib themselves is maintained by JAX.
jax-metal | macOS | jaxlib |
---|---|---|
0.1.0 | Sonoma 14.4+ | >=v0.4.26 |
0.0.7 | Sonoma 14.4+ | >=v0.4.26 |
0.0.6 | Sonoma 14.4 Beta | >=v0.4.22, <v0.4.24 |
0.0.5 | Sonoma 14.2+ | >=v0.4.20, <v0.4.22 |
0.0.4 | Sonoma 14.0 | v0.4.11 |
0.0.3 | Ventura 13.4.1+, Sonoma 14.0 Beta | v0.4.10 |
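If you want to check a version pair against this table programmatically, the sketch below encodes the lower bound from the jaxlib column and compares dotted version strings. MIN_JAXLIB, parse_version, and meets_minimum are hypothetical names for illustration, and the parser simply drops suffixes such as .dev builds, so treat it as a rough check rather than full PEP 440 version handling.

```python
import re

# Lowest jaxlib version listed for each jax-metal release in the table.
# For 0.0.3 and 0.0.4 the table lists a single version; it is used as the minimum.
MIN_JAXLIB = {
    "0.1.0": "0.4.26",
    "0.0.7": "0.4.26",
    "0.0.6": "0.4.22",
    "0.0.5": "0.4.20",
    "0.0.4": "0.4.11",
    "0.0.3": "0.4.10",
}


def parse_version(v: str) -> tuple:
    """Turn '0.4.26' into (0, 4, 26), stopping at the first non-numeric part."""
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break  # e.g. 'dev20240101' or 'rc1' ends the numeric prefix
        parts.append(int(m.group()))
    return tuple(parts)


def meets_minimum(metal_version: str, jaxlib_version: str) -> bool:
    """True if jaxlib_version is at least the table's minimum for metal_version."""
    minimum = MIN_JAXLIB.get(metal_version)
    if minimum is None:
        raise ValueError(f"jax-metal {metal_version} is not in the table")
    # Tuples compare element by element, so (0, 4, 30) >= (0, 4, 26).
    return parse_version(jaxlib_version) >= parse_version(minimum)


print(meets_minimum("0.1.0", "0.4.26"))
print(meets_minimum("0.0.5", "0.4.19"))
```

A jaxlib newer than the listed version still passes this check; running with such a jaxlib is what the ENABLE_PJRT_COMPATIBILITY setting in the compatibility section is for.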
We recommend installing the binary package with venv or Miniconda.
python3 -m venv ~/jax-metal
source ~/jax-metal/bin/activate
python -m pip install -U pip
python -m pip install numpy wheel
python -m pip install jax-metal
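After running the pip commands, you can confirm from Python that the venv is active and see which versions pip resolved. This is a minimal sketch using only the standard library; in_virtualenv and installed_version are hypothetical helper names. The sys.prefix comparison detects venv environments only; inside a conda environment it reports False even when the environment is active.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def in_virtualenv() -> bool:
    # In a venv, sys.prefix points at the environment directory while
    # sys.base_prefix still points at the interpreter it was created from.
    return sys.prefix != sys.base_prefix


def installed_version(pkg: str):
    """Installed distribution version, or None if the package is absent."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


print("venv active:", in_virtualenv())
for pkg in ("jax-metal", "jax", "jaxlib", "numpy"):
    print(pkg, installed_version(pkg) or "not installed")
```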
Verify
python -c 'import jax; print(jax.numpy.arange(10))'
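The one-liner above confirms that JAX imports and runs a trivial operation. A slightly fuller check, sketched below, also prints the backend and devices JAX selected and runs a small jitted matrix multiply; the function name and array sizes are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Backend and devices JAX selected; with jax-metal installed this should be the GPU.
print("default backend:", jax.default_backend())
print("devices:", jax.devices())


@jax.jit
def matmul(a, b):
    return a @ b


a = jnp.ones((256, 256), dtype=jnp.float32)
b = jnp.ones((256, 256), dtype=jnp.float32)
c = matmul(a, b)

# Each output entry is a sum of 256 products of ones, i.e. 256.0.
print(c.shape, c.dtype, float(c[0, 0]))
```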
Compatibility with jaxlib
jax-metal is compatible with the minimum jaxlib version tracked in the table above. It can also run with jaxlib versions newer than the minimum if you set the environment variable ENABLE_PJRT_COMPATIBILITY=1.
pip install -U jaxlib jax
ENABLE_PJRT_COMPATIBILITY=1 python -c 'import jax; print(jax.numpy.arange(10))'
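If you launch Python from an IDE or a notebook rather than a shell, you can set the same variable from Python instead. This is a sketch that assumes the plug-in reads ENABLE_PJRT_COMPATIBILITY when it is loaded, so the variable must be set before jax is first imported in the process.

```python
import os

# Set before importing jax: assumption, the plug-in reads this at load time.
# setdefault keeps any value that is already present in the environment.
os.environ.setdefault("ENABLE_PJRT_COMPATIBILITY", "1")

import jax  # noqa: E402  (imported after the environment variable on purpose)
import jax.numpy as jnp  # noqa: E402

print(jax.__version__, jnp.arange(10))
```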
Run inference on AXLearn Fuji Model
Install Miniconda
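On Intel-based Macs with AMD GPUs, use Miniconda3-latest-MacOSX-x86_64.sh instead of the arm64 installer.
curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh; \
bash Miniconda3-latest-MacOSX-arm64.sh; \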
bash
Set up a conda environment
conda create -n axlearn python=3.9
conda activate axlearn
git clone https://github.com/apple/axlearn.git
cd axlearn
pip install -e .
pip install jax-metal
Run demo
The JAX_DISABLE_JIT environment variable can be removed with the next macOS seed, once the underlying issue is fixed.
JAX_DISABLE_JIT=1 ENABLE_PJRT_COMPATIBILITY=1 CHECKPOINT_PATH=${path_to_checkpoints} python demo.py
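JAX_DISABLE_JIT=1 turns on JAX's jax_disable_jit option, so jitted functions run op by op instead of being compiled. The sketch below illustrates the difference with jax.disable_jit(), the context-manager form of the same switch, by counting how often a jitted function's Python body runs; jax.config.update("jax_disable_jit", True) is the process-wide equivalent. The calls counter and the step function exist only for this illustration.

```python
import jax
import jax.numpy as jnp

calls = {"n": 0}  # counts how often the Python body of step actually runs


@jax.jit
def step(x):
    calls["n"] += 1  # Python side effect: under jit it runs only while tracing
    return x * 2


x = jnp.ones(3)
step(x)  # first call: traced and compiled
step(x)  # same shape and dtype: reuses the compiled executable, no retrace
traced_with_jit = calls["n"]

with jax.disable_jit():
    step(x)  # Python body runs eagerly on every call
    step(x)
calls_after_disable = calls["n"]

print(traced_with_jit, calls_after_disable)
```

Disabling jit makes every call pay Python dispatch overhead per operation, which is why the demo should drop the variable once it is no longer needed.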
# demo.py
import os

import jax
import jax.numpy as jnp
import numpy as np

from axlearn.common.inference import InferenceRunner
from axlearn.experiments import get_named_trainer_config


def get_runner(name, checkpoint_path):
    """Make an inference runner initialized with pre-trained state according to name."""
    trainer_cfg = get_named_trainer_config(
        name,
        config_module="axlearn.experiments.text.gpt.c4_trainer",
    )()
    inference_runner_cfg = InferenceRunner.default_config().set(
        name=f"{name}_inference_runner",
        mesh_axis_names=("data", "expert", "fsdp", "seq", "model"),
        # Shard only along the fsdp axis, across all visible devices.
        mesh_shape=(1, 1, len(jax.devices()), 1, 1),
        model=trainer_cfg.model.set(dtype=jnp.bfloat16),
        inference_dtype=jnp.bfloat16,
    )
    # Load the pre-trained weights from the checkpoint directory.
    inference_runner_cfg.init_state_builder.dir = checkpoint_path
    inference_runner = inference_runner_cfg.instantiate(parent=None)
    return inference_runner


def predict(inference_runner, inputs_ids):
    """Helper method to perform one forward pass for the model."""
    input_batches = [{"input_ids": jnp.array(inputs_ids)}]
    # run() yields one output per input batch; there is only one batch here.
    for output in inference_runner.run(
        input_batches,
        method="predict",
        prng_key=jax.random.PRNGKey(11),
    ):
        return output["outputs"]


def gen_tokens(inference_runner, inputs_ids, max_new_tokens):
    """Helper method to generate multiple tokens for the model."""
    batch_size, prompt_len = inputs_ids.shape
    result_len = prompt_len + max_new_tokens
    # Prefix buffer: the prompt followed by zeros for the tokens to generate.
    prefix = np.zeros((batch_size, result_len), dtype=inputs_ids.dtype)
    prefix[:, :prompt_len] = inputs_ids
    input_batches = [{"input_ids": jnp.array(inputs_ids), "prefix": jnp.array(prefix)}]
    for output in inference_runner.run(
        input_batches,
        method="sample_decode",
        prng_key=jax.random.PRNGKey(11),
    ):
        return output["outputs"]


def get_data(seq_len, vocab_size, batch_size=1):
    """Generate random input in shape of [batch, seq]."""
    rng = np.random.RandomState(11)
    input_ids = rng.randint(0, vocab_size, (batch_size, seq_len)).astype(np.int32)
    return input_ids


if __name__ == "__main__":
    model_name = "fuji-7B-v1-single-host"
    checkpoint_path = os.getenv("CHECKPOINT_PATH")
    vocab_size = 32768
    seq_len = 10
    model = get_runner(model_name, checkpoint_path=checkpoint_path)

    input_data = get_data(seq_len, vocab_size)
    print(f"Creating random input ids: {input_data}")

    print("Extracting logits after 1 step.....")
    predict_output = predict(model, input_data)
    logits = predict_output["logits"]
    # Convert bfloat16 logits to float32 for NumPy.
    res = np.array(logits).astype(np.float32)
    last_logit = res[:, -1, :]
    print(f"Last logits are {last_logit}")
    token_id = np.argmax(last_logit, axis=-1)
    print(f"Predicted token is {token_id}")

    max_new_tokens = 5
    print(f"Generating {max_new_tokens} extended tokens.....")
    decoding_output = gen_tokens(model, input_data, max_new_tokens)
    # Drop the prompt positions to keep only the newly generated tokens.
    new_tokens = decoding_output.sequences[:, :, seq_len:]
    print(f"Extended tokens are {new_tokens}")
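The demo picks the next token with np.argmax over the logits at the last position, and gen_tokens passes a prefix buffer that holds the prompt followed by zeros. The sketch below shows that bookkeeping with plain greedy decoding. toy_logits is a hypothetical stand-in for the model, and AXLearn's sample_decode samples tokens rather than always taking the argmax, so this illustrates the buffer layout and slicing, not that method.

```python
import numpy as np


def toy_logits(tokens: np.ndarray, vocab_size: int) -> np.ndarray:
    """Hypothetical model stand-in: [batch, seq] -> [batch, seq, vocab] logits
    that always favor (token + 1) % vocab_size at every position."""
    batch, seq = tokens.shape
    logits = np.zeros((batch, seq, vocab_size), dtype=np.float32)
    favored = (tokens + 1) % vocab_size
    np.put_along_axis(logits, favored[..., None], 1.0, axis=-1)
    return logits


def greedy_extend(prompt: np.ndarray, max_new_tokens: int, vocab_size: int) -> np.ndarray:
    batch, prompt_len = prompt.shape
    # Same layout as the prefix in gen_tokens: prompt first, zeros after it.
    buf = np.zeros((batch, prompt_len + max_new_tokens), dtype=prompt.dtype)
    buf[:, :prompt_len] = prompt
    for t in range(prompt_len, prompt_len + max_new_tokens):
        logits = toy_logits(buf[:, :t], vocab_size)
        # Next token = argmax over the vocabulary at the last filled position.
        buf[:, t] = np.argmax(logits[:, -1, :], axis=-1)
    # Like sequences[:, :, seq_len:] in the demo: keep only the new tokens.
    return buf[:, prompt_len:]


out = greedy_extend(np.array([[3, 7]], dtype=np.int32), max_new_tokens=5, vocab_size=10)
print(out)
```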
Testing
Check the jax-metal testing status in the Metal workflow action of the JAX GitHub repository.
Currently not supported
The Metal plug-in is experimental, and not all JAX functionality is supported. Reported issues are tracked in this list: https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+metal
- Unsupported data types: np.float64, np.complex64, np.complex128
- The Metal plug-in doesn’t pass all tests under https://github.com/google/jax/tree/main/tests.
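Because float64 and the complex types are unsupported, host data produced by NumPy (which defaults to float64) is best converted before it reaches the device. This is a sketch with a hypothetical helper, to_supported. With JAX's default jax_enable_x64=False, jnp.asarray already converts float64 to float32, so the explicit cast mainly makes the intent visible and still yields float32 if x64 is enabled. Complex inputs are rejected because there is no supported dtype to map them to without dropping the imaginary part.

```python
import numpy as np
import jax.numpy as jnp


def to_supported(x) -> jnp.ndarray:
    """Hypothetical helper: move host data to JAX using dtypes the Metal
    plug-in supports (float64 -> float32, complex rejected)."""
    x = np.asarray(x)
    if np.iscomplexobj(x):
        raise TypeError(f"{x.dtype} is not supported by the Metal plug-in")
    if x.dtype == np.float64:
        x = x.astype(np.float32)
    return jnp.asarray(x)


y = to_supported(np.linspace(0.0, 1.0, 5))  # np.linspace returns float64
print(y.dtype)
```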
Questions and feedback
To ask questions and share feedback about the Metal plug-in, visit the Apple Developer Forums. You can also view GitHub JAX Issues with the label “Apple GPU (Metal) plugin”.