在前面的文章里,Flash Attention 这个名字反复出现:
第 2 篇讲 attention 时提到它是「现代推理框架的标配」
第 5 篇讲长上下文时把它列为「四大攻坚维度」之一
第 11 篇讲推理优化时它是 prefill 阶段的核心加速器
这一篇我们正式把它讲透。
为什么 Flash Attention 值得单独一篇?因为它代表了深度学习系统优化的一个里程碑思路——它没有改变任何数学(计算结果完全等价),只通过重新设计数据在显存中的搬运方式,把 attention 的速度提升了 2-4 倍,把显存占用从 O(n²) 降到 O(n)。
如果你做过相关工作,下面这些问题应该不陌生:
为什么 vLLM、SGLang、TensorRT-LLM 都默认用 Flash Attention?
为什么把
attn_implementation="flash_attention_2"加上模型就能跑得快很多?Flash Attention 的"分块"、"在线 softmax"到底是什么?
H100 的 Flash Attention v3 比 v2 快多少?
端到端训练用了 Flash Attention 后能省多少显存?
读完本文你将能:
理解 GPU 显存层级(HBM vs SRAM)—— 这是 Flash Attention 的物理基础
理解 Flash Attention 的两个核心技巧(Tiling + Online Softmax)
知道 v1 / v2 / v3 之间的演进,针对你的硬件选对版本
用 PyTorch / HuggingFace / vLLM 三种方式启用 Flash Attention
判断什么场景 Flash Attention 不适用
我们开始。
一、为什么 Attention 需要专门优化
1.1 一个被忽视的事实:GPU 不是只有算力
很多人对 GPU 的认知停留在「TFLOPS 多少」——比如 H100 SXM 是 989 TFLOPS(FP16)。
但 GPU 还有一个同等重要的指标:显存带宽。
GPU | FP16 算力 | 显存带宽 | 算力/带宽比 |
V100 | 125 TFLOPS | 0.9 TB/s | 139 |
A100 80G | 312 TFLOPS | 2.0 TB/s | 156 |
H100 SXM | 989 TFLOPS | 3.35 TB/s | 295 |
H200 | 989 TFLOPS | 4.8 TB/s | 206 |
B200 | 2250 TFLOPS | 8 TB/s | 281 |
注意「算力/带宽比」——越高表示单位带宽对应的算力越多。
关键认知:
GPU 算力增长比显存带宽增长快得多。
从 V100 到 H100,算力翻了 8 倍,带宽只翻了 3.7 倍。
这意味着「IO 瓶颈」越来越严重。
1.2 GPU 显存的三层结构
我们再深入一层——GPU 内部其实有多级存储:
HBM (High Bandwidth Memory) 80 GB, 3.3 TB/s ↑ ↑ "显存",所有数据默认在这里 非常大,相对慢 L2 Cache 50 MB, ~12 TB/s ↑ 中间层 SRAM (Shared Memory + Registers) ~228 KB / SM ~19 TB/s ↑ "片上",极快但极小 H100 有 132 个 SM,总共也才 30 MB简化版:
[ 80 GB ] HBM ← 慢 ↕↕↕↕↕↕ 数据搬运 [ 30 MB ] SRAM ← 极快 ↑ 计算实际发生的地方核心矛盾:
数据默认在 HBM(80 GB 富余)
但计算必须在 SRAM 进行
每次计算都要把数据从 HBM 搬到 SRAM
HBM 带宽(3.3 TB/s)远低于 SRAM(19 TB/s)
这就是为什么 IO 成了瓶颈——GPU 算力再强,数据搬不进来也没用。
1.3 传统 Attention 的 IO 噩梦
回顾 attention 计算:
S = Q · K^T # [n, n] 矩阵 P = softmax(S) # [n, n] 矩阵 O = P · V # [n, d] 矩阵朴素实现把每一步的中间结果写回 HBM,然后下一步再读回来:
1. 读 Q, K 到 SRAM 2. 计算 S = QK^T 3. 把 S 写回 HBM ← O(n²) 写 4. 读 S 回 SRAM 5. 计算 P = softmax(S) 6. 把 P 写回 HBM ← O(n²) 写 7. 读 P, V 到 SRAM 8. 计算 O = PV 9. 写 O 到 HBM问题:第 3、6 步要把n × n大小的矩阵在 HBM 和 SRAM 之间来回搬。
对于 n=8K 序列:
S 矩阵显存:8K × 8K × 4 bytes =256 MB
这 256 MB 反复在 HBM ↔ SRAM 间来回搬
实测数据(A100 上 attention 计算):
实际算力消耗:约 5% 的 GPU 算力
实际 IO 消耗:约 95% 的 GPU 时间
也就是说,95% 的时间在搬数据,5% 的时间在算——这是工程优化的巨大空间。
1.4 Flash Attention 的「Aha Moment」
Flash Attention 论文(Tri Dao, 2022)的一句话总结了它的核心思想:
能不能让 attention 计算"不要"物化中间矩阵 S 和 P?
如果可以,那么:
IO 量从 O(n²) 降到 O(n)
显存占用从 O(n²) 降到 O(n)
速度提升 2-4×(算力终于能跑满)
但难点在于:softmax 需要看到整行才能归一化——你不知道总和之前,怎么知道每个元素的归一化值?
Flash Attention 的天才之处在于:它用一种叫"在线 softmax"的算法,让 softmax 可以流式计算。
二、Flash Attention v1 原理深入
2.1 核心技巧 1:Tiling(分块)
Flash Attention 不一次计算整个 attention,而是按块计算。
把 Q、K、V 切成 block:
Q : [n, d] → 切成 Tr 块,每块 [Br, d] K : [n, d] → 切成 Tc 块,每块 [Bc, d] V : [n, d] → 切成 Tc 块,每块 [Bc, d]Br、Bc 设计成能装进 SRAM(典型值 128)。
然后双层循环:
for j in range(Tc): # 外层循环 K, V 块 把 Kj, Vj 加载到 SRAM for i in range(Tr): # 内层循环 Q 块 把 Qi 加载到 SRAM 在 SRAM 中计算 Qi · Kj^T → Sij (小矩阵) 在 SRAM 中应用 softmax → Pij 在 SRAM 中计算 Pij · Vj → 输出累积 把累积结果写回 HBM关键:
整个
n × n大矩阵 S 从未物化在 HBM只有小的
Br × Bc块在 SRAM 里HBM ↔ SRAM 的数据搬运量从 O(n²) 降到 O(n²/M)(M 是 SRAM 大小)
2.2 核心技巧 2:在线 Softmax
但 softmax 是个全局操作——它需要先看到整行才能归一化:
softmax(x) = exp(x_i) / Σ exp(x_j) ↑ 需要总和!Flash Attention 用在线 softmax解决:
# 增量计算 softmax # 假设我们已经处理了前 i 个 block # m_i = 前 i 个 block 的最大值 # s_i = 前 i 个 block 的 exp 总和 新来一个 block,计算它的 softmax: m_new = max(m_i, max(new_block)) s_new = exp(m_i - m_new) * s_i + exp(m_new - m_new) * sum(exp(new_block - m_new)) 输出 = 用 m_new 和 s_new 重新归一化所有已处理的部分这个算法的核心数学技巧:
exp(a) + exp(b) = exp(max) * [exp(a - max) + exp(b - max)] ↑ 防止 overflow + 可流式合并直观上:
每个 block 自己算 softmax(用本地 max 防 overflow)
处理完后保存 (max, sum) 两个状态
来新 block 时,用两个 max 之间的"换算因子"调整之前的累积
这个算法数学上完全等价于一次性 softmax——没有任何精度损失。
2.3 完整伪代码
def flash_attention(Q, K, V): n, d = Q.shape M = SRAM_SIZE # SRAM 大小,约 100 KB Br, Bc = derive_block_size(M, d) # 通常 128 Tr, Tc = n // Br, n // Bc # 初始化输出和状态 O = zeros((n, d), in_hbm=True) l = zeros(n, in_hbm=True) # 累积的 sum m = full(n, -inf, in_hbm=True) # 累积的 max for j inrange(Tc): Kj = load_to_sram(K[j*Bc:(j+1)*Bc]) Vj = load_to_sram(V[j*Bc:(j+1)*Bc]) for i inrange(Tr): Qi = load_to_sram(Q[i*Br:(i+1)*Br]) Oi = load_to_sram(O[i*Br:(i+1)*Br]) li = load_to_sram(l[i*Br:(i+1)*Br]) mi = load_to_sram(m[i*Br:(i+1)*Br]) # 在 SRAM 内计算 Sij = Qi @ Kj.T / sqrt(d) # [Br, Bc] mij = row_max(Sij) # [Br] Pij = exp(Sij - mij[:, None]) # [Br, Bc] lij = row_sum(Pij) # [Br] # 在线 softmax 合并 m_new = max(mi, mij) l_new = exp(mi - m_new) * li + exp(mij - m_new) * lij # 更新输出 Oi_new = ( (li * exp(mi - m_new))[:, None] * Oi + exp(mij - m_new)[:, None] * (Pij @ Vj) ) / l_new[:, None] # 写回 HBM write_to_hbm(O[i*Br:(i+1)*Br], Oi_new) write_to_hbm(l[i*Br:(i+1)*Br], l_new) write_to_hbm(m[i*Br:(i+1)*Br], m_new) return O整体效果:
数学等价于标准 attention
中间矩阵 S, P 从未离开过 SRAM
HBM IO 量降为原来的 1/M(M = SRAM 大小,约 100 KB)
2.4 Flash Attention v1 的实际收益
序列长度 | 朴素 Attention | Flash Attention | 速度提升 |
512 | 1× | 1.2× | 1.2× |
1024 | 1× | 1.8× | 1.8× |
4096 | 1× | 2.7× | 2.7× |
16384 | 1× | 3.5× | 3.5× |
结论:序列越长,Flash Attention 越赚。
显存占用:
序列长度 | 朴素 (n² 矩阵) | Flash Attention |
8K | 256 MB | < 1 MB |
32K | 4 GB | < 4 MB |
128K | 64 GB | < 16 MB |
这就是为什么没有 Flash Attention 根本搞不动长上下文——光 attention 矩阵就把显存吃光了。
三、Flash Attention v2 / v3 的演进
3.1 v2 (2023.07):进一步加速
Flash Attention v2 的改进点:
改进 1:减少非矩阵乘法的开销
v1 中有不少 "rescale"、"max compare" 等非 matmul 操作,这些操作虽然简单但累积起来不少。v2 重新设计算法,把它们减少到最少。
改进 2:更好的并行化
v1 内层循环只在 Q 上并行。v2 把外层循环也并行化,更充分利用 GPU 的多个 SM。
改进 3:分配更好的 warp
把 SRAM 分配给更细粒度的 warp,进一步提升计算密度。
实测:
比 v1 快~2×
在 A100 上达到 50-70% 的理论算力
在长序列下尤其明显
3.2 v3 (2024.07):H100 时代的飞跃
Flash Attention v3 专为 H100 设计,引入了 H100 的特殊功能:
特性 1:异步加载(async TMA)
H100 引入了TMA(Tensor Memory Accelerator)——可以异步搬运数据,让计算和数据搬运 overlap。
v3 充分利用这个:
计算 block 1 ── 同时加载 block 2 计算 block 2 ── 同时加载 block 3 计算 block 3 ── 同时加载 block 4 ...特性 2:FP8 支持
v3 第一次支持 FP8 attention:
• 精度:约 0.1% 掉点
• 速度:比 FP16 再快 2×
特性 3:Warpgroup 异步矩阵乘法
H100 的WGMMA(Warpgroup MMA)让矩阵乘法本身就是异步的。v3 充分利用这个,让算力打满。
实测:
v3 在 H100 上达到75% 的理论算力(vs v2 的 35%)
FP8 模式下接近1.5 PFLOPS
3.3 三个版本性能对比
测试设置:H100 SXM,序列长度 8K,d=128,BF16:
版本 | TFLOPS | 利用率 |
标准 PyTorch | 11 | 1.1% |
| Flash v1 | 195 | 19.7% |
| Flash v2 | 348 | 35.2% |
| Flash v3 | 740 | 74.8% |
| Flash v3 FP8 | 1417 | 71.6% (vs FP8 ceil) |
结论:
标准 PyTorch → Flash v3:67× 加速
v2 → v3:约2× 加速(H100 专属)
3.4 哪个版本配哪个硬件
GPU | 推荐 Flash 版本 |
V100 / T4 | v1(v2/v3 不一定支持) |
A100 / L40 | v2 (v3 部分支持但优化不到位) |
H100 / H200 | v3 |
B200 | v3(v4 即将出,专门为 Blackwell 优化) |
四、工程实战:怎么用上 Flash Attention
4.1 用 HuggingFace Transformers 自动启用
最简单的方式:
from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained( "Qwen/Qwen3-32B-Instruct", attn_implementation="flash_attention_2", # ← 关键 torch_dtype="auto", device_map="auto", )支持的选项:
attn_implementation = "eager" # 朴素,PyTorch 实现,慢 attn_implementation = "sdpa" # PyTorch 2.0 内置,使用 backend attn_implementation = "flash_attention_2" # Flash v2 attn_implementation = "flash_attention_3" # Flash v3 (Transformers 4.46+ 支持)Tip:
"sdpa"是 PyTorch 内置的scaled_dot_product_attention,它在底层会自动选择 Flash 或 Memory-Efficient 实现——很多情况下这就够用"flash_attention_2"/"flash_attention_3"需要pip install flash-attn
4.2 用 PyTorch SDPA(最通用)
PyTorch 2.0+ 内置了scaled_dot_product_attention,会自动用 Flash Attention 后端:
import torch.nn.functional as F def my_attention(q, k, v, mask=None): # 自动用 Flash Attention 如果可用 output = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=True, ) return output控制后端:
from torch.nn.attention import SDPBackend, sdpa_kernel with sdpa_kernel(SDPBackend.FLASH_ATTENTION): output = F.scaled_dot_product_attention(q, k, v, is_causal=True)可选 backend:
•
FLASH_ATTENTION── Flash Attention 实现•
EFFICIENT_ATTENTION── Memory Efficient Attention•
MATH── 标准实现(fallback)•
CUDNN_ATTENTION── cuDNN 实现(新)
4.3 在 vLLM 中
vLLM默认就用 Flash Attention,你什么都不用做:
vllm serve Qwen/Qwen3-32B-Instruct # 自动用 Flash Attention v2/v3(看硬件)强制版本:
# vLLM 0.6+ 支持 VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve ... VLLM_ATTENTION_BACKEND=FLASH_ATTN_3 vllm serve ...4.4 训练时的 Flash Attention
训练阶段 Flash Attention 收益更明显——因为序列更长、需要反向传播。
from transformers import AutoModelForCausalLM, TrainingArguments model = AutoModelForCausalLM.from_pretrained( "...", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, ) training_args = TrainingArguments( ..., bf16=True, # 必须用低精度才能用 Flash Attention gradient_checkpointing=True, # 配合用,节省更多显存 )实测:
训练 70B 模型 + 8K context
不开 Flash Attention:每 step 4.2 秒,显存 78 GB
开 Flash Attention v2:每 step 1.8 秒,显存 42 GB
2× 加速 + 47% 显存节省。这就是为什么训练大模型必须用 Flash Attention。
4.5 安装
# Flash Attention v2 pip install flash-attn --no-build-isolation # Flash Attention v3(H100 only,目前仍在 hopper 分支) pip install git+https://github.com/Dao-AILab/flash-attention.git@hopper常见安装坑:
CUDA 版本要匹配(建议 12.x+)
编译时间长(首次安装 30-60 分钟)
需要 ≥ 8 GB 内存编译
没有预编译 wheel 时编译失败 → 安装 ninja 试试
五、扩展话题:Flash 家族还在演进
5.1 Flash Decoding(推理专用)
Flash Attention v2 主要为训练优化(长 seq、batch 大)。推理有不同的瓶颈:
Decode 阶段每次只处理 1 个 token
KV Cache 上的 attention 是 1 × N 矩阵(不是 N × N)
真正的瓶颈是并行度不足
Flash Decoding(Dao 2023.10)专门解决这个:
把 KV 序列也切到多个 SM上并行
每个 SM 处理 KV 的一部分
最后用 log-sum-exp 合并
效果:
Decode 速度提升2-8×(看 batch 和 seq)
长上下文场景尤其明显(128K decode 提升 5×+)
当下地位:vLLM、SGLang 等推理框架都已集成。
5.2 Ring Attention(跨卡 Flash)
第 5 篇我们讲过 Ring Attention——它本质上就是 Flash Attention 的分布式版本:
把 KV 切到多张卡
每张卡持有局部 KV
KV 在卡间环形传递,每张卡轮流和其他卡的 KV 做 Flash Attention
这是训练 / 推理 1M+ 上下文的基础。
5.3 Triton 实现
Flash Attention 原生用 CUDA 写,但Triton 版本越来越流行:
Triton 是 OpenAI 开源的 GPU kernel DSL
比 CUDA 简单
性能接近 CUDA(v2 大概 90%,v3 仍在追赶)
可读性极强——你可以读 Triton 版的 Flash Attention 来理解算法
vLLM 部分 backend 就是 Triton 实现。
5.4 什么时候 Flash Attention 不适用
虽然 Flash Attention 是"标配",但有一些场景不适用或收益有限:
场景 | 原因 |
序列极短(< 256) | IO 占比不大,传统 attention 反而更快 |
自定义 attention(如 ALiBi 老版) | Flash 默认不支持任意 mask,要专门修改 |
FP32 训练 | Flash v1/v2 仅支持 FP16/BF16,v3 加 FP8 |
老 GPU(Pascal / Volta) | Flash 需要 Ampere+ 架构 |
极特殊 attention 模式(局部 + 全局混合) | 需要专门定制 |
但 95% 的场景,Flash Attention 都是无脑选项。
六、Flash Attention 给工程师的启示
6.1 算法 + 硬件 = 真正的优化
Flash Attention 的成功不是算法创新(softmax 还是那个 softmax),也不是新硬件(GPU 没变),而是两者结合:
理解算法的数学结构
理解硬件的物理特性
重新设计两者的接口
这是大模型系统优化的核心方法论:不要只看算法,也不要只看硬件,而是两者协同。
6.2 IO 优化的普适性
Flash Attention 的"分块 + 流式合并"思路在很多地方都能用:
量化:W4 + FP16 也用类似思想分块
MoE:专家计算和数据搬运的 overlap
分布式训练:通信和计算的 overlap
训练 checkpointing:分块保存激活
如果你做系统优化,多想想"能不能不物化中间结果"——这是个屡试不爽的优化方向。
6.3 不要害怕底层
Flash Attention 的实现要写 CUDA / Triton kernel,这让很多工程师望而却步。但理解它的原理并不要求你能从零写——理解 Tiling、在线 softmax、IO/compute 平衡这些概念,已经足够你做正确的部署决策。
七、结语:Flash Attention 是大模型时代的基础设施
读完本文你应该明白:
GPU 算力增长比带宽快,IO 是大模型的主要瓶颈
Flash Attention 用 Tiling + Online Softmax 把 attention IO 量从 O(n²) 降到 O(n)
v1 / v2 / v3 演进:v1 开创、v2 优化、v3 适配 H100 + FP8
使用方式:HuggingFace 加
attn_implementation、PyTorch 用SDPA、vLLM 默认启用训练比推理收益更大:训练 70B + 8K 上下文,2× 加速 + 47% 显存节省
Flash Decoding / Ring Attention:Flash 家族在持续演进
参考文献:
13.Flash Attention 原理与实践:让 Attention 重新成为算力游戏