记我的第一个深度学习模型尝试——MNIST手写数字识别

article/2025/6/9 15:26:14

种一棵树最好的时间是十年前,其次是现在。

目录

前言

一、数据准备

二、构建模型

三、模型精度检验


前言

最近又空闲下来,终于有时间把之前荒废的学习计划给重拾起来了!今天做的是MNIST手写数字识别项目。这可以说是深度学习的“Hello World”级项目了。在AI的帮助下,也是成功的完成了这个项目。记录下来,其中如有不正确的地方,欢迎指正。


一、数据准备

做项目最重要的是什么?数据!因此,我们首先把数据准备好。

我使用的框架是Pytorch,在Pytorch中有现成的方法直接下载。直接通过 torchvision.datasets 模块提供的接口完成。首先需要安装torchvision。

pip install torchvision

下载数据,代码如下。运行后会直接下载到data文件夹,如果没有会直接在当前文件路径新建一个。

from torchvision import datasets, transforms# 定义数据预处理(这里仅做归一化,将像素值从 [0,255] 转为 [0,1])
transform = transforms.Compose([transforms.ToTensor(),  # 转为 PyTorch 张量(形状:[1,28,28])transforms.Normalize((0.1307,), (0.3081,))  # MNIST 全局均值和标准差(经验值)
])# 下载训练集(6万张图)
train_dataset = datasets.MNIST(root='./data',  # 数据集存储路径(当前目录下的 data 文件夹)train=True,     # 是否为训练集(True:训练集,False:测试集)download=True,  # 若本地无数据则下载transform=transform  # 应用预处理
)# 下载测试集(1万张图)
test_dataset = datasets.MNIST(root='./data',train=False,download=True,transform=transform
)

当然,如果使用的是tensorflow框架的话,也是有现成的方法,但是tensorflow使用起来要比Pytorch稍微难上手一点。除此之外,也可以选择直接去官方网站下载

下载完后,我们查看数据集大小,以及对各数字类别分布做一个统计,这样做的目的是为了对这个数据集有更多的了解。机器学习非常依赖数据,所以在进行模型训练前,我们应该对数据集有尽可能多的了解。代码及运行结果如下。

from torch.utils.data import DataLoadertrain_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # 训练集批量加载(打乱顺序)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)   # 测试集批量加载(不打乱)
# 训练集和测试集的图片数量
print(f"训练集图片数量: {len(train_dataset)} 张")  # 输出:60000 张
print(f"测试集图片数量: {len(test_dataset)} 张")   # 输出:10000 张
import numpy as np# 统计训练集标签分布
train_labels = [label for _, label in train_dataset]
train_label_counts = np.bincount(train_labels)  # 统计0-9每个数字的出现次数# 统计测试集标签分布
test_labels = [label for _, label in test_dataset]
test_label_counts = np.bincount(test_labels)# 绘制柱状图
plt.figure(figsize=(12, 5))# 训练集子图
plt.subplot(1, 2, 1)
plt.bar(range(10), train_label_counts)
plt.title("distribution of categories in train set")
plt.xlabel("label")
plt.ylabel("number")# 测试集子图
plt.subplot(1, 2, 2)
plt.bar(range(10), test_label_counts)
plt.title("distribution of categories in test set")
plt.xlabel("label")
plt.ylabel("number")plt.tight_layout()
plt.show()

二、构建模型

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

构建模型,我这里选择搭建了一个三层全连通感知机模型。

class ThreeLayerPerceptron(nn.Module):def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):super(ThreeLayerPerceptron, self).__init__()# 第一层全连接:输入层 -> 隐藏层1self.fc1 = nn.Linear(input_dim, hidden_dim1)# 第二层全连接:隐藏层1 -> 隐藏层2self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)# 第三层全连接:隐藏层2 -> 输出层self.fc3 = nn.Linear(hidden_dim2, output_dim)def forward(self, x):# 输入数据展平(如果是图像等多维输入需要此操作)x = x.view(x.size(0), -1)# 第一层:线性变换 + ReLU激活x = F.relu(self.fc1(x))# 第二层:线性变换 + ReLU激活x = F.relu(self.fc2(x))# 第三层:线性变换(输出层通常不接激活函数,用于分类时后续接softmax)x = self.fc3(x)return x

