【深度学习】实验四 卷积神经网络CNN

article/2025/6/25 17:45:41

 实验四  卷积神经网络CNN

一、实验学时: 2学时

二、实验目的

  1. 掌握卷积神经网络CNN的基本结构;
  2. 掌握数据预处理、模型构建、训练与调参;
  3. 探索CNN在MNIST数据集中的性能表现;

三、实验内容

实现深度神经网络CNN。

四、主要实验步骤及结果

1.搭建一个CNN网络,使用MNIST手写数字数据集进行训练与测试,并体现模型最终结果,CNN网络的具体框架可参考下图,也可自己设计:

图4-1 CNN架构图

(1)该图表示输入层为28*28*1的尺寸,符合MNIST数据集的标准尺寸。

(2)第一个卷积层,使用5*5卷积核,32个滤波器,填充(Padding)为2。输出尺寸为28*28*32。

(3)第一个池化层,使用2*2池化窗口,步长(stride)为2。输出尺寸为14*14*32。

(4)第二个卷积层,使用5*5卷积核,64个滤波器,填充(Padding)为2。输出尺寸为14*14*64。

(5)第二个池化层,使用2*2池化窗口,步长(stride)为2。输出尺寸为7*7*64。

(6)全连接层包含1024个神经元,输出尺寸为1*1*1024。

(7)Dropout层用于防止过拟合。

(8)输出层包含10个神经元,对应手写数字的0-9。输出尺寸为1*1*10。

模型实现:

以该架构图搭建CNN网络,使用MNIST手写数字数据集进行训练与测试,训练和测试结果如图4-2所示:

图4-2 CNN测试结果

2.尝试使用不同的数据增强方法、优化器、损失函数、学习率、batch size和迭代次数来进行训练,记录训练过程,评估模型性能,保存最佳模型。

编号

batch size

训练轮次

学习率

数据增强方法

优化器

实验结果

1

32

2

1e-4

Adam

98.62%

2

64

2

1e-4

Adam

98.56%

3

64

4

1e-4

Adam

99.08%

4

64

4

3e-4

Adam

99.08%

5

64

4

3e-4

旋转+平移

Adam

98.90%

5

64

4

3e-4

Adam(L2正则化)

99.23%

6

64

4

1e-4

SGD+momentum

97.30%

其中数据增强方法采用随机旋转和平移吗,原始代码中包含ToTensor()和Normalize(),给原始代码添加随机旋转10度和随机平移10%,代码如下:

# 数据加载(归一化)
transform = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(10),  # 随机旋转10度torchvision.transforms.RandomAffine(0, translate=(0.1, 0.1)),  # 随机平移10%torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

优化器选择方面使用SGD+momentum(0.9)替代原Adam优化器,

# 使用SGD+momentum
optimizer = torch.optim.SGD(model.parameters(), lr=LEARN_RATE, momentum=0.9)

根据训练过程记录的数据,最佳模型尊却绿为99.23%,最佳模型代码如下:

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoaderBATCH_SIZE = 64
EPOCHS = 4
LEARN_RATE = 3e-4
DROPOUT_RATE = 0.5device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# 数据加载(归一化)
transform = torchvision.transforms.Compose([# torchvision.transforms.RandomRotation(10),  # 随机旋转10度# torchvision.transforms.RandomAffine(0, translate=(0.1, 0.1)),  # 随机平移10%torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))
])train_data = torchvision.datasets.MNIST(root='./mnist',train=True,download=True,transform=transform
)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)test_data = torchvision.datasets.MNIST(root='./mnist',train=False,transform=transform
)
test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_layers = nn.Sequential(# 第一层卷积:5x5 卷积核,32 个过滤器,padding=2nn.Conv2d(1, 32, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),  # 池化后 14x14x32# 第二层卷积:5x5 卷积核,64 个过滤器,padding=2nn.Conv2d(32, 64, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2)  # 池化后 7x7x64)self.fc_layers = nn.Sequential(nn.Linear(64 * 7 * 7, 1024),  # 全连接层:7x7x64 → 1024nn.ReLU(),nn.Dropout(DROPOUT_RATE),  # Dropout层nn.Linear(1024, 10)  # 输出层:1024 → 10)self._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')nn.init.constant_(m.bias, 0)def forward(self, x):x = self.conv_layers(x)x = x.view(x.size(0), -1)  # 展平操作x = self.fc_layers(x)return xmodel = CNN().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE, weight_decay=1e-5)
# optimizer = torch.optim.SGD(model.parameters(), lr=LEARN_RATE, momentum=0.9)  # 使用SGD+momentum
# 训练循环
for epoch in range(EPOCHS):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = loss_fn(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch {epoch + 1} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.4f}')# 测试
model.eval()
correct = 0
with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)pred = output.argmax(dim=1)correct += pred.eq(target).sum().item()print(f'Test Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)')

