Explore distributed inference and training with MLX
Scale machine learning workloads across multiple Macs with MLX. Learn how to tackle the challenges of interconnect efficiency, large-model inference, request batching, and distributed training. Discover how a few Macs on your desk can replace costly cloud infrastructure for demanding AI workloads.
Chapters
- 0:00 - Introduction
- 2:09 - Distributed communication
- 4:32 - Setting up your cluster
- 10:33 - Distributed inference and fine-tuning
- 13:35 - Model parallelism strategies
- 15:53 - Distributed fine-tuning
- 18:34 - CLI, Python, Swift, and C++ APIs
- 20:45 - Next steps
Resources
- MLX Swift LM on GitHub
- MLX Swift Examples
- MLX Examples
- MLX Swift
- MLX LM - Python API
- MLX Explore - Python API
- MLX Framework
- MLX
Related videos
WWDC26
WWDC25
8:31 - Hostfile format for a 4-node MLX cluster
[
  { "ssh": "m3-ultra-0", "ips": ["192.168.1.10"], "rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"] },
  { "ssh": "m3-ultra-1", "ips": ["192.168.1.11"], "rdma": ["rdma_en5", null, "rdma_en4", "rdma_en3"] },
  { "ssh": "m3-ultra-2", "ips": ["192.168.1.12"], "rdma": ["rdma_en5", "rdma_en4", null, "rdma_en3"] },
  { "ssh": "m3-ultra-3", "ips": ["192.168.1.13"], "rdma": ["rdma_en5", "rdma_en4", "rdma_en3", null] }
]
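The hostfile has one entry per Mac. In the `rdma` arrays, entry `j` on host `i` appears to name the local interface that reaches host `j`, and the host's own slot is `null`. This reading is our assumption; the session does not spell it out. If the assumption holds, a small standard-library check like the one below can catch wiring or copy-paste mistakes before you launch a job. The function `check_mesh_hostfile` is hypothetical and is not part of MLX.

```python
import json

# The 4-node hostfile above, inlined so this sketch runs on its own.
HOSTFILE = """
[
  {"ssh": "m3-ultra-0", "ips": ["192.168.1.10"], "rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]},
  {"ssh": "m3-ultra-1", "ips": ["192.168.1.11"], "rdma": ["rdma_en5", null, "rdma_en4", "rdma_en3"]},
  {"ssh": "m3-ultra-2", "ips": ["192.168.1.12"], "rdma": ["rdma_en5", "rdma_en4", null, "rdma_en3"]},
  {"ssh": "m3-ultra-3", "ips": ["192.168.1.13"], "rdma": ["rdma_en5", "rdma_en4", "rdma_en3", null]}
]
"""


def check_mesh_hostfile(hosts):
    """Return a list of problems; an empty list means the file looks like a full mesh.

    Assumption (hypothetical, not confirmed by the session): hosts[i]["rdma"][j]
    names the RDMA interface on host i that reaches host j, and host i's own
    entry is null.
    """
    problems = []
    n = len(hosts)
    for i, host in enumerate(hosts):
        name = host.get("ssh", f"host {i}")
        rdma = host.get("rdma", [])
        if len(rdma) != n:
            problems.append(f"{name}: expected {n} rdma entries, found {len(rdma)}")
            continue
        for j, iface in enumerate(rdma):
            if i == j and iface is not None:
                problems.append(f"{name}: entry for itself should be null")
            elif i != j and not iface:
                problems.append(f"{name}: no interface listed for peer {j}")
        # In a full mesh each peer is reached through a different local interface.
        used = [iface for iface in rdma if iface is not None]
        if len(used) != len(set(used)):
            problems.append(f"{name}: the same interface is listed for several peers")
    return problems


hosts = json.loads(HOSTFILE)
print(check_mesh_hostfile(hosts) or "hostfile looks like a full mesh")
```

Blanking a single peer entry, for example `hosts[2]["rdma"][0] = None`, makes the check report exactly one missing link for `m3-ultra-2`.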
8:56 - Generate the cluster hostfile with mlx.distributed_config
mlx.distributed_config \
  --hosts m3-ultra-0,m3-ultra-1,m3-ultra-2,m3-ultra-3 \
  --output "m3-ultra-jaccl.json" \
  --env MLX_METAL_FAST_SYNCH=1 \
  --auto-setup \
  --backend jaccl
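The setup chapter compares two topologies for wiring the Macs: a full mesh, where every pair is cabled directly, and a ring, where each Mac is cabled only to its two neighbors. They differ in how many cables and ports they need. The hypothetical helper below only does the counting and does not query real hardware. For a 4-node mesh it reports 3 ports per Mac, which matches the three non-null RDMA interfaces per host in the hostfile above.

```python
def topology_requirements(n_nodes: int, topology: str) -> dict:
    """Count cables, ports per Mac, and worst-case hops for a mesh or a ring.

    mesh: every pair of Macs has its own direct cable.
    ring: each Mac is cabled only to its two neighbors.
    """
    if n_nodes < 3:
        raise ValueError("this sketch assumes at least 3 nodes")
    if topology == "mesh":
        return {
            "cables": n_nodes * (n_nodes - 1) // 2,  # one cable per pair of Macs
            "ports_per_node": n_nodes - 1,           # one port per peer
            "max_hops": 1,                           # every peer is directly attached
        }
    if topology == "ring":
        return {
            "cables": n_nodes,
            "ports_per_node": 2,
            # Going the shorter way round, the farthest peer is n // 2 hops away.
            "max_hops": n_nodes // 2,
        }
    raise ValueError(f"unknown topology: {topology!r}")


for topo in ("mesh", "ring"):
    print(topo, topology_requirements(4, topo))
```

Mesh cabling grows quadratically with the number of Macs, so it is practical mainly for small clusters. A ring needs only two ports per Mac, but traffic to distant peers has to cross extra hops.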
11:04 - Run distributed LLM inference with mlx_lm.chat
# Single-device LLM inference
mlx_lm.chat --model "Qwen/Qwen3.6-27B" --max-tokens 2048

# Distributed LLM inference across the cluster
mlx.launch --hostfile "m3-ultra-jaccl.json" -- \
  /remote/path/to/mlx_lm.chat --model "Qwen/Qwen3.6-27B" --max-tokens 2048
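Whether a link's latency or its bandwidth matters more depends on the size of each message. A common back-of-the-envelope model, often called the alpha-beta or Hockney model, estimates transfer time as a fixed latency plus size divided by bandwidth. The link parameters and message sizes below are hypothetical placeholders, not measurements of Thunderbolt 5 or RDMA. They show one thing: for the small per-token messages typical of token-by-token generation, latency dominates, while for large transfers bandwidth does.

```python
def transfer_time_us(n_bytes: int, latency_us: float, bandwidth_gb_s: float) -> float:
    """Alpha-beta estimate: fixed latency plus size / bandwidth, in microseconds.

    1 GB/s = 1e9 bytes/s = 1e3 bytes per microsecond.
    """
    return latency_us + n_bytes / (bandwidth_gb_s * 1e3)


# Hypothetical link parameters (latency in us, bandwidth in GB/s), for illustration only.
LINKS = {"high-latency link": (50.0, 10.0), "low-latency link": (5.0, 10.0)}

# Hypothetical messages: one 5120-wide bf16 activation vector vs. a 1 GB tensor.
MESSAGES = {"10 KiB activation": 5120 * 2, "1 GB tensor": 10**9}

for link, (lat, bw) in LINKS.items():
    for msg, size in MESSAGES.items():
        t = transfer_time_us(size, lat, bw)
        print(f"{link:18} {msg:18} {t:12.3f} us  (latency share {lat / t:.1%})")
```

In this model, cutting latency by 10x speeds up the small message by roughly 8x but barely changes the large transfer.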
15:03 - Run distributed inference with pipeline parallelism
# Tensor parallelism (default)
mlx.launch --hostfile "m3-ultra-jaccl.json" -- \
  /remote/path/to/mlx_lm.chat --model "moonshotai/Kimi-K2.6" \
  --max-tokens 2048

# Pipeline parallelism: append the --pipeline flag
mlx.launch --hostfile "m3-ultra-jaccl.json" -- \
  /remote/path/to/mlx_lm.chat --model "moonshotai/Kimi-K2.6" \
  --max-tokens 2048 \
  --pipeline
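With pipeline parallelism, each Mac holds a contiguous block of layers, so activations only flow from one stage to the next. Below is a minimal sketch of such a split. Two details are illustrative assumptions rather than a description of how MLX LM assigns layers to ranks: the rule that gives the first ranks one extra layer when the split is uneven, and the 61-layer count.

```python
def pipeline_partition(num_layers: int, num_ranks: int) -> list:
    """Split layer indices 0..num_layers-1 into num_ranks contiguous stages.

    When the split is uneven, the first ranks take one extra layer (an assumption
    made for this sketch; MLX LM may balance stages differently).
    """
    base, extra = divmod(num_layers, num_ranks)
    stages, start = [], 0
    for rank in range(num_ranks):
        size = base + (1 if rank < extra else 0)
        stages.append(range(start, start + size))
        start += size
    return stages


# A hypothetical 61-layer model split over the four Macs in the cluster.
for rank, stage in enumerate(pipeline_partition(61, 4)):
    print(f"rank {rank}: layers {stage.start}-{stage.stop - 1} ({len(stage)} layers)")
```

With four stages, a forward pass crosses only three stage boundaries, and each crossing is a point-to-point transfer between neighboring ranks. Tensor parallelism instead splits every layer, so it typically needs collective communication inside each layer.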
17:18 - Run distributed fine-tuning with mlx_lm.lora
# Single-device fine-tuning
mlx_lm.lora --model "Qwen/Qwen3.5-9B" \
  --data "mlx-community/wikisql" \
  --train --batch-size 4

# Distributed fine-tuning (scale --batch-size by the number of devices)
mlx.launch --hostfile "hostfile.json" -- \
  /remote/path/to/mlx_lm.lora --model "Qwen/Qwen3.5-9B" \
  --data "mlx-community/wikisql" \
  --train --batch-size 16
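Data-parallel fine-tuning works in three steps:

1. Keep a full copy of the model on every Mac.
2. Give each Mac a different slice of the batch.
3. Average the gradients with an all-reduce before the update.

The comment in the command suggests that `--batch-size 16` is the global batch, split evenly over four devices, and we assume that here. The pure-Python sketch below checks the key property on a toy one-parameter model: averaging equal-sized per-device gradients of a mean loss reproduces the full-batch gradient. The toy model, the data, and the helper `grad_mse` are hypothetical.

```python
import random


def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the single weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


random.seed(0)
num_devices, per_device_batch = 4, 4
global_batch = num_devices * per_device_batch  # 16, as in --batch-size 16 above
xs = [random.uniform(-1, 1) for _ in range(global_batch)]
ys = [3.0 * x + random.gauss(0, 0.1) for x in xs]
w = 0.5  # current model weight, identical on every device

# Each "device" computes a gradient on its own disjoint, equally sized shard.
shard_grads = []
for rank in range(num_devices):
    lo, hi = rank * per_device_batch, (rank + 1) * per_device_batch
    shard_grads.append(grad_mse(w, xs[lo:hi], ys[lo:hi]))

# All-reduce (sum) followed by division by the device count = gradient average.
averaged = sum(shard_grads) / num_devices
full = grad_mse(w, xs, ys)
print(f"averaged shard gradient: {averaged:.10f}")
print(f"full-batch gradient:     {full:.10f}")
```

The equality depends on equal shard sizes. With uneven shards, a plain average of the per-shard means would give the examples in the smaller shard more weight.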
19:01 - Distributed inference with the MLX LM Python API
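import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.utils import sharded_load

# Initialise distributed backend
group = mx.distributed.init(strict=True, backend="jaccl")

# Define parallelism: tensor parallelism across the whole group, no pipeline stages
tensor_group, pipeline_group = group, None

# Shard the model
model, tokenizer = sharded_load("moonshotai/Kimi-K2.6", pipeline_group, tensor_group)

# Build a chat-formatted prompt
messages = [{"role": "user", "content": "Explain pipeline parallelism in one paragraph."}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# Every rank runs generation; only rank 0 prints the streamed text
for response in stream_generate(model, tokenizer, prompt, max_tokens=1024):
    if group.rank() == 0:
        print(response.text, end="", flush=True)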
19:31 - Shard a layer with the MLX Python API
import mlx.core as mx
import mlx.nn as nn

# Initialise distributed backend
group = mx.distributed.init(strict=True, backend="jaccl")

# Define a layer and shard it column-wise ("all-to-sharded" sharding strategy)
layer = nn.Linear(1024, 1024)
sharded_layer = nn.layers.distributed.shard_linear(
    layer, "all-to-sharded", group=group
)

data = mx.random.normal((1, 1, 1024))
output = sharded_layer(data)
mx.eval(output)
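You can check what `all-to-sharded` computes without MLX. The pure-Python sketch below uses an `(in, out)` weight layout for readability rather than MLX's own layout, and its matrix sizes and values are arbitrary. It compares two ways of splitting a layer:

- **Column split (`all-to-sharded`):** each rank gets a slice of the outputs, and the slices are concatenated.
- **Row split:** each rank produces a partial result, and the partial results must be summed, which is the job of `all_sum`. To our understanding, MLX exposes this complementary case as `sharded-to-all`.

```python
def matvec(x, W):
    """y[j] = sum_i x[i] * W[i][j] for a weight matrix W of shape (in, out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def split(seq, n):
    """Split a list into n equal contiguous parts (length must divide evenly)."""
    k = len(seq) // n
    return [seq[r * k:(r + 1) * k] for r in range(n)]


in_f, out_f, ranks = 8, 6, 2
x = [float(i + 1) for i in range(in_f)]
W = [[float((i * out_f + j) % 7 - 3) for j in range(out_f)] for i in range(in_f)]
full = matvec(x, W)  # what a single, unsharded layer would compute

# Column split ("all-to-sharded"): every rank sees the whole input and owns
# out_f // ranks output columns; concatenating the per-rank outputs gives y.
cols = out_f // ranks
col_shards = [[row[r * cols:(r + 1) * cols] for row in W] for r in range(ranks)]
y_col = [v for r in range(ranks) for v in matvec(x, col_shards[r])]

# Row split: every rank owns a slice of the input features and the matching
# weight rows, producing a partial sum over its slice.
x_parts, W_parts = split(x, ranks), split(W, ranks)
partials = [matvec(x_parts[r], W_parts[r]) for r in range(ranks)]
y_row = [sum(p[j] for p in partials) for j in range(out_f)]  # the all_sum step

print(y_col == full, y_row == full)
```

Megatron-style tensor parallelism commonly pairs the two splits. A column-split layer's per-rank output slice is exactly the input slice a following row-split layer needs, so only one all-sum is required per pair.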
19:47 - All-reduce across devices in Python, Swift, and C++
# Python
import mlx.core as mx

world = mx.distributed.init(strict=True, backend="jaccl")
data = mx.full((4,), float(world.rank()), dtype=mx.float32)
result = mx.distributed.all_sum(data, group=world)
mx.eval(result)

// Swift
let group = try DistributedGroup(strict: .ring)
let data = rank == 0
    ? MLXArray(converting: [1.0, 2.0, 3.0])
    : MLXArray(converting: [5.0, 6.0, 7.0])
let result = try group.allSum(data)

// C++
#include <mlx/mlx.h>

namespace mx = mlx::core;

auto world = mx::distributed::init(/* strict */ true, "jaccl");
mx::array data = mx::full({4}, static_cast<float>(world.rank()), mx::float32);
mx::array result = mx::distributed::all_sum(data, world);
mx::eval(result);
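After `all_sum`, every rank holds the element-wise sum of all ranks' arrays. With four ranks, each contributing an array filled with its own rank as in the Python snippet, every rank ends up with values 0 + 1 + 2 + 3 = 6. The sketch below simulates the classic ring algorithm in plain Python: a reduce-scatter followed by an all-gather. It shows how the result can be reached when each node only talks to its neighbor. It is a textbook illustration, not JACCL's implementation.

```python
def ring_all_reduce(buffers):
    """Simulate a ring all-reduce (sum); returns one reduced buffer per rank."""
    n = len(buffers)
    if len(buffers[0]) % n != 0:
        raise ValueError("sketch assumes buffer length divisible by number of ranks")
    k = len(buffers[0]) // n
    # chunks[r][c] is rank r's current copy of chunk c.
    chunks = [[list(buf[c * k:(c + 1) * k]) for c in range(n)] for buf in buffers]

    # Reduce-scatter: at step s, rank r sends chunk (r - s) % n to rank r + 1,
    # which adds it to its own copy. Afterwards rank r holds the full sum of
    # chunk (r + 1) % n. Sends are snapshotted first to mimic simultaneous transfers.
    for s in range(n - 1):
        sends = [(r, (r - s) % n, list(chunks[r][(r - s) % n])) for r in range(n)]
        for r, c, payload in sends:
            dst = (r + 1) % n
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], payload)]

    # All-gather: at step s, rank r forwards the reduced chunk (r + 1 - s) % n
    # to rank r + 1, which overwrites its stale copy.
    for s in range(n - 1):
        sends = [(r, (r + 1 - s) % n, list(chunks[r][(r + 1 - s) % n])) for r in range(n)]
        for r, c, payload in sends:
            chunks[(r + 1) % n][c] = payload

    return [[v for chunk in rank_chunks for v in chunk] for rank_chunks in chunks]


# Four ranks, each contributing a length-4 array filled with its rank.
inputs = [[float(rank)] * 4 for rank in range(4)]
outputs = ring_all_reduce(inputs)
for rank, out in enumerate(outputs):
    print(rank, out)
```

In this scheme each rank sends 2(n-1)/n of its buffer in total, so per-rank traffic stays nearly constant as the ring grows. The cost is 2(n-1) sequential steps, and each step pays the link latency.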
20:06 - Standalone distributed sum with the JACCL C++ API
#include <jaccl/jaccl.h>
#include <iostream>

int main() {
    // Initialize JACCL group
    auto group = jaccl::init();
    std::cout << "Rank " << group->rank() << " of " << group->size() << std::endl;

    // Perform all-reduce sum
    float data[10] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f};
    float output[10];
    group->all_sum(data, output, sizeof(data), jaccl::Float32);

    std::cout << "Result: " << output[0] << std::endl;
    return 0;
}
- 0:00 - Introduction
Overview of why distributed AI becomes necessary as models grow larger, and a preview of what the session covers: CLI tools, Python API, and Swift for embedding distributed workflows in your apps.
- 2:09 - Distributed communication
A walkthrough of the full hardware and software stack enabling distributed workloads on Apple silicon: RDMA over Thunderbolt 5 for low-latency data movement, JACCL (open-source collective communication library), and MLX as the ML framework that ties them together.
- 4:32 - Setting up your cluster
How to physically connect four M3 Ultras into a cluster — understanding latency vs. bandwidth trade-offs, choosing between mesh and ring topologies, enabling RDMA in System Settings, and using mlx.distributed_config and mlx.launch to configure and orchestrate the cluster.
- 10:33 - Distributed inference and fine-tuning
How to run distributed LLM inference with MLX LM using a single CLI command — wrapping mlx_lm.chat with mlx.launch to shard a 27B-parameter Qwen model across four M3 Ultras, achieving nearly 3x the token generation rate of a single machine.
- 13:35 - Model parallelism strategies
How MLX LM splits large models across machines using tensor parallelism (splitting by width for faster inference) and pipeline parallelism (splitting by depth for simpler communication), including a demo running the 1-trillion-parameter Kimi K2.6 model across four Macs.
- 15:53 - Distributed fine-tuning
How data-parallel training accelerates fine-tuning by replicating the model across machines, processing different data batches in parallel, and averaging the gradients. This is demonstrated by fine-tuning Qwen 3.5 (9B) on the cluster at more than 3x the throughput of a single M3 Ultra.
- 18:34 - CLI, Python, Swift, and C++ APIs
How to use MLX's fine-grained Python, Swift, and C++ APIs for distributed inference: initializing a distributed group, sharding models with tensor parallelism, calling low-level all-reduce primitives such as all_sum, and using JACCL on its own for non-ML distributed workloads.
- 20:45 - Next steps
Summary of the full distributed stack — from RDMA over Thunderbolt to MLX and MLX LM — and next steps including the companion session on local agentic AI, documentation on custom parallelism strategies, and the built-in MLX LM distributed server.