进行模型训练。我们这里是训练了5个epoch,意味着整个数据集经历了五次前向传播和反向传播。其实迭代很少了,但是这个任务比较简单,所以虽然只是经过了简单的训练,但是最后的效果还行。

# 模型参数(以MNIST为例)
input_dim = 784       # 28x28图像展平后的维度
hidden_dim1 = 256     # 第一个隐藏层神经元数
hidden_dim2 = 128     # 第二个隐藏层神经元数
output_dim = 10       # 10类数字# 初始化模型(自动适配CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ThreeLayerPerceptron(input_dim, hidden_dim1, hidden_dim2, output_dim).to(device)# 定义损失函数(分类任务用交叉熵)和优化器(Adam)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def train_model(model, train_loader, criterion, optimizer, epochs=10):model.train()  # 切换训练模式(启用Dropout等)for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (images, labels) in enumerate(train_loader):# 数据移动到目标设备(CPU/GPU)images, labels = images.to(device), labels.to(device)# 前向传播 + 计算损失outputs = model(images)loss = criterion(outputs, labels)# 反向传播 + 优化参数optimizer.zero_grad()  # 清空梯度loss.backward()        # 反向传播optimizer.step()       # 更新参数# 统计训练指标running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)  # 取概率最大的类别total += labels.size(0)correct += (predicted == labels).sum().item()# 每100个批量打印一次进度if (batch_idx+1) % 100 == 0:print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}/{len(train_loader)}], "f"Loss: {running_loss/100:.4f}, Acc: {100*correct/total:.2f}%")running_loss = 0.0  # 重置累计损失print("训练完成!")# 开始训练(建议先试3-5轮,观察准确率是否提升)
train_model(model, train_loader, criterion, optimizer, epochs=5)

三、模型精度检验

测试集精度验证

ef test_model(model, test_loader):model.eval()  # 切换测试模式(禁用Dropout等)correct = 0total = 0with torch.no_grad():  # 不计算梯度(加速测试)for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"测试准确率: {100 * correct / total:.2f}%")# 运行测试
test_model(model, test_loader)

测试集准确率为0.9762,结合之前的训练集准确率为0.9876,可以看到效果还是不错的。

接下来进行混淆矩阵热图可视化。几乎都集中在对角线,模型性能不错。

from sklearn.metrics import confusion_matrix
import seaborn as snsdef plot_confusion_matrix(model, test_loader):model.eval()all_labels = []all_preds = []with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, preds = torch.max(outputs, 1)all_labels.extend(labels.cpu().numpy())all_preds.extend(preds.cpu().numpy())# 计算混淆矩阵cm = confusion_matrix(all_labels, all_preds)# 可视化plt.figure(figsize=(10, 8))sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=range(10), yticklabels=range(10))plt.xlabel('predicted')plt.ylabel('true')plt.title('Confusion Matrix')plt.show()# 调用函数(需已定义model和test_loader)
plot_confusion_matrix(model, test_loader)

错误样本可视化,展示模型分类错误的样本,分析误分类原因

def plot_wrong_samples(model, test_loader, num_samples=9):model.eval()wrong_images = []wrong_labels = []wrong_preds = []with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, preds = torch.max(outputs, 1)# 筛选错误样本mask = (preds != labels)if mask.any():wrong_images.extend(images[mask].cpu())wrong_labels.extend(labels[mask].cpu().numpy())wrong_preds.extend(preds[mask].cpu().numpy())if len(wrong_images) >= num_samples:break# 可视化前9个错误样本plt.figure(figsize=(12, 12))for i in range(num_samples):image = wrong_images[i].squeeze()  # 移除通道维度true_label = wrong_labels[i]pred_label = wrong_preds[i]plt.subplot(3, 3, i+1)plt.imshow(image, cmap='gray')plt.title(f'true: {true_label}, :pred {pred_label}', color='red')plt.axis('off')plt.tight_layout()plt.show()

