从认识AI开始-----解密LSTM:RNN的进化之路

article/2025/7/27 7:20:46

前言

我在上一篇文章中介绍了 RNN,它是一个隐变量模型,主要通过隐藏状态连接时间序列,实现了序列信息的记忆与建模。然而,RNN在实践中面临严重的“梯度消失”与“长期依赖建模困难”问题:

  • 难以捕捉相隔很远的时间步之间的关系
  • 隐状态在不断更新中容易遗忘早期信息。

为了解决这些问题,LSTM(Long Short-Term Memory) 网络于 1997 年被 Hochreiter等人提出,该模型是对RNN的一次重大改进。


一、LSTM相比RNN的核心改进

接下来,我们通过对比RNN、LSTM,来看一下具体的改进:

模型特点优势缺点
RNN单一隐藏转态,时间步传递结构简答容易造成梯度消失/爆炸,对长期依赖差
LSTM多门控机制 + 单独的“记忆单元”解决长距离依赖问题,保留长期信息结构复杂,计算开销大

通过对比,我们可以发现,其实LSTM的核心思想是:引入了一个专门的“记忆单元”,在通过门控机制对信息进行有选择的保留、遗忘与更新


二、LSTM的核心结构

LSTM的核心结构如下图所示:

 如图可以轻松的看出,LSTM主要由门控机制和候选记忆单元组成,对于每个时间步,LSTM都会进行以下操作:

1. 忘记门

忘记门F_t主要的作用是:控制保留多少之前的记忆:

F_t=\sigma(X_t@W_{xf}+H_{t-1}@W_{hf}+b_f)

2. 输入门

输入门I_t主要的作用是:决定当前输入中哪些信息信息被写入记忆:

I_t=\sigma(X_t@W_{xi}+H_{t-1}@W_{hi}+b_i)

3. 候选记忆单元

\tilde C_t=tanh(X_t@W_{xc}+H_{t-1}@W_{hc}+b_c)

4. 输出门

输出门O_t的作用是:决定是是否使用隐状态:

O_t=\sigma(X_t@W_{xo}+H_{t-1}@W_{ho}+b_o)

5. 真正记忆单元

记忆单元( C_t )用于长期存储信息,解决RNN容易遗忘的问题:

C_t=F_t*C_{t-1}+I_t*\tilde C_{t}

7. 隐藏转态

H_t=O_t*tanh(C_t)

LSTM引入了专门的记忆单元 C_t  ,长期存储信息,解决了传统RNN容易遗忘的问题。


三、手写LSTM

通过上面的介绍,我们现在已经知道了LSTM的实现原理,现在,我们试着手写一个LSTM核心层:

首先,初始化需要训练的参数:

import torch
import torch.nn as nn
import torch.nn.functional as Fdef params(input_size, output_size, hidden_size):W_xi, W_hi, b_i = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xf, W_hf, b_f = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xo, W_ho, b_o = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xc, W_hc, b_c = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_hq = torch.randn(hidden_size, output_size) * 0.1b_q = torch.zeros(output_size)params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]for param in params:param.requires_grad = Truereturn params

接着,我们需要初始化0时刻的隐藏转态:

import torchdef init_state(batch_size, hidden_size):return (torch.zeros((batch_size, hidden_size)), torch.zeros((batch_size, hidden_size)))

然后, 就是LSTM的核心操作:

import torch
import torch.nn as nn
def lstm(X, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params(H, C) = stateoutputs = []for x in X:I = torch.sigmoid(torch.mm(x, W_xi) + torch.mm(H, W_hi) + b_i)F = torch.sigmoid(torch.mm(x, W_xf) + torch.mm(H, W_hf) + b_f)O = torch.sigmoid(torch.mm(x, W_xo) + torch.mm(H, W_ho) + b_o)C_tilde = torch.tanh(torch.mm(x, W_xc) + torch.mm(H, W_hc) + b_c)C = F * C + I * C_tildeH = O * torch.tanh(C)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=1), (H, C)

四、使用Pytroch实现简单的LSTM

在Pytroch中,已经内置了lstm函数,我们只需要调用就可以实现上述操作:

