【ConvLSTM第二期】模拟视频帧的时序建模(Python代码实现)

article/2025/8/27 19:07:47

目录

  • 1 准备工作:python库包安装
    • 1.1 安装必要库
  • 案例说明:模拟视频帧的时序建模
    • ConvLSTM概述
    • 损失函数说明
    • (python全代码)
  • 参考

ConvLSTM的原理说明可参见另一博客-【ConvLSTM第一期】ConvLSTM原理。

1 准备工作:python库包安装

1.1 安装必要库

pip install torch torchvision matplotlib numpy

案例说明:模拟视频帧的时序建模

🎯 目标:给定一个人工生成的动态图像序列(例如移动的方块),使用 ConvLSTM 对其进行建模,输出预测结果,并查看输出的维度和特征变化。

ConvLSTM概述

ConvLSTM 的基本结构,包括:

  • ConvLSTMCell:实现了一个时间步的 ConvLSTM 单元,类似于一个“时刻”的神经元。
  • ConvLSTM:实现了多层ConvLSTM结构,能够处理一整个时间序列的视频帧数据。

损失函数说明

MSE(均方误差) 衡量预测值和真实值之间的平均平方差。
在这里插入图片描述

关于训练终止条件:
可以根据 MSE是否达到某个阈值(如 < 0.001)提前终止训练,这是所谓的 “Early Stopping(提前停止)策略”。

(python全代码)

MSE损失函数曲线如下:可知MSE一直在下降,虽然存在振荡
在这里插入图片描述

前9帧图像及预测的第十帧图像得到的动图如下:
在这里插入图片描述

