Recent JAX versions fail on Metal

Hi,

I'm not sure whether this is the appropriate forum for this topic; I followed a link from the JAX Metal plugin page: https://developer.apple.com/metal/jax/

I'm writing a Python app with JAX, and recent JAX versions (e.g. v0.8.2) fail on the Metal backend.

I have to downgrade JAX all the way to 0.4.35 to make it work:

pip install jax==0.4.35 jaxlib==0.4.35 jax-metal==0.1.1
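
In case it helps anyone trying to reproduce this, here is a rough sketch for checking whether an environment matches the pins above. It only reads package metadata and does not import JAX. The names KNOWN_GOOD, installed_versions, and mismatches are my own and not part of JAX or jax-metal.

```python
from importlib import metadata

# The combination from the pip command above that works for me on Metal.
KNOWN_GOOD = {"jax": "0.4.35", "jaxlib": "0.4.35", "jax-metal": "0.1.1"}


def installed_versions(packages):
    """Return {package: installed version, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def mismatches(installed, expected=KNOWN_GOOD):
    """Return {package: (installed, expected)} for every pin that differs."""
    return {
        name: (installed.get(name), want)
        for name, want in expected.items()
        if installed.get(name) != want
    }


# An empty dict means the environment matches the known-good pins.
print(mismatches(installed_versions(KNOWN_GOOD)))
```

Anything it prints is a package that differs from the combination that works on my machine; it says nothing about whether other combinations might also work.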

Can we get an updated release of jax-metal that fixes this issue?

Here is the error I get with JAX v0.8.2:

WARNING:2025-12-26 09:55:28,117:jax._src.xla_bridge:881: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1766771728.118004  207582 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max

systemMemory: 36.00 GB
maxCacheSize: 13.50 GB

I0000 00:00:1766771728.129886  207582 service.cc:145] XLA service 0x600001fad300 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1766771728.129893  207582 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1766771728.130856  207582 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1766771728.130864  207582 mps_client.cc:384] XLA backend will use up to 28990554112 bytes on device 0 for SimpleAllocator.
Traceback (most recent call last):
  File "<string>", line 1, in <module>
    import jax; print(jax.numpy.arange(10))
                      ~~~~~~~~~~~~~~~~^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py", line 5951, in arange
    return _arange(start, stop=stop, step=step, dtype=dtype,
                   out_sharding=sharding)
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py", line 6012, in _arange
    return lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 3415, in broadcasted_iota
    return iota_p.bind(dtype=dtype, shape=shape,
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
                       dimension=dimension, sharding=out_sharding)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 633, in bind
    return self._true_bind(*args, **params)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 649, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 661, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 1210, in process_primitive
    return primitive.impl(*args, **params)
           ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
jax.errors.JaxRuntimeError: UNKNOWN: -:0:0: error: unknown attribute code: 22
-:0:0: note: in bytecode version 6 produced by: StableHLO_v1.13.0

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I0000 00:00:1766771728.149951  207582 mps_client.h:209] MetalClient destroyed.