I found strange behavior when using jax-metal on the GPU (Intel Mac). The Jacobian of the identity function should be the identity matrix, but the jax-metal backend seems to return something else:
import jax
import jax.numpy as jnp
jax.jacfwd(lambda x : x)(jnp.array([0.1,0.1]))
yields
Array([[1., 1.],
       [1., 1.]], dtype=float32)
instead of the correct result, which the CPU backend returns:
Array([[1., 0.],
       [0., 1.]], dtype=float32)
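To narrow this down, the sketch below computes the Jacobian with both forward mode (`jacfwd`) and reverse mode (`jacrev`), once on the default backend and once pinned to the CPU, and checks each result against the identity matrix. It assumes a JAX version recent enough to provide `jax.default_device` as a context manager. The helper names `identity` and `jacobians`, and the choice to compare via NumPy (to avoid mixing arrays from different devices), are illustrative choices, not part of the original report. If `jacrev` is correct on Metal while `jacfwd` is not, that would hint that the problem lies in the forward-mode path. This is a guess to test, not a confirmed diagnosis.

```python
import numpy as np
import jax
import jax.numpy as jnp

def identity(x):
    # Function under test: its Jacobian should be the identity matrix.
    return x

def jacobians(x):
    # Forward-mode and reverse-mode Jacobians, copied to host NumPy arrays
    # so results computed on different devices can be compared safely.
    return (np.asarray(jax.jacfwd(identity)(x)),
            np.asarray(jax.jacrev(identity)(x)))

expected = np.eye(2, dtype=np.float32)

# Default backend (the jax-metal GPU, if the plugin is installed and active).
fwd_default, rev_default = jacobians(jnp.array([0.1, 0.1]))

# Same computation pinned to the CPU backend as a reference.
with jax.default_device(jax.devices("cpu")[0]):
    fwd_cpu, rev_cpu = jacobians(jnp.array([0.1, 0.1]))

print("default backend:", jax.default_backend())
for name, J in [("jacfwd (default)", fwd_default),
                ("jacrev (default)", rev_default),
                ("jacfwd (cpu)", fwd_cpu),
                ("jacrev (cpu)", rev_cpu)]:
    print(f"{name}: matches identity = {np.allclose(J, expected)}")
```

On a correct backend, all four lines should report `True`.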