Explore distributed inference and training with MLX
Scale up your machine learning workloads across multiple Macs with MLX. Learn how to address the challenges of interconnect efficiency, large-model inference, batched request processing, and distributed training. Discover how a few Macs you already own can replace costly cloud infrastructure for demanding AI workloads.
Chapters
- 0:00 - Introduction
- 2:09 - Distributed communication
- 4:32 - Setting up your cluster
- 10:33 - Distributed inference and fine-tuning
- 13:35 - Model parallelism strategies
- 15:53 - Distributed fine-tuning
- 18:34 - CLI, Python, Swift, and C++ APIs
- 20:45 - Next steps
Resources
- MLX Swift LM on GitHub
- MLX Swift Examples
- MLX Examples
- MLX Swift
- MLX LM - Python API
- MLX Explore - Python API
- MLX Framework
- MLX
Related videos
WWDC26
WWDC25
8:31 - Hostfile format for a 4-node MLX cluster
[
  {
    "ssh": "m3-ultra-0",
    "ips": ["192.168.1.10"],
    "rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]
  },
  {
    "ssh": "m3-ultra-1",
    "ips": ["192.168.1.11"],
    "rdma": ["rdma_en5", null, "rdma_en4", "rdma_en3"]
  },
  {
    "ssh": "m3-ultra-2",
    "ips": ["192.168.1.12"],
    "rdma": ["rdma_en5", "rdma_en4", null, "rdma_en3"]
  },
  {
    "ssh": "m3-ultra-3",
    "ips": ["192.168.1.13"],
    "rdma": ["rdma_en5", "rdma_en4", "rdma_en3", null]
  }
]
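In this format, each node's rdma list appears to be indexed by peer: entry j names the local RDMA interface that reaches node j, and a node's own slot is null. As a minimal sketch, the Python below checks those invariants before a launch. The rules are inferred from the 4-node example above (an assumption), not taken from MLX's own validation logic, and `check_hostfile` is a hypothetical helper.

```python
import json

# The 4-node hostfile from the example above.
HOSTFILE = """
[
  {"ssh": "m3-ultra-0", "ips": ["192.168.1.10"], "rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]},
  {"ssh": "m3-ultra-1", "ips": ["192.168.1.11"], "rdma": ["rdma_en5", null, "rdma_en4", "rdma_en3"]},
  {"ssh": "m3-ultra-2", "ips": ["192.168.1.12"], "rdma": ["rdma_en5", "rdma_en4", null, "rdma_en3"]},
  {"ssh": "m3-ultra-3", "ips": ["192.168.1.13"], "rdma": ["rdma_en5", "rdma_en4", "rdma_en3", null]}
]
"""


def check_hostfile(entries):
    """Return a list of structural problems (empty if none were found).

    Hypothetical rules inferred from the example, not MLX's own checks:
    every entry has "ssh", "ips", and "rdma"; "rdma" has one slot per
    node; and a node's slot for itself is null (None after json.loads).
    """
    problems = []
    n = len(entries)
    for i, entry in enumerate(entries):
        missing = [key for key in ("ssh", "ips", "rdma") if key not in entry]
        problems += [f"node {i}: missing '{key}'" for key in missing]
        rdma = entry.get("rdma", [])
        if len(rdma) != n:
            problems.append(f"node {i}: expected {n} rdma slots, got {len(rdma)}")
        elif rdma[i] is not None:
            problems.append(f"node {i}: rdma[{i}] should be null")
    return problems


entries = json.loads(HOSTFILE)
print(check_hostfile(entries))  # []
```

A check like this only catches shape mistakes; whether each named interface is actually cabled to the right peer still has to be confirmed on the machines themselves.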
8:56 - Generate the cluster hostfile with mlx.distributed_config
mlx.distributed_config \
  --hosts m3-ultra-0,m3-ultra-1,m3-ultra-2,m3-ultra-3 \
  --output "m3-ultra-jaccl.json" \
  --env MLX_METAL_FAST_SYNCH=1 \
  --auto-setup \
  --backend jaccl
11:04 - Run distributed LLM inference with mlx_lm.chat
# Single-device LLM inference
mlx_lm.chat --model "Qwen/Qwen3.6-27B" --max-tokens 2048

# Distributed LLM inference across the cluster
mlx.launch --hostfile "m3-ultra-jaccl.json" -- \
  /remote/path/to/mlx_lm.chat --model "Qwen/Qwen3.6-27B" --max-tokens 2048
15:03 - Run distributed inference with pipeline parallelism
# Tensor parallelism (default)
mlx.launch --hostfile "m3-ultra-jaccl.json" -- \
  /remote/path/to/mlx_lm.chat --model "moonshotai/Kimi-K2.6" \
  --max-tokens 2048

# Pipeline parallelism: append the --pipeline flag
mlx.launch --hostfile "m3-ultra-jaccl.json" -- \
  /remote/path/to/mlx_lm.chat --model "moonshotai/Kimi-K2.6" \
  --max-tokens 2048 \
  --pipeline
17:18 - Run distributed fine-tuning with mlx_lm.lora
# Single-device fine-tuning
mlx_lm.lora --model "Qwen/Qwen3.5-9B" \
  --data "mlx-community/wikisql" \
  --train --batch-size 4

# Distributed fine-tuning (scale --batch-size by the number of devices: 4 x 4 = 16)
mlx.launch --hostfile "hostfile.json" -- \
  /remote/path/to/mlx_lm.lora --model "Qwen/Qwen3.5-9B" \
  --data "mlx-community/wikisql" \
  --train --batch-size 16
19:01 - Distributed inference with the MLX LM Python API
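import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.utils import sharded_load

# Initialize the distributed backend
group = mx.distributed.init(strict=True, backend="jaccl")

# Define the parallelism: tensor parallelism over every rank, no pipeline stages
tensor_group, pipeline_group = group, None

# Load the model sharded across the group
model, tokenizer = sharded_load("moonshotai/Kimi-K2.6", pipeline_group, tensor_group)

# Build the prompt (example message; every rank must use the same prompt)
messages = [{"role": "user", "content": "Explain tensor parallelism briefly."}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

for response in stream_generate(model, tokenizer, prompt, max_tokens=1024):
    # Every rank runs generation in lockstep; only rank 0 prints
    if group.rank() == 0:
        print(response.text, end="", flush=True)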
19:31 - Shard a layer with the MLX Python API
import mlx.core as mx
import mlx.nn as nn

# Initialize the distributed backend
group = mx.distributed.init(strict=True, backend="jaccl")

# Define a layer and shard it column-wise: each rank keeps a slice of the output features
layer = nn.Linear(1024, 1024)
sharded_layer = nn.layers.distributed.shard_linear(
    layer, "all-to-sharded", group=group
)

data = mx.random.normal((1, 1, 1024))
output = sharded_layer(data)  # each rank gets (1, 1, 1024 // group.size())
mx.eval(output)
19:47 - All-reduce across devices in Python, Swift, and C++
# Python
import mlx.core as mx

world = mx.distributed.init(strict=True, backend="jaccl")
data = mx.full((4,), float(world.rank()), dtype=mx.float32)
result = mx.distributed.all_sum(data, group=world)
mx.eval(result)

// Swift
let group = try DistributedGroup(strict: .ring)
let data = rank == 0
    ? MLXArray(converting: [1.0, 2.0, 3.0])
    : MLXArray(converting: [5.0, 6.0, 7.0])
let result = try group.allSum(data)

// C++
namespace mx = mlx::core;

auto world = mx::distributed::init(/* strict */ true, "jaccl");
mx::array data = mx::full({4}, static_cast<float>(world.rank()), mx::float32);
mx::array result = mx::distributed::all_sum(data, world);
mx::eval(result);
20:06 - Standalone distributed sum with the JACCL C++ API
#include <jaccl/jaccl.h>
#include <iostream>

int main() {
  // Initialize the JACCL group
  auto group = jaccl::init();
  std::cout << "Rank " << group->rank() << " of " << group->size() << std::endl;

  // Perform an all-reduce sum
  float data[10] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f};
  float output[10];
  group->all_sum(data, output, sizeof(data), jaccl::Float32);

  std::cout << "Result: " << output[0] << std::endl;
  return 0;
}
- 0:00 - Introduction
Overview of why distributed AI becomes necessary as models grow larger, and a preview of what the session covers: CLI tools, Python API, and Swift for embedding distributed workflows in your apps.
- 2:09 - Distributed communication
A walkthrough of the full hardware and software stack enabling distributed workloads on Apple silicon: RDMA over Thunderbolt 5 for low-latency data movement, JACCL (open-source collective communication library), and MLX as the ML framework that ties them together.
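A collective such as the all-sum in the API examples above leaves every rank holding the elementwise sum of all ranks' buffers. As an illustration, here is a pure-Python simulation of the textbook ring all-reduce (a reduce-scatter phase followed by an all-gather phase). This is a classic algorithm chosen for teaching, not a description of JACCL's internals; its real algorithms, chunking, and RDMA transfers may differ, and `ring_all_reduce` is a hypothetical helper.

```python
def ring_all_reduce(buffers):
    """Simulate a ring all-reduce (sum) across len(buffers) ranks.

    buffers[r] is rank r's local data; all buffers have the same length,
    which must be divisible by the number of ranks. Returns every rank's
    final buffer, each equal to the elementwise sum of all inputs.
    """
    n = len(buffers)
    length = len(buffers[0])
    if length % n != 0:
        raise ValueError("buffer length must be divisible by the rank count")
    size = length // n
    data = [list(buf) for buf in buffers]  # copy so the inputs are untouched

    def span(chunk):
        return range(chunk * size, (chunk + 1) * size)

    # Phase 1, reduce-scatter: at step s, rank r sends chunk (r - s) % n to
    # its right neighbor, which adds it in. After n - 1 steps, rank r owns
    # the fully summed chunk (r + 1) % n.
    for s in range(n - 1):
        msgs = [(r, (r - s) % n) for r in range(n)]
        payloads = [[data[r][i] for i in span(c)] for r, c in msgs]  # snapshot before receiving
        for (r, c), payload in zip(msgs, payloads):
            dst = (r + 1) % n
            for i, value in zip(span(c), payload):
                data[dst][i] += value

    # Phase 2, all-gather: at step s, rank r forwards chunk (r + 1 - s) % n,
    # and the receiver overwrites its copy with the finished sum.
    for s in range(n - 1):
        msgs = [(r, (r + 1 - s) % n) for r in range(n)]
        payloads = [[data[r][i] for i in span(c)] for r, c in msgs]
        for (r, c), payload in zip(msgs, payloads):
            dst = (r + 1) % n
            for i, value in zip(span(c), payload):
                data[dst][i] = value
    return data


# Mirrors the Python all_sum example: 4 ranks, each buffer filled with its rank.
result = ring_all_reduce([[float(rank)] * 4 for rank in range(4)])
print(result[0])  # [6.0, 6.0, 6.0, 6.0]
```

Each rank sends 2(n - 1) chunks of length/n elements, so total traffic per rank stays nearly constant as the ring grows, but the number of sequential hops, and therefore accumulated latency, grows with n.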
- 4:32 - Setting up your cluster
How to physically connect four M3 Ultras into a cluster — understanding latency vs. bandwidth trade-offs, choosing between mesh and ring topologies, enabling RDMA in System Settings, and using mlx.distributed_config and mlx.launch to configure and orchestrate the cluster.
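The latency vs. bandwidth trade-off can be made concrete with the standard alpha-beta model, where one transfer of S bytes costs alpha + S / B. The sketch below compares a ring all-reduce with a full-mesh all-reduce under that model. It is a back-of-the-envelope illustration under stated assumptions: the latency and bandwidth values are placeholders rather than measured Thunderbolt 5 figures, and the mesh formula assumes every node can saturate all of its links at once.

```python
def ring_allreduce_time(n, nbytes, latency_s, bytes_per_s):
    """Alpha-beta estimate for a ring all-reduce over n nodes:
    2 * (n - 1) sequential steps, each paying one link latency and
    moving nbytes / n bytes over one link."""
    return 2 * (n - 1) * (latency_s + nbytes / (n * bytes_per_s))


def mesh_allreduce_time(n, nbytes, latency_s, bytes_per_s):
    """Same model for a full mesh, assuming a node can drive all n - 1
    links at full speed at once: one reduce-scatter step plus one
    all-gather step, each moving nbytes / n bytes per link."""
    return 2 * (latency_s + nbytes / (n * bytes_per_s))


def latency_share(nbytes, latency_s, bytes_per_s):
    """Fraction of a single transfer's time spent on fixed latency."""
    return latency_s / (latency_s + nbytes / bytes_per_s)


# Placeholder link parameters (assumptions, not measured values).
LATENCY_S = 10e-6       # 10 microseconds per message
BYTES_PER_S = 10e9      # 10 GB/s per link

small = 8192 * 2        # e.g. one 8192-wide 16-bit activation vector
large = 1_000_000_000   # e.g. 1 GB of gradients

print(f"{latency_share(small, LATENCY_S, BYTES_PER_S):.2f}")  # 0.86 (latency-bound)
print(f"{latency_share(large, LATENCY_S, BYTES_PER_S):.0e}")  # 1e-04 (bandwidth-bound)
for size in (small, large):
    ring = ring_allreduce_time(4, size, LATENCY_S, BYTES_PER_S)
    mesh = mesh_allreduce_time(4, size, LATENCY_S, BYTES_PER_S)
    print(f"{size} bytes: ring {ring:.6f} s, mesh {mesh:.6f} s")
```

Under this idealized model the mesh is faster by a factor of n - 1 at every message size; the price is cabling, since a full mesh needs n - 1 links per node (three for four Macs) while a ring needs two.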
- 10:33 - Distributed inference and fine-tuning
How to run distributed LLM inference with MLX LM using a single CLI command — wrapping mlx_lm.chat with mlx.launch to shard a 27B-parameter Qwen model across four M3 Ultras, achieving nearly 3x the token generation rate of a single machine.
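A rough way to see why sharding speeds up generation: for a dense model, decoding each token reads roughly every weight once, so splitting the weights evenly across devices cuts the bytes each device must read per token. The arithmetic below is a sketch assuming 16-bit weights (2 bytes per parameter; the session does not state the precision used) and perfectly even sharding; the helper names are hypothetical.

```python
def weight_bytes_per_device(num_params, bytes_per_param, num_devices):
    """Weight bytes held (and read per decoded token) by each device,
    assuming every weight is split evenly across devices."""
    return num_params * bytes_per_param / num_devices


def parallel_efficiency(speedup, num_devices):
    """Observed speedup as a fraction of the ideal linear speedup."""
    return speedup / num_devices


PARAMS = 27e9          # 27B-parameter model
BYTES_PER_PARAM = 2    # assumption: 16-bit weights

single = weight_bytes_per_device(PARAMS, BYTES_PER_PARAM, 1)
sharded = weight_bytes_per_device(PARAMS, BYTES_PER_PARAM, 4)
print(f"one Mac: {single / 1e9:.1f} GB, each of four Macs: {sharded / 1e9:.1f} GB")
print(parallel_efficiency(3.0, 4))  # 0.75
```

Four devices give an ideal ceiling of 4x; a speedup of nearly 3x corresponds to roughly 75% parallel efficiency, with the gap plausibly spent on per-token communication and other overhead.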
- 13:35 - Model parallelism strategies
How MLX LM splits large models across machines using tensor parallelism (splitting each layer by width for faster inference) and pipeline parallelism (splitting the model by depth for simpler communication), including a demo running the 1-trillion-parameter Kimi K2.6 model across four Macs.
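Both strategies can be seen on a toy two-layer network in pure Python. Tensor parallelism pairs a layer split by output features (what MLX calls all-to-sharded, as in the shard_linear example above) with a layer split by input features (sharded-to-all), so each pair of layers needs only one all-sum; pipeline parallelism instead gives each device a contiguous block of layers and passes activations along. This is a minimal sketch: loops over ranks stand in for devices, the matrices are made up, every helper name is hypothetical, and MLX LM's real sharding is more involved.

```python
def matvec(W, x):
    """Return W @ x, with W stored as a list of rows (out_features x in_features)."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def all_to_sharded(W, x, rank, size):
    """Full input in, this rank's slice of the output features out (no communication)."""
    rows = len(W) // size
    return matvec(W[rank * rows:(rank + 1) * rows], x)


def sharded_to_all(W, x_shard, rank, size):
    """This rank's slice of the input features in, a full-size partial output out.
    The partial outputs of all ranks must be all-summed to get the real result."""
    cols = len(W[0]) // size
    return matvec([row[rank * cols:(rank + 1) * cols] for row in W], x_shard)


def tensor_parallel(W1, W2, x, size):
    """Two layers split by width (hidden size divisible by size); one all-sum at the end."""
    partials = []
    for rank in range(size):  # each iteration stands in for one device
        h_shard = all_to_sharded(W1, x, rank, size)
        partials.append(sharded_to_all(W2, h_shard, rank, size))
    return [sum(values) for values in zip(*partials)]  # simulated all-sum


def pipeline_parallel(layers, x, size):
    """Layers split by depth (len(layers) divisible by size); stages run in order."""
    per_stage = len(layers) // size
    for stage in range(size):  # each iteration stands in for one device
        for W in layers[stage * per_stage:(stage + 1) * per_stage]:
            x = matvec(W, x)
        # On a real cluster, x would be sent to the next stage here.
    return x


W1 = [[1, 0, 2], [0, 1, 1], [3, 1, 0], [1, 1, 1]]  # 3 -> 4 features
W2 = [[1, 2, 0, 1], [0, 1, 1, 2]]                  # 4 -> 2 features
x = [1, 2, 3]
print(matvec(W2, matvec(W1, x)))      # [23, 22]
print(tensor_parallel(W1, W2, x, 2))  # [23, 22]
```

The sharded result matches the unsharded one because the sum over ranks of W2[:, block] @ h[block] equals W2 @ h. An elementwise activation between the two layers would not change this, since it applies to each hidden slice independently.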
- 15:53 - Distributed fine-tuning
How data-parallel training accelerates fine-tuning by replicating the model across machines, processing different data batches in parallel, and averaging the gradients, demonstrated by fine-tuning Qwen 3.5 (9B) at over 3x the throughput of a single M3 Ultra.
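Gradient averaging is what makes data parallelism equivalent to single-device training on the combined batch. The pure-Python sketch below trains a toy linear model y = w * x + b: each simulated replica computes a mean-squared-error gradient on its own equal-sized shard, and averaging those gradients (an all-sum divided by the device count) reproduces the full-batch gradient. The model, data, and helper names are hypothetical; MLX LM presumably performs the averaging with a real all-reduce. The sizes echo the CLI example, assuming the global batch of 16 is split evenly into 4 samples per device.

```python
def mse_grad(w, b, xs, ys):
    """Gradient of the mean squared error of y_hat = w * x + b on one batch."""
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    grad_w = sum(2 * r * x for r, x in zip(residuals, xs)) / n
    grad_b = sum(2 * r for r in residuals) / n
    return grad_w, grad_b


def data_parallel_step(w, b, xs, ys, num_devices, lr):
    """One gradient-descent step where each simulated replica sees an equal shard."""
    shard = len(xs) // num_devices
    grads = [
        mse_grad(w, b, xs[k * shard:(k + 1) * shard], ys[k * shard:(k + 1) * shard])
        for k in range(num_devices)  # each entry stands in for one replica
    ]
    # All-sum the gradients, then divide by the device count: every replica
    # would end up with this same average and apply the same update.
    grad_w = sum(g[0] for g in grads) / num_devices
    grad_b = sum(g[1] for g in grads) / num_devices
    return w - lr * grad_w, b - lr * grad_b


xs = [float(i) for i in range(16)]  # a global batch of 16 samples
ys = [2.0 * x + 1.0 for x in xs]    # drawn from the line y = 2x + 1

w, b = 0.0, 0.0
for _ in range(6000):
    w, b = data_parallel_step(w, b, xs, ys, num_devices=4, lr=0.01)
print(f"w={w:.4f}, b={b:.4f}")  # w=2.0000, b=1.0000
```

With equal shards, the average of per-shard means equals the mean over the whole batch, so four replicas with 4 samples each take exactly the step one device would take with 16.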
- 18:34 - CLI, Python, Swift, and C++ APIs
How to use MLX's fine-grained Python, Swift, and C++ APIs for distributed inference: initializing a distributed group, sharding models with tensor parallelism, using low-level all-reduce primitives such as all_sum, and using JACCL standalone for non-ML distributed workloads.
- 20:45 - Next steps
Summary of the full distributed stack — from RDMA over Thunderbolt to MLX and MLX LM — and next steps including the companion session on local agentic AI, documentation on custom parallelism strategies, and the built-in MLX LM distributed server.