Hi,
I'm not sure whether this is the appropriate forum for this topic. I just followed a link from the JAX Metal plugin page https://developer.apple.com/metal/jax/
I'm writing a Python app with JAX, and recent JAX versions (e.g. v0.8.2) fail on Metal.
I have to downgrade JAX pretty hard to make it work:
pip install jax==0.4.35 jaxlib==0.4.35 jax-metal==0.1.1
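To check that this pinned combination is what the active virtual environment actually has installed, a small standard-library check along these lines can help. It is only a sketch: the helper names `installed_versions` and `mismatches` are hypothetical, and `KNOWN_GOOD` simply encodes the versions from the pip command above. It is not an official jax-metal compatibility table.

```python
from importlib import metadata

# Versions from the pip command above; reported working in this post,
# NOT an official jax-metal compatibility table.
KNOWN_GOOD = {"jax": "0.4.35", "jaxlib": "0.4.35", "jax-metal": "0.1.1"}


def installed_versions(packages):
    """Return {distribution name: installed version, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def mismatches(installed, expected=KNOWN_GOOD):
    """Return (package, installed, expected) for every package that differs."""
    return [
        (pkg, installed.get(pkg), want)
        for pkg, want in expected.items()
        if installed.get(pkg) != want
    ]


if __name__ == "__main__":
    problems = mismatches(installed_versions(KNOWN_GOOD))
    if not problems:
        print("jax / jaxlib / jax-metal match the pinned versions")
    for pkg, have, want in problems:
        print(f"{pkg}: installed {have}, pinned {want}")
```

Running it inside the project's `.venv` prints one line per package that differs from the pin, or a single confirmation line if everything matches.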
Can we get an updated release of jax-metal that would fix this issue?
Here is the error I get with JAX v0.8.2:
WARNING:2025-12-26 09:55:28,117:jax._src.xla_bridge:881: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1766771728.118004 207582 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max
systemMemory: 36.00 GB
maxCacheSize: 13.50 GB
I0000 00:00:1766771728.129886 207582 service.cc:145] XLA service 0x600001fad300 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1766771728.129893 207582 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1766771728.130856 207582 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1766771728.130864 207582 mps_client.cc:384] XLA backend will use up to 28990554112 bytes on device 0 for SimpleAllocator.
Traceback (most recent call last):
  File "<string>", line 1, in <module>
    import jax; print(jax.numpy.arange(10))
                      ~~~~~~~~~~~~~~~~^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py", line 5951, in arange
    return _arange(start, stop=stop, step=step, dtype=dtype,
                   out_sharding=sharding)
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py", line 6012, in _arange
    return lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 3415, in broadcasted_iota
    return iota_p.bind(dtype=dtype, shape=shape,
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
                       dimension=dimension, sharding=out_sharding)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 633, in bind
    return self._true_bind(*args, **params)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 649, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 661, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 1210, in process_primitive
    return primitive.impl(*args, **params)
           ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
jax.errors.JaxRuntimeError: UNKNOWN: -:0:0: error: unknown attribute code: 22
-:0:0: note: in bytecode version 6 produced by: StableHLO_v1.13.0
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I0000 00:00:1766771728.149951 207582 mps_client.h:209] MetalClient destroyed.
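The two lines that matter in the log above are `error: unknown attribute code: 22` and `note: in bytecode version 6 produced by: StableHLO_v1.13.0`. A plausible reading (an assumption, not something the plugin documents) is that JAX 0.8.2 hands the plugin StableHLO v1.13.0 bytecode containing an attribute encoding that the parser bundled with jax-metal 0.1.1 does not recognize. If an app wants to detect this specific failure, for example to print a clearer message than the raw traceback, a string heuristic like the sketch below could work. The function names and the regex are hypothetical and keyed to the exact wording in this log, which may change between releases.

```python
import re

# Matches the note seen in the log, e.g.
# "-:0:0: note: in bytecode version 6 produced by: StableHLO_v1.13.0"
# (hypothetical pattern keyed to this exact wording).
_PRODUCER_NOTE = re.compile(
    r"bytecode version (\d+) produced by: StableHLO_v(\d+(?:\.\d+)*)"
)


def parse_stablehlo_note(message):
    """Return (bytecode_version, stablehlo_version) from the error text, or None."""
    match = _PRODUCER_NOTE.search(message)
    if match is None:
        return None
    return int(match.group(1)), match.group(2)


def looks_like_bytecode_mismatch(message):
    """Heuristic: the plugin rejected an attribute in the StableHLO bytecode it received."""
    return (
        "unknown attribute code" in message
        and parse_stablehlo_note(message) is not None
    )


if __name__ == "__main__":
    log = (
        "UNKNOWN: -:0:0: error: unknown attribute code: 22\n"
        "-:0:0: note: in bytecode version 6 produced by: StableHLO_v1.13.0"
    )
    print(looks_like_bytecode_mismatch(log))  # True
    print(parse_stablehlo_note(log))          # (6, '1.13.0')
```

In an app, `message` would be `str(err)` for the `jax.errors.JaxRuntimeError` raised on the first JAX call, as in the traceback above.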