Transformer模型:多头注意力机制深度解析

article/2025/8/22 6:15:09

在多头注意力机制里,输入的查询(Query)、键(Key)和值(Value)会被投影到多个子空间(头)进行并行计算,每个头关注输入序列的不同方面。在所有头的注意力计算完成后,需要将这些头的结果拼接起来,然后通过一个线性层进行变换,以整合多头的信息,使其能够适应模型后续的计算需求。

 

一、线性层的定义

在PyTorch中,nn.Linear 类用于实现线性变换。当创建一个 nn.Linear 层时,会自动初始化一个权重矩阵W和一个偏置向量b。在 MultiHeadedAttention 类中,self.linears 是一个包含 4 个线性层的列表,其中 self.linears[-1] 是最后一个线性层,用于最终的线性变换。

class MultiHeadedAttention(nn.Module):    def __init__(self, h, d_model, dropout=0.1):        # ...        #- 4 线性层的列表        #- 线性层的输入和输出的维度都是 d_model        self.linears = clones(nn.Linear(d_model, d_model), 4)        # ...
    def forward(self, query, key, value, mask=None):        # ...        x = (            x.transpose(1, 2)            .contiguous()            .view(nbatches, -1, self.h * self.d_k)        )        # ...        return self.linears[-1](x)  

nn.linear的解释:
1.用于实现线性变换(也称为全连接层)的基础模块,其底层实现基于张量操作和自动微分系统。
2. 线性变换的数学表达式为:y = xW^T + b。其中:
x 是输入张量,形状为 [..., in_features]
W 是可学习的权重矩阵,形状为 [out_features, in_features]
b 是可学习的偏置向量,形状为 [out_features]
y 是输出张量,形状为 [..., out_features]
3. 实现步骤:
(1) 参数初始化在创建nn.Linear(in_features, out_features) 时,会初始化:
权重矩阵 W:形状为 [out_features, in_features],通常用随机值初始化(如 Xavier 或 Kaiming 初始化)。
偏置向量 b:形状为 [out_features],通常初始化为零。 
(2) 前向传播前向传播时,输入张量x会与权重矩阵W相乘,并加上偏置 b。矩阵乘法:x dot W^;添加偏置:结果加上 b。

(3) 自动微分PyTorch 的自动微分系统会跟踪 W和b的梯度,以便在反向传播时更新参数。  

二、Q、K、V线性变换

 

“我”“爱”“AI” 这三个经过词嵌入和位置编码后的输入向量,会分别通过与三个不同的权重矩阵W^Q、W^K 和W^V相乘来得到查询(Query)、键(Key)和值(Value)。

在模型中,通过线性层(神经网络)完成的,每个线性层都相当于一个可学习的权重矩阵。

下面详细解释它们之间的关系:

1. 输入向量的生成

