1. pytorch手写数字预测

article/2025/7/21 22:02:09

1. pytorch手写数字预测

  • 1.背景
  • 2.准备数据集
  • 2.定义模型
  • 3.dataloader和训练
  • 4.训练模型
  • 5.测试模型
  • 6.保存模型

1.背景

因为自身的研究方向是多模态目标跟踪,突然对其他的视觉方向产生了兴趣,所以心血来潮的回到最经典的视觉任务手写数字预测上来,所以这份教程并不是一份非常详尽的教程,是在一部分pytorch,深度学习基础上的教程,如果需要的是非常保姆级的教程建议看别的文章

2.准备数据集

这里我才用了直接导torchvision中的dataset包来下载Mnist数据集,也算是一个非常经典的数据集了

# 导入数据集
from torchvision.datasets import MNIST
import torch# 设置随机种子
torch.manual_seed(3306)# 数据预处理
from torchvision import transforms
# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(),  # 转换为 Tensortransforms.Normalize((0.1307,), (0.3081,))  # 标准化
])# 下载 MNIST 数据集
mnist_train = MNIST(root='./dataset_file/mnist_raw', train=True, download=True,transform=transform)
mnist_test = MNIST(root='./dataset_file/mnist_raw', train=False, download=True,transform=transform)
# 查看数据集大小
print(f"MNIST train dataset size: {len(mnist_train)}")
print(f"MNIST test dataset size: {len(mnist_test)}")

其中,MNIST()中的root代表的是数据集存放的位置,download代表是如果当前位置没有数据集是否需要下载。
transformer则是对数据的处理方式,我这里采用了简单地转成tensor和简单地标准化。

不过这样子下载下来的数据集是二进制格式的,无法直接查看图片,当然,如果你需要查看图片,也有办法。

# 查看图片
import matplotlib.pyplot as pltdef show_image(id):img, label = mnist_train[id]img = img.squeeze().numpy()  # 去掉通道维度print(img.shape)# print(img)plt.imshow(img, cmap='gray')plt.title(f"Label: {label}")plt.axis('off')plt.show()show_image(1)

效果
在这里插入图片描述

又或者你想要下载的数据集是图片格式,我这里也准备了代码

代码是在别人的基础上改的,其中数据集存放路径是dataset_dir,如果需要修改自行打印然后修改位置就好了。

