jax-metal (0.0.6) segmentation fault in `jax.lax.scan`

Hi,

I have encountered a segmentation fault when calling a function via `jax.lax.scan`. A minimal failing example is pasted below:

 $ ipython
Python 3.9.6 (default, Feb  3 2024, 15:58:27)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.18.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax

In [2]: jax.__version__
Out[2]: '0.4.22'

In [3]: import jaxlib

In [4]: jaxlib.__version__
Out[4]: '0.4.22'

In [6]: import jax.numpy as jnp

In [7]: def f(carry, x):
   ...:     return carry + x * x, x * x
   ...:
   ...: jax.lax.scan(f, jnp.zeros((), dtype=jnp.float32), jnp.arange(3, dtype=jnp.float32))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-04-16 01:03:52.483015: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max

systemMemory: 36.00 GB
maxCacheSize: 13.50 GB

zsh: segmentation fault  ipython

This might be related to the thread below: https://developer.apple.com/forums/thread/749080
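For reference, the `jax.lax.scan` documentation describes its semantics as equivalent to a plain Python loop that threads the carry through `f` and collects the per-step outputs. The sketch below reproduces that loop in pure Python (no JAX or Metal involved) to show what the failing call should return on a working backend. `scan_reference` is a hypothetical helper written for illustration, not part of JAX, and it returns `ys` as a list rather than a stacked array.

```python
def scan_reference(f, init, xs):
    # Hypothetical pure-Python stand-in for jax.lax.scan's loop semantics:
    # thread the carry through f and collect each step's output.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def f(carry, x):
    # Same body as the failing example: accumulate x * x and also emit it.
    return carry + x * x, x * x


final_carry, ys = scan_reference(f, 0.0, [0.0, 1.0, 2.0])
print(final_carry, ys)  # 5.0 [0.0, 1.0, 4.0]
```

On a working backend, the JAX call in the example should therefore yield a carry of `5.0` and outputs `[0., 1., 4.]` instead of crashing.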

Strangely, when we call it

`jax.lax.scan` is a very important building block, so I would greatly appreciate it if this could be resolved soon.

It's still not working with the following versions:

  • jax-metal: 0.0.7
  • jax, jaxlib: 0.4.26
  • macOS: Sonoma 14.5