Hi, I just got an Apple M3 Pro and wanted to try it out on some JAX operations. I see development is actively ongoing, so maybe this error report can help.
This is the environment:
Metal device set to: Apple M3 Pro
systemMemory: 18.00 GB
maxCacheSize: 6.00 GB
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='MKFL96VR9YT', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6030', machine='arm64')
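(For reference, the jax/jaxlib/numpy/python/device summary above can be reproduced with jax.print_environment_info(); I believe the "Metal device set to" and memory lines are printed by the jax-metal plugin at startup.)
import jax
jax.print_environment_info()  # prints jax/jaxlib/numpy/python versions and the device list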
This is a minimal example that produces an error, I think due to the FFT part:
from jax import numpy as np
array = np.ones((16, 16))
np.fft.fft2(array)
This is the full traceback:
Traceback (most recent call last):
File "/Users/user/Downloads/wow.py", line 5, in <module>
np.fft.fft2(array)
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/numpy/fft.py", line 216, in fft2
return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/numpy/fft.py", line 210, in _fft_core_2d
return _fft_core(func_name, fft_type, a, s, axes, norm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/numpy/fft.py", line 102, in _fft_core
transformed = lax.fft(arr, fft_type, tuple(s))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 298, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 176, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/core.py", line 2788, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/core.py", line 425, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1494, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1471, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1406, in _pjit_call_impl_python
lowering_parameters=mlir.LoweringParameters()).compile()
^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2369, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2908, in from_hlo
xla_executable, compile_options = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2718, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/compiler.py", line 266, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/compiler.py", line 238, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}], function_type = (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<16x16xf32>):
%0 = "mhlo.convert"(%arg0) : (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>
%1 = "mhlo.fft"(%0) {fft_length = dense<16> : tensor<2xi64>, fft_type = #mhlo<fft_type FFT>} : (tensor<16x16xcomplex<f32>>) -> tensor<16x16xcomplex<f32>>
"func.return"(%1) : (tensor<16x16xcomplex<f32>>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}], function_type = (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<16x16xf32>):
%0 = "mhlo.convert"(%arg0) : (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>
%1 = "mhlo.fft"(%0) {fft_length = dense<16> : tensor<2xi64>, fft_type = #mhlo<fft_type FFT>} : (tensor<16x16xcomplex<f32>>) -> tensor<16x16xcomplex<f32>>
"func.return"(%1) : (tensor<16x16xcomplex<f32>>) -> ()
}) : () -> ()
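Reading the dump, it looks like the Metal lowering rejects the complex<f32> input/output of the mhlo.fft op, so presumably any transform that returns complex values would hit the same failure. In the meantime, a workaround sketch I'm considering (assuming the plain CPU backend is still registered alongside METAL) is to pin just the FFT to the CPU device:
import jax
import jax.numpy as jnp

# Workaround sketch, not a fix: run the FFT on the CPU backend, which
# supports complex64, and keep other work on the Metal device.
cpu_device = jax.devices("cpu")[0]  # assumes the CPU backend is available

array = jnp.ones((16, 16))
with jax.default_device(cpu_device):
    spectrum = jnp.fft.fft2(array)

print(spectrum.dtype)      # complex64
print(spectrum.devices())  # should show the CPU device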
I'd be happy to run more tests if you need them; I'm new to this, so I'm not sure yet which ones would help.
Many thanks!!