Explore distributed inference and training with MLX
Scale machine learning workloads across multiple Macs with MLX. Learn how to tackle the challenges of interconnect efficiency, large-model inference, request batching, and distributed training. Discover how a few Macs on your desk can replace costly cloud infrastructure for demanding AI workloads.
Chapters
- 0:00 - Introduction
- 2:09 - Distributed communication
- 4:32 - Cluster setup
- 10:33 - Distributed inference and fine-tuning
- 13:35 - Model parallelism strategies
- 15:53 - Distributed fine-tuning
- 18:34 - CLI, Python, Swift, and C++ APIs
- 20:45 - Next steps
Resources
- MLX Swift LM on GitHub
- MLX Swift Examples
- MLX Examples
- MLX Swift
- MLX LM - Python API
- MLX Explore - Python API
- MLX Framework
- MLX
Related Videos
WWDC26
WWDC25
8:31 - Hostfile format for a 4-node MLX cluster
[
  {
    "ssh": "m3-ultra-0",
    "ips": ["192.168.1.10"],
    "rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]
  },
  {
    "ssh": "m3-ultra-1",
    "ips": ["192.168.1.11"],
    "rdma": ["rdma_en5", null, "rdma_en4", "rdma_en3"]
  },
  {
    "ssh": "m3-ultra-2",
    "ips": ["192.168.1.12"],
    "rdma": ["rdma_en5", "rdma_en4", null, "rdma_en3"]
  },
  {
    "ssh": "m3-ultra-3",
    "ips": ["192.168.1.13"],
    "rdma": ["rdma_en5", "rdma_en4", "rdma_en3", null]
  }
]
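In the hostfile above, each node's `rdma` array appears to hold one entry per node in the cluster, with `null` at the node's own position. The following is a minimal Python sketch that checks a hostfile for that pattern. The helper name `check_hostfile` is hypothetical and not part of MLX, and the rule it enforces is an assumption read off the example rather than a documented requirement.

```python
import json


def check_hostfile(text):
    """Return a list of problems found in a hostfile-style JSON string.

    Assumption (inferred from the example, not from MLX documentation):
    node i's "rdma" list has one entry per node and is null only at index i.
    """
    nodes = json.loads(text)
    problems = []
    for i, node in enumerate(nodes):
        name = node.get("ssh")
        rdma = node.get("rdma", [])
        if len(rdma) != len(nodes):
            problems.append(f"{name}: expected {len(nodes)} rdma entries, got {len(rdma)}")
            continue
        for j, device in enumerate(rdma):
            # The entry must be null exactly when it refers to the node itself.
            if (device is None) != (i == j):
                problems.append(f"{name}: unexpected rdma entry at index {j}")
    return problems


# A two-node hostfile in the same shape as the example.
example = json.dumps([
    {"ssh": "m3-ultra-0", "ips": ["192.168.1.10"], "rdma": [None, "rdma_en5"]},
    {"ssh": "m3-ultra-1", "ips": ["192.168.1.11"], "rdma": ["rdma_en5", None]},
])
print(check_hostfile(example))  # an empty list means no problems were found
```

A check like this could be run on a hand-edited hostfile before passing it to `mlx.launch`.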
8:56 - Generate the cluster hostfile with mlx.distributed_config
mlx.distributed_config \
  --hosts m3-ultra-0,m3-ultra-1,m3-ultra-2,m3-ultra-3 \
  --output "m3-ultra-jaccl.json" \
  --env MLX_METAL_FAST_SYNCH=1 \
  --auto-setup \
  --backend jaccl
11:04 - Run distributed LLM inference with mlx_lm.chat
# Single-device LLM inference
mlx_lm.chat --model "Qwen/Qwen3.6-27B" --max-tokens 2048

# Distributed LLM inference across the cluster
mlx.launch --hostfile "m3-ultra-jaccl.json" -- \
  /remote/path/to/mlx_lm.chat --model "Qwen/Qwen3.6-27B" --max-tokens 2048
15:03 - Run distributed inference with pipeline parallelism
# Tensor parallelism (default)
mlx.launch --hostfile "m3-ultra-jaccl.json" -- \
  /remote/path/to/mlx_lm.chat --model "moonshotai/Kimi-K2.6" \
  --max-tokens 2048

# Pipeline parallelism: append the --pipeline flag
mlx.launch --hostfile "m3-ultra-jaccl.json" -- \
  /remote/path/to/mlx_lm.chat --model "moonshotai/Kimi-K2.6" \
  --max-tokens 2048 \
  --pipeline
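As a conceptual illustration of what pipeline parallelism means, the sketch below splits a model's layer indices into contiguous stages, one per device. It is a toy, not how `mlx_lm` partitions a model; the function name `pipeline_stages` and the even-split policy are assumptions made for illustration only.

```python
def pipeline_stages(num_layers, num_devices):
    """Split layer indices into contiguous stages, one stage per device.

    Hypothetical helper: earlier devices receive one extra layer when the
    layer count does not divide evenly.
    """
    base, extra = divmod(num_layers, num_devices)
    stages, start = [], 0
    for rank in range(num_devices):
        size = base + (1 if rank < extra else 0)
        stages.append(list(range(start, start + size)))
        start += size
    return stages


# 10 layers over 4 devices
print(pipeline_stages(10, 4))  # [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]
```

In a pipeline arrangement, each device runs only its own stage and passes activations to the next. Tensor parallelism, by contrast, splits the work inside each layer across all devices.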
17:18 - Run distributed fine-tuning with mlx_lm.lora
# Single-device fine-tuning
mlx_lm.lora --model "Qwen/Qwen3.5-9B" \
  --data "mlx-community/wikisql" \
  --train --batch-size 4

# Distributed fine-tuning (scale --batch-size by number of devices)
mlx.launch --hostfile "hostfile.json" -- \
  /remote/path/to/mlx_lm.lora --model "Qwen/Qwen3.5-9B" \
  --data "mlx-community/wikisql" \
  --train --batch-size 16
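The batch-size arithmetic in the comment above (4 per device across 4 devices gives 16) can be sketched in plain Python. Both helper names are hypothetical, and the strided split is only one plausible way to divide a global batch among ranks; it is not taken from the `mlx_lm` source.

```python
def scaled_batch_size(per_device_batch, num_devices):
    """Global batch size that keeps the per-device batch unchanged."""
    return per_device_batch * num_devices


def rank_slice(batch, rank, num_devices):
    """Examples handled by one rank, assuming a simple strided split."""
    return batch[rank::num_devices]


num_devices = 4
batch = list(range(scaled_batch_size(4, num_devices)))  # 16 example indices
for rank in range(num_devices):
    print(rank, rank_slice(batch, rank, num_devices))
```

Each rank ends up with 4 examples, the same amount of work as in the single-device run.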
19:01 - Distributed inference with the MLX LM Python API
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.utils import sharded_load

# Initialize distributed backend
group = mx.distributed.init(strict=True, backend="jaccl")

# Define parallelism
tensor_group, pipeline_group = group, None

# Shard the model
model, tokenizer = sharded_load("moonshotai/Kimi-K2.6", pipeline_group, tensor_group)

# Placeholder prompt (not defined in the original snippet)
prompt = "Explain distributed inference in one paragraph."

# Every rank runs generation; only rank 0 prints the output
for response in stream_generate(model, tokenizer, prompt, max_tokens=1024):
    if group.rank() == 0:
        print(response.text, end="", flush=True)
19:31 - Shard a layer with the MLX Python API
import mlx.core as mx
import mlx.nn as nn

# Initialize distributed backend
group = mx.distributed.init(strict=True, backend="jaccl")

# Define layer and shard it column-wise
layer = nn.Linear(1024, 1024)
sharded_layer = nn.layers.distributed.shard_linear(
    layer, strategy="all-to-sharded", group=group
)

data = mx.random.normal((1, 1, 1024))
output = sharded_layer(data)
mx.eval(output)
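To show what sharding a linear layer by output features amounts to, here is a small pure-Python simulation that does not use MLX. It assumes "all-to-sharded" means every device receives the full input and produces its own slice of the output; that reading, and the helper names `linear` and `shard_rows`, are assumptions for illustration.

```python
def linear(x, weight):
    """Compute y = x @ weight.T, with weight stored as one row per output feature."""
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]


def shard_rows(weight, rank, num_devices):
    """Give each simulated device a contiguous block of output features."""
    per_device = len(weight) // num_devices
    return weight[rank * per_device:(rank + 1) * per_device]


x = [1.0, 2.0, 3.0]
weight = [  # 4 output features, 3 input features
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 1.0],
]

full = linear(x, weight)
# Each simulated device sees the whole input but only its slice of the weights.
parts = [linear(x, shard_rows(weight, rank, 2)) for rank in range(2)]
gathered = [value for part in parts for value in part]
print(full == gathered)  # True
```

Concatenating the per-device outputs reproduces the unsharded result, which is why this kind of split needs no communication until the sharded outputs are combined.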
19:47 - All-reduce across devices in Python, Swift, and C++
# Python
import mlx.core as mx

world = mx.distributed.init(strict=True, backend="jaccl")
data = mx.full((4,), float(world.rank()), dtype=mx.float32)
result = mx.distributed.all_sum(data, group=world)
mx.eval(result)

// Swift
let group = try DistributedGroup(strict: .ring)
let data = rank == 0
    ? MLXArray(converting: [1.0, 2.0, 3.0])
    : MLXArray(converting: [5.0, 6.0, 7.0])
let result = try group.allSum(data)

// C++
namespace mx = mlx::core;

auto world = mx::distributed::init(/* strict */ true, "jaccl");
mx::array data = mx::full({4}, static_cast<float>(world.rank()), mx::float32);
mx::array result = mx::distributed::all_sum(data, world);
mx::eval(result);
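The effect of an all-reduce sum can be simulated on one machine without MLX. In this sketch, each simulated rank starts with a vector filled with its own rank number, mirroring the Python snippet above, and every rank ends up with the element-wise total. The local `all_sum` function is a stand-in written for illustration, not the MLX call.

```python
def all_sum(per_rank_data):
    """Simulate an all-reduce sum: every rank receives the element-wise total."""
    total = [sum(values) for values in zip(*per_rank_data)]
    return [list(total) for _ in per_rank_data]


world_size = 4
# Rank r holds [r, r, r, r], as in mx.full((4,), float(world.rank())).
per_rank = [[float(rank)] * 4 for rank in range(world_size)]

results = all_sum(per_rank)
print(results[0])  # [6.0, 6.0, 6.0, 6.0], since 0 + 1 + 2 + 3 = 6
```

With four ranks, every device would therefore see the same result after the call.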
20:06 - Standalone distributed sum with the JACCL C++ API
#include <jaccl/jaccl.h>
#include <iostream>

int main() {
  // Initialize JACCL group
  auto group = jaccl::init();
  std::cout << "Rank " << group->rank() << " of " << group->size() << std::endl;

  // Perform all-reduce sum
  float data[10] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f};
  float output[10];
  group->all_sum(data, output, sizeof(data), jaccl::Float32);

  std::cout << "Result: " << output[0] << std::endl;
  return 0;
}