大模型-attention汇总解析之-MHA

article/2025/7/22 11:52:41

一、MHA(Multi-Head Attention)

1.1 MHA 原理

MHA(Multi-Head Attention)称为多头注意力,开山之作所提出的一种 Attention 计算形式,它是当前主流 LLM 的基础工作。在数学原理上,多头注意力 MHA 等价于多个独立的单头注意力的拼接, MHA 可以形式地记为:

公式展开下如下:

Attention 的计算公式如下: 

Attention 计算模型结构和MHA的模型结构示意图:

在实践中,为了减少计算复杂度和内存占用,通常会设置 ,其中 d 是模型的维度,h 是缩放因子(也称为头数,即多头注意力中的头的数量)。对于 LLaMA2-7b 模型:模型维度 d = 4096,多头数 h = 32, 因此,d_k = d_v = 128(即 4096 / 32.

这里我们只考虑了主流的自回归 LLM 所用的 Causal Attention,因此在 token by token 递归生成时,新预测出来的第 i+1个 token,并不会影响到已经算好的 前面的i个K, V的值,因此这部分K, V结果我们可以缓存下来供后续生成调用,避免不必要的重复计算,这就是所谓的 KV Cache。下面是kv cache的示意图。

 

 多头MHA的Kv cache 的简单实现:

import torch
import torch.nn as nn
import mathclass CachedAttention(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.head_dim = d_model // num_heads# 定义线性变换层,将输入映射到Query、Key和Value空间self.q_proj = nn.Linear(d_model, d_model)self.k_proj = nn.Linear(d_model, d_model)self.v_proj = nn.Linear(d_model, d_model)# 定义输出线性变换层,将注意力计算结果映射回原维度self.out_proj = nn.Linear(d_model, d_model)def forward(self, x, kv_cache=None):b, t, c = x.shape# 将输入x通过线性变换得到Query,并调整形状和维度q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)# 将输入x通过线性变换得到Key,并调整形状和维度k = self.k_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)# 将输入x通过线性变换得到Value,并调整形状和维度v = self.v_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)if kv_cache is not None:cached_k, cached_v = kv_cache# 将缓存中的Key和当前计算的Key拼接起来k = torch.cat((cached_k, k), dim=2)# 将缓存中的Value和当前计算的Value拼接起来v = torch.cat((cached_v, v), dim=2)# 计算注意力分数,这里除以根号下head_dim是为了缩放attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))# 对注意力分数进行softmax归一化attn = attn.softmax(dim=-1)# 根据注意力分数对Value进行加权求和y = (attn @ v).transpose(1, 2).contiguous().view(b, t, c)# 通过输出线性变换层得到最终输出y = self.out_proj(y)return y, (k, v)

 

1.2 存在的问题

看下attention计算的公式:

 

从上面的可以知道:

  1. attention2的计算和Q2, K1, K2, V1, V2有关系。

  2. 如果我们把之前已经计算好的K1, V1 保存起来,那么这一步的计算量就节省了,从而可以使用空间换时间,加快计算速度。

  3. 人们总是不断的追求极致, 那么能不能再进一步的节省空间,减少KV cache的同时,保证计算的效果还能达到要求呢。

所以后续就出现了一系列的attention的优化方法。这里先上一张简洁明了的示意图。后续再聊


http://www.hkcw.cn/article/TwIAgHnKbg.shtml

相关文章

历年上海交通大学计算机保研上机真题

2025上海交通大学计算机保研上机真题 2024上海交通大学计算机保研上机真题 2023上海交通大学计算机保研上机真题 在线测评链接:https://pgcode.cn/school String Match 题目描述 Finding all occurrences of a pattern in a text is a problem that arises freq…

DeepSeek-R1-0528-Qwen3-8B 本地ollama离线运行使用和llamafactory lora微调

参考: https://huggingface.co/deepseek-ai/DeepSeek-R1-0528-Qwen3-8B 量化版本: https://huggingface.co/unsloth/DeepSeek-R1-0528-Qwen3-8B-GGUF https://docs.unsloth.ai/basics/deepseek-r1-0528-how-to-run-locally 1、ollama运行 升级ollama版本到0.9.0 支持直接…

数字人革新教育:开启智慧教学新时代

随着人工智能技术的迅猛发展,数字人正在逐步走进教育领域,成为传统教学模式的颠覆者。广州深声科技有限公司(以下简称“深声科技”)凭借其在智能语音、数字人及多模态交互等核心技术上的深厚积累,推出了一系列创新性产…

Linux操作系统之进程(四):命令行参数与环境变量

目录 前言: 什么是命令行参数 什么是环境变量 认识环境变量 PATH环境变量 HOME USER OLDPWD 本地变量 本地变量与环境变量的差异 核心要点回顾 结语: 前言: 大家好,今天给大家带来的是一个非常简单,但也十…

IDA dumpdex经典脚本(记录)

一个dumpdex的IDA插件 毕业了,暂时用不着了,存起来 import idaapi import structdef dumpdex(start, len, target):rawdex idaapi.dbg_read_memory(start, len)fd open(target, wb)fd.write(rawdex)fd.close()def getdexlen(start):pos start 0x20mem idaapi.dbg_read_mem…

第2期:APM32微控制器键盘PCB设计实战教程

第2期:APM32微控制器键盘PCB设计实战教程 一、APM32小系统介绍 使用apm32键盘小系统开源工程操作 APM32是一款与STM32兼容的微控制器,可以直接替代STM32进行使用。本教程基于之前开源的APM32小系统,链接将放在录播评论区中供大家参考。 1…

Redis的安装与使用

网址:Spring Data Redis 安装包:Releases tporadowski/redis GitHub 解压后 在安装目录中打开cmd 打开服务(注意:每次客户端连接都有先打开服务!!!) 按ctrlC退出服务 客户端连接…

Redis 难懂命令-- ZINTERSTORE

**背景:**学习的过程中 常用的redis命令都能快速通过官方文档理解 但是还是有一些比较难懂的命令 **目的:**写博客记录一下(当然也可以使用AI搜索) 在Redis中,ZINTERSTORE 是一个用于计算多个有序集合(So…

7.atlas安装

1.服务器规划 软件版本参考: https://cloud.google.com/dataproc/docs/concepts/versioning/dataproc-release-2.2?hlzh-cn 由于hive3.1.3不完全支持jdk8,所以将hive的版本调整成4.0.1。这个版本没有验证过,需要读者自己抉择。 所有的软件都安装再/op…

RabbitMQ和MQTT区别与应用

RabbitMQ与MQTT深度解析:协议、代理、差异与应用场景 I. 引言 消息队列与物联网通信的重要性 在现代分布式系统和物联网(IoT)生态中,高效、可靠的通信机制是构建稳健、可扩展应用的核心。消息队列(Message Queues&am…

【技能篇】RabbitMQ消息中间件面试专题

1. RabbitMQ 中的 broker 是指什么?cluster 又是指什么? 2. 什么是元数据?元数据分为哪些类型?包括哪些内容?与 cluster 相关的元数据有哪些?元数据是如何保存的?元数据在 cluster 中是如何分布…

[3D GISMesh]三角网格模型中的孔洞修补算法

📐 三维网格模型空洞修复技术详解 三维网格模型在扫描、重建或传输过程中常因遮挡、噪声或数据丢失产生空洞(即边界非闭合区域),影响模型的完整性与可用性。空洞修复(Hole Filling)是计算机图形学和几何处…

基于Spring Boot+Vue 网上书城管理系统设计与实现(源码+文档+部署讲解)

技术范围:SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:免费功能设计、开题报告、任务书、中期检查PPT、系统功能实现、代码编写、论文编写和辅导、论文…

[ctfshow web入门] web81

信息收集 新增过滤:,伪协议都有:,这意味着伪协议不能用了 if(isset($_GET[file])){$file $_GET[file];$file str_replace("php", "???", $file);$file str_replace("data", "???", $file);$file st…

2025年应用心理学与社会环境国际会议(ICAPSE 2025)

2025年应用心理学与社会环境国际会议(ICAPSE 2025) 2025 International Conference on Applied Psychology and Social Environment 一、大会信息 会议简称:ICAPSE 2025 大会地点:中国北京 审稿通知:投稿后2-3日内通…

Windows 11 家庭版 安装Docker教程

Windows 家庭版需要通过脚本手动安装 Hyper-V 一、前置检查 1、查看系统 快捷键【winR】,输入“control” 【控制面板】—>【系统和安全】—>【系统】 2、确认虚拟化 【任务管理器】—【性能】 二、安装Hyper-V 1、创建并运行安装脚本 在桌面新建一个 .…

Redis 数据恢复的月光宝盒,闪回到任意指定时间

在数据库的运维工作中,DBA 应该选择哪一种方案,确保 Redis 数据库崩溃后可以对数据进行回档,恢复业务运行? 一般情况下,DBA 可以通过 Redis 原生的持久化机制,如 RDB 快照持久化或者 AOF 日志持久化的方案…

鸿蒙 HarmonyOS - SideBarContainer 组件自学指南

在日常开发中,如果你有类似「左侧导航 右侧内容」的布局需求,比如后台管理界面、文件管理器、设置页等,​​SideBarContainer​​ 是非常值得掌握的组件。它自带侧边栏和主内容区的分离机制,还支持折叠、拖拽、控制按钮和多种显示…

一个Mybatisplus组件扫描不当引起的bug:弄巧成拙,认真的锅,自我怀疑

在我们系统基建层的业务组件包 sby-biz-component 中,最初,我写了两个业务组件,一个是 通道错误码组件,一个是 审核流水组件。 这两个业务组件都要依赖Mybatisplus来操作数据。 com.sby.bizcomponent├── auditflow│ └── A…

t015-预报名管理系统设计与实现 【含源码!!!】

项目演示地址 摘 要 传统办法管理信息首先需要花费的时间比较多,其次数据出错率比较高,而且对错误的数据进行更改也比较困难,最后,检索数据费事费力。因此,在计算机上安装预报名管理系统软件来发挥其高效地信息处理的…