循环神经网络(RNN)模型

article/2025/8/27 4:43:23

一、概述

  循环神经网络(Recurrent Neural Network, RNN)是一种专门设计用于处理序列数据(如文本、语音、时间序列等)的神经网络模型。其核心思想是通过引入时间上的循环连接,使网络能够保留历史信息并影响当前输出。

二、模型原理

  RNN的关键特点是隐藏状态的循环传递,即当前时刻的输出不仅依赖于当前输入,还依赖于之前所有时刻的信息,这种机制使RNN能够建模序列的时序依赖性。一个隐含层神经元的结构示意图如下

在这里插入图片描述
  对于时间步 t t t,有

h t = ( W x x t + W h h t − 1 + b h ) h_t=\left( W_xx_t+W_hh_{t-1}+b_h \right) ht=(Wxxt+Whht1+bh)

y t = g ( W y h t + b y ) y_t=g\left( W_yh_t+b_y \right) yt=g(Wyht+by)

其中, h t h_t ht 是当前隐含状态, x t x_t xt 是当前输入, y t y_t yt 是当前输出, W x , W h , W y W_x,W_h,W_y Wx,Wh,Wy 是权重矩阵, f , g f,g f,g 是激活函数。

  RNN在时间步上展开后,可视为多个共享参数的重复模块链式连接。序列结构过程如图所示

在这里插入图片描述

三、优势与局限性

1. 主要优势

参数共享:所有时间步共享同一组权重,大幅减少参数量。
记忆能力:隐藏状态能够“记忆”,存储历史信息。
灵活输入输出:支持多种序列任务(如一对一、一对多、多对多)。

2. 局限性

梯度问题:传统RNN难以训练长序列(梯度消失/爆炸)。
计算效率:无法并行处理序列(因时间步需顺序计算)。

四、应用场景

自然语言处理:语言模型(如 GPT 早期基于 RNN)、文本生成、机器翻译、情感分析。
语音处理:语音识别(如结合 CTC 损失函数)、语音合成。
时间序列分析:股票价格预测、传感器数据异常检测、天气预测。
视频处理:视频内容理解(如动作识别,结合 CNN 提取空间特征)。

五、Python实现示例

(环境:Python 3.11,PyTorch 2.4.0)

import matplotlib
matplotlib.use('TkAgg')import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt# 设置matplotlib的字体
plt.rcParams['font.sans-serif'] = ['SimHei']  # 'SimHei' 是黑体,也可设置 'Microsoft YaHei' 等
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号# 设置随机种子以确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 生成示例时间序列数据
def generate_data(n_samples=1000, seq_length=20):"""生成简单的正弦波时间序列数据"""x = np.linspace(0, 10 * np.pi, n_samples + seq_length)y = np.sin(x)# 创建序列和目标sequences = []targets = []for i in range(n_samples):sequences.append(y[i:i + seq_length])targets.append(y[i + seq_length])# 转换为PyTorch张量sequences = torch.FloatTensor(sequences).unsqueeze(2)  # [样本数, 序列长度, 特征数]targets = torch.FloatTensor(targets).unsqueeze(1)  # [样本数, 1]# 分割训练集和测试集train_size = int(0.8 * n_samples)train_data = TensorDataset(sequences[:train_size], targets[:train_size])test_data = TensorDataset(sequences[train_size:], targets[train_size:])return train_data, test_data# 定义RNN模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, output_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layers# RNN层self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)# 全连接输出层self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):# 初始化隐藏状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)# 前向传播RNNout, _ = self.rnn(x, h0)# 我们只需要最后一个时间步的输出out = self.fc(out[:, -1, :])return out# 训练函数
def train_model(model, train_loader, criterion, optimizer, device, epochs=100):model.train()for epoch in range(epochs):total_loss = 0for inputs, targets in train_loader:inputs, targets = inputs.to(device), targets.to(device)# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, targets)# 反向传播和优化loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}')# 评估函数
def evaluate_model(model, test_loader, device):model.eval()predictions = []actuals = []with torch.no_grad():for inputs, targets in test_loader:inputs, targets = inputs.to(device), targets.to(device)outputs = model(inputs)predictions.extend(outputs.cpu().numpy())actuals.extend(targets.cpu().numpy())return np.array(predictions), np.array(actuals)# 主函数
def main():# 设备配置device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 超参数input_size = 1  # 输入特征维度hidden_size = 64  # 隐藏层维度num_layers = 1  # RNN层数output_size = 1  # 输出维度seq_length = 20  # 序列长度batch_size = 32  # 批次大小learning_rate = 0.001  # 学习率# 生成数据train_data, test_data = generate_data(seq_length=seq_length)train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_data, batch_size=batch_size)# 初始化模型model = SimpleRNN(input_size, hidden_size, num_layers, output_size).to(device)# 定义损失函数和优化器criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型print("开始训练模型...")train_model(model, train_loader, criterion, optimizer, device)# 评估模型print("评估模型...")predictions, actuals = evaluate_model(model, test_loader, device)# 可视化结果plt.figure(figsize=(10, 6))plt.plot(actuals[:50], label='实际值')plt.plot(predictions[:50], label='预测值')plt.title('RNN模型预测结果')plt.xlabel('样本')plt.ylabel('值')plt.legend()plt.show()if __name__ == "__main__":main()

