Hi, I just ran into this issue on an M1 14-inch MacBook Pro and got it to install and run correctly. Instructions here. The key, for now, is that it needs jaxlib 0.4.10 but jax 0.4.11. JAX seems to accept a jaxlib that is one patch version behind jax, so this combination works. The steps are below:

```shell
# obtain JAX source code
git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch
cd jax

# build jaxlib from source, with capability to load plugin
python build/build.py --bazel_options=--@xla//xla/python:enable_tpu=true

# install jaxlib
python -m pip install dist/*.whl
```

You also need Bazel 5.1.1 to build jaxlib (the build script tells you what to do if it can't find it), and you need Python 3.10, or pip won't install the jaxlib wheel. If you're using Anaconda, you'll have to create an environment with Python 3.10 specifically, not any other version.

At this point the instructions tell you to install JAX via pip, but don't do that: pip will default to jax 0.4.10, which is the wrong version. Instead, download the zip for the source code for the
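Once both packages are installed, a quick script can confirm that the interpreter and the jax/jaxlib pair match this setup. The following is a minimal sketch, assuming the locally built wheel targets Python 3.10. The helper names (`parse_version`, `pair_looks_ok`, `installed`) are hypothetical, and the "same minor release, jaxlib zero or one patch behind" rule only encodes the jax 0.4.11 / jaxlib 0.4.10 pairing reported in this post, not JAX's official compatibility policy.

```python
"""Check that the environment matches the jax 0.4.11 / jaxlib 0.4.10 setup.

Assumption: the version rule below only encodes the pairing reported in this
post (jaxlib may be one patch release behind jax); it is not JAX's official
compatibility policy. All helper names are hypothetical.
"""
from __future__ import annotations

import re
import sys
from importlib import metadata


def parse_version(text: str) -> tuple[int, int, int]:
    """Extract (major, minor, patch) from strings like '0.4.10' or '0.4.10.dev1'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    major, minor, patch = (int(group) for group in match.groups())
    return major, minor, patch


def pair_looks_ok(jax_version: str, jaxlib_version: str) -> bool:
    """True if jaxlib shares jax's major.minor and is 0 or 1 patch releases behind."""
    jax_v = parse_version(jax_version)
    lib_v = parse_version(jaxlib_version)
    return jax_v[:2] == lib_v[:2] and jax_v[2] - lib_v[2] in (0, 1)


def installed(name: str) -> str | None:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def main() -> None:
    # The post reports that the locally built jaxlib wheel only installed under 3.10.
    if sys.version_info[:2] != (3, 10):
        print(f"Python {sys.version_info.major}.{sys.version_info.minor} found; expected 3.10")

    jax_v, jaxlib_v = installed("jax"), installed("jaxlib")
    print(f"jax={jax_v} jaxlib={jaxlib_v}")
    if jax_v and jaxlib_v and not pair_looks_ok(jax_v, jaxlib_v):
        print("This pair differs from the jax 0.4.11 / jaxlib 0.4.10 combination above.")


if __name__ == "__main__":
    main()
```

Run it inside the Python 3.10 environment (for example, saved as a hypothetical `check_env.py`). If the installed pair matches the combination described here, the final "differs" message is not printed. The script only prints warnings rather than exiting with an error, so it is safe to run repeatedly while experimenting.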