jax-metal failing due to incompatibility with jax 0.5.1 or later.

Hello,

I am interested in using jax-metal to train ML models on Apple Silicon. I understand that it is experimental.

After installing jax-metal according to https://developer.apple.com/metal/jax/, my Python code fails with the following error:

JaxRuntimeError: UNKNOWN: -:0:0: error: unknown attribute code: 22
-:0:0: note: in bytecode version 6 produced by: StableHLO_v1.12.1

My issue is identical to the one reported here: https://github.com/jax-ml/jax/issues/26968#issuecomment-2733120325. It is fixed by pinning to jax-metal 0.1.1, jax 0.5.0, and jaxlib 0.5.0.
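For anyone hitting the same error, here is a minimal sketch of how the installed versions could be checked against the pins mentioned above. The helper names `check_pins` and `pip_command` are hypothetical, made up for illustration. The only library used is the standard `importlib.metadata` module, and the pinned versions are the ones that worked for me, not an official compatibility list.

```python
from importlib import metadata

# Versions that resolved the error in this post (an observation, not an official matrix).
PINNED = {"jax-metal": "0.1.1", "jax": "0.5.0", "jaxlib": "0.5.0"}


def check_pins(pins=PINNED, get_version=metadata.version):
    """Return {package: (expected, installed)} for each package that does not match its pin.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


def pip_command(pins=PINNED):
    """Build the pip command that installs exactly the pinned versions."""
    return "pip install " + " ".join(f"{name}=={version}" for name, version in pins.items())


if __name__ == "__main__":
    problems = check_pins()
    if problems:
        for name, (expected, installed) in problems.items():
            print(f"{name}: expected {expected}, found {installed}")
        print("Suggested fix:", pip_command())
    else:
        print("All packages match the pinned versions.")
```

Running this in the environment where the error occurs should show which of the three packages has drifted from the pin, along with the pip command to restore it.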

Thank you!
