jax-metal returns wrong jacobian

I found strange behavior when using jax-metal on the GPU (Intel Mac). The Jacobian of the identity function should be the identity matrix, but that is not what the jax-metal backend returns:

import jax
import jax.numpy as jnp
jax.jacfwd(lambda x: x)(jnp.array([0.1, 0.1]))

yields

Array([[1., 1.],
       [1., 1.]], dtype=float32)

instead of the correct result, which the CPU backend returns:

Array([[1., 0.],
       [0., 1.]], dtype=float32)
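To narrow this down, one option is to run the same computation on the default backend and pinned to the CPU backend in a single script, and to cross-check forward mode against reverse mode. The sketch below assumes a JAX version that provides the `jax.default_device` context manager and that the CPU device is still available alongside jax-metal. The `jacrev` line is only a cross-check; whether reverse mode is affected by the same issue is not known from this report.

```python
import jax
import jax.numpy as jnp
import numpy as np


def identity(x):
    # d(identity)/dx should be the n x n identity matrix.
    return x


x = jnp.array([0.1, 0.1])
expected = np.eye(x.shape[0], dtype=np.float32)

# Forward-mode Jacobian on the default backend (Metal GPU if jax-metal is active).
jac_default = jax.jacfwd(identity)(x)

# Same computation with the CPU device as the default, for reference.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    jac_cpu = jax.jacfwd(identity)(jnp.array([0.1, 0.1]))

# Reverse-mode Jacobian on the default backend as a cross-check.
jac_rev = jax.jacrev(identity)(x)

print("default backend:", jax.default_backend())
print("jacfwd (default):\n", np.asarray(jac_default))
print("jacfwd (cpu):\n", np.asarray(jac_cpu))
print("jacrev (default):\n", np.asarray(jac_rev))
print("default jacfwd matches identity:", np.allclose(jac_default, expected))
print("cpu jacfwd matches identity:", np.allclose(jac_cpu, expected))
```

If the CPU result matches the identity matrix while the default-backend result does not, that points at the Metal backend rather than at `jacfwd` itself.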