深入探究 MNIST 数据集 - Fastai 第三部分

article/2025/7/3 8:58:59

在 fastai 第一部分中,我们学习了如何对 MNIST 数据集进行分类。在本教程中,我们将更深入地了解其底层原理。首先,我们将详细探索 MNIST 数据集。

数据探索

# 第一部分的代码
import torch
import random
from fastai.vision.all import *# 下载简单的 MNIST 数据集(稍后我们会下载完整数据集)
path = untar_data(URLs.MNIST_SAMPLE)
train_path = path/'train'
img_files = list((train_path/'7').ls())
img = PILImage.create(img_files[0])
img.show();

请添加图片描述

print(f"图像模式为: {img.mode}")
图像模式为: RGB

MNIST 图像本是灰度(单通道),但通过 PILImage 创建后会被转换为 RGB,这通常是为了模型兼容性。因此如上所示,图像模式为 RGB。灰度图像的模式应为 “L”(1 通道),而 RGB 为 3 通道。如果你想保持灰度,可以这样:

img = PILImage.create(img_files[0])
img.show();
print(f"图像模式为: {img.mode}")
arr = array(img)
print(arr.shape)
(28, 28, 3)

(28, 28, 3) 是一个三维 NumPy 数组,表示高、宽和通道数。28x28 像素,3 表示 RGB 通道。由于原图为灰度,三个通道的值通常相同。

print(np.unique(arr))
[  0   9  23  24  34  38  44  46  47  64  69  71  76  93  99 104 107 109111 115 128 137 138 139 145 146 149 151 154 161 168 174 176 180 184 185207 208 214 215 221 230 231 240 244 245 251 253 254 255]

np.unique(arr) 显示像素值范围为 0~255,0 为黑,255 为白,中间为灰度。

np.all(arr[:, :, 0] == arr[:, :, 1]) and np.all(arr[:, :, 1] == arr[:, :, 2])
np.True_

上面代码检查三个通道的像素值是否一致。由于是灰度图,结果为 True。

img_t = tensor(arr[:, :, 0])
print(img_t.shape)
df = pd.DataFrame(img_t)
df.style.set_properties(**{'font-size':'6pt'}).background_gradient('Greys')
torch.Size([28, 28])

在上述代码中:
arr[:,:,0] 是常见的三维数组索引方式,: 表示所有元素,0 表示第一个通道。

# 对比选择通道 0 与不选通道时的形状
print(arr[:,:,0].shape)
print(arr[:,:,:].shape)
(28, 28)
(28, 28, 3)
# 查看数组中的唯一值
print(np.unique(arr))
[  0   9  23  24  34  38  44  46  47  64  69  71  76  93  99 104 107 109111 115 128 137 138 139 145 146 149 151 154 161 168 174 176 180 184 185207 208 214 215 221 230 231 240 244 245 251 253 254 255]

这些值就是你看到的数字 7 的灰度图像。计算机就是用这些数字存储图像!手写数字识别的任务本质上就是在数值层面上比较图像与参考数字的相似性。我们将所有图片转为数值向量,分别对 3 和 7 求平均向量,作为数字的代表。这就是"向量空间模型",也是我们模型的基线。之后用它来预测新图片,看其更像 3 还是 7。这是最简单直观的机器学习方法。注意我们将 NumPy 数组转为 tensor(img_t=tensor(arr[:,:,0])),因为 fastai 基于 PyTorch,后者的模型都以 torch.Tensor 为输入输出。Tensor 支持自动微分和 GPU 加速,而 NumPy 仅限 CPU。

tensor(arr).permute(2,0,1).shape
torch.Size([3, 28, 28])

上述代码将 NumPy 数组转为 PyTorch tensor,并调整维度顺序。原始为 (28, 28, 3),调整后为 (3, 28, 28),这是 PyTorch 期望的图像格式(通道,高,宽)。

基线图像

堆叠
torch.stack(seven_tensors):seven_tensors 是数字 7 的二维张量列表,每个张量为 [高, 宽](28x28)。堆叠后变为三维张量 [图片数, 高, 宽]。.mean(0) 沿第 0 维(图片数)求均值,得到一张平均 7 的图片。

