【Day40】


DAY 40: Standard Practices for Training and Testing

Knowledge Review

  1. Standard train/test code for color and grayscale images, encapsulated in functions
  2. The flatten operation: keep the first dimension (batch_size) and flatten everything else
  3. Dropout: neurons are randomly dropped during training; eval mode disables dropout at test time (a minimal sketch follows the homework note below)

Homework: carefully study the testing and training code logic. This is the foundation; this code framework will be reused throughout, and from here the focus gradually shifts toward the model-definition stage.
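
Before the full script, here is a minimal standalone sketch of points 2 and 3 above, assuming only a stock PyTorch install (the tensor shapes are illustrative):

import torch
import torch.nn as nn

x = torch.randn(64, 3, 32, 32)           # dummy batch: batch_size=64, 3-channel 32x32 images
flat = torch.flatten(x, 1)               # keep dim 0 (batch_size), flatten the rest
print(flat.shape)                        # torch.Size([64, 3072]), since 3*32*32 = 3072

drop = nn.Dropout(p=0.5)
drop.train()                             # training mode: each element zeroed with probability 0.5
print((drop(flat) == 0).float().mean())  # roughly 0.5
drop.eval()                              # eval mode: dropout becomes a no-op
print(torch.equal(drop(flat), flat))     # True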

"""
DAY 40 训练和测试的规范写法本节介绍深度学习中训练和测试的规范写法,包括:
1. 训练和测试函数的封装
2. 展平操作
3. dropout的使用
4. 训练过程可视化
"""import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

# Set a Chinese-capable font (fixes Chinese-character display issues)
plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei, a common font on Windows
plt.rcParams['axes.unicode_minus'] = False    # render minus signs correctly

# Set the random seed so results are reproducible
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ====================== 1. Data Loading ======================
def load_data(batch_size=64, is_train=True):
    """Load the CIFAR-10 dataset.

    Args:
        batch_size: batch size
        is_train: whether to load the training split

    Returns:
        dataloader: the data loader
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = torchvision.datasets.CIFAR10(
        root='./data',
        train=is_train,
        download=True,
        transform=transform
    )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,  # shuffle the training set, not the test set
        num_workers=2
    )
    return dataloader


# ====================== 2. Model Definition ======================
class SimpleNet(nn.Module):
    def __init__(self, dropout_rate=0.5):
        super(SimpleNet, self).__init__()
        # The first conv layer takes 3 input channels (color images)
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(dropout_rate)  # 2D dropout for conv layers
        self.dropout2 = nn.Dropout(dropout_rate)    # 1D dropout for fully connected layers
        # Feature-map size after flattening:
        # input image: 32x32
        # conv1: (32-3+1)x(32-3+1) = 30x30
        # maxpool: 15x15
        # conv2: (15-3+1)x(15-3+1) = 13x13
        # maxpool: 6x6
        # so the fully connected layer's input size is 64 * 6 * 6
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)  # random drop during training, disabled automatically at test time
        # Flatten: keep the batch_size dimension, flatten the rest
        x = torch.flatten(x, 1)  # equivalent to x.view(x.size(0), -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


# ====================== 3. Training Function ======================
def train(model, train_loader, optimizer, epoch, history):
    """Train for one epoch.

    Args:
        model: the model
        train_loader: training data loader
        optimizer: the optimizer
        epoch: current epoch number
        history: dict recording training history

    Returns:
        epoch_loss: average loss for this epoch
        epoch_acc: accuracy for this epoch
    """
    model.train()  # training mode: dropout enabled
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()              # clear gradients
        output = model(data)               # forward pass
        loss = F.nll_loss(output, target)  # compute the loss
        loss.backward()                    # backward pass
        optimizer.step()                   # update parameters
        train_loss += loss.item()
        pred = output.max(1, keepdim=True)[1]  # index of the highest probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\t'
                  f'Loss: {loss.item():.6f}\t'
                  f'Accuracy: {100. * correct / total:.2f}%')
    # Average loss and accuracy for this epoch
    epoch_loss = train_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    # Record the training history
    history['train_loss'].append(epoch_loss)
    history['train_acc'].append(epoch_acc)
    return epoch_loss, epoch_acc


# ====================== 4. Testing Function ======================
def test(model, test_loader, history):
    """Evaluate the model on the test set.

    Args:
        model: the model
        test_loader: test data loader
        history: dict recording training history

    Returns:
        test_loss: average loss on the test set
        accuracy: accuracy on the test set
    """
    model.eval()  # evaluation mode: dropout disabled
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients needed at test time
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    # Record the test history
    history['test_loss'].append(test_loss)
    history['test_acc'].append(accuracy)
    print(f'\nTest set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')
    return test_loss, accuracy


# ====================== 5. Visualization Functions ======================
def plot_training_history(history):
    """Plot the training history curves.

    Args:
        history: dict holding training and test history
    """
    epochs = range(1, len(history['train_loss']) + 1)
    # One figure with two subplots
    plt.figure(figsize=(12, 4))
    # Loss curves
    plt.subplot(1, 2, 1)
    plt.plot(epochs, history['train_loss'], 'b-', label='Train loss')
    plt.plot(epochs, history['test_loss'], 'r-', label='Test loss')
    plt.title('Training and test loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    # Accuracy curves
    plt.subplot(1, 2, 2)
    plt.plot(epochs, history['train_acc'], 'b-', label='Train accuracy')
    plt.plot(epochs, history['test_acc'], 'r-', label='Test accuracy')
    plt.title('Training and test accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def visualize_predictions(model, test_loader, num_samples=5):
    """Visualize model predictions.

    Args:
        model: the trained model
        test_loader: test data loader
        num_samples: number of samples to display
    """
    model.eval()
    # Grab one batch of data
    dataiter = iter(test_loader)
    images, labels = next(dataiter)
    # Get predictions
    with torch.no_grad():
        outputs = model(images.to(device))
        _, predicted = torch.max(outputs, 1)
    # Show the images with their predictions
    fig = plt.figure(figsize=(12, 3))
    for idx in range(num_samples):
        ax = fig.add_subplot(1, num_samples, idx + 1)
        img = images[idx] / 2 + 0.5  # undo the normalization
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        ax.set_title(f'Predicted: {classes[predicted[idx]]}\nActual: {classes[labels[idx]]}',
                     color=('green' if predicted[idx] == labels[idx] else 'red'))
        plt.axis('off')
    plt.tight_layout()
    plt.show()


# ====================== 6. Main Function ======================
# CIFAR-10 class names
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def main():
    # Hyperparameters
    batch_size = 64
    epochs = 9
    lr = 0.01
    dropout_rate = 0.5
    # Initialize the training history record
    history = {
        'train_loss': [],
        'train_acc': [],
        'test_loss': [],
        'test_acc': []
    }
    # Load data
    print("Loading training set...")
    train_loader = load_data(batch_size, is_train=True)
    print("Loading test set...")
    test_loader = load_data(batch_size, is_train=False)
    # Create the model
    model = SimpleNet(dropout_rate=dropout_rate).to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    # Train and test
    print(f"Starting training on device: {device}")
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train(model, train_loader, optimizer, epoch, history)
        test_loss, test_acc = test(model, test_loader, history)
    # Visualize the training process
    print("Training finished, plotting training history...")
    plot_training_history(history)
    # Visualize predictions
    print("Visualizing model predictions...")
    visualize_predictions(model, test_loader)


if __name__ == '__main__':
    main()

"""
Key points:
1. Training vs. testing:
   - Training: model.train(), dropout enabled
   - Testing: model.eval(), dropout disabled
2. The flatten operation:
   - torch.flatten(x, 1) or x.view(x.size(0), -1)
   - Keeps the first dimension (batch_size) and flattens the rest
3. Using dropout:
   - Training phase: randomly drops neurons to prevent overfitting
   - Testing phase: dropout is disabled automatically and the full network is used
4. Advantages of this standard structure:
   - Clear code organization, easy to maintain
   - Modular functionality, easy to reuse
   - Controllable training process, easy to debug
   - Applies to different datasets and models
"""
Train Epoch: 1 [19200/50000 (38%)]      Loss: 1.878432  Accuracy: 24.59%
Train Epoch: 1 [25600/50000 (51%)]      Loss: 1.737842  Accuracy: 27.03%
Train Epoch: 1 [32000/50000 (64%)]      Loss: 1.608304  Accuracy: 29.29%
Train Epoch: 1 [38400/50000 (77%)]      Loss: 1.654722  Accuracy: 30.90%
Train Epoch: 1 [44800/50000 (90%)]      Loss: 1.781868  Accuracy: 32.24%

Test set: Average loss: 1.4125, Accuracy: 4879/10000 (48.79%)

Train Epoch: 2 [0/50000 (0%)]   Loss: 1.725113  Accuracy: 31.25%
Train Epoch: 2 [6400/50000 (13%)]       Loss: 1.371717  Accuracy: 43.70%
Train Epoch: 2 [12800/50000 (26%)]      Loss: 1.377221  Accuracy: 43.85%
Train Epoch: 2 [19200/50000 (38%)]      Loss: 1.497515  Accuracy: 44.32%
Train Epoch: 2 [25600/50000 (51%)]      Loss: 1.509949  Accuracy: 44.92%
Train Epoch: 2 [32000/50000 (64%)]      Loss: 1.322219  Accuracy: 45.19%
Train Epoch: 2 [38400/50000 (77%)]      Loss: 1.451519  Accuracy: 45.65%
Train Epoch: 2 [44800/50000 (90%)]      Loss: 1.284523  Accuracy: 46.09%

Test set: Average loss: 1.2420, Accuracy: 5596/10000 (55.96%)

Train Epoch: 3 [0/50000 (0%)]   Loss: 1.457208  Accuracy: 57.81%
Train Epoch: 3 [6400/50000 (13%)]       Loss: 1.411661  Accuracy: 49.80%
Train Epoch: 3 [12800/50000 (26%)]      Loss: 1.251750  Accuracy: 49.25%
Train Epoch: 3 [19200/50000 (38%)]      Loss: 1.485202  Accuracy: 49.98%
Train Epoch: 3 [25600/50000 (51%)]      Loss: 1.219448  Accuracy: 50.09%
Train Epoch: 3 [32000/50000 (64%)]      Loss: 1.319644  Accuracy: 50.40%
Train Epoch: 3 [38400/50000 (77%)]      Loss: 1.431417  Accuracy: 50.58%
Train Epoch: 3 [44800/50000 (90%)]      Loss: 1.321420  Accuracy: 51.04%

Test set: Average loss: 1.1419, Accuracy: 6067/10000 (60.67%)

Train Epoch: 4 [0/50000 (0%)]   Loss: 1.274258  Accuracy: 54.69%
Train Epoch: 4 [6400/50000 (13%)]       Loss: 1.455593  Accuracy: 53.57%
Train Epoch: 4 [12800/50000 (26%)]      Loss: 1.439796  Accuracy: 53.95%
Train Epoch: 4 [19200/50000 (38%)]      Loss: 1.333504  Accuracy: 54.18%
Train Epoch: 4 [25600/50000 (51%)]      Loss: 1.127613  Accuracy: 54.53%
Train Epoch: 4 [32000/50000 (64%)]      Loss: 1.197434  Accuracy: 54.76%
Train Epoch: 4 [38400/50000 (77%)]      Loss: 1.217459  Accuracy: 54.58%
Train Epoch: 4 [44800/50000 (90%)]      Loss: 1.249435  Accuracy: 54.67%

Test set: Average loss: 1.0938, Accuracy: 6156/10000 (61.56%)

Train Epoch: 5 [0/50000 (0%)]   Loss: 1.200900  Accuracy: 54.69%
Train Epoch: 5 [6400/50000 (13%)]       Loss: 1.200518  Accuracy: 55.96%
Train Epoch: 5 [12800/50000 (26%)]      Loss: 1.267728  Accuracy: 56.58%
Train Epoch: 5 [19200/50000 (38%)]      Loss: 1.501915  Accuracy: 56.76%
Train Epoch: 5 [25600/50000 (51%)]      Loss: 1.248580  Accuracy: 56.72%
Train Epoch: 5 [32000/50000 (64%)]      Loss: 1.385589  Accuracy: 56.64%
Train Epoch: 5 [38400/50000 (77%)]      Loss: 1.377769  Accuracy: 56.59%
Train Epoch: 5 [44800/50000 (90%)]      Loss: 1.355240  Accuracy: 56.62%

Test set: Average loss: 1.0414, Accuracy: 6448/10000 (64.48%)

Train Epoch: 6 [0/50000 (0%)]   Loss: 1.194540  Accuracy: 64.06%
Train Epoch: 6 [6400/50000 (13%)]       Loss: 1.255205  Accuracy: 59.00%
Train Epoch: 6 [12800/50000 (26%)]      Loss: 1.216109  Accuracy: 58.45%
Train Epoch: 6 [19200/50000 (38%)]      Loss: 0.916238  Accuracy: 58.74%
Train Epoch: 6 [25600/50000 (51%)]      Loss: 1.081454  Accuracy: 58.52%
Train Epoch: 6 [32000/50000 (64%)]      Loss: 1.170482  Accuracy: 58.42%
Train Epoch: 6 [38400/50000 (77%)]      Loss: 1.263351  Accuracy: 58.43%
Train Epoch: 6 [44800/50000 (90%)]      Loss: 1.197278  Accuracy: 58.45%

Test set: Average loss: 0.9976, Accuracy: 6609/10000 (66.09%)

Train Epoch: 7 [0/50000 (0%)]   Loss: 1.296109  Accuracy: 51.56%
Train Epoch: 7 [6400/50000 (13%)]       Loss: 1.194998  Accuracy: 59.25%
Train Epoch: 7 [12800/50000 (26%)]      Loss: 1.045425  Accuracy: 58.80%
Train Epoch: 7 [19200/50000 (38%)]      Loss: 1.096962  Accuracy: 59.35%
Train Epoch: 7 [25600/50000 (51%)]      Loss: 1.002581  Accuracy: 59.48%
Train Epoch: 7 [32000/50000 (64%)]      Loss: 1.101984  Accuracy: 59.45%
Train Epoch: 7 [38400/50000 (77%)]      Loss: 0.934384  Accuracy: 59.56%
Train Epoch: 7 [44800/50000 (90%)]      Loss: 1.025743  Accuracy: 59.56%

Test set: Average loss: 0.9824, Accuracy: 6663/10000 (66.63%)

Train Epoch: 8 [0/50000 (0%)]   Loss: 1.121836  Accuracy: 60.94%
Train Epoch: 8 [6400/50000 (13%)]       Loss: 1.057686  Accuracy: 60.47%
Train Epoch: 8 [12800/50000 (26%)]      Loss: 1.132846  Accuracy: 60.13%
Train Epoch: 8 [19200/50000 (38%)]      Loss: 1.094760  Accuracy: 59.88%
Train Epoch: 8 [25600/50000 (51%)]      Loss: 1.392307  Accuracy: 59.98%
Train Epoch: 8 [32000/50000 (64%)]      Loss: 0.905305  Accuracy: 60.01%
Train Epoch: 8 [38400/50000 (77%)]      Loss: 1.293327  Accuracy: 60.11%
Train Epoch: 8 [44800/50000 (90%)]      Loss: 1.154168  Accuracy: 60.13%

Test set: Average loss: 0.9402, Accuracy: 6824/10000 (68.24%)

Train Epoch: 9 [0/50000 (0%)]   Loss: 0.742247  Accuracy: 70.31%
Train Epoch: 9 [6400/50000 (13%)]       Loss: 0.880693  Accuracy: 60.89%
Train Epoch: 9 [12800/50000 (26%)]      Loss: 1.063176  Accuracy: 61.19%
Train Epoch: 9 [19200/50000 (38%)]      Loss: 1.462891  Accuracy: 61.12%
Train Epoch: 9 [25600/50000 (51%)]      Loss: 1.227893  Accuracy: 61.29%
Train Epoch: 9 [32000/50000 (64%)]      Loss: 0.829324  Accuracy: 61.12%
Train Epoch: 9 [38400/50000 (77%)]      Loss: 1.199507  Accuracy: 61.10%
Train Epoch: 9 [44800/50000 (90%)]      Loss: 1.242885  Accuracy: 61.04%

Test set: Average loss: 0.9322, Accuracy: 6954/10000 (69.54%)

Training finished, plotting training history...

[Figure: training and test loss/accuracy curves]

Visualizing model predictions...

[Figure: sample test images with predicted vs. actual labels]
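
A closing design note: the model returns F.log_softmax and the training loop pairs it with F.nll_loss. An equivalent, and nowadays more common, formulation is to return raw logits and apply cross-entropy directly. The sketch below (random tensors, purely illustrative) confirms the two compute the same loss:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)                # raw scores: batch of 8, 10 classes
target = torch.randint(0, 10, (8,))

loss_a = F.nll_loss(F.log_softmax(logits, dim=1), target)  # formulation used in the script
loss_b = F.cross_entropy(logits, target)                   # fused equivalent
print(torch.allclose(loss_a, loss_b))      # True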

浙大疏锦行 

