Jax 0.4.11, jaxlib 0.4.10, without jax-metal:
0.4.11 <----print(jax.__version__)
[ 0. 1. 2. 3. 4. nan nan] <----print(jnp.take(jnp.arange(5).astype(float), jnp.arange(7)))
[0. 1. 2. 3. 4. 4. 4.] <----print(jnp.take(jnp.arange(5).astype(float), jnp.arange(7), mode='clip'))
[0. 1. 2. 3. 4. 4. 4.] <----print(np.take(np.arange(5).astype(float), np.arange(7), mode='clip'))
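Jax 0.4.11, jaxlib 0.4.10, jax-metal 0.0.3:
0.4.11 <----print(jax.__version__)
[0. 1. 2. 3. 4. 0. 0.] <----print(jnp.take(jnp.arange(5).astype(float), jnp.arange(7)))
[0. 1. 2. 3. 4. 0. 0.] <----print(jnp.take(jnp.arange(5).astype(float), jnp.arange(7), mode='clip'))
[0. 1. 2. 3. 4. 4. 4.] <----print(np.take(np.arange(5).astype(float), np.arange(7), mode='clip'))

With jax-metal, the out-of-bounds indices (5 and 6) return 0. in both the default mode and mode='clip'. Without it, they return nan in the default mode and the last element (4.) with mode='clip', which matches NumPy.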
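A minimal sketch for checking a backend against these results, assuming NumPy and JAX are installed. `check_take` is a hypothetical helper name, not part of JAX. It compares `jnp.take` with a NumPy reference for the same array and indices: for `mode='clip'` the reference is `np.take(..., mode='clip')`, and for the default mode it assumes JAX's documented default of `mode='fill'`, which returns NaN for out-of-bounds indices into a float array. The sketch only handles non-negative indices.

```python
import numpy as np
import jax
import jax.numpy as jnp


def check_take(x_np, idx_np):
    """Compare jnp.take with NumPy-derived expectations for out-of-bounds indices.

    Hypothetical helper. Assumes a 1-D float array and non-negative indices.
    Returns the active JAX backend plus one bool per mode (True = matches).
    """
    x = jnp.asarray(x_np)
    idx = jnp.asarray(idx_np)
    n = x_np.shape[0]
    in_bounds = idx_np < n

    # mode='clip': out-of-bounds indices clamp to the last valid index, as in NumPy.
    expected_clip = np.take(x_np, idx_np, mode="clip")
    got_clip = np.asarray(jnp.take(x, idx, mode="clip"))

    # Default mode ('fill'): out-of-bounds entries become NaN for float inputs.
    expected_fill = np.where(in_bounds, x_np[np.minimum(idx_np, n - 1)], np.nan)
    got_fill = np.asarray(jnp.take(x, idx))

    return {
        "backend": jax.default_backend(),
        "clip": np.array_equal(got_clip, expected_clip),
        "fill": np.array_equal(got_fill, expected_fill, equal_nan=True),
    }


if __name__ == "__main__":
    print(check_take(np.arange(5).astype(float), np.arange(7)))
```

The script reports the active backend alongside the two checks. Going by the outputs above, both checks would be expected to come out True without jax-metal and False under jax-metal 0.0.3.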