jax-metal breaks jax.numpy.take() for out-of-bounds indices

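JAX 0.4.11, jaxlib 0.4.10, without jax-metal (expected behavior):

import jax
import jax.numpy as jnp
import numpy as np

print(jax.__version__)  # 0.4.11
print(jnp.take(jnp.arange(5).astype(float), jnp.arange(7)))  # [ 0. 1. 2. 3. 4. nan nan]
print(jnp.take(jnp.arange(5).astype(float), jnp.arange(7), mode='clip'))  # [0. 1. 2. 3. 4. 4. 4.]
print(np.take(np.arange(5).astype(float), np.arange(7), mode='clip'))  # [0. 1. 2. 3. 4. 4. 4.]

Without an explicit mode, jnp.take uses 'fill', which returns NaN for out-of-bounds indices on float arrays. With mode='clip', the result matches np.take.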

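JAX 0.4.11, jaxlib 0.4.10, jax-metal 0.0.3:

print(jax.__version__)  # 0.4.11
print(jnp.take(jnp.arange(5).astype(float), jnp.arange(7)))  # [0. 1. 2. 3. 4. 0. 0.]
print(jnp.take(jnp.arange(5).astype(float), jnp.arange(7), mode='clip'))  # [0. 1. 2. 3. 4. 0. 0.]
print(np.take(np.arange(5).astype(float), np.arange(7), mode='clip'))  # [0. 1. 2. 3. 4. 4. 4.]

With jax-metal, both jnp.take calls return 0 for the out-of-bounds indices 5 and 6. The default mode should give NaN there, and mode='clip' should repeat the last element (4.), as np.take does.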
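To make the expected values explicit, here is a NumPy-only sketch of what jnp.take should return in this 1-D float case, based on the results without jax-metal. `expected_take_1d` is a hypothetical helper, not part of JAX or NumPy, and it deliberately covers only float input and non-negative indices. Its output could be compared against a backend's jnp.take result with np.testing.assert_array_equal, which treats NaNs in matching positions as equal.

```python
import numpy as np


def expected_take_1d(a, indices, mode="fill"):
    """Hypothetical reference for 1-D take with out-of-bounds indices.

    Assumes float input and non-negative indices; negative-index
    handling is intentionally left out of this sketch.
    """
    a = np.asarray(a, dtype=float)
    indices = np.asarray(indices)
    if np.any(indices < 0):
        raise ValueError("this sketch only handles non-negative indices")
    n = a.shape[0]
    if mode == "clip":
        # Indices past the end are replaced by the last valid index (n - 1).
        return a[np.minimum(indices, n - 1)]
    if mode == "fill":
        # In-range positions are gathered; out-of-range positions become NaN.
        in_bounds = indices < n
        out = np.full(indices.shape, np.nan)
        out[in_bounds] = a[indices[in_bounds]]
        return out
    raise ValueError(f"unsupported mode: {mode!r}")


a = np.arange(5).astype(float)
idx = np.arange(7)
print(expected_take_1d(a, idx))               # [ 0.  1.  2.  3.  4. nan nan]
print(expected_take_1d(a, idx, mode="clip"))  # [0. 1. 2. 3. 4. 4. 4.]
```

Against this reference, the jax-metal results are wrong in both modes: the default mode yields 0 instead of NaN at positions 5 and 6, and mode='clip' yields 0 instead of 4.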