结果如下:

计算类别级准确率,查看每个类别的分类准确率。

def plot_class_accuracy(model, test_loader):model.eval()class_correct = [0] * 10class_total = [0] * 10with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, preds = torch.max(outputs, 1)for label, pred in zip(labels, preds):if label == pred:class_correct[label] += 1class_total[label] += 1# 计算每个类别的准确率class_acc = [100 * class_correct[i]/class_total[i] for i in range(10)]# 绘制柱状图plt.figure(figsize=(10, 6))plt.bar(range(10), class_acc)plt.xticks(range(10))plt.xlabel('label')plt.ylabel('accuracy(%)')# plt.title('各类别分类准确率')plt.ylim(80, 100)  # MNIST模型通常准确率较高,调整Y轴范围plt.show()# 调用函数
plot_class_accuracy(model, test_loader)

其实从几个样本的结果,还有柱状图,可以看出模型对9这个数字的识别明显不如其他类别。



http://www.hkcw.cn/article/xFxxQbwQba.shtml

相关文章

杭州白塔岭画室怎么样?和燕壹画室哪个好?

杭州作为全国美术艺考集训的核心区域,汇聚了众多实力强劲的画室,其中白塔岭画室和燕壹画室备受美术生关注。对于怀揣艺术梦想的考生而言,选择一所契合自身需求的画室,对未来的艺术之路影响深远。接下来,我们将从多个维…

AI与区块链:数据确权与模型共享的未来

AI与区块链:数据确权与模型共享的未来 系统化学习人工智能网站(收藏):https://www.captainbed.cn/flu 文章目录 AI与区块链:数据确权与模型共享的未来摘要引言技术路线对比1. 数据确权:从中心化存储到分布…

【T2I】Decouple-Then-Merge: Finetune Diffusion Models as Multi-Task Learning

CODE: CVPR 2025 GitHub - MqLeet/DeMe: [CVPR2025] Official implementation of "Decouple-Then-Merge: Finetune Diffusion Models as Multi-Task Learning" Abstract 扩散模型是通过学习一系列模型来训练的,这些模型可以逆转噪声衰减的每一步。通常&…

二分查找的边界艺术:LeetCode 34 题深度解析

文章目录 一、问题引入:寻找区间的边界二、二分的核心:二段性三、左边界的查找逻辑(找第一个 ≥ target 的位置)四、右边界的查找逻辑(找最后一个 ≤ target 的位置)五、代码实现六、二分边界模板总结结语 …

系统思考:短期利益与长期系统影响

一个决策难题:一家公司接到了一个大订单,客户提出了10%的降价要求,而企业的产能还无法满足客户的需求。你会选择增加产能,接受这个订单,还是拒绝?从系统思考的角度来看,这个决策不仅仅是一个简单…

【数据结构 -- B树】

目录 一、前言二、B树示例定义查找数据插入数据删除数据 一、前言 前面我们已经学习了二叉搜索树和AVL树,它们的查找、插入、删除数据效率都很高,我们首先需要了解它们是怎么操作数据的 首先将所有数据一次性调到内存中,再在内存中进行处理…

新手小白使用VMware创建虚拟机练习Linux

新手小白想要练习linux,找不到合适的地方,可以先创建一个虚拟机,在自己创建的虚拟机里面进行练习,接下来我给大家接受一下创建虚拟机的步骤。 VMware选择创建新的虚拟机 选择自定义 硬件兼容性选择第一个,不同的版本&a…

C++ Vector算法精讲与底层探秘:从经典例题到性能优化全解析

