 
Train your machine learning and AI models on Apple GPUs
Learn how to train your models on Apple silicon with Metal for PyTorch, JAX and TensorFlow. Take advantage of new attention operations and quantization support for improved transformer model performance on your devices.
Chapters
- 0:00 - Introduction
- 1:36 - Training frameworks on Apple silicon
- 4:16 - PyTorch improvements
- 11:26 - ExecuTorch
- 13:19 - JAX features
Hello, my name is Yona Havocainen and I'm a software engineer from the GPU, graphics and display software team. Today I will present how to train your machine learning and AI models with Apple silicon GPUs and what new features have been added this year.
Apple silicon offers lots of amazing features for machine learning on your devices. The powerful GPU excels at the type of computations required to optimize modern neural networks. Combined with the unified memory architecture, the GPU can directly access significant amounts of memory. This large memory allows you to train and run large models locally on your device, or to use larger batch sizes during training, often leading to faster convergence. And without the need to distribute the model weights over multiple machines, the path from training to deployment becomes simpler.
Training is the initial step in deploying models on Apple's platforms. Once the model has been trained, it has to be prepared for deployment on device. After being prepared, the model is ready to be integrated into your application. If you haven't yet seen the talk on the overall flow for deploying machine learning models, please refer to our video on machine learning workflow on Apple devices. In this talk, I will focus on training and present some of the frameworks that are able to utilize the unique compute capabilities of Apple silicon.
To access the powerful GPU, you can use the Metal backend in one of the popular frameworks for machine learning: TensorFlow, PyTorch, JAX, and MLX. TensorFlow is the trusted framework for many industry applications. The Metal backend supports features like distributed training for really large projects, and mixed precision to boost training performance. Enabling the Metal backend for TensorFlow is easier than ever. Just install TensorFlow using a package manager like pip and import it into your project. You can learn more about the TensorFlow Metal backend from our video in WWDC21.
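As a quick check after installing TensorFlow, you could list the GPUs that TensorFlow can see. This is a minimal sketch: the helper name `list_tensorflow_gpus` is hypothetical, and it assumes that the Metal plugin package (`tensorflow-metal`) is installed alongside TensorFlow so that the Apple GPU shows up as a device. It returns an empty list when TensorFlow is missing or no GPU is registered.

```python
def list_tensorflow_gpus():
    """Return the names of the GPU devices TensorFlow has registered."""
    try:
        import tensorflow as tf
    except ImportError:
        # TensorFlow is not installed in this environment.
        return []
    # On Apple silicon the GPU is expected to appear here once the
    # Metal plugin is installed (an assumption, not shown in the talk).
    return [device.name for device in tf.config.list_physical_devices("GPU")]


gpus = list_tensorflow_gpus()
print(f"GPUs visible to TensorFlow: {gpus}")
```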
Another widely used framework is PyTorch. The Metal backend supports features such as custom operations and profiling, making it easy to benchmark and improve your network's performance. Getting started with the Metal backend in PyTorch is also simple. Install and import PyTorch into your project and set your default device to mps. For more information on the PyTorch Metal backend, please refer to our video in WWDC22.
JAX is a recent addition to the frameworks supported through a Metal backend. It supports features like just-in-time compilation and a NumPy-like interface to arrays. You can install jax-metal and import JAX into your project to get started with the JAX Metal backend. We cover more on the Metal backend for JAX in our video in WWDC23.
MLX is the newest framework supported through a Metal backend. MLX is designed and optimized for Apple silicon. It supports features such as a NumPy-like API, just-in-time compilation, distributed training and unified memory natively. The framework offers bindings for Python, Swift, C and C++. You can find examples for running machine learning tasks like fine-tuning transformer models, image generation and audio transcription in our code repository. Getting started with MLX is as easy as with the other frameworks. Just install the wheel in your Python environment and import it into your project. To find out more about the MLX framework, please check out our documentation and our code repository.
Now that we have our Apple silicon primed for training, let me move on to the main topic for today. I would like to showcase some of the new features and improvements that have been made for two of the frameworks in particular: PyTorch and JAX. Let me first start with PyTorch. A year ago in WWDC23, the MPS backend advanced to beta stage in its development. Since then, support has been added for custom kernels, wider op coverage, and unified memory architecture.
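Setting the default device to mps, as mentioned earlier for PyTorch, might look like the following minimal sketch. The helper name `pick_device` and the CPU fallback are assumptions added for illustration and are not shown in the talk. `torch.backends.mps.is_available()` and `torch.set_default_device` are existing PyTorch calls; the latter is only present in newer PyTorch releases, so the sketch checks for it first.

```python
def pick_device():
    """Return "mps" when PyTorch can use the Metal backend, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so there is nothing to configure.
        return "cpu"
    name = "mps" if torch.backends.mps.is_available() else "cpu"
    # Make newly created tensors land on the chosen device by default.
    if hasattr(torch, "set_default_device"):
        torch.set_default_device(name)
    return name


device = pick_device()
print(f"Using device: {device}")
```

Falling back to the CPU keeps the same script usable on machines without a Metal-capable GPU.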
In addition, there have been numerous improvements and fixes to both performance and functionality, many of them contributed by the open source community around PyTorch. The coverage for various networks has also improved through the year. For example, in the Hugging Face repository for state-of-the-art transformer models, the PyTorch MPS backend is now able to accelerate the top 50 most popular networks out of the box. This covers a large number of models that have become famous during the year, like Stable Diffusion, Meta Llama models, and Gemma.
Speaking of improvements, I want to highlight three that impact transformer models in particular: support for 8- and 4-bit integer quantization, which allows even larger models to fit into device memory; fused scaled dot product attention, which improves the performance of many common models; and unified memory support, which removes redundant copies of tensors when dispatching computations to the GPU. Next, let me tell you about each of these topics in more detail.
Data formats like the 32-bit floating point number or the 16-bit floating point number shown here are common when training models. In the 16-bit format, one bit represents the sign of the value, 5 bits the exponent, and 10 bits the fraction. The precision is useful when updating the parameters during training. After training, a technique called quantization can be used to reduce the memory required for the parameters. By representing the same value as an 8-bit integer instead of a 16-bit float, we can cut the required memory in half. This gives benefits such as a smaller memory footprint for models and better throughput for computations and, depending on the model, this can be achieved with only a small or no degradation in output accuracy.
Scaled dot product attention is at the heart of many transformer models. The operation starts with an input of tokenized text. The input is split into three tensors, referred to as the query, key, and value tensors.
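To make the 16-bit layout and the idea of 8-bit quantization described above concrete, here is a small sketch using only the Python standard library. The symmetric per-tensor scheme and the function names are assumptions chosen for illustration; the talk does not say which quantization scheme the PyTorch MPS backend uses.

```python
import struct


def float16_fields(value):
    """Split a value stored as IEEE 754 half precision into its bit fields."""
    (bits,) = struct.unpack("<H", struct.pack("<e", value))
    sign = bits >> 15               # 1 bit
    exponent = (bits >> 10) & 0x1F  # 5 bits
    fraction = bits & 0x3FF         # 10 bits
    return sign, exponent, fraction


def quantize_int8(weights):
    """Symmetric per-tensor quantization to signed 8-bit codes (assumed scheme)."""
    scale = max(abs(w) for w in weights) / 127.0 or 1.0
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Map the integer codes back to approximate floating point values."""
    return [c * scale for c in codes]


weights = [0.5, -1.0, 0.25, 0.75]
codes, scale = quantize_int8(weights)
restored = dequantize(codes, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))

print("fields of -1.0:", float16_fields(-1.0))
print("bytes per value:", struct.calcsize("e"), "->", struct.calcsize("b"))
print("codes:", codes, "max round-trip error:", max_error)
```

Each stored value shrinks from two bytes to one, and the round-trip error stays within half a quantization step.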
The three tensors are then manipulated through a series of matrix multiplication, scaling, and softmax operations. By fusing the series of operations into a single kernel call, the overhead from dispatching many small computations to the GPU can be avoided, improving the overall performance of many networks.
And the last performance update I want to highlight is the benefit from the unified memory architecture on Apple devices. This allows a tensor to simply exist in the main memory and be accessible to both the CPU and the GPU, without the need to copy the bits from one area of the memory to another.
Next, I will wrap up the PyTorch discussion by showing an end-to-end workflow on how you can take a language model, customize it, fine-tune it to your use case, and run it on your device. I will start by importing torch and locking in my random seeds for reproducible results. I'm using the popular transformers library to download and set up my model and tokenizer. It provides me a convenient method for getting models from the Hugging Face repository. I'm using OpenLLaMA version 2 with 3 billion parameters as my base model for my task. And I will also need the corresponding tokenizer the model has been trained with. To attach a fine-tuning adapter to my model, I'll use the peft library with a LoraConfig. I'll define the parameters for my adapter and then create a new PeftModel with the base model and the config. Now I am ready to send my model to our computing device, MPS.
Next, I'll need to select data to use for my tuning. I'll use the tiny Shakespeare dataset from Andrej Karpathy as my training input. It consists of works of Shakespeare conveniently concatenated into a single file. Once the dataset has been downloaded, I can load it into a dataset object and instruct which tokenizer should be used for the data. I will need to set some training parameters for the tune.
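The setup steps narrated so far (seeding, loading the model and tokenizer, attaching a LoRA adapter, and moving the model to MPS) could look roughly like the sketch below. The model id, the adapter hyperparameters, and the function name `build_model` are assumptions for illustration and are not taken from the talk. The library calls themselves (`AutoModelForCausalLM.from_pretrained`, `AutoTokenizer.from_pretrained`, `LoraConfig`, `get_peft_model`) exist in transformers and peft.

```python
# Assumed Hugging Face id for the 3B OpenLLaMA v2 checkpoint.
MODEL_ID = "openlm-research/open_llama_3b_v2"

# Hypothetical adapter hyperparameters; tune these for your own task.
LORA_SETTINGS = dict(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)


def build_model(model_id=MODEL_ID, seed=42):
    """Load the base model, wrap it with a LoRA adapter, and move it to MPS."""
    # Imports are local so the heavy libraries are only needed when called.
    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    torch.manual_seed(seed)  # fix the random seed for reproducible results
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    base_model = AutoModelForCausalLM.from_pretrained(model_id)
    model = get_peft_model(base_model, LoraConfig(**LORA_SETTINGS))
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    return model.to(device), tokenizer
```

Calling `build_model()` downloads several gigabytes of weights on first use, so the sketch only defines the function rather than running it.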
I can use the Trainer class and set arguments like batch size and number of training epochs to use. The data collator object forms the training batches for the trainer. Now I am ready to pass the model, arguments, data collator, and the training dataset to create my trainer object.
Before I start the training, let me check what the model outputs before I fine-tune it. I'll add a small convenience function that will take an input text, tokenize it for the model to consume, generate the output, and de-tokenize it back into human-readable text. I'll test it with a sentence from Shakespeare to see what kind of response it invokes. Before the tuning, it seems to me that the model is acting like some type of dictionary entry, where it starts off correctly attributing the quote and then rambles off into some discussion about the head of a family. It's not much fun to talk to a dictionary, so let's see if we can liven up the model a bit through my fine-tuning.
I'll kick off the training with the trainer class. It will use the parameters I defined earlier to crunch through the dataset. After a while, the training is nearing its end after running 10 epochs over the dataset. Now let me try giving it the same input as before. An interesting take from Menenius! Clearly our fine-tuning has accomplished something. Now let me save the model for future use. I'll merge the adapter and base model into a single entity to make it easier to use. I'm also making sure to store the tokenizer with the model.
Now that my model has been trained, I would like to try it out by deploying it to a device. For most networks, the preferred way to deploy would be to use Core ML. This is discussed in more detail in our talk on deploying your models on device. In this case, I wanted to stay inside the PyTorch ecosystem, so I can use the new ExecuTorch framework to deploy the model. ExecuTorch is aimed at deploying PyTorch models across various devices for inference.
You can use any custom operations you defined in your PyTorch training seamlessly in deployment on ExecuTorch. ExecuTorch uses the MPS Partitioner to analyze the computational graph and accelerate recognized patterns using the MPS device. This is how you can set up ExecuTorch on your local device. First, I clone the repository to my machine. Then I update its submodules. Finally, I'll run the installation script, passing the option for using MPS bindings when building ExecuTorch.
Now let me show you how to deploy a model in ExecuTorch. I'm following along with the example from the ExecuTorch repository. I will use the Meta Llama 2 model as my test model. The model has been converted to a 4-bit integer data type using the groupwise quantization method. This makes it both smaller and faster. On my Mac, I will build the demo app for iOS in the repository and use my iPad Pro as the deployment target. Once the app is built, I will choose the model to use and the corresponding tokenizer that the model has been trained with. Next, I will ask if the model could tell me how to make lasagna. I am using the Llama 2 prompt template for my query here, since this model has been fine-tuned to act like a chatbot and it is expecting this format. Running locally on my iPad through ExecuTorch, the model is giving me some pretty good ideas for dinner. Although I think I'll add some tomatoes and cheese to the list. Those are all the ways you can use the new features and improvements to speed up your PyTorch workflows.
Next, I'll move on to discuss the new features we have added to another popular machine learning framework supported by MPSGraph: JAX. The JAX-Metal plugin was released to developers at last year's WWDC23. Since then, the plugin has continued to evolve and add more functionality and performance-related updates. These include changes such as improved advanced array indexing, adoption of the CI runner workflow in the official JAX repository, and support for mixed precision.
I want to highlight a couple of users who have adopted the JAX-Metal backend since our release. The first one is MuJoCo, an open source framework for use cases where fast and accurate simulations are needed, such as robotics and biomechanics. The JAX-Metal backend enables them to get the best possible performance for users running on Mac platforms. The second is AXLearn, a library for developing large-scale deep-learning models. The Metal backend enables rapid training and testing of workflows on the local device. You can go check out both of these libraries and test how the JAX-Metal backend makes using them a great experience on your Mac devices.
Next, let's go over some of the improvements added to the JAX-Metal backend in more detail. I will cover mixed precision, NDArray indexing, and padding in JAX. One of the updates this year is that the BFloat16 data type is now supported in the JAX-Metal framework. This data type represents a wide dynamic range of floating point values and is good for use cases like mixed precision training. You can use the new data type just like any other data type in JAX. For example, you can create a tensor with the new data type. Another improvement is that NDArray indexing and update support lets you manipulate arrays with a NumPy-like syntax. For example, if you create a small array of two rows and two columns, you can divide the first column by 10 using the NumPy indexing syntax. JAX allows you to define padding policies. These padding policies are now supported by the JAX-Metal backend. You can use this to add padding between elements, known as dilation. To do this, call the pad function, which accepts three parameters for each dimension. You can also remove elements from the tensor with negative padding. Do this by passing negative values in the padding configuration.
I want to conclude my JAX section with a small example on using JAX. I will be using the AXLearn library discussed earlier.
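The three features covered above (BFloat16, NumPy-style indexed updates, and padding with dilation or negative amounts) can be sketched as follows. The pure-Python `pad_1d` helper is a hypothetical reference written for this example; it mirrors one dimension of a (low, high, interior) padding configuration, assuming the pad function mentioned in the talk is `jax.lax.pad`. The JAX part only runs when JAX is installed.

```python
def pad_1d(values, pad_value, low, high, interior):
    """Reference for one dimension of a (low, high, interior) padding config."""
    # Interior padding (dilation) goes between neighbouring elements first.
    out = []
    for index, value in enumerate(values):
        if index:
            out.extend([pad_value] * interior)
        out.append(value)
    # Edge padding comes next; a negative amount trims elements instead.
    out = [pad_value] * low + out if low >= 0 else out[-low:]
    out = out + [pad_value] * high if high >= 0 else out[:high]
    return out


dilated = pad_1d([1, 2, 3], 0, low=1, high=2, interior=1)
trimmed = pad_1d([1, 2, 3, 4], 0, low=-1, high=-1, interior=0)
print(dilated)
print(trimmed)

try:
    import jax.numpy as jnp
    from jax import lax
except ImportError:
    jnp = None  # JAX is optional for this sketch

if jnp is not None:
    # A 2x2 tensor stored as BFloat16.
    x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
    # JAX arrays are immutable, so the update returns a new array; this is
    # the functional counterpart of NumPy's `x[:, 0] /= 10`.
    x = x.at[:, 0].divide(10)
    # Same padding config as `dilated` above: low=1, high=2, interior=1.
    y = lax.pad(jnp.array([1, 2, 3]), jnp.array(0), [(1, 2, 1)])
    print(x.dtype, y)
```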
I will pick the Fuji 7-billion-parameter model from there to run, and I am using the BFloat16 data type for the model as discussed earlier. The script will create a small random input to pass to the model and ask the model to produce the next tokens. The output will give us logits and the output tokens. Once the prediction is done, I'll re-run the same script, but I will set JAX to run on the CPU instead with an environment variable. Here I can see that the CPU output agrees with the output from the Metal backend for our inference, concluding our demo. And that brings us to the end of our presentation for JAX and for this WWDC.
To recap what we discussed today: Apple silicon has the advantage of a unified memory architecture that provides great benefits for many machine learning tasks. It enables the use of larger models and larger batch sizes, and it removes the need for copies between the CPU and GPU, since both have access to the same memory. You can use the powerful Apple silicon GPUs through our Metal backends for the popular frameworks PyTorch, JAX, TensorFlow, and MLX. This year, a lot of new performance work has gone into supporting the very popular transformer class of models. You can take advantage of these updates by making sure you are using an up-to-date release of the frameworks, and remember to update your macOS. Thank you for your attention, and we can't wait to see what you do with these new features.