在这里插入图片描述
在这里插入图片描述

示例实现过程包括以下几个部分:

  数据生成:创建了一个简单的正弦波时间序列数据集,用于训练和测试模型。
  模型架构:定义了一个简单的 RNN 模型,包含一个 RNN 层处理序列输入、一个全连接层将 RNN 的输出映射到预测值
  训练流程:实现了完整的训练循环,包括前向传播、计算损失、反向传播和参数更新。
  评估和可视化:训练完成后,模型在测试数据上进行评估,并可视化预测结果与实际值的对比。

  示例展示了 RNN 在时间序列预测任务中的基本用法。可以通过调整超参数(如隐藏层大小、学习率、RNN 层数等)来优化模型性能,也可将此框架应用到其他序列数据相关的预测任务中。


End.


http://www.hkcw.cn/article/NCmucPdQtU.shtml

相关文章

【stm32开发板】原理图设计(电源部分)附:设计PCB流程

一、PCB设计流程 二、操作步骤 1.新建工程 文件→新建→工程 2.命名工程 保存后进入该页面 自生成了一个原理图和PCB 3.新建图页及重命名 右键第一个图页,选择新建图页 右键选择重命名可以为图页改名 4.取消设计规则的22项 5.原理图尺寸调整 如果觉得原理图框的…

MCP入门实战(极简案例)

MCP简介 MCP(Model Context Protocol,模型上下文协议)2024年11月底由 Antbropic 推出的一种开放标准,旨在统一大型语言模型(LLM)与外部数据源和工具之间的通信协议。 Function Calling是AI模型调用函数的机制,MCP是一个标准协议,使AI模型与API无缝交互,而Al Agent是一个…

SCL语言两台电机正反转控制程序从选型、安装到调试全过程的详细步骤指南(下)