import torch
import torch.nn as nnclass mylstm(nn.Module):def __init__(self, input_size, output_size, hidden_size):super(mylstm, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, h0, c0):out, (hn, cn) = self.lstm(x, h0, c0)out = self.fc(out)return out, (hn, cn)# 示例
input_size = 10
hidden_size = 20
output_size = 10
batch_size = 1
seq_len = 5
num_layer = 1 # lstm堆叠层数h0 = torch.zeros(num_layer, batch_size, hidden_size)
c0 = torch.randn(num_layer, batch_size, hidden_size)
x = torch.randn(batch_size, seq_len, hidden_size)model = mylstm(input_size=input_size, hidden_size=hidden_size, output_size=output_size)out, _ = model(x, (h0, c0))
print(out.shape)

总结

在现实中,LSTM的实际应用场景很多,比如语言模型、文本生成、时间序列预测、情感分析等长序列任务重,这是因为相比于RNN而言,LSTM能够更高地捕捉长期依赖,而且也更好的缓解了梯度消失问题;但是由于LSTM引入了三个门控机制,导致参数量比RNN要多,训练慢。

总的来说,LSTM是对传统RNN的一次革命性升级,引入门控机制和记忆单元,使模型能够选择性地记忆与遗忘,从而有效地捕捉长距离依赖。尽管LSTM近年来Transformer所取代,但LSTM依然是理解深度学习序列模型不可绕开的一环,有时在其他任务上甚至优于Transformer。


如果小伙伴们觉得本文对各位有帮助,欢迎:👍点赞 | ⭐ 收藏 |  🔔 关注。我将持续在专栏《人工智能》中更新人工智能知识,帮助各位小伙伴们打好扎实的理论与操作基础,欢迎🔔订阅本专栏,向AI工程师进阶!


http://www.hkcw.cn/article/KMEiuKvpcs.shtml

相关文章

基于javaweb的JSP+Servlet家政服务系统设计与实现(源码+文档+部署讲解)

技术范围:SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:免费功能设计、开题报告、任务书、中期检查PPT、系统功能实现、代码编写、论文编写和辅导、论文…

银行数字化应用解决方案

行业背景 银行业数字化转型已成为大势所趋,新技术浪潮为银行带来巨大的创新机遇,移动互联催生全新的用户体验需求,大数据应用带来深刻的客户洞察。但错综复杂的业务场景、严格的监管和合规要求、复合型人才的匮乏等问题,严重制约…

视频加密技术和防翻录技术有哪些?

更新:问答播放器的截图效果。 摘要:视频加密技术通过分布式编码、切片加密、动态水印等手段保护内容安全,防止盗录和二次分发。主流方案包括:1)VRM12采用混合算法加密与密码本混淆;2)H5优化HLS机…

嵌入式开发学习日志(linux系统编程--进程(4)——线程锁)Day30

