Attention 层:GQA / MHA 标准路径

【免费下载链接】cannbot-skills CANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体,本仓库为其提供可复用的 Skills 模块。 【免费下载链接】cannbot-skills 项目地址: https://gitcode.com/cann/cannbot-skills

参考模型cann-recipes-infer/models/qwen3_moe/(MoE)、cann-recipes-infer/models/gpt_oss/(Dense)

核心特征:标准多头 / 分组查询注意力,KV Cache 存完整的 K/V。Prefill 和 Decode 走不同 FA 参数

Prefill 链路

# ─── Pre-Norm ───
hidden_states, residual = npu_add_rms_norm(residual, hidden_states, weight, eps)
#   首层无 residual 时退化为 npu_rms_norm

# ─── QKV 投影 ───
q, k, v = qkv_proj(hidden_states).split(...)

# ─── QK Head Norm(部分模型有,如 qwen3-moe)───
q = npu_rms_norm(q, ...)
k = npu_rms_norm(k, ...)

# ─── RoPE ───
q, k = npu_apply_rotary_pos_emb(q, k, cos, sin, layout='BSH')

# ─── KV Cache 写入 ───
scatter_update_(past_key, kv_len, k, dim=-2)
scatter_update_(past_val, kv_len, v, dim=-2)

# ─── Flash Attention(用当前 batch 的 k/v)───
output = npu_fused_infer_attention_score(q, k, v, sparse_mode=3, ...)

# ─── O 投影 ───
output = o_proj(output)

Decode 链路

# ─── Pre-Norm ───
hidden_states, residual = npu_add_rms_norm(residual, hidden_states, weight, eps)

# ─── QKV 投影 ───
q, k, v = qkv_proj(hidden_states).split(...)

# ─── QK Head Norm ───
q = npu_rms_norm(q, ...)
k = npu_rms_norm(k, ...)

# ─── RoPE ───
q, k = npu_apply_rotary_pos_emb(q, k, cos, sin, layout='BSH')

# ─── KV Cache 写入 ───
scatter_update_(past_key, kv_len, k, dim=-2)
scatter_update_(past_val, kv_len, v, dim=-2)

# ─── Flash Attention(用完整 cache)───
output = npu_fused_infer_attention_score(q, past_key, past_val, actual_seq_lengths_kv=..., ...)

# ─── O 投影 ───
output = o_proj(output)

Prefill vs Decode 关键差异

环节 Prefill Decode
FA 的 KV 输入 当前 batch 的 k/v(非 cache) 完整 past_key/past_value(cache)
FA 参数 sparse_mode=3(causal mask,推荐) sparse_mode,传 actual_seq_lengths_kv

【免费下载链接】cannbot-skills CANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体,本仓库为其提供可复用的 Skills 模块。 【免费下载链接】cannbot-skills 项目地址: https://gitcode.com/cann/cannbot-skills

Logo

小龙虾开发者社区是 CSDN 旗下专注 OpenClaw 生态的官方阵地,聚焦技能开发、插件实践与部署教程,为开发者提供可直接落地的方案、工具与交流平台,助力高效构建与落地 AI 应用

更多推荐