# 查看 3 和 7 的图片
three_path = train_path/'3'
seven_path = train_path/'7'three_tensors = [tensor(Image.open(o)) for o in three_path.ls()]
seven_tensors = [tensor(Image.open(o)) for o in seven_path.ls()]# 堆叠所有 3 和 7
stacked_threes = torch.stack(three_tensors).float()/255
stacked_sevens = torch.stack(seven_tensors).float()/255# 计算所有 3 和 7 的均值
mean3 = stacked_threes.mean(0)
mean7 = stacked_sevens.mean(0)# 显示均值图片
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.imshow(mean3)
plt.title('3 的均值')
plt.subplot(1, 2, 2)
plt.imshow(mean7)
plt.title('7 的均值')
plt.tight_layout()

在这里插入图片描述

上述代码计算了所有 3 和 7 的均值。均值图片显示了该数字所有图片的平均像素值,给出了典型 3 和 7 的模板。

# 计算 3 和 7 的均值差异
diff = mean3 - mean7
plt.imshow(diff)

在这里插入图片描述

差异图片显示了两者差异最大的区域。亮区表示 3 的像素值高于 7,暗区则相反。

# 计算每个 3 与均值 3、均值 7 的相似度
three_similarity = [((t - mean3)**2).mean().item() for t in stacked_threes]
three_to_seven_similarity = [((t - mean7)**2).mean().item() for t in stacked_threes]# 计算每个 7 与均值 7、均值 3 的相似度
seven_similarity = [((t - mean7)**2).mean().item() for t in stacked_sevens]
seven_to_three_similarity = [((t - mean3)**2).mean().item() for t in stacked_sevens]# 绘制相似度
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(three_similarity, label='3 到均值 3')
plt.plot(three_to_seven_similarity, label='3 到均值 7')
plt.legend()
plt.title('3: 与均值 3 和 7 的相似度')plt.subplot(1, 2, 2)
plt.plot(seven_similarity, label='7 到均值 7')
plt.plot(seven_to_three_similarity, label='7 到均值 3')
plt.legend()
plt.title('7: 与均值 7 和 3 的相似度')
plt.tight_layout()

在这里插入图片描述

相似度计算
((tensor(Image.open(o)).float() - seven_avg)**2).mean() 计算图片 o 与平均 7 的均方误差(MSE),MSE 越低相似度越高。MSE 是衡量预测与实际差异的常用指标。

图示

  • y 轴为 MSE。
  • x 轴为图片索引。
  • 可以看到每个 3 与均值 3 的 MSE 通常低于与均值 7。

在完整 MNIST 上训练神经网络

现在我们用完整的 MNIST 数据集训练神经网络。

# 下载完整 MNIST 数据集
path = untar_data(URLs.MNIST)
path

在这里插入图片描述

# 创建 DataBlock
mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(train_name='training', valid_name='testing'),get_y=parent_label,batch_tfms=Normalize()
)

DataBlock

fastai 的 DataBlock 是数据处理的高级 API。它定义了数据获取、标签、变换、划分和输入输出类型。

  1. blocks=(ImageBlock(cls=PILImageBW), CategoryBlock):输入为黑白图片,输出为类别。
  2. get_items=get_image_files:获取所有图片文件。
  3. splitter=GrandparentSplitter(train_name='training', valid_name='testing'):按上级文件夹划分训练/验证集。
  4. get_y=parent_label:标签为父文件夹名。
  5. batch_tfms=Normalize():归一化图片。
# 创建 DataLoaders
dls = mnist.dataloaders(path, bs=64)
# 显示一批图片
dls.show_batch(max_n=9, figsize=(4,4))

在这里插入图片描述

现在我们创建并训练一个卷积神经网络(CNN)用于数字分类。

