Jax-Metal - error: failed to legalize operation 'mhlo.cholesky'

After building jaxlib as per the instructions and installing jax-metal, I tested an existing model that works fine on the CPU (and on the GPU under Linux), and got the following error.

jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: error: failed to legalize operation 'mhlo.cholesky'
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: called from
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: see current operation: %406 = "mhlo.cholesky"(%405) {lower = true} : (tensor<50x50xf32>) -> tensor<50x50xf32>

I have tried to reproduce this with the following minimal example, but it works fine.

from jax import jit
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

key = jnr.PRNGKey(0)
A = jnr.normal(key, (100,100))

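def calc_cholesky_decomp(test_matrix):
    # M @ M.T is symmetric positive semi-definite, so it can be Cholesky-factored
    psd_test_matrix = test_matrix @ test_matrix.T
    col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
    return col_decomp

# Eager (op-by-op) call
calc_cholesky_decomp(A)

# The same computation compiled with jit
jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
jitted_calc_cholesky_decomp(A)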

I am unable to attach the full error message, as it exceeds the size restrictions on uploads attached to a post.

I am more than happy to try a more complex model if you have any suggestions.
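Since the full model fails while the minimal example above reportedly runs, it may help to find which functions in the model actually trace to a Cholesky factorization. A minimal sketch: `uses_cholesky` is a hypothetical helper (not part of JAX or jax-metal) that traces a function with `jax.make_jaxpr`, which only traces and does not compile, and searches the printed jaxpr for the word `cholesky`. It is a text heuristic, not a structured walk of the program, and says nothing about what the Metal plugin supports.

```python
import jax
import jax.random as jnr
import jax.scipy as jsp


def uses_cholesky(fn, *args):
    """Hypothetical helper: True if tracing fn(*args) yields a Cholesky op.

    Heuristic: searches the printed jaxpr text for the word "cholesky".
    """
    jaxpr_text = str(jax.make_jaxpr(fn)(*args))
    return "cholesky" in jaxpr_text


def calc_cholesky_decomp(test_matrix):
    psd_test_matrix = test_matrix @ test_matrix.T
    return jsp.linalg.cholesky(psd_test_matrix, lower=True)


A = jnr.normal(jnr.PRNGKey(0), (100, 100))
print(uses_cholesky(calc_cholesky_decomp, A))  # True
print(uses_cholesky(lambda m: m @ m.T, A))     # False
```

The same call can be pointed at individual pieces of a larger model (for example the function containing `sparse_gps.py:66`) with representative arguments.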

Hi, I am getting the same error. I am running Google's Neural Tangents cookbook tutorial (https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/neural_tangents_cookbook.ipynb#scrollTo=zDIYbtgA_atG) after downloading it onto my iMac with an M1 processor. Even the minimal example given above fails for me:

XlaRuntimeError                           Traceback (most recent call last)
Cell In[24], line 14
     11     col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
     12     return col_decomp
---> 14 calc_cholesky_decomp(A)
     16 jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
     17 jitted_calc_cholesky_decomp(A)

Cell In[24], line 11, in calc_cholesky_decomp(test_matrix)
      9 def calc_cholesky_decomp(test_matrix):
     10     psd_test_matrix = test_matrix @ test_matrix.T
---> 11     col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
     12     return col_decomp

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/linalg.py:54, in cholesky(***failed resolving arguments***)
     49 @_wraps(scipy.linalg.cholesky,
     50         lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
     51 def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
     52              check_finite: bool = True) -> Array:
     53   del overwrite_a, check_finite  # Unused
---> 54   return _cholesky(a, lower)

    [... skipping hidden 14 frame]

File ~/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py:465, in backend_compile(backend, module, options, host_callbacks)
    460   return backend.compile(built_c, compile_options=options,
    461                          host_callbacks=host_callbacks)
    462 # Some backends don't have `host_callbacks` option yet
    463 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    464 # to take in `host_callbacks`
--> 465 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: /var/folders/jh/ms0xbzxj5cq7vswmxdsdvh2w0000gn/T/ipykernel_77014/3178916729.py:11:0: error: failed to legalize operation 'mhlo.cholesky'
/var/folders/jh/ms0xbzxj5cq7vswmxdsdvh2w0000gn/T/ipykernel_77014/3178916729.py:11:0: note: see current operation: %2 = "mhlo.cholesky"(%arg0) {lower = true} : (tensor<100x100xf32>) -> tensor<100x100xf32>

It would be nice if someone could fix this soon.

I am getting the same error when calling jax.random.multivariate_normal. Please fix this; it is very basic functionality.
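`jax.random.multivariate_normal` factors the covariance with a Cholesky decomposition by default, which would explain why it hits the same error. It also accepts a `method` argument (`"cholesky"`, `"eigh"` or `"svd"`). The sketch below uses a made-up 3 x 3 covariance, checks which methods still trace to a Cholesky op, and draws samples with `method="eigh"`. Whether the Metal plugin can compile `eigh` or `svd` in these releases is an untested assumption; nothing in this thread confirms it.

```python
import jax
import jax.numpy as jnp
import jax.random as jnr

key = jnr.PRNGKey(0)
mean = jnp.zeros(3)
# Made-up symmetric positive-definite covariance, for illustration only.
cov = jnp.array([[2.0, 0.5, 0.0],
                 [0.5, 1.0, 0.3],
                 [0.0, 0.3, 1.5]])


def sample(method):
    """Return a sampler that factors `cov` with the given method."""
    return lambda k: jnr.multivariate_normal(
        k, mean, cov, shape=(1000,), method=method)


# Which factorization methods still trace to a Cholesky op?
for method in ("cholesky", "eigh", "svd"):
    jaxpr_text = str(jax.make_jaxpr(sample(method))(key))
    print(method, "cholesky" in jaxpr_text)

samples = sample("eigh")(key)
print(samples.shape)  # (1000, 3)
```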

Any update on this?

No reply from Apple? Does anyone monitor these forums?

I can confirm that this is still broken in the latest release (0.0.5).

Still broken in 0.0.6. I needed to update macOS first, of course; after that, I get the same error as above.

Still not working in 0.0.7.
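Until the op is supported, one possible workaround is to run only the factorization on the CPU backend. A minimal sketch, assuming a CPU device is available alongside the Metal plugin; `cholesky_on_cpu` is a hypothetical helper, not part of JAX or jax-metal. It only helps for eager calls made outside a larger `jax.jit` region, and the returned factor stays on the CPU, so it may need a `jax.device_put` back to the Metal device before further use.

```python
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp


def cholesky_on_cpu(matrix, lower=True):
    """Hypothetical helper: compute the Cholesky factor on the CPU backend.

    Call it eagerly (outside jax.jit); the returned array lives on the CPU.
    """
    cpu = jax.devices("cpu")[0]  # raises RuntimeError if no CPU backend exists
    with jax.default_device(cpu):
        # device_put commits the input to the CPU, so the factorization is
        # compiled and run there instead of on the default (Metal) device.
        return jsp.linalg.cholesky(jax.device_put(matrix, cpu), lower=lower)


# Well-conditioned symmetric positive-definite test matrix (100 x 100).
B = jnr.normal(jnr.PRNGKey(0), (100, 200))
psd = B @ B.T
L = cholesky_on_cpu(psd)
print(jnp.allclose(L @ L.T, psd, rtol=1e-3, atol=1e-2))  # True
```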
