大模型面试必考:大模型显存计算全攻略,训练推理显存占用详解与并发优化!
本文详细解析大模型训练推理过程中的显存占用计算,包括模型权重、中间激活和KV Cache等。对比不同架构(Qwen3、DeepSeek V3)的显存差异,提供定量计算公式,分析批量大小影响,并给出不同规模模型的并发支持估算,助力开发者高效部署大模型。
简介
本文详细解析大模型训练推理过程中的显存占用计算,包括模型权重、中间激活和KV Cache等。对比不同架构(Qwen3、DeepSeek V3)的显存差异,提供定量计算公式,分析批量大小影响,并给出不同规模模型的并发支持估算,助力开发者高效部署大模型。
一、训练过程中显存占用量计算
中间激活:前向传播计算过程中,前一层的输出就是后一层的输入,相邻两层的中间结果也是需要 GPU 显存来保存的,中间结果变量也叫激活内存,值相对很小。
在模型训练过程中,设备内存中除了需要模型权重之外,还需要存储中间变量(激活)、梯度和优化器状态动量,后者显存占用量与 batch size 成正比
训练总内存 = 模型内存 + 优化器内存 + 中间激活内存 + 梯度内存
在模型训练过程中,存储前向传播的所有中间变量(激活)结果,称为 memory_activations,用以在反向传播过程中计算梯度时使用。而模型中梯度的数量通常等于中间变量的数量,所以
。
假设 是指存储模型所有参数所需的内存、 是优化器状态变量所需内存。综上,模型训练过程中,显存占用量的理论计算公式为:
值得注意的是,对于 LLM 训练而言,现代 GPU 通常受限于内存瓶颈,而不是算力。因此,激活重计算 (activation recomputation,或称为激活检查点 (activation checkpointing)) 就成为一种非常流行的 以计算换内存 的方法。
激活重计算主要的做法是重新计算某些层的激活而不是把它们存在 GPU 内存中,从而减少内存的使用量,内存的减少量取决于我们选择清除哪些层的激活。
二、推理过程中显存占用量计算
深度学习模型推理任务中,占用 GPU 显存的主要包括三个部分:模型权重、输入输出以及中间激活结果。(该结论来源论文[1])因此,LLM 显存占用可分为 3 部分:
2.1 模型权重显存占用
存储模型权重参数所需的显存计算公式(params 是模型参数量,参数类型为 fp16/bf16):
传统模型:
Qwen3模型:
计算示例:
对于Qwen3 0.6B(层,,使用bf16格式):
2.2 中间激活显存占用(额外开销)
和模型训练需要存储前向传播过程中的中间变量结果不同,模型推理过程中并不需要存储中间变量,因此推理过程中涉及到的中间结果内存会很小(中间结果用完就会释放掉),一般指相邻两层的中间结果或者算子内部的中间结果。
2.2.1 注意力层中间激活
Self-attention的完整计算流程包含多个步骤,每步都有中间激活:
步骤1:计算 + 输入输出形状: + 显存占用:
步骤2:Softmax + Attention×V计算 + Attention分数矩阵:,占用显存 + V矩阵:,占用显存 + 输出:,占用显存
峰值显存分析:
在标准实现中,需要同时存储attention分数矩阵和V矩阵进行相乘,因此峰值显存为:
峰值显存
其中。当s较大时,项占主导;当时,项更重要。
注意:现代GPU实现通常采用融合注意力(Fused Attention)操作,将、softmax和与V相乘合并执行,可以显著减少实际的峰值显存占用。
2.2.2 MLP层中间激活
传统MLP块:
- • 第一个线性层的输出结果形状为
- • 占用显存大小:
Qwen3的SwiGLU MLP块:
包含三个线性层,需要同时存储Gate和Up线性层的输出结果:
- • Gate Linear层输出:,占用显存
- • Up Linear层输出:,占用显存
- • 两者需要同时存在内存中进行SiLU激活和元素级乘法,所以峰值显存为
2.2.3 中间激活显存计算公式
# 伪代码
memory_intermediate of attention = 2 * batch_size * n_head * square_of(sequence_length) + 4 * batch_size * s * h
memory_intermediate of mlp:
+ Traditional: fc1 layer output = 2 * batch_size * s * 4h
+ Qwen3 SwiGLU: gate + up outputs = 2 * batch_size * s * 3h * 2
注意力与MLP中间激活对比:
- • Attention峰值显存:
- • MLP峰值显存:传统模型 ,Qwen3模型
当 时,MLP的中间激活通常仍然大于注意力层,但差距会缩小。当序列长度很大时,注意力层的 项可能成为主导。
中间激活显存总结:
考虑到注意力层和MLP层的峰值显存,每层的总中间激活为:
传统模型:
(当 时)
Qwen3模型:
(当 时)
计算示例:
对于Qwen3 0.6B,当 时:
- • 注意力峰值:
- • MLP峰值:
- • 总峰值:
注意:根据经验,在模型实际前向传播过程中产生的这些额外开销(中间激活)通常控制在总模型参数内存的 20% 以内(只有 80% 的有效利用率)。
这份完整版的大模型 AI 学习和面试资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费】
2.3 KV Cache显存占用
LLM 推理优化中 kv cache 是常见的方法,本质是用空间换时间。假设输入序列的长度为 ,输出序列的长度为 ,decoder layers 数目为 ,以 float16 来保存 KV cache。
2.3.1 传统MHA模型的KV Cache
2.3.2 Qwen3的GQA模型的KV Cache
对于使用GQA的Qwen3模型,由于键值头数量少于查询头数量:
上式,第一个 2 表示 K/V cache,第二个 2表示 float16 占 2 个 bytes
2.3.3 DeepSeek V3的MLA模型的KV Cache
DeepSeek V3采用Multi-Head Latent Attention (MLA)架构,通过压缩KV表示来减少KV Cache。根据MLA的设计原理:
MLA的KV Cache只需要存储两个向量:
1.压缩的KV表示:
- • 向量大小:维
- • 多头共享,每个position存储一个512维向量
- • 存储大小:
-
- RoPE相关的:
- • 向量大小:维
- • 多头共享,每个position存储一个64维向量
- • 存储大小:
MLA每层KV Cache:
DeepSeek V3总KV Cache:
存储量对比:
根据论文原文,对比MQA(每层有一个 维度的 和一个 维度的 ,共 个元素),MLA的存储量是MQA的2.25倍。 对比MHA的 ,由于 通常远大于2.25(比如说Qwen就是8 > 2.25),所以 MLA 相比 MHA 能显著减少缓存。
- • MQA:每层 个元素
- • MLA:每层 个元素
- • MLA是MQA的倍,但相比的个元素,节省了大量存储
2.3.4 每个Token的KV缓冲大小对比
传统模型:
Qwen3模型:
DeepSeek V3的MLA模型:
具体计算示例:
对于Qwen3 0.6B:每个token的kv缓冲大小 = 4,096 bytes × 28层 = 114,688 bytes ≈ 112KB
对于DeepSeek V3:每个token的kv缓冲大小 = 1152 bytes × 61层 = 70,272 bytes ≈ 69KB
其实从这里我们就可以看到MLA有多夸张,deepseek这么大的模型每个token的kv cache比Qwen0.6B还要小
当时:
- • Qwen3:
- • DeepSeek V3:
2.4 总显存消耗公式
综上分析可知,llm 推理时,GPU 显存占用主要是:模型权重 + KV Cache,总显存消耗计算如下:
传统模型:
Qwen3模型:
DeepSeek V3模型:
中间激活
注意:模型推理时,中间激活最大不会超过模型权重参数内存的 20%,当 h 较大时可忽略
2.5 显存占用示例分析
2.5.1 Qwen3 0.6B推理显存示例
配置:
- • 模型权重: 1.192GB
- • 中间激活: ≈6.3MB(可忽略)
- • KV Cache:
- • 总计: ≈1.31GB
2.5.2 DeepSeek V3推理显存示例
配置: (由于模型更大,使用更长序列)
- • 模型权重: ≈1342GB(fp16格式)
- • 中间激活: 可忽略相对于模型权重
- • KV Cache:
总计:
注意:DeepSeek V3由于参数量巨大(671B),通常需要使用模型并行、量化等技术来部署(比如我们大部分就是用的w8a8来测试的),实际显存占用会根据优化策略而变化。
2.5.3 批量大小对显存的影响
LLaMA-13B示例: 当 时
- • KV Cache显存占用
- • 是模型参数显存(26GB)的1.6倍
Qwen3 0.6B示例: 当 时
- • KV Cache显存占用
- • 已经大大超过了模型参数显存(1.2GB)
DeepSeek V3示例: 当 时
- • KV Cache显存占用
- • 相对于模型参数显存(1342GB)可忽略,但对于实际部署仍然需要考虑
结论: batch_size
的增加能带来近乎线性的 throughput
增加,LLM服务模块的调度策略就是动态调整批次大小,并尽可能让它最大。在 batch_size
> 某个阈值时,占推理显存大头的是 KV cache。
三、显存占用计算的定性分析和定量结论
1.模型推理阶段,当输入输出上下文长度之和比较小的时候,占用显存的大头主要是模型参数,但是当输入输出上下文长度之和很大的时候,占用显存的大头主要是 kv cache
。
2.每个 GPUkv cache
显存所消耗的量和 输入 + 输出序列长度 成正比,和 batch_size
也成正比。
3.不同规模模型的显存消耗对比:
- • 有文档[2]指出,13B 的 LLM 推理时,每个 token 大约消耗 1MB 的显存
- • 对于 Qwen3 0.6B,每个 token 大约消耗 112KB 的显存
- • 对于 DeepSeek V3,每个 token 大约消耗 69KB 的显存(得益于MLA架构)
以 A100-40G GPU 为例,LLaMA-13B 模型参数占用了 26GB,那么剩下的 14GB 显存中大约可以容纳 14,000 个 token。在部署项目中,如果将输入序列长度限制为 512,那么该硬件下最多只能同时处理大约 28 个序列。
以 RTX 4090-24G GPU 为例,Qwen3 0.6B 模型参数占用了 1.2GB,那么剩下的 22.8GB 显存中大约可以容纳 203,500 个 token(按 112KB/token 计算)。在部署项目中,如果将输入序列长度限制为 1024,那么该硬件下最多只能同时处理大约 199 个序列。
四、LLM 并发支持估算
4.1 估算场景设定
以集群上的单节点 8 卡 V100 机器运行 llama-13b 模型为例,估算极端情况下聊天系统同时服务 10000 人并发所需要的节点数量。这里的极端情况是指每个请求的输入长度为 512、输出长度为 1536(即上下文长度为 2048)且没有 latency 要求。
LLaMA 系列模型配置文件中 “max_sequence_length”: 2048, 即代表预训练的 LLaMA 模型的最大 Context Window 只有 2048。另外,V100 卡的显存为 32GB。
结合前面的显存分析章节可知,K、V Cache 优化中对于每个 token 需要存储的字节数为 。
4.2 LLaMA-13B并发支持分析
对 llama-13b 模型推理而言,使用 fp16 推理时,其模型参数实际占用显存量为 24.6GB(略小于理论参数显存 26),每个 token 大约消耗 1MB 的显存(其实是 kv cache 占用的缓冲),对于输入输出上下文长度(512+1536)和为 2048 的请求,其每个请求需要的显存是 2GB。这里对每个请求所需要显存的估算是没有计算推理中间结果所消耗显存(其比较小,可忽略),另外不同框架支持张量并行所需要的额外显存也各不相同,这里暂时也忽略不计。
- • 在模型权重为 float16 的情况下,支持的理论 batch 上限为 ()。
- • 在模型权重为 int8 的情况下,支持的理论 batch 上限为 ()。
以上是理论值即上限值,float16 权重的实际 batch 数量会小于 115.7,目前的 deepspeed 框架运行模型推理时实测 batch 数量只可以达到 50 左右。
10000/50 = 200 (台 8 卡 V100 服务器)。
实际场景中的并发请求具有稀疏性,不可能每个请求都是 2048 这么长的上下文长度,因此实际上 200 台 8 卡 V100 服务器能服务的并发请求数目应该远多于 10000,可能是几倍。
4.3 Qwen3 0.6B并发支持对比
对比Qwen3 0.6B模型,以单卡 RTX 4090-24G
运行为例:
对 Qwen3 0.6B 模型推理而言,使用 bf16 推理时,其模型参数实际占用显存量为 1.2GB,每个 token 大约消耗 112KB 的显存,对于输入输出上下文长度为 2048 的请求,其每个请求需要的显存是 229MB。
- • 在模型权重为 bf16 的情况下,支持的理论 batch 上限为 ()。
单卡RTX 4090理论上就能支持约100个并发请求,相比13B模型需要200台8卡V100服务器,Qwen3 0.6B大大降低了部署成本。
4.4 DeepSeek V3并发支持分析
对于 DeepSeek V3 模型而言,其推理时,每个 token 大约消耗 69KB(MLA架构优化后的KV Cache占用)的显存,因此,极限情况下每个请求(序列长度2048)需要的显存是 141MB。
考虑到DeepSeek V3的671B参数量,使用fp16格式需要约1342GB显存,通常需要多卡部署:
- • 8卡A100-80G部署:总显存640GB,无法完整加载模型参数,需要使用量化技术
- • 16卡A100-80G部署:总显存1280GB,仍需要量化(如W8A8)才能加载
- • 使用W8A8量化:模型参数占用约671GB,在16卡A100上可部署
以16卡A100-80G + W8A8量化为例:
- • 模型参数显存占用:671GB + 剩余可用显存:1280GB - 671GB = 609GB
- • 支持的理论batch上限为:609GB / 0.141GB = 4,319
DeepSeek V3得益于MLA架构,在超大参数量下仍能支持较高的并发请求数。
4.5 并发估算通用公式
这里的并发估算的请求长度是按照模型支持的最大上下文长度来估算的,且没有 latency ,模型显存占用也没有用1.2倍来要求,所以实际会比这小。
最大并发数
单卡显存卡数模型参数显存占用量化前后单个请求最大显存占用
单个请求最大显存占用 = 每个 token 的 kv cache 占用显存 × max_sequence_length
- • Llama-13B:每个token消耗1MB显存,单个2048长度请求需要2GB,需要200台8卡V100服务器支持10000并发
- • Qwen3 0.6B:每个token消耗112KB显存,单个2048长度请求需要229MB,单卡RTX 4090支持100个并发
- • DeepSeek V3:每个token消耗69KB显存,单个2048长度请求需要141MB,16卡A100+量化支持4319个并发
五、结论
对于典型自回归 llm,假设 decoder layers 层数为 n,隐藏层大小(Embedding 向量维度)为,输入数据形状为 。当隐藏维度比较大,且远大于序列长度 时,则参数量和计算量的估算都可以忽略一次项,则有以下关于参数量、计算量和显存占用计算分析结论。
这三篇文章的一些定性结论:
1.参数量和输入序列长度无关。传统模型:;模型:。
2.每个 token 对应的计算量:传统模型 ,Qwen3模型 ,计算量随序列长度呈线性增长。其中 (传统)();每轮 decode 的计算量
(传统)()。
3.每个 token 的 kv cache 占用显存大小:传统模型是 4nh,Qwen3模型是 114,688字节(由于GQA优化),kv cache 显存占用量随(输入 + 输出序列长度)以及批量大小 batch_size 呈线性增长。kv cache 显存占用量 (传统)(Qwen3),单位为字节 byte。
4.self-attention 的内存和计算复杂度随序列长度 s 呈二次方增长。注意力输出矩阵要求 的 FLOPs,并且除了输入和输出内存之外,需要额外的 内存。
定量结论(近似估算):
1.一次迭代训练中,对于每个 token 和每个模型参数:传统模型需要进行 6 次浮点数运算,Qwen3模型也需要进行 6 次浮点数运算。
2.随着模型变大,MLP 和 Attention 层参数量占比越来越大:传统模型分别接近 66% 和 33%,Qwen3模型分别为 60% 和 40%。
3.不同规模模型的显存消耗对比:
- • 13B 的 LLM 推理时,每个 token 大约消耗 1MB 的显存
- • Qwen3 0.6B 推理时,每个 token 大约消耗 112KB 的显存
- • DeepSeek V3 推理时,每个 token 大约消耗 69KB 的显存(MLA架构优势)
- • 小模型在部署成本和并发能力上具有显著优势
六、AI大模型学习和面试资源
我在一线互联网企业工作十余年里,指导过不少同行后辈。帮助很多人得到了学习和成长。
我意识到有很多经验和知识值得分享给大家,也可以通过我们的能力和经验解答大家在人工智能学习中的很多困惑,所以在工作繁忙的情况下还是坚持各种整理和分享。但苦于知识传播途径有限,很多互联网行业朋友无法获得正确的资料得到学习提升,故此将并将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。
这份完整版的大模型 AI 学习和面试资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费】
第一阶段: 从大模型系统设计入手,讲解大模型的主要方法;
第二阶段: 在通过大模型提示词工程从Prompts角度入手更好发挥模型的作用;
第三阶段: 大模型平台应用开发借助阿里云PAI平台构建电商领域虚拟试衣系统;
第四阶段: 大模型知识库应用开发以LangChain框架为例,构建物流行业咨询智能问答系统;
第五阶段: 大模型微调开发借助以大健康、新零售、新媒体领域构建适合当前领域大模型;
第六阶段: 以SD多模态大模型为主,搭建了文生图小程序案例;
第七阶段: 以大模型平台应用与开发为主,通过星火大模型,文心大模型等成熟大模型构建大模型行业应用。
👉学会后的收获:👈
• 基于大模型全栈工程实现(前端、后端、产品经理、设计、数据分析等),通过这门课可获得不同能力;
• 能够利用大模型解决相关实际项目需求: 大数据时代,越来越多的企业和机构需要处理海量数据,利用大模型技术可以更好地处理这些数据,提高数据分析和决策的准确性。因此,掌握大模型应用开发技能,可以让程序员更好地应对实际项目需求;
• 基于大模型和企业数据AI应用开发,实现大模型理论、掌握GPU算力、硬件、LangChain开发框架和项目实战技能, 学会Fine-tuning垂直训练大模型(数据准备、数据蒸馏、大模型部署)一站式掌握;
• 能够完成时下热门大模型垂直领域模型训练能力,提高程序员的编码能力: 大模型应用开发需要掌握机器学习算法、深度学习框架等技术,这些技术的掌握可以提高程序员的编码能力和分析能力,让程序员更加熟练地编写高质量的代码。
1.AI大模型学习路线图
2.100套AI大模型商业化落地方案
3.100集大模型视频教程
4.200本大模型PDF书籍
5.LLM面试题合集
6.AI产品经理资源合集
👉获取方式:
😝有需要的小伙伴,可以保存图片到wx扫描二v码免费领取【保证100%免费】🆓
更多推荐
所有评论(0)