python完整代码如下:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image# 设置字体
plt.rcParams['font.family'] = 'Times New Roman'# 创建保存图像目录
os.makedirs("./Figures", exist_ok=True)# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# ====================================
# 一、ConvLSTM 模型结构
# ====================================class ConvLSTMCell(nn.Module):def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):super(ConvLSTMCell, self).__init__()padding = kernel_size // 2self.input_channels = input_channelsself.hidden_channels = hidden_channelsself.conv = nn.Conv2d(input_channels + hidden_channels, 4 * hidden_channels, kernel_size, padding=padding, bias=bias)def forward(self, x, h_prev, c_prev):combined = torch.cat([x, h_prev], dim=1)conv_output = self.conv(combined)cc_i, cc_f, cc_o, cc_g = torch.chunk(conv_output, 4, dim=1)i = torch.sigmoid(cc_i)f = torch.sigmoid(cc_f)o = torch.sigmoid(cc_o)g = torch.tanh(cc_g)c = f * c_prev + i * gh = o * torch.tanh(c)return h, cclass ConvLSTM(nn.Module):def __init__(self, input_channels, hidden_channels, kernel_size, num_layers):super(ConvLSTM, self).__init__()self.num_layers = num_layerslayers = []for i in range(num_layers):in_channels = input_channels if i == 0 else hidden_channelslayers.append(ConvLSTMCell(in_channels, hidden_channels, kernel_size))self.layers = nn.ModuleList(layers)def forward(self, input_seq):b, t, c, h, w = input_seq.size()h_t = [torch.zeros(b, layer.hidden_channels, h, w).to(input_seq.device) for layer in self.layers]c_t = [torch.zeros(b, layer.hidden_channels, h, w).to(input_seq.device) for layer in self.layers]for time in range(t):x = input_seq[:, time]for i, layer in enumerate(self.layers):h_t[i], c_t[i] = layer(x, h_t[i], c_t[i])x = h_t[i]return h_t[-1]  # 返回最后一层最后一帧的隐藏状态# ====================================
# 二、生成移动方块序列数据
# ====================================def generate_moving_square_sequence(batch_size, time_steps, height, width):data = torch.zeros((batch_size, time_steps, 1, height, width))for b in range(batch_size):dx = np.random.randint(1, 3)dy = np.random.randint(1, 3)x = np.random.randint(0, width - 6)y = np.random.randint(0, height - 6)for t in range(time_steps):data[b, t, 0, y:y+5, x:x+5] = 1.0x = (x + dx) % (width - 5)y = (y + dy) % (height - 5)return data# ====================================
# 三、模型、损失、优化器
# ====================================class ConvLSTM_Predictor(nn.Module):def __init__(self):super().__init__()self.convlstm = ConvLSTM(input_channels=1, hidden_channels=16, kernel_size=3, num_layers=1)self.decoder = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)def forward(self, input_seq):hidden = self.convlstm(input_seq)pred = self.decoder(hidden)return predmodel = ConvLSTM_Predictor().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)# ====================================
# 四、训练过程
# ====================================mse_list = []
max_epochs = 100
mse_threshold = 0.001
height, width = 64, 64for epoch in range(max_epochs):model.train()seq = generate_moving_square_sequence(8, 10, height, width).to(device)input_seq = seq[:, :9]target_frame = seq[:, 9, 0].unsqueeze(1)optimizer.zero_grad()output = model(input_seq)loss = criterion(output, target_frame)loss.backward()optimizer.step()mse = loss.item()mse_list.append(mse)print(f"Epoch {epoch+1}/{max_epochs}, MSE: {mse:.6f}")# 提前停止条件if mse < mse_threshold:print(f"✅ 提前停止:MSE 已达到阈值 {mse_threshold}")break# ====================================
# 五、测试与可视化结果
# ====================================model.eval()
with torch.no_grad():test_seq = generate_moving_square_sequence(1, 10, height, width).to(device)input_seq = test_seq[:, :9]true_frame = test_seq[:, 9, 0]pred_frame = model(input_seq)[0, 0].cpu().numpy()# 保存输入帧
for t in range(9):frame = input_seq[0, t, 0].cpu().numpy()plt.imshow(frame, cmap='gray')plt.title(f"Input Frame t={t}")plt.colorbar()plt.savefig(f"./Figures/input_frame_{t}.png")plt.close()# 保存 Ground Truth
plt.imshow(true_frame[0].cpu().numpy(), cmap='gray')
plt.title("Ground Truth Frame t=9")
plt.colorbar()
plt.savefig("./Figures/ground_truth_t9.png")
plt.close()# 保存预测帧
plt.imshow(pred_frame, cmap='gray')
plt.title("Predicted Frame t=9")
plt.colorbar()
plt.savefig("./Figures/predicted_t9.png")
plt.close()# 保存 MSE 曲线图
plt.plot(mse_list)
plt.title("Training MSE Loss")
plt.xlabel("Epoch")
plt.ylabel("MSE")
plt.grid(True)
plt.savefig("./Figures/mse_curve.png")
plt.close()# ---------------- 生成动图 ----------------frames = []# 添加前9帧输入
for t in range(9):img = Image.open(f"./Figures/input_frame_{t}.png")frames.append(img.copy())# 添加预测帧
img = Image.open("./Figures/predicted_t9.png")
frames.append(img.copy())# 保存动图
frames[0].save("./Figures/sequence.gif", save_all=True, append_images=frames[1:], duration=500, loop=0)
print("✅ 所有图像和动图已保存至 ./Figures 文件夹")

参考


http://www.hkcw.cn/article/GoDpQrZJgf.shtml

相关文章

clickhouse如何查看操作记录,从日志来查看写入是否成功

背景 插入表数据后&#xff0c;因为原本表中就有数据&#xff0c;一时间没想到怎么查看插入是否成功&#xff0c;因为对数据源没有很多的了解&#xff0c;这时候就想怎么查看下插入是否成功呢&#xff0c;于是就有了以下方法 具体方法 根据操作类型查找&#xff0c;比如inse…

【GESP真题解析】第 15 集 GESP 二级 2024 年 6 月编程题 2:计数

