
-
开始使用适用于 Apple 芯片的 MLX
MLX 是一个灵活高效的阵列框架,适用于 Apple 芯片上的数字计算和机器学习。我们将探索统一内存、懒性计算和函数转换等基本功能。我们还将了解一些有关使用 Swift 和 Python API 来构建和加速支持不同 Apple 平台的机器学习模型的更高级技巧。
章节
资源
- MLX
- MLX LM - Python API
- MLX Examples
- MLX Explore - Python API
- MLX Framework
- MLX Llama Inference
- MLX Swift
- MLX Swift Examples
相关视频
WWDC25
-
搜索此视频…
大家好 我叫 Awni 今天 我很高兴 为大家介绍 MLX MLX 是专为 Apple 芯片构建的开源阵列框架 它高度灵活 既适用于基础数值计算 也能在 Apple 设备上运行 最前沿的大规模机器学习模型 如果要使用大语言模型生成文本 或使用最新的模型生成图像、音频 甚至视频 MLX 将是你的理想选择 你也可以直接在 Mac 上训练、微调 或自定机器学习模型 我首先会介绍 MLX 的基本概念 以及适用的场景 我还会讲解有关在 Python 中使用 MLX 的基础知识 包括安装步骤和基础的数组运算 接着 我会介绍 MLX 相较于 其他框架的 一些核心优势与特色功能 接下来 我会带你了解 MLX 提供的一些工具 帮助你在 Apple 芯片上 尽可能快速地运行各类机器学习任务
最后 我会简要介绍 MLX Swift 并演示如何快速上手 下面 我们将开始了解 MLX MLX 从一开始就是为 Apple 芯片量身打造的 具备卓越的性能表现 它可在 CPU 上运行 也可借助 GPU 实现高效加速 MLX 适用于各种应用场景 从小规模的数值计算到大规模的 机器学习任务 都能胜任 它在保持高速与高效的前提下 兼具易用性与灵活性
MLX 提供的核心 API 与 NumPy 的设计高度一致 通常可以用它直接替代 NumPy 用于加速大多数 数值计算任务 MLX 还内建了 机器学习所需的各种工具 涵盖自动微分机制和高阶功能库 这些高阶 API 类似于 PyTorch 和 JAX 如果你用过这些框架 MLX 用起来一点也不会感到陌生 甚至更容易上手 你可以直接在设备上使用 MLX 执行更高级的机器学习任务 例如 MLX 被广泛应用于 LM Studio 这是一款人气颇高的应用程序 可直接在 Mac 上 通过运行大语言模型来生成文本
你可以使用基于 MLX 构建的 MLX LM 软件包来生成文本 或对参数规模高达数千亿的语言模型 进行微调 要进一步了解更多相关内容 请观看讲座 “借助 MLX 在 Apple 芯片上探索大语言模型” MLX 提供功能完备的 Python API 非常适合快速开发原型 它还提供了一个 Swift API 其中包含构建和优化神经网络所需的 各类高阶工具包 MLX 还提供 C++ 和 C 的 API 可以使用上述任一语言 通过 MLX 在 Apple 芯片上运行最新的机器学习模型 其中包括 Mac、iPhone、 iPad 和 Apple Vision Pro 所有 MLX 软件均基于 宽松的 MIT 许可协议开源发布 MLX 的核心软件已在 GitHub 上提供 同时还有多个使用 Python 和 Swift API 构建的示例项目和软件包 MLX 在 Hugging Face 上 也拥有活跃的模型创作者社区 许多最新模型 已发布在 MLX 社区的 Hugging Face 组织中 并且每天都有新的模型持续上传 使用 Python 开始体验 MLX 最简单 的方式是通过 PyPi 进行安装 只需在终端中运行单行命令 pip3 install mlx 即可 MLX 简单易用 要开始进行数组计算 只需打开一个 Python 文件 并导入 MLX 即可 接下来就可以创建一些数组 并执行基础运算了 例如 这里我们对两个 整数数组进行相加运算
你也可以轻松查看数组的相关信息 例如它的形状和数据类型 我之前提到 MLX Python API 与 NumPy 十分相似 大多数操作的名称、参数签名以及 行为都与 NumPy 保持一致 如果用过 NumPy 或类似框架 MLX 用起来一点也不会感到陌生 并且容易上手 现在你已经了解了 MLX 的 基本概念、应用场景以及基础用法 接下来让我们看看它的一些核心特性 这些特性包括统一内存管理、 延迟计算、函数变换 以及用于构建和优化 神经网络的高阶软件包 MLX 的设计充分发挥了 Apple 芯片的优势 其中包括专为统一内存管理 设计的新型编程模型
大多数常用于机器学习的系统 采用的是独立 GPU 并拥有独立的内存 而 Apple 芯片 则采用统一内存架构 也就是说 CPU 和 GPU 共享 同一块物理内存
为了充分利用统一内存架构 MLX 的运作方式与传统框架有所不同 在传统框架中 计算通常是由数据驱动的 如果数组位于 CPU 内存中 计算就会在 CPU 上执行 如果数组位于 GPU 内存中 计算就会在 GPU 上执行
在 MLX 内 数组会被分配在统一内存中 无需将数组拷贝到任何其他地方 即可在所有受支持的 设备上直接使用
相反 要在设备上执行运算 只需为运算指定目标设备即可 例如 这里我们在 GPU 上 对 a 和 b 执行加法运算 而在 CPU 上对它们执行乘法运算 这些运算甚至可以并行执行 MLX 会在有依赖关系时 自动管理这些依赖关系 MLX 的另一项核心功能是延迟计算 为了提升执行效率 尤其是在处理大规模计算时 MLX 采用了延迟执行引擎
在对两个数组执行加法等运算时 系统并不会立刻进行实际计算 相反 系统会构建一个计算图 就像你现在看到的这样
这时数组 C 尚未计算 只有真正需要使用它时 系统才会触发计算 例如 当你打印 C 或将它从 MLX 数组 转换为 Python 列表时 系统才会真正执行计算 你也可以通过调用 mx.eval 来 显式执行计算图 强制完成计算 延迟计算有几个好处 通过对计算图的构建与执行进行解耦 MLX 可以在计算出结果之前对图形进行 转换与优化 此外 借助延迟计算机制 用户只需为实际使用的部分承担开销 函数变换 是 MLX 的另一项核心功能 有了它 MLX 不仅仅是一个 高性能数组框架 更成为训练与优化 机器学习模型的强大工具
函数变换是一类以函数作为输入 并返回新函数作为输出的转换操作
MLX 提供了多种函数变换机制 这些函数变换大致可分为两类 一类是用于自动微分的变换 另一类则用于优化计算图 例如 你可以通过函数变换自动 计算任意 MLX 函数的梯度
假设你有一个用于计算 输入值正弦的基本函数 要获取这个函数的梯度 只需使用 mx.grad 函数变换即可 mx.grad 会返回一个新函数 将数组传入这个函数后 它会返回对应的导数值 函数变换具有高度可组合性 可以通过对 mx.grad 的输出 再次使用 mx.grad 轻松计算二阶导数
这样得到的结果仍是一个函数 在传入数组时 它会返回这个函数的二阶导数值
MLX 还提供了两个高阶软件包 有助于更轻松地构建与训练神经网络
第一个是 mlx.nn 它是一个 模块化的库 用于构建神经网络结构 第二个是 mlx.optimizers 它是一个由常用优化算法构成的库 这两个软件包既可独立使用 也可无缝整合在一起
mlx.nn 软件包提供构建神经网络 所需的全部功能 nn.Module 是所有 层与容器继承的基类 它公开了一系列实用的方法 用于访问、加载和 保存参数等操作
nn 库还附带了一套标准且现成的层 如 nn.Linear 同时也可以通过继承 nn.Module 来构建自己的层
常用的损失函数和参数初始化方法 也已包含在 nn.losses 和 nn.init 子软件包中
下面我们来看一个示例 了解如何使用 mlx.nn 构建 简单的多层神经网络
第一步是创建一个从 nn.Module 继承的自定类 在这个示例中 我们会构建一个 带有单一隐藏层的简单神经网络 我们在模块的初始化方法中 创建线性层 接着我们实现调用函数 用于根据输入 计算这个模块的输出结果 调用函数先调用第一层线性层 接着应用 ReLU 激活函数 最后再调用第二层线性层 尽管 MLX 是专为 Apple 芯片的 统一内存架构设计与优化的 但它的高阶神经网络软件包 也借鉴了 PyTorch 等 常用机器学习框架的设计风格 便于快速上手
我们来对比一下使用 MLX 和 PyTorch 实现相同模型的方式 两者几乎完全相同 只有输出计算函数中 存在两处细微差别 如果你有使用 PyTorch 构建模型的经验 迁移到 MLX 会非常顺畅 现在你已经了解了 MLX 的大部分核心功能 接下来让我们来看看如何充分 发挥它的优势 让你的机器学习 工作负载更快、更高效 我会从函数编译开始 介绍如何通过 编译提升函数执行效率 然后 我会给大家介绍 mx.fast 子软件包 它提供了一套现成的高性能 常用机器学习操作 同时还提供 API 用于添加你自己的自定 Metal 内核 之后 我们将了解如何使用量化 降低内存占用、加快模型执行 最后 我会介绍 如何使用 MLX 在多台设备之间分布式执行计算任务 在 MLX 中 几乎所有 实际的计算都由 对数组进行一系列操作的函数构成 加速这类函数的一个简单方法 是使用 mx.compile 函数变换 假设你有一个函数 它执行了多个逐元素操作 例如如下所示的 GELU 激活函数 这个函数对应的计算图中 包含多个节点 每个节点在底层 都对应一次 GPU 内核调用
使用编译功能可以将这些分散的内核 融合为一个单一的内核 从而减少内存带宽消耗以及执行开销 并显著提升计算效率
使用 mx.compile 十分简单 只需为目标函数 添加 mx.compile 装饰器即可 编译通常效果很好 但对于一些更复杂的操作 特别是机器学习中的常见计算任务 使用 mx.fast 子软件包 可能会获得更出色的性能 例如 Transformer 模型中的许多核心构建模块 就使用了 mx.fast 中的操作 这些操作包括位置编码 归一化层以及缩放点积注意力机制等 mx.fast 中的操作 虽然更具专用性 但都经过深度优化 能够在 训练和推理阶段都实现高性能 此外 它们也具有高度可配置性 能灵活支持给定计算的多种变体 例如 缩放点积注意力操作可以 接收可选的 mask 参数 这个 mask 可以是 加性掩码、布尔掩码 也可以是一个字符串 用于指示掩码类型 让我们仔细看看 mx.fast 中的 RMSNorm 操作 RMSNorm 几乎应用于所有现代基于 Transformer 的大语言模型中 使用 MLX 操作的基础实现 通常会构建出较大的计算图
而使用 mx.fast.rms_norm 只需一行代码 即可完成相同功能 不仅代码更简洁 计算图也简化为一个节点 计算效率也会显著提升
MLX 提供了一个用于添加自定 Metal 内核的 API 适用于那些 尚未包含在 mx.fast 中 但通过定制实现可获得 更高性能的函数 只需编写自定的 Metal 内核代码 其余工作都由 MLX 自动处理 包括即时编译与调度执行 这些内核是用 Metal 编写的 它是 Apple 推出的编程语言和 API 专用于在 Apple GPU 上执行函数 只需提供 Metal 源代码字符串 以及关于输入和输出的相关信息 即可构建内核
通过指定线程网格大小以及输出的 形状与类型来调用内核 MLX 会将对内核的调用 视为普通操作的一部分 它会在计算图中创建一个节点 并采用延迟计算的方式执行
在优化机器学习任务 性能的众多工具中 量化是另一项重要方法 大型模型需占用大量内存与带宽 才能实现快速运行 而在很多情况下 推理所需精度 通常低于训练阶段 但依然能保持相似效果 降低精度可让模型占用更少的内存 从而容纳更大的模型并提升运行速度
如果你的模型使用的是 32 位浮点精度 可以使用 bfloat16 或 float16 作为第一步 将内存需求减少一半 当 16 位仍然过大时 MLX 提供了内建方法 可将数组进一步量化为更小的格式 并对它们执行运算 例如 你可以将每个元素量化为 4 位 以进一步降低内存需求
要对矩阵进行量化 可以使用 mx.quantize 你可以指定每个元素使用的 位数和分组大小 分组大小决定了量化矩阵中 共享同一缩放值和偏移值的元素数量 位数越少、分组越大 得到的结果 就越小巧 运行也越快 MLX 提供多种位数 和分组大小选项 尽可能为你提供最大的灵活性
可以使用 mx.quantized_matmul 将任意 未经量化的向量或矩阵 与量化矩阵相乘
使用 mx.dequantize 可近似还原为原始输入
MLX.mn 还提供了一个实用工具 只需一条命令即可量化整个模块 假设你有一个模型 由一个嵌入层和 和多个线性层堆叠而成 你可以使用 nn.quantize 对整个模型进行量化 Quantize 命令还支持 可选的回调函数 有助于更细致地控制要量化的层 以及为特定层选择使用的精度
在使用大型语言模型生成文本时 量化能显著降低内存占用 并提升每秒生成的 Token 数量 在某些情况下 仅靠一台机器是远远不够的 例如 你可能希望 使用一个无法完全装入 单台机器内存的大型模型来生成文本 或者你正在微调模型 或在大型数据集上进行评估 这两者都易于并行处理 使用多台机器可以显著加快速度
MLX 内建分布式计算能力 可轻松在多台机器上 执行各类计算任务 这些机器可以通过以太网或雷雳连接
可以使用 mx.distributed 子软件包 在多台机器之间分发计算任务 MX distributed 主要是一组通信操作
例如 all_Sum 会对所有机器上的输入数组进行求和 all_sum 的输出是所有输入的总和 并在每台机器上返回相同的结果 我们来详细看看如何在多台 机器之间对数组进行求和 使用 mx.distributed.init 初始化分布式后端 这一步是可选的 但如果需要访问通信组 就需要调用它 通信组中包含一些实用信息 例如进程总数和当前进程的索引
然后 在每个进程上创建 一个只包含单个值的数组 并调用 mx.distributed.all_sum 来对所有进程的数组进行求和
MLX 提供一个便捷的启动器 可用于在多台机器上运行 MLX 程序 要在 4 台机器上运行程序 可以使用 mlx.launch 并 传入这 4 台机器的 IP 地址 目前为止我们介绍的所有内容 都是在 Python MLX 在很多情况下 你可能更喜欢 Python 的灵活性和便捷性 而在某些情况下 你可能更偏好使用 Swift 因此 MLX 也提供了 功能完整的 Swift API 它构建于 Metal 之上 可在 macOS、iOS、 iPadOS、visionOS 等平台上高效运行 开始在 Swift 中使用 MLX 非常简单 只需将它作为软件包添加 到 Xcode 项目中即可 点击项目文件 然后在 “软件包依赖项”选项卡中 点击加号图标 然后输入 MLX Swift GitHub 仓库的链接地址 并点击“添加软件包”按钮 就这么简单 你就可以 开始使用 MLX Swift 开发了 为了尽可能简化 Python 与 Swift 之间的迁移 这两种语言的 API 设计刻意保持一致 下面是我们之前看到的 Python 代码片段 以及与它对应的 MLX Swift 之间的对比 在 Swift 中创建数组、 对数组执行运算 以及查看数组的元数据 与在 Python 中几乎相同 我们之前介绍的在 Python 中使用 MLX 的所有核心功能 以及“加速 MLX”部分 提到的各项优化在 MLX Swift 中同样适用 我们已经大致了解了 MLX 的多种关键功能 要想进一步了解这个框架 可以访问 MLX 网站 其中提供相关文档、示例代码 等丰富资源 Python 和 Swift 的 API 都配有示例代码库 涵盖了许多常见的机器学习应用场景 包括语言模型的训练与生成、 图像生成、语音识别等 这些示例是进一步了解 MLX 的好方法 也为你在自己的项目中使用 MLX 进行构建提供了良好的起点 谢谢观看 非常期待看到大家用 MLX 开发 的精彩作品
-
-
3:48 - Basics
import mlx.core as mx # Make an array a = mx.array([1, 2, 3]) # Make another array b = mx.array([4, 5, 6]) # Do an operation c = a + b # Access information about the array shape = c.shape dtype = c.dtype print(f"Result c: {c}") print(f"Shape: {shape}") print(f"Data type: {dtype}")
-
5:31 - Unified memory
import mlx.core as mx a = mx.array([1, 2, 3]) b = mx.array([4, 5, 6]) c = mx.add(a, b, stream=mx.gpu) d = mx.multiply(a, b, stream=mx.cpu) print(f"c computed on the GPU: {c}") print(f"d computed on the CPU: {d}")
-
6:20 - Lazy computation
import mlx.core as mx # Make an array a = mx.array([1, 2, 3]) # Make another array b = mx.array([4, 5, 6]) # Do an operation c = a + b # Evaluates c before printing it print(c) # Also evaluates c c_list = c.tolist() # Also evaluates c mx.eval(c) print(f"Evaluate c by converting to list: {c_list}") print(f"Evaluate c using print: {c}") print(f"Evaluate c using mx.eval(): {c}")
-
7:32 - Function transformation
import mlx.core as mx def sin(x): return mx.sin(x) dfdx = mx.grad(sin) def sin(x): return mx.sin(x) d2fdx2 = mx.grad(mx.grad(mx.sin)) # Computes the second derivative of sine at 1.0 d2fdx2(mx.array(1.0))
-
9:16 - Neural Networks in MLX
import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim class MLP(nn.Module): """A simple MLP.""" def __init__(self, dim, h_dim): super().__init__() self.linear1 = nn.Linear(dim, h_dim) self.linear2 = nn.Linear(h_dim, dim) def __call__(self, x): x = self.linear1(x) x = nn.relu(x) x = self.linear2(x) return x
-
9:57 - MLX and PyTorch
# MLX version import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim class MLP(nn.Module): """A simple MLP.""" def __init__(self, dim, h_dim): super().__init__() self.linear1 = nn.Linear(dim, h_dim) self.linear2 = nn.Linear(h_dim, dim) def __call__(self, x): x = self.linear1(x) x = nn.relu(x) x = self.linear2(x) return x # PyTorch version import torch import torch.nn as nn import torch.optim as optim class MLP(nn.Module): """A simple MLP.""" def __init__(self, dim, h_dim): super().__init__() self.linear1 = nn.Linear(dim, h_dim) self.linear2 = nn.Linear(h_dim, dim) def forward(self, x): x = self.linear1(x) x = x.relu() x = self.linear2(x) return x
-
11:35 - Compiling MLX functions
import mlx.core as mx import math def gelu(x): return x * (1 + mx.erf(x / math.sqrt(2))) / 2 @mx.compile def compiled_gelu(x): return x * (1 + mx.erf(x / math.sqrt(2))) / 2 x = mx.random.normal(shape=(4,)) out = gelu(x) compiled_out = compiled_gelu(x) print(f"gelu: {out}") print(f"compiled gelu: {compiled_out}")
-
12:32 - MLX Fast package
import mlx.core as mx import time def rms_norm(x, weight, eps=1e-5): y = x.astype(mx.float32) y = y * mx.rsqrt(mx.mean( mx.square(y), axis=-1, keepdims=True, ) + eps) return (weight * y).astype(x.dtype) batch_size = 8192 feature_dim = 4096 iterations = 1000 x = mx.random.normal([batch_size, feature_dim]) weight = mx.ones(feature_dim) bias = mx.zeros(feature_dim) start_time = time.perf_counter() for _ in range(iterations): y = rms_norm(x, weight, eps=1e-5) mx.eval(y) rms_norm_time = time.perf_counter() - start_time print(f"rms_norm execution: {gelu_time:0.4f} sec") start_time = time.perf_counter() for _ in range(iterations): mx.eval(mx.fast.rms_norm(x, weight, eps=1e-5)) fast_rms_norm_time = time.perf_counter() - start_time print(f"mx.fast.rms_norm execution: {compiled_gelu_time:0.4f} sec") print(f"mx.fast.rms_norm speedup: {rms_norm_time/fast_rms_norm_time:0.2f}x")
-
13:30 - Custom Metal kernel
import mlx.core as mx # Build the kernel source = """ uint elem = thread_position_in_grid.x; out[elem] = metal::exp(inp[elem]); """ kernel = mx.fast.metal_kernel( name="myexp", input_names=["inp"], output_names=["out"], source=source, ) # Call the kernel on a sample input x = mx.array([1.0, 2.0, 3.0]) out = kernel( inputs=[x], grid=(x.size, 1, 1), threadgroup=(256, 1, 1), output_shapes=[x.shape], output_dtypes=[x.dtype], )[0] print(out)
-
14:41 - Quantization
import mlx.core as mx x = mx.random.normal([1024]) weight = mx.random.normal([1024, 1024]) quantized_weight, scales, biases = mx.quantize( weight, bits=4, group_size=32, ) y = mx.quantized_matmul( x, quantized_weight, scales=scales, biases=biases, bits=4, group_size=32, ) w_orig = mx.dequantize( quantized_weight, scales=scales, biases=biases, bits=4, group_size=32, )
-
15:23 - Quantized models
import mlx.nn as nn model = nn.Sequential( nn.Embedding(100, 32), nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 1), ) print(model) nn.quantize( model, bits=4, group_size=32, ) print(model)
-
16:50 - Distributed
import mlx.core as mx group = mx.distributed.init() world_size = group.size() rank = group.rank() x = mx.array([1.0]) x_sum = mx.distributed.all_sum(x) print(x_sum)
-
17:20 - Distributed launcher
mlx.launch --hosts ip1, ip2, ip3, ip4 my_script.py
-
18:20 - MLX Swift
// Swift import MLX // Make an array let a = MLXArray([1, 2, 3]) // Make another array let b = MLXArray([1, 2, 3]) // Do an operation let c = a + b // Access information about the array let shape = c.shape let dtype = c.dtype // Print results print("a: \(a)") print("b: \(b)") print("c = a + b: \(c)") print("shape: \(shape)") print("dtype: \(dtype)")
-
-
- 0:00 - 简介
MLX 是专为 Apple 芯片打造的开源阵列框架。它能够高效地执行机器学习任务,并允许你使用 Python 和 Swift 直接在设备上运行大型语言模型。
- 1:15 - MLX 概览
这个高性能机器学习框架针对 Apple 芯片进行了优化,可以在 CPU 和 GPU 上快速完成数值计算和机器学习任务。它拥有类似 NumPy 的核心 API,而高级 API 则类似于 PyTorch 和 JAX。 你可以在 LM Studio 等应用程序中使用它,通过大语言模型在设备端生成文本。MLX 提供了 Python、Swift、C++ 和 C 的 API。 MLX 在 MIT 许可协议下开源,并且在 Hugging Face 上拥有活跃的社区。
- 4:21 - 主要功能
MLX 非常高效,因为它针对 Apple 芯片量身定制,充分利用了芯片的统一内存架构。由于在 CPU 和 GPU 之间共享内存,因此无需拷贝数据;操作只需指定所需的设备即可。 MLX 不会立即执行计算,而是构建计算图,仅在需要获取结果时才执行。 通过函数变换,MLX 可以将函数作为输入并返回新的函数,便于进行自动微分和其他优化。 MLX 包括了用于构建和训练神经网络的高级软件包,以及常用机器学习操作。这些软件包采用模块化设计,类似 PyTorch 等热门框架,方便开发者轻松切换。
- 10:15 - 加速 MLX
使用“mx.compile”进行函数编译会将多个 GPU 内核启动融合到单个内核中,从而减少内存带宽和执行开销。 对于更复杂的操作,“mx.fast”子软件包提供了高度优化的常用机器学习操作 (例如 RMSNorm 和注意力机制) 专用实现。 MLX 支持量化,能以更少的内存开销实现更快、更高效的推理,从而在不过度影响质量的情况下降低精度。 大规模计算可以利用“mx.distributed”子软件包在多台机器之间分配任务。
- 17:30 - MLX Swift
MLX 还提供 Swift API,这个 API 实现了同样的效率提升,可在 Xcode 中跨 Apple 平台无缝开发。 要开始使用并了解详细信息,请访问 MLX 网站或下载示例代码库。