Flash Attention完全指南:从原理到代码,解决大模型显存优化难题
本文详解Flash Attention如何通过分块计算降低注意力机制的时空复杂度,将QK^T大矩阵分块处理,减少GPU HBM访问量并利用更快的SRAM。文章包含标准注意力机制回顾、分块计算数学原理推导、完整的PyTorch代码实现,并提及PyTorch 2.0+已支持原生Flash Attention算法,有效解决大模型训练中的显存占用问题。01我自己的理解是通过将 QK^T 这个大矩阵。
本文详解Flash Attention如何通过分块计算降低注意力机制的时空复杂度,将QK^T大矩阵分块处理,减少GPU HBM访问量并利用更快的SRAM。文章包含标准注意力机制回顾、分块计算数学原理推导、完整的PyTorch代码实现,并提及PyTorch 2.0+已支持原生Flash Attention算法,有效解决大模型训练中的显存占用问题。
01
Flash Attention 降低时空复杂度原理
我自己的理解是通过将 QK^T 这个大矩阵分块,一方面降低了空间复杂度,另一方面可以让 torch 从 GPU 的 HBM(HighBandwidthMemory)进入到 SRAM(StaticRandom-AccessMemory)。
众所周知 SRAM 具有更快的访问速度和更低的延迟,同时可以降低 HBM 的访问量,达到提升速度的效果。
PS:果然当知道了原理,我就不会被 flash 这个单词迷惑,认为他单纯只是加速而没有降低空间复杂度了。
02
Flash Attention 计算原理
- 标准注意力机制回顾
对于查询向量 q(Q 的一行),键矩阵 K 和值矩阵 V,标准注意力的输出为:

2.分块( tiling)计算
分块计算里面有两个问题,一个是 softmax 函数是按整行计算的,所以如果切成两个小块分别计算再拼起来是不准确的(因为分母不一样),所以必须想到一种办法可以让不断的更新全局最大值。
这自然而然就想到了类似动态规划的计算方法,每次都记录最大值并且更新输出,只要更新到最后,我的 output 和原始的使用整个矩阵相乘是一样的,下面我们一步一步来看。
先将 K,V 切成 B 个 blocks:

我们对于每个块 i 维护三个状态量:
- m(i):处理前 i 个块后的全局最大值。
- l(i):处理前 i 个块后的调整指数和(基于 m(i))。
- o(i):处理前 i 个块后的部分输出(未归一化,相当于调整后的分子)。
初始化时:

对于每个块 i=1 到 B:

计算调整因子:

更新指数和:

更新输出:

处理完所有块后,最终输出为:

3. 通过数学归纳法证明并计算一下
我们证明对于任意 i ,以下关系成立:

i=0 时显然成立,上面两个量都为 0。
假设对 i-1 成立,则对 i 有:

因此:

03
Flash Attention torch 代码
import torch
import math
def flash_attention_forward(Q, K, V, block_size=256):
"""
参数:
Q: 查询张量, shape (batch_size, seq_len, d_k)
K: 键张量, shape (batch_size, seq_len, d_k)
V: 值张量, shape (batch_size, seq_len, d_v)
block_size: 分块大小, 控制内存使用
返回:
O: 注意力输出, shape (batch_size, seq_len, d_v)
"""
batch_size, seq_len, d_k = Q.shape
d_v = V.shape[-1]
O = torch.zeros_like(V) # 输出张量, 初始化为零
L = torch.zeros(batch_size, seq_len, device=Q.device) # 指数和统计量
M = torch.full((batch_size, seq_len), -torch.inf, device=Q.device) # 最大值统计量
# 计算缩放因子 (1/sqrt(d_k))
scale = 1.0 / math.sqrt(d_k)
# 外层循环: 按块处理键/值序列
for j in range(0, seq_len, block_size):
# 获取当前键块和值块
K_block = K[:, j:j+block_size, :] # [batch_size, block_size, d_k]
V_block = V[:, j:j+block_size, :] # [batch_size, block_size, d_v]
# 计算查询与键块的相似度分数
S_j = torch.matmul(Q, K_block.transpose(-2, -1)) * scale # [batch_size, seq_len, block_size]
# 计算当前块的最大值 (按最后一个维度)
m_j = torch.max(S_j, dim=-1).values # [batch_size, seq_len]
# 更新全局最大值
M_new = torch.maximum(M, m_j) # [batch_size, seq_len]
# 计算调整因子 (指数部分)
# 旧部分的调整因子: exp(M_prev - M_new)
exp_old = torch.exp(M - M_new) # [batch_size, seq_len]
# 新块的调整因子: exp(m_j - M_new)
exp_new = torch.exp(m_j - M_new) # [batch_size, seq_len]
# 使用当前块的局部最大值计算指数 (数值稳定性)
# P_j = exp(S_j - m_j)
P_j = torch.exp(S_j - m_j.unsqueeze(-1)) # [batch_size, seq_len, block_size]
# 计算当前块的指数和
l_j = P_j.sum(dim=-1) # [batch_size, seq_len]
# 计算当前块的加权和输出
# O_j = P_j @ V_block
O_j = torch.matmul(P_j, V_block) # [batch_size, seq_len, d_v]
# 调整旧输出: O_prev * exp(M_prev - M_new)
O = O * exp_old.unsqueeze(-1) # 保持维度一致
# 调整新块输出: O_j * exp(m_j - M_new)
O_j_adjusted = O_j * exp_new.unsqueeze(-1)
# 更新输出: O = O_prev_adjusted + O_j_adjusted
O = O + O_j_adjusted
# 调整旧指数和: L_prev * exp(M_prev - M_new)
L = L * exp_old
# 调整新块指数和: l_j * exp(m_j - M_new)
l_j_adjusted = l_j * exp_new
# 更新指数和: L = L_prev_adjusted + l_j_adjusted
L = L + l_j_adjusted
# 更新全局最大值
M = M_new
# 最终归一化: O = O / L
O = O / L.unsqueeze(-1)
return O
不过在 pytorch2.0 以上的版本已经支持原生的 Flash Attention 算法:
import torch.nn.functional as F
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
Flash Attention 通过数学和工程上的简单优化,达到了即降低运行时间,又降低显存占用,可谓十分优雅,确实是太强了。
AI时代,未来的就业机会在哪里?
答案就藏在大模型的浪潮里。从ChatGPT、DeepSeek等日常工具,到自然语言处理、计算机视觉、多模态等核心领域,技术普惠化、应用垂直化与生态开源化正催生Prompt工程师、自然语言处理、计算机视觉工程师、大模型算法工程师、AI应用产品经理等AI岗位。

