Hello,
I am interested in using jax-metal to train ML models on Apple Silicon. I understand this is experimental.
After installing jax-metal according to https://developer.apple.com/metal/jax/, my Python code fails with the following error:
JaxRuntimeError: UNKNOWN: -:0:0: error: unknown attribute code: 22
-:0:0: note: in bytecode version 6 produced by: StableHLO_v1.12.1
My issue is identical to the one reported at https://github.com/jax-ml/jax/issues/26968#issuecomment-2733120325 and is fixed by pinning to jax-metal 0.1.1, jax 0.5.0, and jaxlib 0.5.0.
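In case it helps anyone hitting the same error, here is a minimal sketch for checking whether an environment matches those pins. It only uses the standard library's `importlib.metadata`; the names `PINNED`, `installed_versions`, and `mismatches` are hypothetical helpers I made up for illustration, not part of jax or jax-metal.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions reported in this post as working together (not an official matrix).
PINNED = {"jax-metal": "0.1.1", "jax": "0.5.0", "jaxlib": "0.5.0"}


def installed_versions(names):
    """Return {name: installed version, or None if the package is missing}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def mismatches(installed, pinned=PINNED):
    """Return {name: (installed, expected)} for every package off its pin."""
    return {
        name: (installed.get(name), expected)
        for name, expected in pinned.items()
        if installed.get(name) != expected
    }


if __name__ == "__main__":
    problems = mismatches(installed_versions(PINNED))
    if not problems:
        print("All three packages match the pinned versions.")
    for name, (have, want) in problems.items():
        print(f"{name}: installed {have}, expected {want}")
```

If any package is reported as off its pin, reinstalling with the exact versions above (for example via pip's `==` specifiers) should bring the environment back to the combination that worked for me.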
Thank you!