Thanks for your post. The install went okay and the command
print(xla_bridge.get_backend().platform)
returned METAL.
Unfortunately, jax itself seems to be broken. Running
import jax
import jax.numpy as jnp
gave
AttributeError: module 'jax' has no attribute 'custom_jvp'
I'm using miniforge. The package versions look okay.
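
For anyone who wants to check the same thing, here is a rough sketch of how the versions and import locations could be inspected without importing jax at all. It uses only the standard library. The helper name describe_package is made up for this post, and "jax-metal" assumes the Metal plugin was installed under that distribution name. One possible cause (a guess, not confirmed here) is a stray file or directory named jax earlier on the path shadowing the real package, and the reported origin would show that.

```python
import importlib.metadata
import importlib.util


def describe_package(name):
    """Report installed version and import location of `name` without importing it.

    `describe_package` is a hypothetical helper for this post, not part of jax.
    """
    try:
        version = importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        version = None  # no installed distribution with this name
    try:
        spec = importlib.util.find_spec(name)
    except (ImportError, ValueError):
        spec = None  # the import system could not resolve the name
    origin = spec.origin if spec is not None else None
    return {"name": name, "version": version, "origin": origin}


if __name__ == "__main__":
    # "jax-metal" is a distribution name, not an importable module name,
    # so only its version is expected to be filled in.
    for pkg in ("jax", "jaxlib", "jax-metal"):
        print(describe_package(pkg))
```

If the origin printed for jax points somewhere other than the miniforge environment's site-packages, that would be worth looking at before reinstalling anything.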