本文首发于知乎,现迁移至个人博客。
目录
本讲从张量开始,完整搭建模型、优化器与训练循环;重点掌握内存核算与算力核算,建立「效率优先」的大模型开发思维。
一、为什么要做资源核算?——两道餐巾纸速算题开场
效率是大模型开发的核心,每一笔算力和显存开销,最终都会直接换算成白花花的银子。而所有的成本估算,都可以用「餐巾纸速算(napkin math)」完成,先看两个问题:
问题1:在 1024 块 H100 显卡上,用 15 万亿个 Token 训练一个 700 亿参数的稠密 Transformer 模型需要多长时间?
- 训练总浮点运算量(FLOPs)= 6 × 参数总量 × Token总量(这个6是咋来的后面有讲解推导)
- 参考H100的理论峰值算力和模型浮点利用率(MFU)为0.5(行业常规水平)
- 计算1024张H100一天能提供的总算力,用总需求总算力 ÷ 单日算力,最终得到结果:约144天
问题2:如果不使用特别精巧的优化手段,在 8 块 H100 上用 AdamW 训练,最大能训多大的模型?
- 单张H100有80GB高带宽显存(HBM)
- AdamW训练时,单个参数需要占用16字节显存:参数本身、梯度、优化器的一阶矩+二阶矩,共4份,每份FP32占4字节,4×4=16字节
- 总参数量 = 总显存 ÷ 单参数字节数 = (8×80GB) ÷ 16B ≈ 400亿参数
P.S.: 这只是一个粗略估算,因为它没有计入激活值(activations),激活值的大小取决于批次大小和序列长度,这部分本讲不会细讲。
二、内存核算:从张量到浮点精度的全拆解
2.1 张量:深度学习的「原子」
张量是深度学习里存储所有数据的基础单元:参数、梯度、优化器状态、训练数据、激活值,全都是张量。 张量不是一个单纯的数学数组,可以理解为「指向显存块的指针 + 元数据(metadata)」:
- 张量底层是一段连续的内存/显存数组
- 元数据里的步长(stride),定义了每个维度的索引对应到底层数组的偏移量
- 多个张量可以共享同一块底层存储,不用复制数据,极大节省显存
2.2 视图(view)与连续(contiguous):张量的零成本形状变换
视图是 PyTorch 张量最常用的形状调整操作,也是显存优化的核心技巧,其底层逻辑和使用限制是新手最易踩坑的点,所有规则均基于张量的内存存储特性推导。 view 的核心设计目标是在不复制数据、不占用额外显存的前提下,修改张量的形状。PyTorch 中所有张量的底层都是连续的一维内存块,同时附带尺寸(shape)、步长(stride)等元数据,用于描述如何将一维内存解析为多维张量。view 操作仅修改张量的元数据,完全不触碰底层内存数据,因此属于零计算开销、零显存开销的免费操作。
- 数据共享:通过 view 生成的新张量,与原张量共用同一块底层内存,修改任意一个张量的值,另一个会同步变化;
- 无数据拷贝:不会创建新的内存空间,是最高效的形状变换方式;
- 严格限制:仅支持连续(contiguous) 张量使用,无法改变张量的总元素数量。
需要注意 transpose(转置)、permute(维度重排)等操作会改变张量的索引逻辑,导致张量变为非连续状态。此时张量的元数据与底层内存的映射关系被打乱,PyTorch 无法通过 view 重新解析形状,会直接抛出报错:view size is not compatible with input tensor's size and stride
这个时候用 contiguous () 会强制将非连续张量的底层数据重新排列为连续的内存块,并生成一个全新的张量:
- 执行后张量恢复连续状态,可以正常使用 view;
- 会触发数据拷贝,占用少量额外显存,不再与原张量共享内存;
- 非必要情况下不建议使用,避免额外开销。
2.3 浮点精度:显存占用与训练稳定性的核心权衡
所有张量的显存占用,都遵循:显存总占用 = 张量元素总数 × 单个元素的数据类型字节数 4×8的float32矩阵,元素总数32,单个元素4字节,总占用128字节;GPT-3前馈层的单个矩阵(12288×4 × 12288),显存占用直接达到2.3GB。
- **float32(FP32,单精度):**32位的构成:1位符号位 + 8位指数位 + 23位尾数位。是深度学习默认的数据类型。有些人把float32叫全精度,这很容易混淆。你跟搞科学计算的人说float32是全精度,他们会笑你——他们用的是float64甚至更高精度。但深度学习就是这么糙,float32基本就是我们需要的最高精度了。
- **float16(FP16,半精度):**16位的构成:1位符号位 + 5位指数位 + 10位尾数位。优势是显存直接减半,比FP32省一半空间。但是有个致命缺陷:动态范围极差,讲义验证:
1e-8会直接下溢成0;大模型训练时极易出现不稳定、上溢/下溢问题。现在基本不推荐在深度学习里用float16了。 - bfloat16(BF16,Brain Float)在2018年被提出,专门为深度学习设计的,16位的构成是:1位符号位 + 8位指数位 + 7位尾数位。他和FP16占用相同的显存,但动态范围和FP32完全一致,
1e-8不会下溢,完美解决了FP16的问题。Trade off :尾数位更少,数值精度更低,但深度学习对这个精度损失基本不敏感。一般来说,前向计算用BF16完全足够,但是参数、梯度、优化器状态,必须用FP32存储,不然训练会直接乱掉。 - **FP8(8位浮点数)**是2022年NVIDIA推出的,仅H100及以上显卡支持。他有两种标准变体: E4M3(范围[-448, 448],侧重精度)、E5M2([-57344, 57344],侧重范围)。现在业界在探索全程FP8训练,核心挑战是解决数值不稳定问题,需要配合数值控制技巧。
2.4 混合精度训练
混合精度训练就是在模型的不同环节,使用最低可行的精度,平衡显存、速度与稳定性。
- 前向传播(激活值)用BF16/FP8,参数、梯度、优化器状态用FP32
- PyTorch提供自动混合精度(AMP)工具,NVIDIA Transformer Engine支持线性层FP8加速
- 有一个通用的原则,长期累积更新的量(参数、优化器状态)用高精度FP32;临时的前向计算用低精度BF16/FP8
三、算力核算:从张量运算到FLOPs与MFU
3.1 张量的设备管理:CPU与GPU的鸿沟
PyTorch中张量是默认创建在CPU内存中的(x.device == cpu),**不用 GPU 的话训练速度会慢好几个量级。**有两种把张量从 CPU 转移到 GPU 的方式:
x.to("cuda:0")移动已有张量;- 直接创建GPU张量
torch.zeros(32,32, device="cuda:0")。
想精确控制计算和数据移动,写代码时必须时刻清楚手上的张量现在在CPU还是GPU,只看变量名和代码不一定能看出来,我们甚至可以在代码里加断言做校验。
3.2 张量运算:逐元素 + 矩阵乘法
张量运算可以分为两类,这两类操作在算力方面的开销差别很大。
- 逐元素操作:对每个元素独立计算(平方、开方、三角掩码等),FLOPs与张量大小呈线性关系,开销极低;
triu()上三角矩阵,专门用于生成因果注意力掩码,是Transformer必备操作 - 矩阵乘法:深度学习的算力核心,三次方复杂度开销,硬件专门为其优化
3.3 拯救代码可读性:einops 与 jaxtyping
课上在这里还提到了 Einops —— 一个用于简化张量操作的工具库,通过语义化维度命名替代 PyTorch 原生的负数字索引,解决维度易写错、难调试的问题,搭配 jaxtyping 标注维度,核心提供 einsum(矩阵运算)、reduce(归约)、rearrange(维度拆分合并)三个功能,可读性和维护性远优于原生写法。
1 | import torch |
- einsum(爱因斯坦求和):升级版矩阵乘法,用命名维度定义运算,自动处理广播、归约;
- rearrange:升级版
view(),优雅拆分/合并维度(如把hidden拆分为heads × hidden_dim); - reduce:优雅实现求和、均值、最大池化等归约操作;
- torch.compile会自动优化einsum,找最优的计算顺序,通常比手写的代码更高效
3.4 FLOPs
需要注意一点,小写s的FLOPs,代表的是 floating point operations(总浮点运算次数),衡量你做了多少计算;大写S的FLOPS 代表floating points per second(每秒浮点运算次数),衡量硬件的速度。这门课不用大写S,用/s表示每秒,避免混淆。
下面用几个直观的数字来感受一下这个概念:
- GPT-3训练总消耗约 3.14e23 FLOPs
- GPT-4训练总消耗(行业推测)约 2e25 FLOPs
- A100:FP32峰值 19.5TFLOP/s,BF16/FP16峰值 312 TFLOP/s
- H100:稠密矩阵峰值989.5 TFLOP/s(标称1979是稀疏优化,行业几乎不用)
H100的峰值带个星号,标注了with sparsity(带稀疏)。这个稀疏要求每4个元素里必须有2个是0,是特定的结构化稀疏,实际行业里根本没人用,基本就是营销用的数字。
3.5 矩阵乘法:深度学习的算力核心
一次矩阵乘法的总FLOPs = 2 × 左矩阵行数 × 公共维度 × 右矩阵列数
- 为什么是2?因为每个元素的计算,包含一次乘法 + 一次加法,各算1次FLOP
- 批量高维张量的矩阵乘法,会自动广播前序维度,算力公式依然成立
- 在深度学习中,对于足够大的矩阵,没有任何运算比矩阵乘法更贵。其他所有运算的开销都是线性的,只有矩阵乘法是三次方级别。
3.6 模型浮点利用率(MFU)
MFU = 实际有效FLOP/s ÷ 硬件理论峰值FLOP/s= (总FLOPs ÷ 实际耗时) ÷ 硬件标称峰值算力 行业参考标准:
- MFU>0.5:行业内公认的优秀水平
- MFU≈5%:非常差的水平,说明硬件完全没被利用起来
- BF16的MFU远高于FP32,低精度能大幅提升硬件利用率
一定要自己给代码跑基准测试,不要迷信宣传的峰值算力。标称值都太乐观了,比如BF16下你会发现,实际跑出来的MFU,往往比你预想的低。
四、反向传播的资源开销
前面我们只算了前向传播的算力,而训练的完整算力,必须包含反向传播的梯度计算。
模型结构:x(B×D) @ W1(D×D) → h1(B×D) @ W2(D×K) → 输出 → 损失
- 前向传播总FLOPs:2BDD + 2BDK = 2 × 总参数量 × 批次大小
- 反向传播总FLOPs(链式法则+矩阵乘法): 计算W2梯度:2BDK FLOPs 计算h1梯度:2BDK FLOPs 计算W1梯度:4BDD FLOPs 反向总FLOPs = 4 × 总参数量 × 批次大小
- 一次训练迭代总FLOPs:2×参数量×数据量 + 4×参数量×数据量 = 6 × 参数量 × 数据量
五、从零搭建训练全流程:从参数到训练循环
5.1 参数初始化
参数初始化分为通用张量创建和神经网络权重专用初始化两类。通用方法仅用于构造数据,模型权重必须使用专用方案,否则会出现梯度消失 / 爆炸,导致训练失效。 常用的张量初始化方法:
1 | x = torch.tensor([[1., 2, 3], [4, 5, 6]]) # 手动定义 |
如果直接用标准高斯分布初始化权重 → 输出值随维度增大爆炸,梯度不稳定。为了让输出值不依赖输入维度,保持恒定稳定,引入 Xavier 初始化(提出者是 Xavier Glorot,也叫 Glorot 初始化): ,将随机权重除以输入维度的平方根,让输出稳定在标准正态分布,避免数值爆炸。结合截断正态分布就是大模型权重初始化的最优方案,讲义里的实现代码为:
1 | w = nn.Parameter(nn.init.trunc_normal_(torch.empty(input_dim, output_dim), std=1 / np.sqrt(input_dim), a=-3, b=3)) |
5.2 随机性控制
深度学习训练过程中存在大量随机操作,若不加以约束,同一段代码在两次运行中会产生不同结果,导致实验无法复现、BUG 难以定位、模型效果无法对比。随机性控制的核心,是固定所有随机数生成源,让程序的运行结果完全确定。固定随机种子后,相同代码、相同硬件、相同数据,无论运行多少次,输出的损失值、参数更新、模型效果完全一致。这是调试代码、验证优化方案、复现实验结果的基础前提。
1 | seed = 0 |
5.3 数据加载
工业级训练数据(如 LLaMA 数据集)可达 2.8TB,远超单机内存上限,无法一次性全量加载到内存中。使用 **np.memmap 内存映射(对操作系统mmap 的封装)**来懒加载数据,将磁盘文件直接映射为内存变量,仅在访问指定数据片段时才从磁盘读取对应内容,全程不会占用全量数据的内存空间。
数据从 CPU 迁移至 GPU 时,还可以搭配 pin_memory() 固定内存 + non_blocking=True 实现异步数据传输,让 CPU 数据加载与 GPU 模型计算并行执行,避免数据传输阻塞训练流程,最大化硬件利用效率。
5.4 优化器:从 SGD 到 Adam,以及显存开销
优化器的核心作用是根据梯度更新模型权重,课程中实现的 SGD 为基础梯度下降算法,收敛速度较慢,AdaGrad 通过累积梯度平方实现自适应学习率,Adam 则结合了动量与 RMSProp 的优势,收敛速度快且训练稳定,是大模型训练的标配优化器。
- SGD:基础梯度下降
- AdaGrad:累积梯度平方,自适应学习率
- Adam = 动量 + RMSProp,大模型标配
我们用 FP32 精度训练时,一个数字占 4 个字节,显卡显存会被 4 类东西占满:
- 模型参数:就是我们初始化的权重 W,必须存 1 份;
- 梯度:反向传播算出来的错误值,必须存 1 份;
- 优化器状态:优化器要记住历史梯度,才能正常工作:AdaGrad 只需要记 1 份历史数据;Adam 需要记 2 份历史数据(所以 Adam 比 SGD 更占显存);
- 激活值:前向传播的中间计算结果,反向传播必须用到,占比最大。
所以有,总显存占用 = 4字节 × (参数总数 + 激活值总数 + 梯度总数 + 优化器状态总数)
激活值的存储是显存消耗的主要部分,朴素实现会存储全部激活值用于反向传播计算梯度,激活检查点作为主流优化方案,通过反向传播时重计算激活值,算力换显存,Lec5 中亦有记载。
5.5 标准训练循环
训练流程本质是一个循环执行的迭代过程,其中定义模型、定义优化器为仅执行一次的初始化准备工作,获取批次数据、前向传播计算损失、反向传播计算梯度、优化器更新参数、梯度清零为核心循环体,该循环体需要重复执行成千上万次,通过不断迭代修正模型参数,让模型逐步收敛至最优状态,因此整套流程被称为训练循环。
- 定义模型 → 迁移到GPU
- 定义优化器
- 循环获取批次数据
- 前向传播 → 计算损失
- 反向传播 → 计算梯度
- 优化器更新参数
- 梯度清零(
optimizer.zero_grad(set_to_none=True),释放显存)
检查点(Checkpoint):训练要跑很久,中途一定会崩,所以我们必须保存模型权重 + 优化器状态 + 迭代步数,此为 checkpoint。
六、总结与行业趋势
- 本讲核心总结:我们从张量的底层原理、创建方式讲起,覆盖了内存核算公式、浮点精度选型、算力FLOPs计算、MFU评估、反向传播6倍公式推导,最终用代码实现了参数初始化、自定义模型、数据懒加载、优化器、训练循环、检查点的全流程,核心就是建立「资源核算」的思维。
- 低精度训练的趋势:业界一直在探索更低的训练精度,已经有论文验证可以全程用FP8训练,甚至在探索int4训练。核心挑战是数值稳定性,需要配合模型设计和数值控制技巧。
- 训练与推理的精度差异:训练对低精度非常敏感,需要相对高的精度保证稳定;但训好的模型,在推理时可以做非常激进的量化,获得极大的显存和速度收益。
- 模型与硬件的协同设计:Transformer之所以长成现在的样子,就是为了适配GPU的特性。现在的模型设计,越来越受硬件特性的驱动。能适配低精度硬件的模型,会获得巨大的效率优势。