Hi everyone, I'm trying to test some functionality of jax-metal and got this error. Any help please?
import jax
import jax.numpy as jnp
import numpy as np
def f(x):
y1=x+x*x+3
y2=x*x+x*x.T
return y1*y2
x = np.random.randn(3000,3000).astype('float32')
jax_x_gpu = jax.device_put(jnp.array(x), jax.devices('METAL')[0])
jax_x_cpu = jax.device_put(jnp.array(x), jax.devices('cpu')[0])
jax_f_gpu = jax.jit(f, backend='METAL')
jax_f_gpu(jax_x_gpu)
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[1], line 17
13 jax_x_cpu = jax.device_put(jnp.array(x), jax.devices('cpu')[0])
15 jax_f_gpu = jax.jit(f, backend='METAL')
---> 17 jax_f_gpu(jax_x_gpu)
[... skipping hidden 5 frame]
File ~/.virtualenvs/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py:817, in _create_sharding_with_device_backend(device, backend)
814 elif backend is not None:
815 assert device is None
816 out = SingleDeviceSharding(
--> 817 xb.get_backend(backend).get_default_device_assignment(1)[0])
818 return out
XlaRuntimeError: UNIMPLEMENTED: DefaultDeviceAssignment not supported for Metal Client.