从认识AI开始-----解密门控循环单元(GRU):对LSTM的再优化

article/2025/7/15 17:10:58

前言

在此之前,我已经详细介绍了RNN和LSTM,RNN虽然在处理序列数据中发挥了重要的作用,但它在实际使用中存在长期依赖问题,处理不了长序列,因为RNN对信息的保存只依赖一个隐藏状态,当序列过长,隐藏转态保存的东西过多时,它对于前面的信息的抽取就会变得困难。为了解决这个问题,LSTM被提出,它通过设计复杂的门控机制以及记忆单元,实现了对信息重要性的提取:因为在现实中,对于一个序列来说,并不是序列中所有的信息都是同等重要的,这就意味着模型可以只记住相关的观测信息即可,但LSTM因为过多的门控机制与记忆单元,导致参数过多,训练速度慢。而GRU则是对LSTM的进一步优化,它的结构简单,训练更高效,并且性能同样出色


一、GRU诞生背景:RNN与LSTM的局限性

1. RNN的问题

RNN 依赖于隐藏单元循环结构来记忆序列信息,但在面对较长序列时会遇到:

  • 梯度消失/爆炸问题
  • 长期依赖问题
  • 训练效率低下

2. LSTM的改进

LSTM 通过设计输入门、遗忘门、输出门,以及单独的记忆单元,有效控制信息流,解决了上述问题。但 LSTM 的结构较为复杂,参数量大,训练慢。

二、GRU:结构更简单性能同样优秀的门控循环单元

GRU在2014年被提出来,其思想来源于LSTM的设计,但是对LSTM的进一步简化:

  • 没有单独的记忆单元,只有一个隐藏转态
  • 将LSTM的输入门和忘记门合并为一个更新门
  • 另有一个重置门控制新信息与历史信息的融合程度

其具体结构如下:

如图可以看出,GRU由重置门、更新门、隐藏状态组成,对于每个时间步,GRU都会进行以下操作:

1. 重置门

重置门(R_t)的作用是:决定遗忘多少过去的信息

R_t=\sigma(X_t @ W_{xr} + H_{t-1} @ W_{hr} + b_r)

2. 更新门

更新门(Z_t)的作用是:决定保留多少过去的信息

Z_t=\sigma(X @ W_{xz} + H @ W_{hz} + b_z)

3. 候选隐藏转态

候选隐藏转态(\tilde {H})能控制对前面信息的遗忘程度,因为 R_t 经过 Sigmoid 后的值在 [0,1] 之间,当 R_t 趋近于 0 时,则表示要遗忘之前的信息,趋近于 1 时,要记住前面的信息

\tilde{H}=tanh(X_t@ W_{xh} + (R_t * H_{t-1}) @ W_{hh} + b_h)

4. 真正的隐藏转态

当 Z_t 为 1 时,忽略当前的候选隐藏转态,直接用前面的隐藏转态 H_{t-1} 作为当前的隐藏转态,当 Z_t 为 0 时,GRU就相当于退化成RNN了。

H_t=Z_t*H_{t-1}+(1-Z_t)*\tilde {H}


三、GRU与RNN/LSTM的比较

特性RNNLSTMGRU
是否解决长期依赖
参数量较少
门控机制输入、输出、遗忘重置、更新
记忆单元无(仅隐藏转态)
训练速度快、但性能差
表现一般类似甚至优于LSTM

GRU相比LSTM来说,结构简洁,参数少,训练更快,在多数任务上性能媲美甚至优于LSTM。更少的参数对过拟合更友好。但由于简化了部分结构,缺少了记忆单元的独立控制,无法像LSTM一样分开控制信息流


 四、手写GRU

通过上面的介绍,我们现在已经知道了GRU的实现原理,现在,我们试着手写一个GRU核心层:

首先,与RNN、LSTM一样,我们先初始化所需要的参数:

import torch
import torch.nn as nn
import torch.nn.functional as Fdef params(input_size, output_size, hidden_size):W_xz, W_hz, b_z = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xr, W_hr, b_r = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xh, W_hh, b_h = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_hq = torch.randn(hidden_size, output_size) * 0.1b_q = torch.zeros(output_size)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad = Truereturn params

然后,定义初始隐藏转态: 

import torchdef init_state(batch_size, hidden_size):return (torch.zeros((batch_size, hidden_size)), )

最后,是GRU的核心操作:

