Hello,
I'm interested in trying the new JAX Metal plug-in and followed the steps at https://developer.apple.com/metal/jax/. After installation, the backend device that JAX detects looks no different from a pure CPU setup:
>>> import jax
>>> jax.devices()
[CpuDevice(id=0)]
>>> jax.devices()[0].platform
'cpu'
>>> jax.devices()[0].device_kind
'cpu'
>>> jax.devices()[0].client.platform
'cpu'
>>> jax.devices()[0].client.runtime_type
'tfrt'
Is this really using the Metal backend? How can I tell for sure?
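In case it helps frame the question, here is a rough sketch of the check I have in mind. It assumes that any platform string other than 'cpu' means a non-CPU backend was picked up; that is my assumption, not something I found in the Metal plug-in docs. The helper names (summarize, has_accelerator, check_jax) are made up for illustration; only jax.devices(), jax.default_backend(), and the .platform / .device_kind attributes come from JAX itself.

```python
def summarize(devices):
    """Return (platform, device_kind) pairs for a list of device-like objects."""
    return [(d.platform, d.device_kind) for d in devices]


def has_accelerator(devices):
    """True if at least one device reports a platform other than 'cpu'.

    Assumption: a non-'cpu' platform string indicates an accelerator backend.
    The comparison is case-insensitive because I don't know how the plug-in
    spells its platform name.
    """
    return any(d.platform.lower() != "cpu" for d in devices)


def check_jax():
    """Print what JAX reports in the current environment."""
    import jax  # imported here so the helpers above work without JAX installed

    devices = jax.devices()
    print("default backend:", jax.default_backend())
    print("devices:", summarize(devices))
    print("accelerator found:", has_accelerator(devices))
```

Calling check_jax() in my environment should just repeat the 'cpu' values from the session above, so I assume has_accelerator would report False there.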
Thank you!