Jax-Metal error: failed to legalize operation of mhlo.fft

Hi, just got an Apple M3 Pro to try it out on some Jax operations. I see the development is actively ongoing so maybe this error can help.

This is the environment:

Metal device set to: Apple M3 Pro

systemMemory: 18.00 GB
maxCacheSize: 6.00 GB

jax:    0.4.26
jaxlib: 0.4.23
numpy:  1.26.4
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='MKFL96VR9YT', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6030', machine='arm64')

This is a minimal example which produces an error, I think due to the fft part:

from jax import numpy as np
array = np.ones((16, 16))
np.fft.fft2(array)

This is the full traceback:

Traceback (most recent call last):
  File "/Users/user/Downloads/wow.py", line 5, in <module>
    np.fft.fft2(array)
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/numpy/fft.py", line 216, in fft2
    return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/numpy/fft.py", line 210, in _fft_core_2d
    return _fft_core(func_name, fft_type, a, s, axes, norm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/numpy/fft.py", line 102, in _fft_core
    transformed = lax.fft(arr, fft_type, tuple(s))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 298, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 176, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/core.py", line 2788, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/core.py", line 425, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1494, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1471, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1406, in _pjit_call_impl_python
    lowering_parameters=mlir.LoweringParameters()).compile()
                                                   ^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2369, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2908, in from_hlo
    xla_executable, compile_options = _cached_compilation(
                                      ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2718, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/compiler.py", line 266, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/compiler.py", line 238, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}], function_type = (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<16x16xf32>):
  %0 = "mhlo.convert"(%arg0) : (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>
  %1 = "mhlo.fft"(%0) {fft_length = dense<16> : tensor<2xi64>, fft_type = #mhlo<fft_type FFT>} : (tensor<16x16xcomplex<f32>>) -> tensor<16x16xcomplex<f32>>
  "func.return"(%1) : (tensor<16x16xcomplex<f32>>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}], function_type = (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<16x16xf32>):
  %0 = "mhlo.convert"(%arg0) : (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>
  %1 = "mhlo.fft"(%0) {fft_length = dense<16> : tensor<2xi64>, fft_type = #mhlo<fft_type FFT>} : (tensor<16x16xcomplex<f32>>) -> tensor<16x16xcomplex<f32>>
  "func.return"(%1) : (tensor<16x16xcomplex<f32>>) -> ()
}) : () -> ()

I'd be happy running more tests should you need them, I'm new to this, so not sure which just yet.

Many thanks!!