import torch
import torch.nn as nn
def gru(X, state, params):[W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q] = params(H, C) = stateoutputs = []for x in X:Z = torch.sigmoid(torch.mm(x, W_xz) + torch.mm(H, W_hz) + b_z)R = torch.sigmoid(torch.mm(x, W_xr) + torch.mm(H, W_hr) + b_r)H_tilde = torch.tanh(torch.mm(x, W_xh) + torch.mm((R * H), W_hh) + b_h)H = Z * H + (1 - Z) * H_tildeY = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=1), (H,)

四、使用Pytroch实现简单的LSTM

在Pytroch中,已经内置了gru函数,我们只需要调用就可以实现上述操作:

import torch
import torch.nn as nnclass mygru(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(mygru, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, h0):out, hn = self.gru(x, h0)out = self.fc(out)return out, hn# 示例
# 参数定义
input_size = 10
hidden_size = 20
output_size = 10
seq_len = 5
batch_size = 1
num_layers = 1model = mygru(input_size, hidden_size, output_size, num_layers)
inputs = torch.randn(batch_size, seq_len, input_size)
h0 = torch.zeros(num_layers, batch_size, hidden_size)output, hn = model(inputs, h0)
print(output.shape)

总结

以上就是本文的全部内容,算上本篇,我们已经系统性的讲述了RNN、RNN的进化版 LSTM、LSTM的优化版 GRU,相信小伙伴们已经对序列模型有了相当深刻的认识。GRU是一种比LSTM更轻量的门控循环单元,保留了长距离依赖建模能力,同时减少了参数量和计算复杂度。对于大多数NLP和时间序列任务来说,GRU提供了一个在性能与效率之间平衡良好的选择。


如果小伙伴们觉得本文对各位有帮助,欢迎:👍点赞 | ⭐ 收藏 |  🔔 关注。我将持续在专栏《人工智能》中更新人工智能知识,帮助各位小伙伴们打好扎实的理论与操作基础,欢迎🔔订阅本专栏,向AI工程师进阶!


http://www.hkcw.cn/article/ceboFXBWba.shtml

相关文章

历年西北工业大学计算机保研上机真题

2025西北工业大学计算机保研上机真题 2024西北工业大学计算机保研上机真题 2023西北工业大学计算机保研上机真题 在线测评链接:https://pgcode.cn/school 计算整数乘积 题目描述 给定 n n n 组数,每组两个整数,输出这两个整数的乘积。 …

ansible-playbook 进阶 接上一章内容

1.异常中断 做法1:强制正常 编写 nginx 的 playbook 文件 01-zuofa .yml - hosts : web remote_user : root tasks : - name : create new user user : name nginx-test system yes uid 82 shell / sbin / nologin - name : test new user shell : gete…

基于cornerstone3D的dicom影像浏览器 第二十七章 设置vr相机,复位视图

文章目录 前言一、VR视图设置相机位置1. 相机位置参数2. 修改mprvr.js3. 调用流程1) 修改Toolbar3D.vue2) 修改View3d.vue3) 修改DisplayerArea3D.vue 二、所有视图复位1.复位流程说明2. 调用流程1) Toolbar3D中添加"复位"按钮,发送reset事件2) View3d.vu…

以色列防长:哈马斯要么接受美方提案 要么面临毁灭

当地时间5月30日,以色列国防部长卡茨通过其个人社交媒体账号发表声明称,在以军强大的军事压力之下,巴勒斯坦伊斯兰抵抗运动(哈马斯)将被迫接受选择:接受美方提出加沙停火提案,或者被以色列消灭。△以色列国防部长卡茨(资料图)卡茨在声明中表示,当前以军正全力在加沙地…

古巴外交部召见美国临时代办 抗议其无礼行为

△古巴哈瓦那(资料图)当时间5月30日,古巴外交部召见了美国驻古巴临时代办迈克哈默(Mike Hammer)并表示,迈克哈默自2024年11月抵达古巴以来,对古巴表现出的不友好行为,既不符合他外交官的身份,也表现了对古巴人民的不尊重。古巴外交部美国双边事务总司主任加西亚向迈克…

Java处理动态的属性:字段不固定、需要动态扩展的 JSON 数据结构

引言 应用场景: 签名测试接口、表单配置项、参数列表、插件信息等。技术实现:JSONObject 接收、使用json格式的字符串,或者@JsonAnySetter/@JsonAnyGetter注解方法来处理动态的属性。I JSONObject 接收和返回 例子:表单配置 接口对应的表单配置信息 JSONObject 接收和返回…

leetcode1201. 丑数 III -medium

1 题目:1201. 丑数 III. 官方标定难度:中 丑数是可以被 a 或 b 或 c 整除的 正整数 。 给你四个整数:n 、a 、b 、c ,请你设计一个算法来找出第 n 个丑数。 示例 1: 输入:n 3, a 2, b 3, c 5 输出…