掌握大模型技能,就是把握高薪未来。
那么,普通人如何抓住大模型风口?
AI技术的普及对个人能力提出了新的要求,在AI时代,持续学习和适应新技术变得尤为重要。无论是企业还是个人,都需要不断更新知识体系,提升与AI协作的能力,以适应不断变化的工作环境。
因此,这里给大家整理了一份《2025最新大模型全套学习资源》,包括2025最新大模型学习路线、大模型书籍、视频教程、项目实战、最新行业报告、面试题等,带你从零基础入门到精通,快速掌握大模型技术!
由于篇幅有限,有需要的小伙伴可以扫码获取!

1. 成长路线图&学习规划
要学习一门新的技术,作为新手一定要先学习成长路线图,方向不对,努力白费。这里,我们为新手和想要进一步提升的专业人士准备了一份详细的学习成长路线图和规划。
2. 大模型经典PDF书籍
书籍和学习文档资料是学习大模型过程中必不可少的,我们精选了一系列深入探讨大模型技术的书籍和学习文档,它们由领域内的顶尖专家撰写,内容全面、深入、详尽,为你学习大模型提供坚实的理论基础。(书籍含电子版PDF)

3. 大模型视频教程
对于很多自学或者没有基础的同学来说,书籍这些纯文字类的学习教材会觉得比较晦涩难以理解,因此,我们提供了丰富的大模型视频教程,以动态、形象的方式展示技术概念,帮助你更快、更轻松地掌握核心知识。

4. 大模型项目实战
学以致用 ,当你的理论知识积累到一定程度,就需要通过项目实战,在实际操作中检验和巩固你所学到的知识,同时为你找工作和职业发展打下坚实的基础。

5. 大模型行业报告
行业分析主要包括对不同行业的现状、趋势、问题、机会等进行系统地调研和评估,以了解哪些行业更适合引入大模型的技术和应用,以及在哪些方面可以发挥大模型的优势。

6. 大模型面试题
面试不仅是技术的较量,更需要充分的准备。
在你已经掌握了大模型技术之后,就需要开始准备面试,我们将提供精心整理的大模型面试题库,涵盖当前面试中可能遇到的各种技术问题,让你在面试中游刃有余。

为什么大家都在学AI大模型?
随着AI技术的发展,企业对人才的需求从“单一技术”转向 “AI+行业”双背景。企业对人才的需求从“单一技术”转向 “AI+行业”双背景。金融+AI、制造+AI、医疗+AI等跨界岗位薪资涨幅达30%-50%。
同时很多人面临优化裁员,近期科技巨头英特尔裁员2万人,传统岗位不断缩减,因此转行AI势在必行!

这些资料有用吗?
这份资料由我们和鲁为民博士(北京清华大学学士和美国加州理工学院博士)共同整理,现任上海殷泊信息科技CEO,其创立的MoPaaS云平台获Forrester全球’强劲表现者’认证,服务航天科工、国家电网等1000+企业,以第一作者在IEEE Transactions发表论文50+篇,获NASA JPL火星探测系统强化学习专利等35项中美专利。本套AI大模型课程由清华大学-加州理工双料博士、吴文俊人工智能奖得主鲁为民教授领衔研发。
资料内容涵盖了从入门到进阶的各类视频教程和实战项目,无论你是小白还是有些技术基础的技术人员,这份资料都绝对能帮助你提升薪资待遇,转行大模型岗位。


大模型全套学习资料已整理打包,有需要的小伙伴可以
微信扫描下方CSDN官方认证二维码,免费领取【保证100%免费】

更多推荐



所有评论(0)