M2 Air jax-metal

Thanks for your post. The install went fine, and the command

print(xla_bridge.get_backend().platform)

returned METAL

Unfortunately, the jax installation itself seems to be broken:

import jax
import jax.numpy as jnp

gave

AttributeError: module 'jax' has no attribute 'custom_jvp'

I'm using Miniforge. The package versions look okay.
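
In case it helps with diagnosis, here is a minimal sketch of how the installed versions and the import location could be checked without importing jax itself. The helper names `installed_version` and `module_origin` are made up for illustration, and the idea that a stray local `jax.py` or a jax/jaxlib version mismatch might be involved is only an assumption, not something confirmed here.

```python
import importlib.metadata
import importlib.util


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def module_origin(module_name):
    """Return the file a top-level module would be loaded from, without importing it."""
    spec = importlib.util.find_spec(module_name)
    return spec.origin if spec is not None else None


# Distribution names as they would appear in `pip list` (assumed, not verified here).
for dist in ("jax", "jaxlib", "jax-metal"):
    print(dist, installed_version(dist))

# If this path points into the working directory rather than site-packages,
# a local file or folder named "jax" may be shadowing the real package.
print("jax is loaded from:", module_origin("jax"))
```

If the printed path is not inside the Miniforge environment's site-packages, that would be the first thing to look at.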