【Oracle】DML语言

个人主页:Guiat 归属专栏:Oracle 文章目录 1. DML概述1.1 什么是DML?1.2 DML的核心功能 2. INSERT语句详解2.1 基础插入操作2.2 子查询插入2.3 多表插入2.4 批量插入优化 3. UPDATE语句详解3.1 基础更新操作3.2 关联更新3.3 批量更新优化 4. …

安装启动Mosquitto以及问题error: cjson/cJSON.h: No such file or directory解决

安装Mosquitto 在官方下载地址:https://mosquitto.org/files/source/ 选择版本下载 安装环境是linux centos7,上传 mosquitto-2.0.18.tar.gz 文件到 /mqtt 文件夹下 tar -xvf mosquitto-2.0.18.tar.gz #解压 cd mosquitto-2.0.18/ #切换到解压目录下…

附件上传唯一性校验

1. Overridepublic String uploadFile(MultipartFile file, String id, String funNo, String ctType) {//TODO 附件重复判断// 计算文件哈希值// 将MultipartFile转换为临时File对象String fileHash "";try {File tempFile convertMultipartFileToFile(file);// …

正点原子AU15开发板!板载40G QSFP、PCIe3.0x8和FMC LPC等接口,性能强悍!

正点原子AU15开发板!板载40G QSFP、PCIe3.0x8和FMC LPC等接口,性能强悍! 正点原子AU15开发板搭载Xilinx Artix UltraScale 系列FPGA,核心板主控芯片的型号是XCAU15P-FFVB676-2I。开发板由核心板+底板组成,外…

Attention-> flashAttention材料参考

1、 一文看懂 Attention(本质原理3大优点5大类型)_attention结构-CSDN博客2​​​​​​​2https://blog.csdn.net/haima1998/article/details/107845549 2、 一文看懂 NLP 里的模型框架 Encoder-Decoder 和 Seq2Seq (easyai.tech) 3、 详解深度学习…

MySQL高可用集群

https://dev.mysql.com/doc/mysql-shell/8.4/en/mysql-innodb-cluster.html 1 什么是MySQL高可用集群 MySQL高可用集群:MySQL InnoDB ClusterInnoDB Cluster是MySQL官方实现高可用读写分离的架构方案,包含以下组件 MySQL Group Replication:简…

山洪灾害声光电监测预警解决方案

一、方案背景 我国是一个多山的国家,山丘区面积约占国土面积的三分之二。每年汛期,受暴雨等因素影响,极易引发山洪和泥石流。山洪、泥石流地质灾害具有突发性、流速快、流量大、物质容量大和破坏力强等特点,一旦发生,将…

2025年最新工程项目管理系统应该具备哪些模块?

随着数字化转型浪潮席卷工程行业,工程项目管理系统的作用愈发凸显。2025年,工程项目管理系统的核心目标不仅是提升项目效率,更在于通过智能化、集成化技术实现全生命周期的精细化管理。基于行业趋势和企业实际需求,结合金众诚工程…

unity入门:同一文本不同颜色显示

unity入门:同一文本不同颜色显示 同一文本不同颜色显示#RRGGBBAA(带透明度)用法 同一文本不同颜色显示 在Unity中,如果想让文本中的某一部分显示不同的颜色,可以使用富文本(Rich Text)标记,在字符串中插入…

128、STM32H723ZGT6实现串口IAP

Bootloader程序通过串口接收*.bin文件数据,写入到内部flash区域,然后跳转APP应用程序 flash读写数据参考我的博客:127、stm32h743XI内部flash 注意:H723系列flash必须32字节写入,并且擦除时别重启|断电,不然…

【Netty系列】Reactor 模式 2

目录 流程图说明 关键流程 以下是 Reactor 模式流程图,结合 Netty 的主从多线程模型,帮助你直观理解事件驱动和线程分工: 流程图说明 Clients(客户端) 多个客户端(Client 1~N)向服务端发起连…

接口自动化测试用例的编写方法

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 phpunit 接口自动化测试系列 Post接口自动化测试用例 Post方式的接口是上传接口,需要对接口头部进行封装,所以没有办法在浏览器下直接调…

2025030给荣品PRO-RK3566开发板单独升级Android13的boot.img

./build.sh init ./build.sh -K ./build.sh kernel 【导入配置文件】 Z:\Android13.0\rockdev\Image-rk3566_t\config.cfg 【更新的内核】 Z:\Android13.0\rockdev\Image-rk3566_t\boot.img 【导入分区表,使用原始的config.cfg会出错的^_】 Z:\Android13.0\rockdev\…