3.使用画图工具将自己的学号逐个写出,使用保存的最佳模型对每个数字进行推理,比较模型对每个数字的准确率预测,也可以尝试实现一个实时识别手写数字的demo。
(1)使用画图工具将自己的学号逐个写出,进行反色处理,并将图片命名为“x_001.png”格式。

图4-3手写数字

(2)在训练代码(CNN.py)中添加模型保存代码。

torch.save(model.state_dict(), 'mnist_cnn.pth')

(3)编写推理代码读取img文件夹中的手写图片并预测,预测代码如下所示:

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import os# 定义模型结构(需与训练代码一致)
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(2, 2))self.fc_layers = nn.Sequential(nn.Linear(64 * 7 * 7, 1024),nn.ReLU(),nn.Dropout(0.5),nn.Linear(1024, 10))def forward(self, x):x = self.conv_layers(x)x = x.view(x.size(0), -1)x = self.fc_layers(x)return x# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN().to(device)
model.load_state_dict(torch.load('mnist_cnn.pth', map_location=device))
model.eval()# 定义预处理(与训练一致)
transform = transforms.Compose([transforms.Resize((28, 28)),  # 确保输入为28x28transforms.Grayscale(num_output_channels=1),  # 转换为单通道transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 遍历img文件夹中的图片并推理
img_dir = 'img'
digit_stats = {str(i): {'correct': 0, 'total': 0} for i in range(10)}for filename in os.listdir(img_dir):if filename.lower().endswith(('.png', '.jpg', '.jpeg')):# 从文件名中提取真实标签(假设文件名为 "label_xxx.png")try:true_label = filename.split('_')[0]  # 例如文件名 "3_001.png" → 标签为3true_label = int(true_label)if true_label < 0 or true_label > 9:continueexcept:print(f"跳过文件 {filename}(文件名格式错误)")continue# 加载并预处理图像img_path = os.path.join(img_dir, filename)image = Image.open(img_path)image = transform(image).unsqueeze(0).to(device)  # 添加batch维度# 推理with torch.no_grad():output = model(image)pred = output.argmax(dim=1).item()# 统计结果digit_stats[str(true_label)]['total'] += 1if pred == true_label:digit_stats[str(true_label)]['correct'] += 1print(f"图片 {filename} 真实标签: {true_label}, 预测: {pred} → {'正确' if pred == true_label else '错误'}")# 计算每个数字的准确率
accuracies = {}
for digit in digit_stats:if digit_stats[digit]['total'] > 0:acc = digit_stats[digit]['correct'] / digit_stats[digit]['total']accuracies[digit] = accprint(f"数字 {digit} 的准确率: {acc:.2%}")

预测结果如图4-4所示:

图4-4预测结果

预测结果显示“1”和“4”预测结果错误,其他均正确。

五、实验小结(包括问题和解决办法、心得体会、意见与建议等)

1.问题和解决办法:

问题1:RuntimeError: Dataset not found. You can use download=True to download it。

解决方法:添加下载训练集的参数download=True。

问题2:使用SGD+momentum优化器后,准确率反而下降了。

解决方法:因为SGD对学习率比较敏感,学习率没有适配,使用StepLR梯度衰减,另外也可以增加训练轮次。

问题3:预测结果全部错误。

解决方法:图片要像素28*28,且黑色背景,白色笔迹,对Windows画图的图片反色处理即可。

2.心得体会:通过本次CNN手写数字识别实验的完整实践,我深刻体会到深度学习模型性能的提升是一个系统工程,需要从数据、模型、训练策略到结果分析的全流程精细化把控,尝试使用不同的数据增强方法、优化器、损失函数、学习率、batch size和迭代次数来进行训练,迭代出最佳模型,再手写数字进行测试。通过以上的学习和实践,我对神经网络的原理和应用有了更深入的理解。神经网络的发展给人工智能带来了巨大的影响,它在图像识别、自然语言处理等领域发挥着重要的作用。我相信,随着技术的进步,神经网络将会有更广泛的应用。


http://www.hkcw.cn/article/yUIEjnVqWw.shtml

相关文章

通俗理解“高内聚,低耦合”

在软件开发中&#xff0c;良好的架构设计能够大幅降低系统的复杂度&#xff0c;提高代码的可维护性。而“高内聚&#xff0c;低耦合”正是指导我们如何合理组织代码的核心原则之一。本文将从通俗的角度解释这一概念&#xff0c;并结合实际案例说明其重要性。 一&#xff0c;高…

Unity + HybirdCLR热更新 入门篇

官方文档 HybridCLR | HybridCLRhttps://hybridclr.doc.code-philosophy.com/docs/intro 什么是HybirdCLR? HybridCLR&#xff08;原名 huatuo&#xff09;是一个专为 Unity 项目设计的C#热更新解决方案&#xff0c;它通过扩展 IL2CPP 运行时&#xff0c;使其支持动态加载和…

Python基础:人生重开模拟器(小游戏)

引言 手把手带你速通Python 实现人生重开模拟器&#xff08;小游戏&#xff09;的意义&#xff1a;增强对条件语句&#xff0c;循环语句的运用&#xff0c;增加写代码的乐趣。 一、 游戏介绍 网页版的人生重开模拟器&#xff1a; 人生重开模拟器-重来-重启 (aizhancloud.cn) …

【Elasticsearch】ILM(Index Lifecycle Management)策略详解

ILM&#xff08;Index Lifecycle Management&#xff09;策略详解 1.什么是 ILM 策略&#xff1f;2.ILM 解决的核心业务问题3.ILM 生命周期阶段3.1 Hot&#xff08;热阶段&#xff09;3.2 Warm&#xff08;温阶段&#xff09;3.3 Cold&#xff08;冷阶段&#xff09;3.4 Delete…

【存储基础】数据存储基础知识

文章目录 1. 概述&#xff1a;数据存储基础知识2. 存储物理介质3. 数据存储的分类3.1按存储架构分类DAS 直连存储SAN 存储区域网络NAS 网络附加存储分布式存储四种架构之间的核心区别 3.2 按数据模型分类块存储文件存储对象存储 4. 数据存储的关键技术方案和核心机制冗余与容错…

【Part 3 Unity VR眼镜端播放器开发与优化】第二节|VR眼镜端的开发适配与交互设计

文章目录 《VR 360全景视频开发》专栏Part 3&#xff5c;Unity VR眼镜端播放器开发与优化第一节&#xff5c;基于Unity的360全景视频播放实现方案第二节&#xff5c;VR眼镜端的开发适配与交互设计一、Unity XR开发环境与设备适配1.1 启用XR Plugin Management1.2 配置OpenXR与平…

小米YU7还有5款颜色即将发布 更多色彩敬请期待

6月1日,小米在5月22日的发布会上公布了YU7的四款颜色:钛金属色、宝石绿、熔岩橙和寒武岩灰。官方透露,除了这四款已发布的颜色外,还有五款新颜色即将推出,每一种都设计得非常经典。回顾之前的小米SU7,在刚推出时就提供了9种颜色选择,涵盖了跑车色系、时尚色系、豪华色系…

老人被甩客执法人员送其回家 温情护送获赠枇杷

日前,重庆交通执法总队轨道交通支队三大队在重庆西站巡查时发现一名老人误乘“黑车”。考虑到她年近九旬行动不便,执法人员开车将其安全护送回家。老人感激地拿出自己种的枇杷送给执法人员以示感谢。5月27日上午,执法人员在巡查过程中发现一辆渝A籍车辆正在下客,随即上前检…

2025最新 MacBook Pro苹果电脑M系列芯片安装zsh教程方法大全

2025最新 MacBook Pro苹果电脑M系列芯片安装zsh教程方法大全 本文面向对 macOS 环境和终端操作尚不熟悉的“小白”用户。我们将从最基础的概念讲起&#xff0c;结合实际操作步骤&#xff0c;帮助你在 2025 年最新 MacBook Pro&#xff08;搭载苹果 M 系列芯片&#xff09;的环境…

女子多次上门骚扰邻居 持刀砍门引发恐慌

近日,辽宁大连有网友发布视频称,疑似患有精神疾病的邻居多次持刀上门砍其家门。当事人刘女士向媒体透露,楼下60多岁的邻居自去年10月搬家入住后,就反复上门找事,声称刘女士一家是脑控组织,想要入侵她的大脑。刘女士表示自己是外地人,去年刚搬进来,为方便孩子上学才购买…

遥控器竟牵出10亿元大案 数据篡改揭秘

涉及河南、四川、浙江等16省市,涉案交易金额达10.3亿元的合同诈骗案成功告破。内蒙古自治区鄂尔多斯市杭锦旗警方通过流量计调节流量和篡改数据的方式,将27名犯罪嫌疑人全部抓获。2024年3月,一封匿名举报信揭露了某石油工程有限公司通过更改流量计数据窃取国家能源的行为。杭…

弗朗西斯卡说樊振东加盟像做梦一样 莫大荣耀与期待

6月2日,据外媒报道,队长弗朗西斯卡在接受采访时谈到夺冠以及樊振东加盟表示:“感觉像做梦一样,樊振东加盟是我们莫大的荣耀。”决赛周末俱乐部宣布了轰动消息,奥运冠军兼前世界冠军樊振东将代表球队出战全部三项赛事。对于下赛季谁能击败他们的问题,弗朗西斯卡表示这取决…

董宇辉在陕西汉江赛龙舟夺冠 体验传统民俗乐趣

5月31日,陕西安康第25届汉江龙舟节开幕。当日,龙舟方阵展演,龙舟横渡汉江,抢鸭子、摸鲤鱼等传统环节亮相开幕式,节日氛围浓厚。今年有27支队伍600多名选手参加龙舟竞渡比赛。董宇辉现身安康龙舟文化园,与现场市民、游客热情互动,齐喊端午安康。他和团队成员在汉江边进行…

租客退房现垃圾场 下水道都堵了 屋内堆满垃圾几乎无处下脚

山东潍坊一名租客租住半年退房时,房东发现屋内堆满生活垃圾。发布视频的当地民宿工作人员称租客是一名年轻女生,房东已报警。辖区派出所表示正在处理此事。5月31日,抖音实名认证的潍坊潍城区怀夏民宿发布了一段54秒的现场视频,显示屋内一片狼藉,堆满了生活垃圾,几乎无处下…

上迪情侣和一家三口扭打 拍照冲突引发热议

5月31日,有网友发布视频称,在上海迪士尼乐园内一对情侣与一家三口发生冲突并扭打在一起,此事引发广泛关注。视频中可以看到双方在现场互相推搡,周围游客纷纷上前劝阻。据权威人士透露,事件发生在5月31日,地点并非排队区域,而是游客自由拍照的点位。情侣和一家三口因拍照…

北京大兴警方严查炸街摩托 夜查行动见效

大兴公安分局针对南海子公园南环路牡丹园南广场区域夜间改装摩托车聚集扰民问题,开展了专项整治行动。这些非法改装的摩托车不仅存在安全隐患,其巨大的噪音也严重干扰了周边居民的生活。5月30日晚,大兴交通支队旧宫中队在“炸街车”夜间活动高发时段,科学部署警力,采取定点…

AI矢量软件|Illustrator 2025网盘下载与安装教程指南

说起AI&#xff0c;很多人第一印象可能是AI人工智能&#xff0c;是与Python相关。实际上&#xff0c;本文要讲的AI&#xff0c;是Adobe Illustrator的缩写&#xff0c;它是一款基于矢量的图形制作软件&#xff0c;主要应用于插画、包装、印刷出版、书籍排版、动画和网页制作等领…

【Spring】RAG 知识库基础

1. RAG 基础概念 1.1 什么是 RAG&#xff1f; RAG&#xff08;Retrieval-Augmented Generation&#xff0c;检索增强生成&#xff09;是一种将检索技术与人工智能生成技术相结合的混合架构&#xff0c;用于解决大模型时效性限制与幻觉问题 你可以这样理解&#xff1a;RAG 技…

NLP学习路线图(十七):主题模型(LDA)

在浩瀚的文本海洋中航行&#xff0c;人类大脑天然具备发现主题的能力——翻阅几份报纸&#xff0c;我们迅速辨别出"政治"、"体育"、"科技"等板块&#xff1b;浏览社交媒体&#xff0c;我们下意识区分出美食分享、旅行见闻或科技测评。但机器如何…

信息安全管理与评估山东卷无线部分答案

配置解析 配置解析 配置解析 radio 1工作在2.4g频段下 radio 2工作在5.0g频段下 配置解析 station-isolation配置关联在同一个VAP下的用户无法互通,但是可以和其他VAP下关联的用户互通,这里的隔离功能类似于交换的端口隔离功能。 arp-suppression开启该功能后则自动使能ARP…