大家好,我是莫小特。 这篇文章给大家分享 GESP 二级 2024 年 6 月编程题第 2 题:计数。 题目链接 洛谷链接:B4007 计数 一、完成输入 根据输入格式描述,输入两行,正整数 n 和正整数 k,数据范围: 1 < = n < = 1000 , 1 < = k < = 9 1<=n<=1000,1&…

NumPy 2.x 完全指南【二十一】元素重排操作

文章目录 1. 翻转1.1 fliplr1.2 fliplr1.3 flipud 2. 滚动2.1 roll2.2 rot90 1. 翻转 1.1 fliplr numpy.flip&#xff1a; 沿指定轴翻转数组元素顺序&#xff0c;返回视图&#xff0c;共享原数组内存。 函数定义&#xff1a; def flip(m, axisNone)参数说明&#xff1a; m…

彻底卸载安装的虚拟机VMware Workstation软件

文章目录 前言一、结束“任务管理器”中的相关任务二、停止“服务”中的相关服务三、卸载vmware软件四、删除vmware相关文件五、删除vmware相关注册表 前言 VMware Workstation 是 VMware 推出的桌面虚拟计算机软件&#xff0c;支持在单台物理机上运行多个操作系统。它提供强大…

Python 进阶【三】:Excel操作

1. 概述与库介绍 1.1 Excel自动化的重要性 在数据处理领域&#xff0c;Excel是最常用的工具之一。手动操作Excel对于小规模数据和简单任务尚可&#xff0c;但当面对&#xff1a; 大规模数据集重复性操作复杂计算和分析 时&#xff0c;手动操作效率低下且容易出错。Python提供…

Oracle RMAN自动恢复测试脚本

说明 此恢复测试脚本&#xff0c;基于rman备份脚本文章使用的fullbak.sh做的备份。 数据库将被恢复到RESTORE_LO参数设置的位置。 在恢复完成后&#xff0c;执行一个测试sql,确认数据库恢复完成&#xff0c;数据库备份是好的。恢复测试数据库的参数&#xff0c;比如SGA大小都…

亚马逊桌布运营中的利润核算与优化:从成本管控到决策升级

在亚马逊电商市场&#xff0c;卖家运营面临利润核算与决策难题。​ 一、卖家运营核心痛点 &#xff08;一&#xff09;利润核算复杂性 亚马逊费用体系复杂&#xff1a;平台销售佣金因类目而异&#xff0c;FBA 费用包含仓储、配送等项目&#xff0c;且随淡旺季、仓储时长动态…

C# Costura.Fody 排除多个指定dll

按照网上的说在 FodyWeavers.xml 里修改 然后需要注意的是 指定多个排除项 不是加 | 是换行 一个换行 就排除一项 我测试的 <?xml version"1.0" encoding"utf-8"?> <Weavers xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance&quo…

设计模式-发布订阅

文章目录 发布订阅概念发布订阅 vs 监听者例子代码 发布订阅概念 发布/订阅者模式最大的特点就是实现了松耦合&#xff0c;也就是说你可以让发布者发布消息、订阅者接受消息&#xff0c;而不是寻找一种方式把两个分离 的系统连接在一起。当然这种松耦合也是发布/订阅者模式最大…

算法第32天|509. 斐波那契数、70. 爬楼梯、746. 使用最小花费爬楼梯

