Error while using JAX

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-23 22:04:38.947506: W pjrt_plugin/src/] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

loc("-":0:0): error: current mps dialect version is 1.0.0, can't parse version 1.1.0
/AppleInternal/Library/BuildRoots/495c257e-668e-11ee-93ce-926038f30c31/Library/Caches/ failed assertion `Error importing MLIR bytecode. '
zsh: abort python -c 'import jax; print(jax.numpy.arange(10))'

I was able to resolve this by updating macOS from 14.0 to 14.4. I'm currently using jax/jaxlib 0.4.23 and jax-metal 0.0.6.
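The error reads like a version mismatch: the plugin emitted MLIR bytecode for MPS dialect 1.1.0, but the Metal framework on macOS 14.0 could only parse 1.0.0, which would explain why the OS update fixed it. Below is a minimal sketch for checking an environment before running JAX on Metal. It uses only the standard library (`platform.mac_ver`, `importlib.metadata`). The 14.4 threshold is an assumption taken from this thread, not a documented requirement, and the helper names (`parse_version`, `macos_is_new_enough`, `installed_version`) are hypothetical.

```python
import platform
from importlib import metadata

# Assumption: 14.4 is simply the macOS version that worked in this thread,
# not an officially documented minimum for jax-metal.
MIN_MACOS = (14, 4)


def parse_version(text):
    """Turn a dotted version string like '14.4.1' into a tuple of ints."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component
        parts.append(int(digits))
    return tuple(parts)


def macos_is_new_enough(release, minimum=MIN_MACOS):
    """Return True if a macOS release string meets the minimum version."""
    if not release:
        return False  # platform.mac_ver() returns '' on non-macOS systems
    return parse_version(release) >= minimum


def installed_version(package):
    """Return the installed version of a distribution, or None if missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    release = platform.mac_ver()[0]
    print("macOS:", release or "not macOS")
    print("macOS >= 14.4:", macos_is_new_enough(release))
    for pkg in ("jax", "jaxlib", "jax-metal"):
        print(f"{pkg}:", installed_version(pkg) or "not installed")
```

Comparing integer tuples avoids the string-comparison trap where "14.10" would sort before "14.4".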

Thanks, it works for me. I hope it won't interfere with torch.backends.