“我爱 AI” 经过分词得到 “我”“爱”“AI”,对这些词进行词嵌入操作,将每个词映射为一个固定维度的向量。为了让模型能够感知词的位置信息,还会对这些词嵌入向量添加位置编码,最终得到 “我”“爱”“AI” 对应的三个输入向量。假设这些输入向量的维度为d_model(通常在 Transformer 中其值为512。

2. 线性投影的作用

在多头注意力机制中,为了让模型能够从不同的子空间关注输入序列的不同方面,需要将输入向量分别投影到查询、键和值的空间中。这是通过与三个不同的权重矩阵W^Q、W^K 和W^V相乘来实现的。

3. 具体的线性投影过程

在 the_annotated_transformer.py 文件中的 MultiHeadedAttention 类里,线性投影的实现如下:

class MultiHeadedAttention(nn.Module):    def __init__(self, h, d_model, dropout=0.1):        # ... 初始化代码 ...        self.linears = clones(nn.Linear(d_model, d_model), 4)        # ... 其他代码 ...
    def forward(self, query, key, value, mask=None):        # ... 其他代码 ...        query, key, value = [            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)            for lin, x in zip(self.linears, (query, key, value))        ]        # ... 其他代码 ...

这里的 self.linears 包含了用于线性投影的线性层,其中前三个线性层分别对应 W^Q、W^K 和W^V。具体步骤如下: • 对 “我”“爱”“AI” 的输入向量进行投影: ◦ 假设 “我”“爱”“AI” 对应的输入向量分别为x1、x2、x3,它们的维度都是 d_model。 ◦ 对于查询向量 Q,将输入向量 x1、x2、x3 与 W^Q相乘,得到对应的查询向量q1、q2、q3。 ◦ 对于键向量 K,将输入向量x1、x2、x3 与 W^K相乘,得到对应的键向量 k1、k2、k3。 ◦ 对于值向量 V,将输入向量x1、x2、x3与 W^V相乘,得到对应的值向量v1、v2、v3。 • 多头处理:在得到查询、键和值向量后,还会将它们拆分为多个头(在 Transformer 中通常为 8 个头),以便并行计算。

小结

“我”“爱”“AI” 这三个经过词嵌入和位置编码后的输入向量,会分别与W^Q、W^K 和W^V相乘,得到对应的查询、键和值向量,用于后续的多头注意力计算。这样做可以让模型从不同的子空间关注输入序列的不同方面,提高模型的表达能力。

三、拼接后线性变换

 

在多头注意力机制里,Concat(拼接)操作的目的是将多个头的注意力结果合并成一个张量,之后再通过一个线性层进行变换。  在 MultiHeadedAttention 类的 forward 方法里,Concat 操作是通过形状重塑和转置达成的。以下是相关代码:

class MultiHeadedAttention(nn.Module):    def __init__(self, h, d_model, dropout=0.1):        # ... 初始化代码 ...        self.linears = clones(nn.Linear(d_model, d_model), 4)        # ... 其他代码 ...
    def forward(self, query, key, value, mask=None):        # ... 其他代码 ...        # 1) 进行线性投影        query, key, value = [            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)            for lin, x in zip(self.linears, (query, key, value))        ]
        # 2) 应用注意力机制        x, self.attn = attention(            query, key, value, mask=mask, dropout=self.dropout        )
        # 3) "Concat" 操作        x = (            x.transpose(1, 2)            .contiguous()            .view(nbatches, -1, self.h * self.d_k)        )        del query        del key        del value        return self.linears[-1](x)
 *******************        拼接操作步骤 ******************* 1. 转置维度:x = x.transpose(1, 2)    #-将头的维度和序列长度维度进行交换,使x的形状变为 (nbatches, seq_len, self.h, self.d_k)。 2. 确保内存连续:x = x.contiguous()   3. 重塑形状:x = x.view(nbatches, -1, self.h * self.d_k)    #-将多个头的结果拼接在一起,x的形状变为 (nbatches, seq_len, self.h * self.d_k)    #-self.h 代表头(head)的数量,self.d_k 是每个头的维度。    #-self.h * self.d_k 这一操作是为了算出所有头拼接后的总维度。    #-这个维度实际上等同于模型的总维度d_model
*********************线性变换*********************   return self.linears[-1](x)#-线性层 self.linears[-1] 用于对拼接后的结果进行线性变换#-self.linears[-1](x) 把拼接后的结果 x 传入最后一个线性层#-线性层中的可学习权重矩阵 W
*********************权重矩阵*********************1. 含义:   权重矩阵W是一个可学习的参数,它的作用是将拼接后的结果   从一个 d_model 维的向量空间映射到另一个 d_model 维的向量空间。   在训练过程中,模型会根据输入数据自动调整 W 的值,以学习到最优的映射关系。2. 线性层的工作原理示例代码import torchimport torch.nn as nn
# 假设 d_model = 512d_model = 512# 创建一个线性层linear_layer = nn.Linear(d_model, d_model)
# 输入张量,形状为 (batch_size, seq_len, d_model)batch_size = 32seq_len = 10x = torch.randn(batch_size, seq_len, d_model)
# 进行线性变换y = linear_layer(x)
# 查看输出形状print("输入形状:", x.shape)  # 输出: torch.Size([32, 10, 512])print("输出形状:", y.shape)  # 输出: torch.Size([32, 10, 512])
# 查看权重矩阵 W 和偏置向量 b 的形状W = linear_layer.weightb = linear_layer.biasprint("权重矩阵 W 的形状:", W.shape)  # 输出: torch.Size([512, 512])print("偏置向量 b 的形状:", b.shape)  # 输出: torch.Size([512])

拼接(Concat)后进行线性变换的主要目的是让模型能够学习如何整合不同头的信息,并将其映射到一个更有意义的表示空间。这一步骤是多头注意力设计的核心,下面从原理和代码两方面详细解释。

1. 拼接操作的局限性在多头注意力中,输入会被投影到多个子空间(头),每个头关注输入的不同方面。例如: 

  • 一个头可能关注主语和谓语的关系。

  • 另一个头可能关注实体之间的语义关联。当所有头的计算完成后,直接拼接这些结果只是简单地将不同视角的信息堆叠在一起,但并没有让模型学习如何融合这些信息。此时的输出只是多个子空间表示的罗列,缺乏整体的语义整合。

2. 线性变换的作用拼接后的线性变换(即代码中的 self.linears[-1])通过一个可学习的权重矩阵 W 和偏置向量 b,让模型能够:

(1)整合多头信息:学习不同头之间的关联和权重,将分散的子空间表示融合为一个统一的表示。

(2)增加模型表达能力:线性变换引入了额外的参数,使模型能够学习更复杂的映射关系。

(3)保持维度一致性:确保输出维度与输入维度相同(即 d_model),便于后续层的处理。

 


http://www.hkcw.cn/article/QYxVHskUil.shtml

相关文章

刑拘!男子在家自学制售假币还收徒 网络“发财”梦破灭

七星关公安分局经侦大队民警在洪山街道虎踞路将涉嫌制售假币的余某抓获。在余某住处,警方收缴了464张假币以及电脑、打印机等制作工具。据余某交代,他在某APP上浏览时收到一条陌生人私信,得知有制作假币这条“发财”路。经过详细了解后,余某根据对方教授的步骤,在某软件上…

02.K8S核心概念

服务的分类 有状态服务:会对本地环境产生依赖,例如需要把数据存储到本地磁盘,如mysql、redis; 无状态服务:不会对本地环境产生任何依赖,例如不会存储数据到本地磁盘,如nginx、apache&#xff…

搭建MQTT服务器

搭建MQTT服务器 安装EMQX命令配置 EMQX Apt 源:安装 EMQX启动 EMQX 卸载EMQX登录EMQX控制台开放端口打开测试MQTT通信 MQTT客户端测试添加客户端认证配置 客户端授权配置API接口说明安装MySQL数据库1. 下载 MySQL APT 配置包2. 安装仓库配置包3. 更新系统包索引4. 安…

【博客系统】博客系统第十一弹:部署博客系统项目到 Linux 系统

搭建 Java 部署环境 apt apt(Advanced Packaging Tool)是 Linux 软件包管理工具,用于在 Ubuntu、Debian 和相关 Linux 发行版上安装、更新、删除和管理 deb 软件包。 大多数 apt 命令必须以具有 sudo 权限的用户身份运行。 apt 常用命令 列出…

如何利用categraf的exec插件实现对Linux主机系统用户及密码有效期进行监控及告警?

需求描述 Categraf作为夜莺监控平台的数据采集工具,为了保障Linux主机的安全,需要实现对系统用户密码有效期的监控,并在密码即将到期时及时告警,以提醒运维人员更改密码。本章将详细介绍如何利用Categraf的exec插件来实现这一功能…

Houdini POP入门学习02

本篇继续随教程学习POP,并附带学习一些wrangle知识点等。 1.新建空项目,添加Geometry sphere小球。 2.连接popnet,现在粒子随小球形态发射 3.双击进入popnet,在wire_pops_into_here处连接popwind,添加风力 4.设置Wind…

《藏海传》平津侯被斩首!着实让人恨的牙痒痒

《藏海传》平津侯被斩首。藏海传演到今天,目前最大的反派就是平津侯,他霸道强势,杀人如麻,掌控许多人的命运,又有实力派演员黄觉演绎,着实让人恨的牙痒痒。平津侯名字庄芦隐,战功赫赫,他一副正义凛然不信鬼神之说的样子,其实并不是。他逼杀藏海父母,都知道是为了癸玺…

哪吒汽车总部LOGO被连夜拆除?公司回应!原CEO张勇名下超4000万股权被冻结 搬迁与股权冻结引关注

哪吒汽车上海总部外墙的“哪吒汽车”LOGO已被拆除,一同被拆除的还有位于总部的哪吒体验中心标志。据透露,拆除原因是场地到期,公司即将搬家。具体的新办公室地址尚未公布。哪吒汽车原CEO张勇名下股权被冻结,金额为4050万元,冻结期限从2025年5月13日至2028年5月12日。张勇是…

特朗普政府请求上诉法院暂停关税裁决 裁决暂时搁置

5月29日,美国联邦巡回上诉法院批准了特朗普政府的请求,暂时搁置了美国国际贸易法院此前做出的禁止执行依据《国际紧急经济权力法》对多国加征关税措施的裁决。联邦巡回上诉法院表示,在审议相关动议文件期间,美国国际贸易法院在这些案件中作出的判决和永久性禁令将暂时中止,…

禁招国际生案哈佛再获胜 美政府改立场

禁招国际生案哈佛再获胜 美政府改立场提出“30天限期”当地时间29日,美国马萨诸塞州联邦地区法院一名法官批准了哈佛大学提出的发布初步禁令请求,“叫停”特朗普政府取消哈佛大学招收外国学生资质的政策。该法院法官艾莉森伯勒斯29日就该案举行听证会。法院网站最新信息显示,…

中国对沙特等4国试行免签!欢迎说走就走的中国行

5月28日,外交部发言人毛宁主持例行记者会。有记者提问称,中方在东盟-中国-海合会峰会期间宣布对沙特等四国试行免签政策,希望了解具体情况。毛宁表示,为便利中外人员往来,中方决定扩大免签国家范围。自2025年6月9日至2026年6月8日,对沙特、阿曼、科威特、巴林持普通护照人…

【Unity基础】Unity新手实战教程:用ScriptableObject控制Cube颜色

目录 项目概述🛠️ 完整操作步骤(10分钟内完成)步骤1:创建ScriptableObject类步骤2:创建颜色配置资产步骤3:创建Cube控制器步骤4:设置场景和Cube步骤5:添加简单UI提示步骤6&#xff…

美宣布撤销中国留学生签证 我使馆:已提出严正交涉

美方宣布撤销中国留学生签证 我使馆:已提出严正交涉关于美国务院发表声明称将撤销有关中国在美留学生签证一事,中国驻美使馆发言人5月29日在回答媒体提问时表示,中方坚决反对美方这一政治性、歧视性做法。中国驻美使馆表示,美方此举将严重损害中国在美留学人员正当合法权益…

硅基计划2.0 学习总结 伍 类的继承 初稿

文章目录 一、 继承1. 为什么要继承2. 如何继承3. 情况一:子父类成员变量重名4. 情况二:子父类成员方法重名5. 子父类构造方法问题6. 继承中代码块调用顺序7. protected关键字7. 继承方式8. final关键字9. 继承和组合 一、 继承 1. 为什么要继承 假设一…

长安链合约操作 查询合约命令解析

这个命令使用 ChainMaker 的 cmc 客户端工具查询智能合约 fact 的 find_by_file_hash 方法,通过文件哈希值检索链上存储的数据。以下是详细解析: 命令功能 调用合约 fact 的 查询方法 find_by_file_hash,根据文件哈希值 ab3456df5799b87c77…

嵌入式开发之STM32学习笔记day15

STM32F103C8T6 USART串口协议 1 通信接口 通信的目的:将一个设备的数据传送到另一个设备,扩展硬件系统通信协议:制定通信的规则,通信双方按照协议规则进行数据收发 名称 引脚 双工 时钟 电平 设备 USART TX、RX 全双工 …

Java版本的VPN(wlcn)1.3.1-JDK17-SNAPSHOT

项目介绍 wu-lazy-cloud-network 是一款基于(wu-framework-parent)孵化出的项目,内部使用Lazy ORM操作数据库,主要功能是网络穿透,对于没有公网IP的服务进行公网IP映射 使用环境JDK17 Spring Boot 3.0.2 版本更新 1…

javaweb 前言

Web的发展历史 Web的诞生 (1989-1991): 1989年,欧洲核子研究组织(CERN)的蒂姆伯纳斯-李提出了World Wide Web的概念,并发明了统一资源定位符(URL)、超文本传输协议(HTTP&#xff09…

<el-date-picker>配置禁用指定日期之前的时间选择(Vue2+Vue3)

今天突然接受到一个离谱的需求&#xff1a;有一个需要配置定时任务开始执行时间的组件&#xff0c;之前的做法都是用<el-form>的rules定义校验规则&#xff0c;也能实现效果&#xff0c;但是今天产品突发奇想&#xff1a;不能选的时间就置灰&#xff08;就是我们说的禁用…

Redis 主从节点

Redis 主从节点的核心区别 特性主节点 (Master)从节点 (Slave/Replica)读写权限可读可写只读&#xff08;默认配置&#xff09;数据流向数据来源从主节点同步数据连接关系可连接多个从节点只能连接一个主节点故障切换故障时需要手动/自动提升从节点可被提升为新的主节点命令执…