# 自定义 MNIST CNN 模型
class MnistCNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2)self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.relu3 = nn.ReLU()self.pool3 = nn.MaxPool2d(kernel_size=2)self.fc1 = nn.Linear(64 * 3 * 3, 128)self.relu4 = nn.ReLU()self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool1(self.relu1(self.conv1(x)))x = self.pool2(self.relu2(self.conv2(x)))x = self.pool3(self.relu3(self.conv3(x)))x = x.view(x.size(0), -1)x = self.relu4(self.fc1(x))x = self.fc2(x)return xmodel = MnistCNN()
learn = Learner(dls, model, loss_func=nn.CrossEntropyLoss(), metrics=accuracy)
print("模型结构:")
print(learn.model)
模型结构:
MnistCNN((conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(relu1): ReLU()(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(relu2): ReLU()(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(relu3): ReLU()(pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(fc1): Linear(in_features=576, out_features=128, bias=True)(relu4): ReLU()(fc2): Linear(in_features=128, out_features=10, bias=True)
)
# 训练模型 1 轮
dls = mnist.dataloaders(path, bs=64)
learn = Learner(dls, model, loss_func=nn.CrossEntropyLoss(), metrics=accuracy)
learn.fine_tune(1)

模型评估

让我们在验证集上评估模型表现:

# 获取预测
preds, targets = learn.get_preds()
pred_classes = preds.argmax(dim=1)# 混淆矩阵
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as snscm = confusion_matrix(targets, pred_classes)
plt.figure(figsize=(5, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('预测')
plt.ylabel('实际')
plt.title('混淆矩阵')
plt.show()

在这里插入图片描述

预测

用训练好的模型对测试图片进行预测:

# 获取测试图片
test_files = get_image_files(path/'testing')
random_test_files = random.sample(test_files, 10)
test_dl = learn.dls.test_dl(random_test_files)
# 预测
preds, _ = learn.get_preds(dl=test_dl)
pred_classes = preds.argmax(dim=1)
# 显示图片和预测
fig, axes = plt.subplots(2, 5, figsize=(6, 3))
axes = axes.flatten()
for i, (img_file, pred) in enumerate(zip(random_test_files, pred_classes)):img = PILImage.create(img_file)axes[i].imshow(img, cmap='gray')axes[i].set_title(f"预测: {pred.item()}")axes[i].axis('off')
plt.tight_layout()
plt.show()

在这里插入图片描述

与模板匹配法对比

前面我们用均方误差(MSE)模板匹配区分 3 和 7。现在对比神经网络的表现:

# 获取所有测试集中的 3 和 7
test_3_files = get_image_files(path/'testing'/'3')
test_7_files = get_image_files(path/'testing'/'7')# 创建只包含 3 和 7 的测试集
test_files_3_7 = test_3_files[:50] + test_7_files[:50]
test_dl_3_7 = learn.dls.test_dl(test_files_3_7)# 预测
preds, _ = learn.get_preds(dl=test_dl_3_7)
pred_classes = preds.argmax(dim=1)# 计算 3 和 7 的准确率
true_labels = torch.tensor([3] * 50 + [7] * 50)
correct = (pred_classes == true_labels).float().mean()
print(f"神经网络在 3 和 7 上的准确率: {correct.item():.4f}")
神经网络在 3 和 7 上的准确率: 0.9900

可视化特征图

让我们通过查看第一层卷积的激活来可视化我们的 CNN 学到的特征:

# 获取一批图片
x, y = dls.one_batch()# 获取自定义模型的第一层卷积
conv1 = learn.model.conv1# 应用第一层卷积获取激活
with torch.no_grad():activations = conv1(x)# 可视化第一张图片的激活
# 我们的自定义模型第一层有 16 个滤波器
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
axes = axes.flatten()# 显示原图
axes[0].imshow(x[0][0].cpu(), cmap='gray')
axes[0].set_title(f"原图: {y[0].item()}")
axes[0].axis('off')# 显示前 15 个滤波器的激活图
for i in range(1, 16):axes[i].imshow(activations[0, i-1].detach().cpu(), cmap='viridis')axes[i].set_title(f"滤波器 {i}")axes[i].axis('off')plt.tight_layout()
plt.show()# 同时可视化滤波器权重
weights = conv1.weight.data.cpu()
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
axes = axes.flatten()# 再次显示原图
axes[0].imshow(x[0][0].cpu(), cmap='gray')
axes[0].set_title(f"原图: {y[0].item()}")
axes[0].axis('off')# 显示前 15 个滤波器的权重
for i in range(1, 16):# 每个滤波器只有一个输入通道(灰度)axes[i].imshow(weights[i-1, 0], cmap='viridis')axes[i].set_title(f"滤波器 {i} 权重")axes[i].axis('off')plt.tight_layout()
plt.show()

在这里插入图片描述

在这里插入图片描述

结论

在本教程中,我们深入探索了 MNIST 数据集,并训练了一个卷积神经网络来对手写数字进行分类。我们看到了模型的表现,并可视化了其部分内部表示。

主要收获:

  1. 神经网络可以在数字分类任务上取得很高的准确率。
  2. CNN 的第一层会学习到边缘和纹理等简单特征。
  3. 我们之前的模板匹配方法虽然更简单,但准确率远不如完整的神经网络。
  4. fastai 让构建、训练和解释深度学习模型变得非常容易。

http://www.hkcw.cn/article/LAZWREtfwi.shtml

相关文章

5. 算法与分析 (2)

本节主要介绍算法时间复杂度的具体求法和空间复杂度 本文部分ppt、视频截图来自:[青岛大学-王卓老师的个人空间-王卓老师个人主页-哔哩哔哩视频] 1. 算法分析 1.1 分析算法时间复杂度的基本方法 定理1.1 即忽略所有低次幂项和最高次幂系数,体现出增长…

RCU初步分析

RCU初步分析 背景知识RCU介绍名词定义RCU的基本执行过程基本过程基本思想示意图基本流程示例代码并发执行示意图 RCU特性简易RCU实现基于spinlock的实现基于计数器的实现基于线程变量的实现 Linux内核中经典RCU实现介绍 背景知识 随着硬件晶体管的尺寸越来越小,CPU的频率上限基…

多名中国公民被印度拘捕 中使馆提醒 边境风险需警惕

中国驻尼泊尔使馆近期发布消息,提醒旅尼中国公民避免前往尼印边境地区。尽管多次发出警告,仍有部分中国公民未听从劝告,执意前往该区域,导致多起被捕事件。尼泊尔和印度之间的边界开放,两国公民可凭身份证件自由往来,但外国公民不能免签证经尼泊尔进入印度。中国公民在尼…

解锁AI超级能力:30+款MCP服务器全景指南

MCP服务器是当前AI领域的热门话题,几乎每个人都渴望参与其中。简单来说,MCP(模型上下文协议,Model Context Protocol)服务器是一种REST API服务器,充当大型语言模型(LLM)与各种外部系…

【cpp-httplib】 安装与使用

cpp-httplib 1. 介绍2. 安装3. 类与接口3.1 httplib请求3.2 httplib响应3.3 httplib服务端3.4 httplib客户端 4. 使用4.1 服务端4.2 客户端 1. 介绍 C HTTP 库(cpp-httplib)是一个轻量级的 C HTTP 客户端/服务器库,它提供了简单的 API 来创建…

HPE推出全新分布式服务交换机及有线无线产品组合,全面赋能AI与高性能计算需求

HPE Aruba Networking将分布式服务交换机性能全面升级,实现能力翻倍 休斯顿-2025年5月29日-慧与科技(NYSE:HPE)日前宣布全面扩展HPE Aruba Networking有线及无线网络产品组合,并重磅推出全新HPE Aruba Networking CX 10K分布式服务交换机。该系列交换机搭载AMD Pensando可编程…

烟草工业数字化转型:科技领航,重塑传统产业新生态

在科技浪潮席卷各行业的当下,烟草工业这一传统产业也迎来了深刻变革。《烟草工业数字化转型:科技领航,重塑传统产业新生态》这一主题,精准揭示了数字化技术如何在具有独特生产工艺与严格监管要求的烟草工业中,发挥关键…

单依纯《歌手》被吐槽像吃了跳跳糖 转型争议不断

2025年开春,单依纯这个名字在娱乐圈频繁出现。这个被称为“00后王菲”的女孩,在《歌手2025》的舞台上表现亮眼,但随之而来的争议也越来越多。有人称赞她的唱功,有人批评她“卖肉博眼球”。这些争论背后反映出两个问题:女艺人应该怎样生活,观众到底想看什么。回顾单依纯的…

断眉袭榜单依纯成功 青春DNA狂飙舞台

歌手2025第三期是袭榜赛,断眉作为袭榜者开场演唱了大热单曲《See You Again》,前奏一响即勾起观众的集体回忆,被评价为“青春DNA狂飙”的舞台。尽管部分观众认为其高音表现稍显吃力,但整体感染力仍获认可。接下来是歌手2025第三期出场顺序及淘汰名单:GAI周延第一个出场,演…

二维平面点集相似问题思考及优化

欢迎关注更多精彩 关注我,学习常用算法与数据结构,一题多解,降维打击。 问题描述 如果两个点集可以通过平移,X轴对称,Y轴对称,中心对称得到相同的点集,则移两个点集相似。 给定多个点集&…

AI炼丹日志-23 - MCP 自动操作 自动进行联网检索 扩展MCP能力

点一下关注吧!!!非常感谢!!持续更新!!! Java篇: MyBatis 更新完毕目前开始更新 Spring,一起深入浅出! 大数据篇 300: Hadoop&…

WebFuture:设置不自动删除操作日志

问题描述: 客户要求保留系统操作日期为1年 或者不删除 问题处理: 在平台安全配置中 将自动清理后台操作日志功能 选择为否,或者设置自动清理的时间为365天

国产高安全芯片在供应链自主可控中的综合优势与案例分析

摘要:本文深入探讨了国产高安全芯片在实现供应链自主可控中的关键作用,通过分析国科安芯的 AS32A601、ASM1042、ASP3605 和 ASP4644 芯片的技术特性,结合其在工业控制、汽车电子、航天航空和电力系统等领域的应用场景,系统阐述了国…

Sigma-Aldrich3D细胞培养支架有哪些类型?

体内生长的哺乳动物细胞处于复杂的三维(3D)环境中。围绕在细胞周围的细胞外基质(ECM)的形状和化学组成能够决定许多生理反应。传统的细胞培养技术和实验方案在通常由玻璃或聚苯乙烯制成的二维(2D)表面上进行…

【开发心得】AstrBot对接飞书失败的问题探究

飞书与AstrBot的集成使用中,偶尔出现连接不稳定的现象。尽管不影响核心功能,但为深入探究技术细节并推动后续优化,需系统性记录该问题。先从底层通信机制入手,分析连接建立的逻辑与数据交互流程。基于实际现象,明确问题…

低功耗可编程RTU 在供水管网监控中的应用

1.1 智慧水务之管网 供水管网监控系统适用于供水企业实施供水管网的远程监测,工作人员在调度中心远程监测供水管网的压力及流量情况,可以对远程现场的运行设备进行监控,以实现管道压力、水流量的数据传送及阀门开关的自动管制&#xff…

RK3568 OH5.1 编译运行程序hellworld

编写helloworld 代码根目录创建sample子系统文件夹在子系统目录下创建hello部件文件夹hello文件夹中创建hello源码目录及源码 sample/hello/src/helloworld.c&#xff1a; #include <stdio.h> #include "helloworld.h"void hello_oh(void);int main(int arg…

LangChain-结合魔塔社区modelscope的embeddings实现搜索

首先要安装modelscope pip install modelscope 安装完成后测试 from langchain_community.embeddings import ModelScopeEmbeddingsembeddings ModelScopeEmbeddings(model_id"iic/nlp_gte_sentence-embedding_chinese-base")text "这是一个测试句子"…

千库/六图素材下载工具

—————【下 载 地 址】——————— 【​本章下载一】&#xff1a;https://pan.xunlei.com/s/VORW9TbxC9Lmz8gCynFrgdBzA1?pwdxiut# 【​本章下载二】&#xff1a;https://pan.quark.cn/s/829e2a4085d3 【百款黑科技】&#xff1a;https://ucnygalh6wle.feishu.cn/wiki/…

老板发百万让员工带薪收麦子 暖心福利获赞

5月23日,河南长垣的河南省矿山起重机有限公司内举行了一场特别的“三夏”生产动员暨表彰大会。公司董事长崔培军在会上宣布了一项暖心措施:他现场拿出360万元现金,每位员工都收到了700元现金、一袋小米和4箱啤酒作为“三夏”福利。崔培军表示,正值“三夏”大忙季节,考虑到…