jax-metal error jax.numpy.linalg.inv

Hi, I have an issue with jax.numpy.linalg.inv(a).

import jax.numpy as jnp
import jax.numpy.linalg as jnpl
B = jnp.identity(2)
jnpl.inv(B)

This throws the following error:

XlaRuntimeError: UNKNOWN: /var/folders/pw/wk5rfkjj6qggqp8r8zb2bw8w0000gn/T/ipykernel_34334/2572982404.py:9:0: error: failed to legalize operation 'mhlo.triangular_solve'

/var/folders/pw/wk5rfkjj6qggqp8r8zb2bw8w0000gn/T/ipykernel_34334/2572982404.py:9:0: note: called from

/var/folders/pw/wk5rfkjj6qggqp8r8zb2bw8w0000gn/T/ipykernel_34334/2572982404.py:9:0: note: see current operation: %120 = "mhlo.triangular_solve"(%42#4, %119) {left_side = true, lower = true, transpose_a = #mhlo<transpose NO_TRANSPOSE>, unit_diagonal = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
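For context: as far as I can tell, jnp.linalg.inv is computed from an LU factorization followed by triangular solves, which would explain why the failing op is mhlo.triangular_solve with lower = true and unit_diagonal = true (the unit-diagonal L factor). The sketch below spells that out with jax.scipy.linalg.lu and jax.scipy.linalg.solve_triangular. It is only an illustration; inv_via_lu is a hypothetical helper name, and on jax-metal it would presumably hit the same unsupported op.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu, solve_triangular


def inv_via_lu(a):
    # Factor a = p @ l @ u (p: permutation, l: unit lower triangular, u: upper triangular).
    p, l, u = lu(a)
    # a^{-1} = u^{-1} @ l^{-1} @ p.T, since a permutation matrix's inverse is its transpose.
    # Lower, unit-diagonal solve: the same kind of op as the failing
    # mhlo.triangular_solve in the log (lower = true, unit_diagonal = true).
    y = solve_triangular(l, p.T, lower=True, unit_diagonal=True)
    # An upper-triangular solve with u finishes the inverse.
    return solve_triangular(u, y, lower=False)


a = jnp.array([[4.0, 3.0], [6.0, 3.0]])
print(jnp.allclose(inv_via_lu(a) @ a, jnp.eye(2), atol=1e-5))
```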

Any ideas what could be the issue or how to solve it?
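One workaround, assuming the CPU backend is still available alongside jax-metal (it normally is, but I haven't checked every jax-metal release), is to run just the inverse on the CPU. inv_on_cpu is a hypothetical helper name; jax.devices, jax.default_device and jax.device_put are standard JAX calls.

```python
import jax
import jax.numpy as jnp

# Assumption: a CPU device is registered next to the Metal one.
cpu = jax.devices("cpu")[0]


def inv_on_cpu(a):
    # Commit the input to the CPU and compute there, so the LU + triangular_solve
    # lowering is compiled by the CPU backend instead of jax-metal.
    with jax.default_device(cpu):
        return jnp.linalg.inv(jax.device_put(a, cpu))


B = jnp.identity(2)
print(inv_on_cpu(B))
```

The result lives on the CPU; jax.device_put(result, jax.devices()[0]) should move it back to the default (Metal) device if the rest of the computation needs it there.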

I get the same issue. Here is my environment: OS: macOS Sonoma 14.2.1 (23C71), Chip: Apple M1 Max.

I wonder whether other chips (M2, M3) have the same issue.
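When comparing chips, it would help to post the exact versions as well. A small sketch that collects them; describe_environment is a hypothetical helper, and I'm assuming the plugin's pip distribution name is "jax-metal".

```python
import importlib.metadata
import platform

import jax


def describe_environment():
    # Gather the details that matter for a jax-metal bug report.
    try:
        metal_version = importlib.metadata.version("jax-metal")
    except importlib.metadata.PackageNotFoundError:
        metal_version = "not installed"
    return {
        "macOS": platform.mac_ver()[0] or "n/a",  # empty string on non-macOS systems
        "machine": platform.machine(),            # e.g. "arm64" on Apple silicon
        "jax": jax.__version__,
        "jax-metal": metal_version,
        "default backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


for key, value in describe_environment().items():
    print(f"{key}: {value}")
```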

I see the same thing. The Cholesky decomposition (cho_factor) has the same problem: "failed to legalize operation". It looks to me like this op still needs to be implemented in jax-metal.

XlaRuntimeError                           Traceback (most recent call last)
Cell In[33], line 1
----> 1 jsl.cho_factor(Sigma_0_inv)

File ~/miniconda3/envs/jaxmetal/lib/python3.10/site-packages/jax/_src/scipy/linalg.py:61, in cho_factor(failed resolving arguments)
     56 @_wraps(scipy.linalg.cho_factor,
     57         lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
     58 def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
     59                check_finite: bool = True) -> tuple[Array, bool]:
     60   del overwrite_a, check_finite # Unused
---> 61   return (cholesky(a, lower=lower), lower)

File ~/miniconda3/envs/jaxmetal/lib/python3.10/site-packages/jax/_src/scipy/linalg.py:54, in cholesky(failed resolving arguments)
     49 @_wraps(scipy.linalg.cholesky,
     50         lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
     51 def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
     52              check_finite: bool = True) -> Array:
     53   del overwrite_a, check_finite # Unused
---> 54   return _cholesky(a, lower)

File ~/miniconda3/envs/jaxmetal/lib/python3.10/site-packages/jax/_src/compiler.py:255, in backend_compile(backend, module, options, host_callbacks)
    250   return backend.compile(built_c, compile_options=options,
    251                          host_callbacks=host_callbacks)
    252 # Some backends don't have host_callbacks option yet
    253 # TODO(sharadmv): remove this fallback when all backends allow compile
    254 # to take in host_callbacks
--> 255 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: /var/folders/9d/0035yr7j3bx84h3ghpp_86pc0000gn/T/ipykernel_40684/3046650730.py:1:0: error: failed to legalize operation 'mhlo.cholesky'
/var/folders/9d/0035yr7j3bx84h3ghpp_86pc0000gn/T/ipykernel_40684/3046650730.py:1:0: note: see current operation: %5 = "mhlo.cholesky"(%4) {lower = true} : (tensor<200x200xf32>) -> tensor<200x200xf32>
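Until mhlo.cholesky is supported, one possible stopgap is a Cholesky factorization written with basic jnp/lax ops, computed column by column (the Cholesky-Crout ordering), so no mhlo.cholesky is emitted. This is a sketch that assumes a symmetric positive-definite float input and assumes jax-metal can lower the indexing and scatter ops it uses, which I haven't verified; cholesky_lower is a hypothetical name. It returns the lower factor L with a = L @ L.T. Anything downstream that uses triangular solves (e.g. jax.scipy.linalg.cho_solve) would likely still hit the first error.

```python
import jax.numpy as jnp
from jax import lax


def cholesky_lower(a):
    """Lower-triangular L with a = L @ L.T, built from basic jnp ops.

    Assumes a is a symmetric positive-definite floating-point matrix.
    """
    n = a.shape[-1]
    idx = jnp.arange(n)

    def body(j, L):
        # Columns >= j of L are still zero, so full-row sums only pick up k < j.
        row_j = L[j]
        # Diagonal: L[j, j] = sqrt(a[j, j] - sum_{k<j} L[j, k]^2).
        d = jnp.sqrt(a[j, j] - jnp.dot(row_j, row_j))
        # Below the diagonal: L[i, j] = (a[i, j] - sum_{k<j} L[i, k] * L[j, k]) / d.
        col = (a[:, j] - L @ row_j) / d
        # Keep rows i > j, put d on the diagonal, zeros above it.
        col = jnp.where(idx > j, col, 0.0).at[j].set(d)
        return L.at[:, j].set(col)

    return lax.fori_loop(0, n, body, jnp.zeros_like(a))


m = jnp.array([[4.0, 2.0], [2.0, 3.0]])
L = cholesky_lower(m)
print(jnp.allclose(L @ L.T, m, atol=1e-5))
```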