扩:typedef三种用法(简化代码编写) 一、线程的控制——互斥和同步 (一)实例引入 1、示例: 运行结果: 两个线程都在运行,出现问题原因:资源竞争(对全局变量都…

从图像处理到深度学习:直播美颜SDK的人脸美型算法详解

在直播的镜头前,每一位主播都希望自己“光彩照人”。但在高清摄像头无死角的审视下,哪怕是天生丽质,也难免需要一点技术加持。于是,美颜SDK应运而生,成为直播平台提升用户粘性和视觉体验的重要工具。 尤其是在“人脸美…

编译rustdesk,使用flutter、hwcodec硬件编解码

目录 安装相应的环境安装visual studio安装vpkg安装rust开发环境安装llvm和clang编译源码下载源码使用Sciter作为UI的(已弃用)使用flutter作为UI的(主流)下载flutter sdk桥接静默安装最近某desk免费的限制越来越多,实在没办法,平时远程控制用的比较多,只能用rustdesk了,…

Dynamics 365 Business Central EC Sales List 欧洲共同体 (EC) 销售列表

什么是EC Sales List? 是在欧盟境内 开立的具有增值税主体公司的一项报告义务,提供欧盟国家/地区企业之间的跨境交易记录。ESL 的目的是确保这些交易中的所有相关方都支付和申报了适当金额的增值税。 随着出海企业越来越多的在欧州开展业务,此项报告需…

将图片存为二进制流到数据库并展示到前端的实现

使用图片直接存储到数据库中可能会出现以下问题: 1.图片的存储太多了占用数据库的存储空间 2.图片占用内存较大在传输和渲染的情况下会影响应用性能 3.一般情况下是将图片上传云服务器然后数据库存地址,这里讲解的情况只适合图片较少的情景 这里使用…

pikachu通关教程-RCE

目录 RCE(remote command/code execute)概述: exec "ping" 管道符 乱码问题 RCE(remote command/code execute)概述: RCE漏洞,可以让攻击者直接向后台服务器远程注入操作系统命令或者代码,从而控制后台系统 分为远程代码和远程命令两种.当…

合合信息首批通过中国信通院文档图像篡改检测平台能力完备性测评

随着 AIGC 技术的迅速发展,图像篡改手段日益多样化和隐蔽化,给各行业带来了严峻挑战。虚假证照、伪造合同等文档不仅威胁企业的运营安全,也对社会诚信体系构成冲击。在中国信通院最新开展的文档图像篡改检测平台能力完备性测评中,…

线程池的详细知识(含有工厂模式)

前言 下午学习了线程池的知识。重点探究了ThreadPoolExecutor里面的各种参数的含义。我详细了解了这部分的知识。其中有一个参数涉及工厂模式,我将这一部分知识分享给大家~ 线程池的详细介绍(含工厂模式) 结语 分享到此结束啦。byebye~

力扣HOT100之动态规划:279. 完全平方数

这道题之前在刷代码随想录的时候做过,但是现在给忘干净了,现在甚至都不记得这是一个背包问题。。。又反过头去看代码随想录的视频才做出来的。这道题就是一个背包问题,这个问题可以抽象为:对于容量为j的背包,要计算出恰…

Pytorch Geometric官方例程pytorch_geometric/examples/link_pred.py环境安装教程及图数据集制作

最近需要训练图卷积神经网络(Graph Convolution Neural Network, GCNN),在配置GCNN环境上总结了一些经验。 我觉得对于初学者而言,图神经网络的训练会有2个难点: ①环境配置 ②数据集制作 一、环境配置 我最初光想…

AI预测3D新模型百十个定位预测+胆码预测+去和尾2025年5月30日第93弹

从今天开始,咱们还是暂时基于旧的模型进行预测,好了,废话不多说,按照老办法,重点8-9码定位,配合三胆下1或下2,杀1-2个和尾,再杀4-5个和值,可以做到100-300注左右。 (1)定…

架构加速-深度学习教程

由于RK、jetson nano和电脑的GPU不相同,对应的pytorch也不同,因此不能直接将电脑训练好的模型丢到板端运行,因为训练的模型框架不同。就像你torch1.13和torch2.0都不一定支持,更何况不同平台上的torch。因此需要进行onnx模型转化&…

顶会新热门:机器学习可解释性

🧀机器学习模型的可解释性一直是研究的热点和挑战之一,同样也是近两年各大顶会的投稿热门。 🧀这是因为模型的决策过程不仅需要高准确性,还需要能被我们理解,不然我们很难将它迁移到其它的问题中,也很难进…

MicroPython+L298N+ESP32控制电机转速

要使用MicroPython控制L298N电机驱动板来控制电机的转速,你可以通过PWM(脉冲宽度调制)信号来调节电机速度。L298N是一个双H桥驱动器,可以同时控制两个电机的正反转和速度。 硬件准备: 1. L298N 电机控制板 2. ESP32…

Chainlink:连接 Web2 与 Web3 的去中心化桥梁

区块链技术通过智能合约实现了去中心化的自动执行,但智能合约无法直接访问链下数据,限制了其在现实世界的应用。Chainlink 作为去中心化预言机网络,以信任最小化的方式解决了这一问题,成为连接传统互联网(Web2&#xf…

杨传辉:构建 Data × AI 能力,打造 AI 时代的一体化数据底座|OceanBase 开发者大会实录

5 月 17 日,OceanBase 在广州举办第三届开发者大会。主论坛环节,OceanBase CTO 杨传辉系统阐述了 Data AI 战略,并正式推出三大产品:PowerRAG、共享存储 及OceanBase桌面版。 杨传辉指出,数据与AI模型的一体化融合&a…

AU6825集成音频DSP的2x32W数字型ClaSSD音频功率放大器(替代TAS5825)

1.特性 ● 输出配置 - 立体声 2.0: 2 x 32W (8Ω,24V,THD N 10%) - 立体声 2.0: 2 x 26W (8Ω,21V,THD N 1%) ● 供电电压范围 - PVDD:4.5V -26.4V - DVDD: 1.8V 或者 3.3V ● 静态功耗 - 37mA at PVDD12V ● 音频性能指标 - THDN ≤ 0.02% at 1W,1kHz - SNR ≥ 107dB (A-wei…