Error when calling a JIT-compiled function on a Metal GPU array: UNIMPLEMENTED: DefaultDeviceAssignment not supported for Metal Client.

Hi everyone, I'm testing some jax-metal functionality and get the error below when I call a function compiled with jax.jit(f, backend='METAL'). Any help would be appreciated.

import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    y1 = x + x * x + 3
    y2 = x * x + x * x.T
    return y1 * y2

x = np.random.randn(3000,3000).astype('float32')

jax_x_gpu = jax.device_put(jnp.array(x), jax.devices('METAL')[0])
jax_x_cpu = jax.device_put(jnp.array(x), jax.devices('cpu')[0])

jax_f_gpu = jax.jit(f, backend='METAL')

jax_f_gpu(jax_x_gpu)
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[1], line 17
     13 jax_x_cpu = jax.device_put(jnp.array(x), jax.devices('cpu')[0])
     15 jax_f_gpu = jax.jit(f, backend='METAL')
---> 17 jax_f_gpu(jax_x_gpu)

    [... skipping hidden 5 frame]

File ~/.virtualenvs/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py:817, in _create_sharding_with_device_backend(device, backend)
    814 elif backend is not None:
    815   assert device is None
    816   out = SingleDeviceSharding(
--> 817       xb.get_backend(backend).get_default_device_assignment(1)[0])
    818 return out

XlaRuntimeError: UNIMPLEMENTED: DefaultDeviceAssignment not supported for Metal Client.
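The traceback shows that the failure comes from the backend is not None branch in _create_sharding_with_device_backend, which calls get_default_device_assignment, the method the Metal client reports as unimplemented. A possible workaround, sketched below under the assumption that jax-metal exposes Metal as the default device, is to drop backend= from jax.jit and let the committed input (placed with jax.device_put) decide where the computation runs. The names accel, x_accel and f_jit are illustrative only.

```python
import jax
import numpy as np

def f(x):
    y1 = x + x * x + 3
    y2 = x * x + x * x.T
    return y1 * y2

x = np.random.randn(3000, 3000).astype(np.float32)

# Assumption: with jax-metal installed, jax.devices()[0] is the Metal device.
# Without the plugin this is just the default device (e.g. CPU).
accel = jax.devices()[0]
cpu = jax.devices("cpu")[0]

# Commit the inputs to specific devices; jit follows the input placement,
# so no backend= argument is needed.
x_accel = jax.device_put(x, accel)
x_cpu = jax.device_put(x, cpu)

f_jit = jax.jit(f)  # note: no backend='METAL'

out_accel = f_jit(x_accel)  # runs on `accel`
out_cpu = f_jit(x_cpu)      # runs on the CPU device

print(out_accel.devices(), out_cpu.devices())
print(np.allclose(out_accel, out_cpu, rtol=1e-4, atol=1e-4))
```

This also puts the otherwise unused CPU copy from the original snippet to work as a reference result; small float32 differences between devices are expected, hence the tolerances in allclose.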