Jax-metal on M2 Pro does not recognize GPU

I built and installed Jax and jax-metal on an M2 Pro Mac mini, following the instructions here: https://developer.apple.com/metal/jax/

However, the following check suggests that XLA is using the CPU rather than the GPU.

>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
cpu
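For reference, jax.devices() gives the same answer; a small helper (a sketch that returns None when jax is missing or the backend fails to initialise) makes the check reusable:

```python
def backend_platform():
    """Return the platform of the default JAX backend, or None if unavailable."""
    try:
        import jax
        return jax.devices()[0].platform
    except Exception:  # jax not installed, or backend failed to initialise
        return None

print(backend_platform())  # 'cpu' here; 'METAL' once the plugin is picked up
```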

Has anyone got it to report the GPU?

Thanks in advance!


Replies

Hi, I just ran into this issue on an M1 14" MBP. I got it to install and run correctly; instructions below.

The key is that, for now, it needs jaxlib 0.4.10 but jax 0.4.11. Jax seems to allow a jaxlib that is one point version behind, so this combination works.
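That compatibility rule can be sketched roughly like this (my reading of the behaviour, not jax's actual version guard):

```python
def jaxlib_is_compatible(jax_version, jaxlib_version):
    """Rough sketch: jaxlib may trail jax by at most one point release."""
    jx = tuple(int(p) for p in jax_version.split("."))
    jl = tuple(int(p) for p in jaxlib_version.split("."))
    # jaxlib must not be newer than jax, must share major.minor,
    # and may trail by at most one point release
    return jl <= jx and jx[:2] == jl[:2] and jx[2] - jl[2] <= 1

print(jaxlib_is_compatible("0.4.11", "0.4.10"))  # True  -> the combo used here
print(jaxlib_is_compatible("0.4.11", "0.4.9"))   # False -> trails too far
```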

The key instructions are below:

# obtain JAX source code
git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch
cd jax
# build jaxlib from source, with capability to load plugin
python build/build.py --bazel_options=--@xla//xla/python:enable_tpu=true
# install jaxlib
python -m pip install dist/*.whl

You also need Bazel 5.1.1 to build jaxlib (it’ll give you instructions if it can’t find it) and Python 3.10 or it won’t install the jaxlib wheel. If you’re using Anaconda you’ll have to create an environment using 3.10 and not any other version.
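A quick way to confirm the interpreter matches that constraint before kicking off a long build (a sketch; the requirement comes from the wheel's tags, not from jax itself):

```python
import sys

def wheel_will_install(version_info=sys.version_info):
    """The jaxlib wheel built above only installs on CPython 3.10."""
    return tuple(version_info[:2]) == (3, 10)

print(wheel_will_install((3, 10, 9)))  # True
print(wheel_will_install((3, 11, 2)))  # False
```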

At this point the instructions tell you to install Jax via pip, but don't do that, or it will default to 0.4.10, which is the wrong version. Instead, download the source zip for the 0.4.11 release of Jax: https://github.com/google/jax/releases/tag/jax-v0.4.11

# make sure you're in the jax-v0.4.11 folder
pip install -e .

This should install correctly if it finds the jaxlib version it wants. From there, you should be able to load Jax and confirm it is using the GPU by running:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

Good luck.

I'm having the exact same issue with jax 0.4.11, jaxlib 0.4.10, and jax-metal 0.0.2: the GPU is not recognised on my 2020 MBA M1.

Any other things I could possibly try?

  • I banged my head on the same problem for a while. Rebooting the machine after installing everything fixed it for me.


same problem here, but when I changed the instructions to

git clone https://github.com/google/jax.git --branch jaxlib-v0.4.11 --single-branch

and

python -m pip install jax==0.4.11

it now seems to recognize the GPU:

>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
Metal device set to: Apple M2 Max

systemMemory: 96.00 GB
maxCacheSize: 36.00 GB

METAL
>>> import jax
>>> jax.devices()
[MetalDevice(id=0, process_index=0)]
>>> jax.devices()[0].platform
'METAL'
>>> jax.devices()[0].device_kind
'Metal'
>>> jax.devices()[0].client.platform
'METAL'
>>> jax.devices()[0].client.runtime_type
'tfrt'

But now, x = jnp.ones((10000, 10000)) generates errors:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: -:0:0: error: bytecode version 5 is newer than the current version 1
  • This is the solution that worked for me on an M2 MBA. I cloned and built jaxlib-v0.4.11 locally as shown above, pip installed it, then ran pip install jax jax-metal, which resulted in jax 0.4.11 and jax-metal 0.0.3. Now I get this:

    $ python -c 'import jax; print(jax.lib.xla_bridge.get_backend().platform)'
    WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
    Metal device set to: Apple M2

    systemMemory: 24.00 GB
    maxCacheSize: 8.00 GB

    METAL

I was facing the same issue but I appear to have solved it (although my device is an M1 Max). Jaxlib 0.4.10, Jax 0.4.11, and jax-metal 0.0.2.

Install Jaxlib 0.4.10:

# obtain JAX source code
git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch
cd jax
# build jaxlib from source, with capability to load plugin
python build/build.py --bazel_options=--@xla//xla/python:enable_tpu=true
# install jaxlib
python -m pip install dist/*.whl

Now install Jax 0.4.11:

  1. Download the source code zip of release 0.4.11 of Jax: https://github.com/google/jax/releases/tag/jax-v0.4.11
  2. Unzip it, change into the folder, and install Jax: python -m pip install -e .

Finally, install jax-metal 0.0.2: python -m pip install jax-metal
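After those three installs, it's worth double-checking what pip actually resolved. A small sketch using the standard library's importlib.metadata:

```python
from importlib.metadata import version, PackageNotFoundError

def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

# expect jax 0.4.11, jaxlib 0.4.10, jax-metal 0.0.2 with this setup
for pkg in ("jax", "jaxlib", "jax-metal"):
    print(pkg, installed_version(pkg))
```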

Python 3.10.9 (main, Mar  1 2023, 12:20:14) [Clang 14.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
Metal device set to: Apple M1 Max

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

[MetalDevice(id=0, process_index=0)]

Hope this helps!

Thanks! That helped and Jax now recognizes the GPU.

Unfortunately, when I tried to run a simple example of the Newton algorithm from https://jax.quantecon.org/newtons_method.html it fails:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 11, in newton
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:4:0: error: failed to legalize operation 'mhlo.scatter'
<stdin>:4:0: note: called from
<stdin>:4:0: note: see current operation: 
%2177 = "mhlo.scatter"(%2052, %2176, %2167) ({
^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):
  "mhlo.return"(%arg7) : (tensor<f32>) -> ()
}) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<5000x128xf32>, tensor<1xsi32>, tensor<5000xf32>) -> tensor<5000x128xf32>

Running the same code on the CPU works fine.
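Until that op is supported, one workaround (assuming your jax honours the JAX_PLATFORMS environment variable, which 0.4.x does) is to pin individual runs to the CPU without touching the install; newton_example.py below is a hypothetical script name:

```shell
# force this one run onto the CPU backend; other runs keep using Metal
JAX_PLATFORMS=cpu python newton_example.py
```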

In my MBP (M1 Max), jax-metal works in the following setup. (I followed the instructions in https://developer.apple.com/metal/jax/ with modifications shown below in parentheses. I used Anaconda environments.)

jaxlib 0.4.10  ( git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch )
jax 0.4.11       ( pip install jax==0.4.11 )
jax-metal 0.0.2

Some catches:

  • This does not work well in Jupyter, but it works on the command line. So I use it in the Spyder editor (with IPython).
  • Some basic jax.numpy operations may have bugs; for example, the following code does not work:
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(1)
x = random.normal(key, (200, 100))
y = random.normal(key, (100,))

a = jnp.dot(y, x.T) # OK
print(a.shape)
b = jnp.dot(x, y[:, jnp.newaxis]) # OK, but it becomes 2D
print(b.shape)
c = jnp.dot(x, y) # Error with jax-metal 0.0.2 on the GPU; works fine on the CPU (with jaxlib 0.4.12, which does not work on the GPU, so there is no point pairing it with jax-metal)
  • Jax-metal 0.0.3 appears to have fixed the jnp.dot(x, y) bug. It still needs jaxlib 0.4.10 and jax 0.4.11 to run in IPython. Performance is not good (slightly better than a free Google Colab GPU): it was much slower than the CPU for a small model, but became faster than the CPU with a bigger model.
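For reference, the shapes those three products should produce (sketched here with NumPy, whose dot rules jax.numpy follows):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(200, 100))
y = rng.normal(size=(100,))

a = np.dot(y, x.T)               # (100,) . (100, 200) -> (200,)
b = np.dot(x, y[:, np.newaxis])  # (200, 100) . (100, 1) -> (200, 1), hence 2-D
c = np.dot(x, y)                 # (200, 100) . (100,) -> (200,)

print(a.shape, b.shape, c.shape)
```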
