Accelerated JAX on Mac

Metal plug-in

JAX uses the new Metal plug-in to provide Metal acceleration on Mac platforms. The Metal plug-in uses the OpenXLA compiler and PjRT runtime to accelerate JAX machine learning workloads on GPU. The OpenXLA compiler lowers JAX primitives to StableHLO, which is then converted into MPSGraph executables and dispatched to the GPU through the Metal runtime APIs.
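
Once the plug-in is installed, this pipeline is exercised transparently whenever a function is jit-compiled; a minimal sketch (device names vary by machine):

import jax
import jax.numpy as jnp

# jit-compiled functions are lowered to StableHLO and dispatched through MPSGraph.
@jax.jit
def f(x):
    return x * 2.0 + 1.0

print(jax.devices())        # should list a Metal device once jax-metal is installed
print(f(jnp.arange(8.0)))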

Requirements

  • Mac computers with Apple silicon or AMD GPUs
  • Python 3.9 or later
  • Xcode command-line tools: xcode-select --install
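
You can quickly confirm that the Python and command-line tools requirements are met; for example:

python3 --version    # should report 3.9 or later
xcode-select -p      # prints the developer directory if the tools are installed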

Get started

Install

The table below lists each jax-metal release alongside the compatible macOS and jaxlib versions. Compatibility between jax and jaxlib is maintained by the JAX project.

jax-metal   macOS                               jaxlib
0.1.0       Sonoma 14.4+                        >=v0.4.26
0.0.7       Sonoma 14.4+                        >=v0.4.26
0.0.6       Sonoma 14.4 Beta                    >=v0.4.22, <=v0.4.24
0.0.5       Sonoma 14.2+                        >=v0.4.20, <=v0.4.22
0.0.4       Sonoma 14.0                         v0.4.11
0.0.3       Ventura 13.4.1+, Sonoma 14.0 Beta   v0.4.10
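
To check which jax, jaxlib, and jax-metal versions are installed in the current environment (a quick sanity check, not an official step):

python -c 'import jax; print(jax.__version__)'
python -c 'import jaxlib.version; print(jaxlib.version.__version__)'
pip show jax-metal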

We recommend installing the binary package with venv or Miniconda.

python3 -m venv ~/jax-metal
source ~/jax-metal/bin/activate
python -m pip install -U pip
python -m pip install numpy wheel
python -m pip install jax-metal

Verify

python -c 'import jax; print(jax.numpy.arange(10))'
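
If the plug-in is working, the command prints the array (the plug-in typically also logs which Metal device it selected):

[0 1 2 3 4 5 6 7 8 9]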

Compatibility with jaxlib

jax-metal is compatible with the minimal jaxlib version tracked in the table above. It can also run with jaxlib versions beyond that minimum by setting the environment variable ENABLE_PJRT_COMPATIBILITY=1.

pip install -U jaxlib jax
ENABLE_PJRT_COMPATIBILITY=1 python -c 'import jax; print(jax.numpy.arange(10))'
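
The variable can also be set from Python, as long as it is set before JAX initializes its backend; a minimal sketch, setting it before the first import to be safe:

import os

# Must be set before JAX loads the PjRT plug-in.
os.environ["ENABLE_PJRT_COMPATIBILITY"] = "1"

import jax
print(jax.numpy.arange(10))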

Run inference on the AXLearn Fuji model

Install Miniconda

# Apple silicon installer; on Intel Macs use Miniconda3-latest-MacOSX-x86_64.sh
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh; \
	bash Miniconda3-latest-MacOSX-arm64.sh; \
	bash

Set up the environment

conda create -n axlearn python=3.9
conda activate axlearn
git clone https://github.com/apple/axlearn.git
cd axlearn
pip install -e .
pip install jax-metal
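
Before running the demo, you can confirm that AXLearn imports cleanly and that the Metal device is visible from the new environment:

python -c 'import axlearn; import jax; print(jax.devices())'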

Run demo

The JAX_DISABLE_JIT environment variable can be removed with the next macOS seed, once the underlying issue is fixed.

JAX_DISABLE_JIT=1 ENABLE_PJRT_COMPATIBILITY=1 CHECKPOINT_PATH=${path_to_checkpoints} python demo.py
# demo.py
import os

import jax
import jax.numpy as jnp
import numpy as np

from axlearn.common.inference import InferenceRunner
from axlearn.experiments import get_named_trainer_config

def get_runner(name, checkpoint_path):
    """Make an inference runner initialized with pre-trained state according to name."""
    trainer_cfg = get_named_trainer_config(
        name,
        config_module="axlearn.experiments.text.gpt.c4_trainer",
    )()
    inference_runner_cfg = InferenceRunner.default_config().set(
        name=f"{name}_inference_runner",
        mesh_axis_names=("data", "expert", "fsdp", "seq", "model"),
        mesh_shape=(1, 1, len(jax.devices()), 1, 1),
        model=trainer_cfg.model.set(dtype=jnp.bfloat16),
        inference_dtype=jnp.bfloat16,
    )
    inference_runner_cfg.init_state_builder.dir = checkpoint_path
    inference_runner = inference_runner_cfg.instantiate(parent=None)
    return inference_runner

def predict(inference_runner, inputs_ids):
    """
    Helper method to perform one forward pass for the model.
    """

    input_batches = [{"input_ids": jnp.array(inputs_ids)}]
    # inference_runner.run yields one result per input batch; return the first.
    for result in inference_runner.run(
        input_batches,
        method="predict",
        prng_key=jax.random.PRNGKey(11),
    ):
        return result["outputs"]

def gen_tokens(inference_runner, inputs_ids, max_new_tokens):
    """
    Helper method to generate multiple tokens for the model.
    """
    batch_size, prompt_len = inputs_ids.shape

    result_len = prompt_len + max_new_tokens
    result = np.zeros(
        (batch_size, result_len), dtype=inputs_ids.dtype)
    result[:, :prompt_len] = inputs_ids

    input_batches = [{"input_ids": jnp.array(inputs_ids),
    "prefix": jnp.array(result)}]

    for result in inference_runner.run(
        input_batches,
        method="sample_decode",
        prng_key=jax.random.PRNGKey(11),
    ):
        return result["outputs"]

def get_data(seq_len, vocab_size, batch_size=1):
    """ Generate random input in shape of [batch, seq] """
    rng = np.random.RandomState(11)
    input_ids = rng.randint(0, vocab_size, (batch_size, seq_len)).astype(np.int32)
    return input_ids

if __name__ == "__main__":
    model_name = 'fuji-7B-v1-single-host'
    checkpoint_path = os.getenv("CHECKPOINT_PATH")
    vocab_size = 32768
    seq_len = 10
    
    model = get_runner(
        model_name,
        checkpoint_path=checkpoint_path,
    )

    input_data = get_data(seq_len, vocab_size)
    print(f'Creating random input ids: {input_data}')

    print("Extracting logits after 1 step.....")
    predict_output = predict(model, input_data)
    logits = predict_output['logits']
    res = np.array(logits).astype(np.float32)
    last_logit = logits[:,-1,:]
    print(f'Last logits are {last_logit}')
    token_id = np.argmax(res[:,-1,:], axis=-1)
    print(f'Predicated token is {token_id}')
    
    max_new_tokens = 5
    print(f'Generating {max_new_tokens} extended tokens...')
    decoding_output = gen_tokens(model, input_data, max_new_tokens)
    # Slice off the prompt to keep only the newly generated tokens.
    new_tokens = decoding_output.sequences[:, :, seq_len:]
    print(f'Extended tokens are {new_tokens}')

Testing

Check the jax-metal testing status through the Metal workflow action in the JAX GitHub repository.

Currently not supported

The Metal plug-in is experimental, and not all JAX functionality may be supported. Reported and tracked issues are listed here: https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+metal

Questions and feedback

To ask questions and share feedback about the Metal plug-in, visit the Apple Developer Forums. You can also browse JAX issues on GitHub with the label “Apple GPU (Metal) plugin”.