前引:在C标准模板库(STL)中,vector作为动态数组的实现,既是算法题解的基石,也是性能优化的关键战场。其连续内存布局、动态扩容机制和丰富的成员函数,使其在面试高频题(如LeetCode、…

【macbook】触控板手势

在 MacBook 上,你可以使用「触控板手势」或快捷键来实现在多个窗口/应用间切换,以下是几种方式: ✅ 1. 三指或四指左右滑动:切换“全屏应用”或“桌面”空间 **操作方式:**三指或四指在触控板上左右滑动。**适用场景&…

帝可得 - 策略管理

一. 需求说明 策略管理主要涉及到二个功能模块,业务流程如下: 新增策略: 允许管理员定义新的策略,包括策略的具体内容和参数(如折扣率) 策略分配: 将策略分配给一个或多个售货机。 graph TDA[登录系统] A --> B…

立志成为一名优秀测试开发工程师(第十一天)—Postman动态参数/变量、文件上传、断言策略、批量执行及CSV/JSON数据驱动测试

目录 一、Postman接口关联与正则表达式应用 1.正则表达式解析 2.提取鉴权码。 二、Postman内置动态参数以及自定义动态参数 1.常见内置动态参数: 2.自定义动态参数: 3.“编辑”接口练习 三、图片上传 1.文件的上传 2.上传后内容的验证 四、po…

学习路之PHP--easyswoole使用视图和模板

学习路之PHP--easyswoole使用视图和模板 一、安装依赖插件二、 实现渲染引擎三、注册渲染引擎四、测试调用写的模板五、优化六、最后补充 一、安装依赖插件 composer require easyswoole/template:1.1.* composer require topthink/think-template相关版本: "…

【C++高并发内存池篇】性能卷王养成记:C++ 定长内存池,让内存分配快到飞起!

📝本篇摘要 在本篇将介绍C定长内存池的概念及实现问题,引入内存池技术,通过实现一个简单的定长内存池部分,体会奥妙所在,进而为之后实现整体的内存池做铺垫! 🏠欢迎拜访🏠&#xff…

前端验证下跨域问题(npm验证)

文章目录 一、背景二、效果展示三、代码展示3.1)index.html3.2)package.json3.3) service.js3.4)service2.js 四、使用说明4.1)安装依赖4.2)启动服务器4.3)访问前端页面 五、跨域解决方案说明六…

nginx+Tomcat负载均衡群集

目录 一. LVS,HAProxy,Nginx的区别 1. 核心区别 2. 负载均衡算法对比 2. 1 LVS 负载均衡算法 2.2 HAProxy 负载均衡算法 2.3 Nginx 负载均衡算法 2.4 总结 二. 案例分析 1. 案例概述 (1) Tomcat 简介 (2)应用场景 2. 案例环境 3. 案例实施 …

WSL安装及使用 (适用于 Linux 的 Windows 子系统)

WSL简介 WSL:适用于 Linux 的 Windows 子系统,有1和2两个版本,1是windows重新实现了linux接口,2是原生linux内核。目前 WSL2 为默认模式,兼容性和性能更好。 wsl中文官网 安装 确保以下功能开启: 控制面…

JavaSec | SpringAOP 链学习分析

目录: 链子分析 反向分析 正向分析 poc 构造 总结 链子分析 反向分析 依赖于 Spring-AOP 和 aspectjweaver 两个包,在我们 springboot 中的 spring-boot-starter-aop 自带包含这俩类,所以也可以说是 spring boot 的原生反序化链了,调用…

PV操作的C++代码示例讲解

文章目录 一、PV操作基本概念(一)信号量(二)P操作(三)V操作 二、PV操作的意义三、C中实现PV操作的方法(一)使用信号量实现PV操作代码解释: (二)使…

医疗内窥镜影像工作站技术方案(续)——EFISH-SCB-RK3588国产化替代技术深化解析

一、异构计算架构的医疗场景适配 ‌多核任务调度优化‌ ‌A76/A55协同计算‌:4Cortex-A762.4GHz负责AI推理(如息肉识别算法),4Cortex-A551.8GHz处理DICOM影像传输协议,多任务负载效率比赛扬N系列提升80%‌NPU加速矩阵…

HCIP-Datacom Core Technology V1.0_3 OSPF基础

动态路由协议简介 静态路由相比较动态路由有什么优点呢。 静态路由协议,当网络发生故障或者网络拓扑发生变更,它需要管理员手工配置去干预静态路由配置,但是动态路由协议,它能够及时自己感应网络拓扑变化,不路由选择…