阶段三:PLC 编程 (SCL 语言)(为了学会结构体和I/O映射可能看着有点复杂,多电机控制及维护好修改) 程序结构思路: 1. 定义清晰的数据结构 (STRUCT) 来管理每台电机的所有变量(输入、输出、状态、互锁条件&…

apptrace 的优势以及对 App 的价值

官网地址:AppTrace - 专业的移动应用推广追踪平台 apptrace 的优势以及对 App 的价值​ App 拉起作为移动端深度链接技术的关键应用,能实现从 H5 网页到 App 的无缝跳转,并精准定位到 App 内指定页面。apptrace 凭借专业的技术与丰富的经验…

西门子嵌入式学习笔记---(1)裸机和调度器开发

🌈个人主页: 羽晨同学 💫个人格言:“成为自己未来的主人~” 裸机和调度器开发的对比 嵌入式开发是为了特定目的而设计的计算系统编写软件的过程,这些系统通常会具有受限的资源(处理能力,、内存、能源等&…

Rust使用Cargo构建项目

文章目录 你好,Cargo!验证Cargo安装使用Cargo创建项目新建项目配置文件解析默认代码结构 Cargo工作流常用命令速查表详细使用说明1. 编译项目2. 运行程序3.快速检查4. 发布版本构建 Cargo的设计哲学约定优于配置工程化优势 开发建议1. 新项目初始化​2. …

Python自动化之selenium语句——元素点击、输入、清空和八大元素定位方法

目录 一、元素定位配置 1.导包 2.查找元素 二、元素交互操作 1.点击 2.输入 3.清空 三、元素定位方法 1.ID 2.NAME 3.CLASS_NAME 4.TAG_NAME 5.LINK_TEXT 6.PARTIAL_LINK_TEXT 7.CSS_SELECTOR 8.XPATH 本节讲解元素定位相关知识 一、元素定位配置 1.导包 2.查…

C++并集查找

前言 C图论 C算法与数据结构 本博文代码打包下载 基本概念 并查集(Union-Find)是一种用于处理动态连通性(直接或间接相连)的数据结构,主要支持两种操作:union 和 find。通过这两个基本操作,可…

DeepSeek - 尝试一下GitHub Models中的DeepSeek

1.简单介绍 当前DeepSeek使用的人很多,各大AI平台中也快速引入了DeekSeek,比如Azure AI Foundary(以前名字是Azure AI Studio)中的Model Catalog, HuggingFace, GitHub Models等。同时也出现了一些支持DeepSeek的.NET类库。微软的Semantic Kernel也支持…

2025年人文发展与教育心理学国际会议(ICHDEP 2025)

2025年人文发展与教育心理学国际会议(ICHDEP 2025) 2025 International Conference on Humanistic Development and Educational Psychology 一、大会信息 会议简称:ICHDEP 2025 大会地点:中国广州 审稿通知:投稿后2…

实测,大模型谁更懂数据可视化?

大家好,我是 Ai 学习的老章 看论文时,经常看到漂亮的图表,很多不知道是用什么工具绘制的,或者很想复刻类似图表。 实测,大模型 LaTeX 公式识别,出乎预料 前文,我用 Kimi、Qwen-3-235B-A22B、…

MySQL高可用方案:Keepalived+双主库架构深度解析与实战指南

MySQL高可用方案:Keepalived+双主库架构深度解析与实战指南 一、方案概述 MySQL双主+Keepalived架构通过双节点互为主从模式结合VRRP协议,实现数据库服务的高可用与自动故障转移。该方案具备以下核心优势: 双活写入能力:两节点均可处理读写请求,通过双向复制保持数据强一…

【MySQL】联合查询(下)

目录 一. 子查询 单行子查询 多行子查询 多列子查询 在from子句中使用子查询 二. 合并查询 union all union 三.插入查询结果 上期我们讲了内连接、外连接、自连接查询,今天我们继续讲其他联合查询,没看过的之前的可以先去看看上期博客&#xff1…

unity—特效闪光衣服的设置

模型设置两个材质球,一个基础色,一个闪光色 闪光层设置 基础色设置

lvs-keepalived高可用群集

目录 1.Keepalived 概述及安装 1.1 Keepalived 的热备方式 1.2 keepalived的安装与服务控制 (1)安装keep alived (2)控制 Keepalived 服务DNF 安装 keepalived 后,执行以下命令将keepalived 服务设置为开机启动。 2.使用 Keepalived 实现双机热备 …

多端 API 兼容性设计:如何统一 iOS / Android / Web 接口规范?

在移动互联网时代,一个后台服务往往需要同时支撑 iOS、Android 和 Web 三端业务。当某电商App在Android端出现支付接口返回结构不一致导致崩溃,而iOS端却正常运行时;当某个Web端新功能因接口版本问题延期上线时——多端API的兼容性问题已成为…

Linux的SHELL脚本中的常用命令

一、设置主机名称 1.文件的方式 注:修改完毕文件后在当前的shell中是不生效的,如果需要看到效果,关闭当前shell后重新开启新的shell 2.通过命令更改主机名 注:hostnamectl hostname后加上你要改的主机名,即改即生效&…

ultraiso制作U盘镜像 针对win2012及win2016等需要特殊处理

1.按照正常操作步骤制作U盘镜像 以管理员方式运行软碟通2.正常制作镜像 3.由于磁盘格式,大于4G的文件是写不进去的 手动拷贝资源文件,右键将镜像挂载到电脑上 4.转换U盘格式 convert H:/fs:NTFS 执行该命令 此次需要保证U盘不被占用 这个时候就能存储…

【AI News | 20250529】每日AI进展

AI Repos 1、WebAgent 阿里巴巴通义实验室近日发布了WebDancer,一款旨在实现自主信息搜索的原生智能体搜索推理模型。WebDancer采用ReAct框架,通过分阶段训练范式,包括浏览数据构建、轨迹采样、监督微调和强化学习,赋予智能体自主…

【Python】3.函数与列表

文章目录 一、函数1、函数是什么?2、语法格式3、函数参数4、函数返回值5、变量作用域6、函数执行过程7、链式调用8、嵌套调用9、函数递归10、参数默认值11、关键字参数小结 二、列表和元组1、列表是什么,元组是什么?2、创建列表3、访问下标4、…