509. 斐波那契数 题目 思路与解法 class Solution:def fib(self, n: int) -> int:fib [1] * nif n 0:return 0if n 1 or n 2 :return 1for i in range(2, n):fib[i] fib[i-1] fib[i-2]return fib[n-1]70. 爬楼梯 题目 思路与解法 class Solution:def climbStairs(…

涂鸦智能的TuyaOpen框架入门指南:智能插座实战

目录 引言 TuyaOpen框架简介 程序下载和编译 安装依赖 克隆仓库 设置与编译 step1. 设置环境变量 step2. 选择待编译项目 step3. 编译 step4. menuconfig 配置 在Ubuntu上测试示例程序Switch Demo 创建产品并获取产品的 PID 确认 TuyaOpen 授权码 运行程序 程序…

快速上手shell条件测试

一、命令执行结果判定 && 命令执行后如果没有任何报错时会执行符号后面的动作 || 在命令执行后如果命令有报错会执行符号后的动作 二、条件判断方法 条件测试语法说明示例test 测试表达式test命令和 测试表达式 之间至少有一个空格[ 测试表达式 ]该方法和test命令的…

每日刷题c++

快速幂 #include <iostream> using namespace std; #define int long long int power(int a, int b, int p) {int ans 1;while (b){if (b % 2){ans * a;ans % p; // 随时取模}a * a;a % p; // 随时取模b / 2;}return ans; } signed main() {int a, b, p;cin >> a …

什么是node.js、npm、vue

一、Node.js 是什么&#xff1f; &#x1f63a; 定义&#xff1a; Node.js 是一个基于 Chrome V8 引擎的 JavaScript 运行时环境&#xff0c;让你可以在浏览器之外运行 JavaScript 代码&#xff0c;主要用于服务端开发。 &#x1f63a;从计算机底层说&#xff1a;什么是“运…

华为OD机试真题——求最多可以派出多少支队伍(2025A卷:100分)Java/python/JavaScript/C/C++/GO最佳实现

2025 A卷 100分 题型 本专栏内全部题目均提供Java、python、JavaScript、C、C++、GO六种语言的最佳实现方式; 并且每种语言均涵盖详细的问题分析、解题思路、代码实现、代码详解、3个测试用例以及综合分析; 本文收录于专栏:《2025华为OD真题目录+全流程解析+备考攻略+经验分…

webrtc初了解

1. webrtc的简介 一、WebRTC 是什么&#xff1f; Web Real-Time Communication&#xff08;网页实时通信&#xff09;&#xff0c;是浏览器原生支持的实时音视频通信技术&#xff0c;无需安装插件或客户端&#xff0c;可直接在浏览器之间实现点对点&#xff08;P2P&#xff09…

【Deepseek 学网络互联】跨节点通信global 和节点内通信CLAN保序

Clan模式下的源端保序与Global类似&#xff0c;目的端保序则退化成通道保序&#xff0c;此时仅支持网络单路径保序。”这里的通道保序怎么理解&#xff1f; 用户可能正在阅读某种硬件架构文档&#xff08;比如NVIDIA的NVLink或InfiniBand规范&#xff09;&#xff0c;因为"…

​Windows 11 安装 Miniconda 与 Jupyter 全流程指南​

​一、Miniconda 安装与配置​ 1. 下载安装程序 ​访问官网​&#xff1a;打开 Miniconda 官网&#xff0c;下载 ​Python 3.x 版本的 Windows 64 位安装包​。​安装路径选择​&#xff1a; 推荐路径&#xff1a;D:\Miniconda3&#xff08;避免使用中文路径和空格&#xff0…

飞牛NAS+Docker技术搭建个人博客站:公网远程部署实战指南

文章目录 前言1. Docker下载源设置2. Docker下载WordPress3. Docker部署Mysql数据库4. WordPress 参数设置5. 飞牛云安装Cpolar工具6. 固定Cpolar公网地址7. 修改WordPress配置文件8. 公网域名访问WordPress总结 前言 在数字化浪潮中&#xff0c;传统网站搭建方式正面临前所未…

批目标灵活模拟!成都鼎讯雷达模拟器,打造沉浸式雷达对抗训练场景

在现代战争的电磁频谱博弈中&#xff0c;能否构建高度逼真的雷达干扰与目标环境&#xff0c;直接决定着雷达装备性能的上限与作战人员的实战能力。成都鼎讯凭借在数字射频存储&#xff08;DRFM&#xff09;、数字干扰调制&#xff08;DJS&#xff09;等前沿技术的深厚积累&…