#!/usr/bin/env python3
# -*- encoding utf-8 -*-'''
@File: save_mnist_to_jpg.py
@Date: 2024-08-23
@Author: KRISNAT
@Version: 0.0.0
@Email: ****
@Copyright: (C)Copyright 2024, KRISNAT
@Desc:1. 通过 torchvision.datasets.MNIST 下载、解压和读取 MNIST 数据集;2. 使用 PIL.Image.save 将 MNIST 数据集中的灰度图片以 JPEG 格式保存。
'''import sys, os
sys.path.insert(0, os.getcwd())from torchvision.datasets import MNIST
import PIL
from tqdm import tqdmif __name__ == "__main__":home_dir = os.path.abspath('.')root = os.path.abspath(os.path.join(home_dir, '../dataset_file'))print(root)# exit(0)# 图片保存路径dataset_dir = os.path.join(root, 'mnist_jpg')if not os.path.exists(dataset_dir):os.makedirs(dataset_dir)# 从网络上下载或从本地加载MNIST数据集# 训练集60K、测试集10K# torchvision.datasets.MNIST接口下载的数据一组元组# 每个元组的结构是: (PIL.Image.Image image model=L size=28x28, 标签数字 int)training_dataset = MNIST(root='mnist',train=True,download=True,)test_dataset = MNIST(root='mnist',train=False,download=True,)# 保存训练集图片with tqdm(total=len(training_dataset), ncols=150) as pro_bar:for idx, (X, y) in enumerate(training_dataset):f = dataset_dir + "/" + "training_" + str(idx) + \"_" + str(training_dataset[idx][1] ) + ".jpg"  # 文件路径training_dataset[idx][0].save(f)pro_bar.update(n=1)# 保存测试集图片with tqdm(total=len(test_dataset), ncols=150) as pro_bar:for idx, (X, y) in enumerate(test_dataset):f = dataset_dir + "/" + "test_" + str(idx) + \"_" + str(test_dataset[idx][1] ) + ".jpg"  # 文件路径test_dataset[idx][0].save(f)pro_bar.update(n=1)

2.定义模型

这里我准备了两个模型,一个MLP模型和一个简单地CNN模型,其中MLP模型参数量1M,CNN模型参数量大概8M,当然这俩模型也没有很仔细的规划

import torch
import torch.nn as nnclass DigitLinear(nn.Module):def __init__(self):super(DigitLinear, self).__init__()self.fc1 = nn.Linear(28 * 28, 1000)self.fc2 = nn.Linear(1000, 500)self.dropout = nn.Dropout(0.3)self.fc3 = nn.Linear(500, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = self.fc1(x)x = torch.relu(x)x = self.dropout(x)x = self.fc2(x)x = torch.relu(x)x = self.fc3(x)return xclass DigitCNN(nn.Module):def __init__(self):super(DigitCNN,self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64*28*28, 128)self.dropout = nn.Dropout(0.1)self.fc2 = nn.Linear(128, 10)def forward(self, x):# print("x.shape:", x.shape)B,N,H,W = x.shapex = self.conv1(x)x = torch.relu(x)x = self.conv2(x)x = torch.relu(x)x = x.view(B, -1)  # 展平x = self.fc1(x)x = torch.relu(x)x = self.dropout(x)x = self.fc2(x)return x

3.dataloader和训练

这里的代码就很简单了,就是一些参数的选择,例如epoch,batchsize。其中的训练函数我写的买有很全面,只是勉强满足了训练功能,还有好多可以优化的点,比如打印fps,断点续训练啥的,不过这个任务提不起劲去干这事,大家可以自行优化。

# 数据加载器
from torch.utils.data import DataLoader
from lib.model.DigitModel import DigitLinear,DigitCNN
# 定义数据加载器
batch_size = 256
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)epoch = 50
# 训练模型
net = DigitLinear() # 参数量1M 97.50%
# net = DigitCNN() # 参数量8M 98.81%
net.cuda()# 定义损失函数和优化器
import torch.optim as optim
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 训练函数def train_model(model, train_loader, criterion, optimizer, num_epochs=10):model.train()  # 设置模型为训练模式for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (inputs, labels) in enumerate(train_loader):inputs= inputs.cuda()y = torch.tensor(torch.zeros((inputs.shape[0],10), dtype=torch.float)).cuda()y[torch.arange(inputs.shape[0]), labels] = 1optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, y)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels.cuda()).sum().item()epoch_loss = running_loss / len(train_loader)epoch_acc = 100. * correct / totalprint(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 训练模型
train_model(net, train_loader, criterion, optimizer, num_epochs=epoch)

4.训练模型

有了上面的代码就可以开始训练了,我这里训练的截图是我的MLP模型,效果不是很好,CNN的效果稍微好一点,比MLP高1%,但是图忘记截了。反正够用了,因为本身MNIST的数据就不是很完美,有很多类似于噪声的数据例如:
在这里插入图片描述
这些数字我人眼都分不出是什么玩意。

训练效果如下
在这里插入图片描述

5.测试模型

训练完当然是测试了
最后我的MLP模型跑了97.50%的准确率

代码如下

# 测试模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.eval()
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device).float(), labels.to(device).float()outputs = net(inputs)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels.cuda()).sum().item()# print(f"Predicted: {predicted}, Ground Truth: {targets}")print(f"Accuracy: {correct / total * 100:.4f} %")

在这里插入图片描述

6.保存模型

保存模型代码就更简单了

# 保存模型
torch.save(net.state_dict(), './digit_model.pth')

http://www.hkcw.cn/article/JLSNEleStI.shtml

相关文章

武警智能兵器室系统架构设计与关键技术解析

在现代化武警部队建设中,武器弹药的安全管理与快速响应能力直接影响部队战斗力。本文基于某实战化智能兵器室建设案例,深入解析其系统架构设计、关键技术实现及创新管理机制,为安防领域提供可借鉴的解决方案。 整体拓扑结构 系统采用分层分布…

HTML5 列表、表格与媒体元素、页面结构分析

1. 列表 无序列表 有序列表 定义列表 列表对比 2. 表格 跨列 跨行 跨行和跨列 3. HTML5的媒体元素 视频元素 注意:autoplay现在很多浏览器不支持了! 音频元素 4. 页面结构分析 5. 总结

中文文本分析及词云生成

一、代码解析(按执行顺序) 1. 库导入 import jieba # 中文分词工具 from wordcloud import WordCloud # 词云生成器 from collections import Counter # 词频统计 import matplotlib.pyplot as plt # 可视化 import numpy as np # 图像矩阵处理 f…

芯片手册解读

一: 1.这是一款差分转单端的芯片: 2.给出了逻辑高低的识别门限:并不是大于100mv和小于-100mv就识别不到了——而是大于100mv和小于-100mv都可以识别到,手册的意思仅仅代表门限节点而已,完全可以在进入门限后的其他电…

LangChain-Tool和Agent结合智谱AI大模型应用实例2

1.Tool(工具) 定义与功能 单一功能模块:Tool是完成特定任务的独立工具,每个工具专注于一项具体的操作,例如:搜索、计算、API调用等 无决策能力:工具本身不决定何时被调用,仅在被触发时执行预设操作 输入输出明确:每个工具需明确定义输入、输出参数及格式 2.Agent(…

专业级图片分割解决方案

在日常处理图片的过程中,我们常常会遇到需要将一张图分割成多个小图的情况。这一款高效又实用的图片分割工具——它操作简单、功能强大,关键是完全免费开源,适合所有有图像处理需求的朋友! 在使用之前,先花几分钟把它…

Re--题

一.[NSSCTF 2022 Spring Recruit]easy C 直接看for循环,异或 写代码 就得到了flag easy_Re 二.[SWPUCTF 2021 新生赛]非常简单的逻辑题 先对这段代码进行分析 flag xxxxxxxxxxxxxxxxxxxxx s wesyvbniazxchjko1973652048$-&*&l…

iOS 集成网易云信IM

云信官方文档在这 看官方文档的时候&#xff0c;版本选择最新的V10。 1、CocoPods集成 pod NIMSDK_LITE 2、AppDelegate.m添加头文件 #import <NIMSDK/NIMSDK.h> 3、初始化 NIMSDKOption *mrnn_option [NIMSDKOption optionWithAppKey:"6f6568e354026d2d658a…

边缘计算网关支撑医院供暖系统高效运维的本地化计算与边缘决策

一、项目背景 医院作为人员密集的特殊场所&#xff0c;对供暖系统的稳定性和高效性有着极高的要求。其供暖换热站传统的人工现场监控方式存在诸多弊端&#xff0c;如人员值守成本高、数据记录不及时不准确、故障发现和处理滞后、能耗难以有效监测和控制等&#xff0c;难以满足…

Google Earth Pro 7.3 中文绿色版 - 谷歌地球专业版(精准定位,清晰查看您家位置)

谷歌卫星高清地图 下载链接&#xff1a;https://pan.quark.cn/s/c6069864c9f3 Google Earth Pro-7.3.6.9796-x64 Google Earth WinMac安装版 GoogleEarthProPortable googleearthpromac-intel-7.3.6.10155 GoogleEarthProWin-7.3.6.10155 GoogleEarthProWin-x64-7.3.6.10…

【工作笔记】 WSL开启报错

【工作笔记】 WSL开启报错 时间&#xff1a;2025年5月30日16:50:42 1.现象 Installing, this may take a few minutes... WslRegisterDistribution failed with error: 0x80370114 Error: 0x80370114 ??????????????????Press any key to continue...

《TCP/IP 详解 卷1:协议》第3章:链路层

以太网和IEEE802局域网/城域网标准 IEEE802局域网/城域网标准 IEEE 802 是一组由 IEEE&#xff08;电气与电子工程师协会&#xff09;定义的局域网和城域网通信标准系列&#xff0c;涵盖了从物理层到链路层的多个网络技术。其中&#xff1a; IEEE 802.3 定义的是传统的以太网…

【定昌linux开发板】设置密码的有效时间

查看密码策略命令 sudo chage -l 用户名 先查询下默认情况下&#xff0c;密码策略 结果&#xff1a; 可以看出&#xff0c;密码没什么策略 那么我要设置30天后过期&#xff0c;并且七天前要进行提醒 sudo chage -M 30 用户名 再进行查询&#xff0c;结果如下 显示密码的有…

Vue-数据监听

数据监听 基础信息 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><title>数据监听</title><!-- 引入Vue --><script type"text/javascript" src"../js/vue.js&qu…

Java 注解式限流教程(使用 Redis + AOP)

Java 注解式限流教程&#xff08;使用 Redis AOP&#xff09; 在上一节中&#xff0c;我们已经实现了基于 Redis 的请求频率控制。现在我们将进一步升级功能&#xff0c;使用 Spring AOP 自定义注解 实现一个更优雅、可复用的限流方式 —— 即通过 RateLimiter 注解&#xf…

C++学习-入门到精通【10】面向对象编程:多态性

C学习-入门到精通【10】面向对象编程&#xff1a;多态性 目录 C学习-入门到精通【10】面向对象编程&#xff1a;多态性一、多态性介绍&#xff1a;多态电子游戏二、类继承层次中对象之间的关系1.从派生类对象调用基类函数2.将派生类指针指向基类对象3.通过基类指针调用派生类的…

基于springboot的医护人员排班系统设计与实现(源码+文档+部署讲解)

技术范围&#xff1a;SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容&#xff1a;免费功能设计、开题报告、任务书、中期检查PPT、系统功能实现、代码编写、论文编写和辅导、论文…

1、python代码实现与大模型的问答交互

一、基础知识 1.1导入库 torch 是一个深度学习框架&#xff0c;用于处理张量和神经网络。modelscope是由阿里巴巴达摩院推出的开源模型库。 AutoTokenizer 是ModelScope 库的类&#xff0c;分词器应用场景包括自然语言处理&#xff08;NLP&#xff09;中的文本分类、信息抽取…

再见Cursor!Trae Pro 登场

5 月 27 日&#xff0c;字节跳动旗下的 AI 编辑器 Trae 国际版正式推出了 Pro 订阅计划。长期以来&#xff0c;Trae 凭借免费使用和出色的编程体验&#xff0c;深受大家喜爱。不过&#xff0c;免费版在实际使用中&#xff0c;排队等待的情况时有发生&#xff0c;着实给用户带来…

【Docker 从入门到实战全攻略(一):核心概念 + 命令详解 + 部署案例】

1. 是什么 Docker 是一个用于开发、部署和运行应用程序的开源平台&#xff0c;它使用 容器化技术 将应用及其依赖打包成独立的容器&#xff0c;确保应用在不同环境中一致运行。 2. Docker与虚拟机 2.1 Docker&#xff08;容器化&#xff09; 容器化是一种轻量级的虚拟化技术…