Day 35 训练

article/2025/7/27 13:34:13

Day 35 训练

环境准备

确保已安装 PyTorch、sklearn、matplotlib、tqdm 等必要的 Python 库。若未安装,可通过以下命令安装:

pip install torch scikit-learn matplotlib tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple

由于网络原因,使用清华大学的镜像源来加速安装过程。如果链接无法访问,请检查网络连接,或更换其他可用的 PyPI 镜像源。

模型训练

数据准备与预处理

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt# 设置GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 归一化数据
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 将数据转换为PyTorch张量并移至GPU
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test).to(device)

这里我们使用 torch.device 来检测并使用可用的 GPU 设备,以加速模型训练过程。加载鸢尾花数据集后,将其划分为训练集和测试集,并对特征数据进行归一化处理,使数据分布在 [0, 1] 范围内,有助于模型更快地收敛。最后将数据转换为 PyTorch 张量格式,并传输到指定设备(GPU 或 CPU)上。

定义模型

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(4, 10)  # 输入层到隐藏层self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)  # 隐藏层到输出层def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 实例化模型并移至GPU
model = MLP().to(device)

定义了一个简单的多层感知机(MLP)模型,包含一个输入层、一个隐藏层和一个输出层。输入层接收 4 个特征维度的数据,隐藏层有 10 个神经元并使用 ReLU 激活函数,输出层有 3 个神经元,用于对应鸢尾花的 3 个类别。通过 model.to(device) 将模型放置在指定设备上。

损失函数与优化器

# 分类问题使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 使用随机梯度下降优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

对于分类任务,选择交叉熵损失函数(CrossEntropyLoss)作为损失函数,它可以结合 softmax 操作,衡量模型预测结果与真实标签之间的差异。采用随机梯度下降(SGD)优化器来更新模型参数,学习率(lr)设置为 0.01。

模型训练过程

# 训练模型
num_epochs = 20000  # 训练的轮数# 用于存储每100个epoch的损失值和对应的epoch数
losses = []start_time = time.time()  # 记录开始时间for epoch in range(num_epochs):# 前向传播outputs = model(X_train)  # 隐式调用forward函数loss = criterion(outputs, y_train)# 反向传播和优化optimizer.zero_grad() #梯度清零loss.backward() #  反向传播计算梯度optimizer.step() # 更新参数# 记录损失值if (epoch + 1) % 200 == 0:losses.append(loss.item()) # item()方法返回一个Python数值,loss是一个标量张量print(f'Epoch [{epoch+1}/{num_epochs}, Loss: {loss.item():.4f}')# 打印训练信息if (epoch + 1) % 100 == 0: # 每100个epoch打印一次print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')time_all = time.time() - start_time  # 计算训练时间
print(f'Training time: {time_all:.2f} seconds')

设置训练轮数为 20000 次。在每次训练迭代中,首先进行前向传播计算模型输出和损失值,然后执行反向传播计算梯度,并利用优化器更新模型参数。每 100 个 epoch 打印一次训练信息,每 200 个 epoch 记录一次损失值用于后续可视化。训练完成后,输出整个训练过程所花费的时间。

模型结构可视化

使用 torch.nn.Module 内置功能

# 直接输出模型结构
print(model)# 输出模型的可训练参数迭代器
for name, param in model.named_parameters():print(f"Parameter name: {name}, Shape: {param.shape}")

直接打印模型对象可以查看模型的结构,显示各层的名称和参数信息。通过 model.named_parameters() 可以获取模型中可训练参数的名称和形状,有助于我们了解模型的参数细节。
在这里插入图片描述
在这里插入图片描述

使用 torchsummary

from torchsummary import summary# 打印模型摘要
summary(model, input_size=(4,))

需先安装 torchsummary 库。summary 函数会输出模型的详细摘要,包括每层的输出形状、参数数量等信息。input_size 参数指定输入数据的形状,以便推断模型各层的信息。

使用 torchinfo

from torchinfo import summary# 打印模型摘要
summary(model, input_size=(4, ))

torchinfo 提供了比 torchsummary 更为详细的模型信息,包括每层的输入输出形状、参数数量、计算量等。

进度条功能实现

使用 tqdm 库可以方便地在训练过程中添加进度条,实时显示训练进度。

from tqdm import tqdm# 创建tqdm进度条
with tqdm(total=num_epochs, desc="训练进度", unit="epoch") as pbar:# 训练模型for epoch in range(num_epochs):# 前向传播outputs = model(X_train)  loss = criterion(outputs, y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 记录损失值并更新进度条if (epoch + 1) % 200 == 0:losses.append(loss.item())epochs.append(epoch + 1)# 更新进度条的描述信息pbar.set_postfix({'Loss': f'{loss.item():.4f}'})# 每1000个epoch更新一次进度条if (epoch + 1) % 1000 == 0:pbar.update(1000)  # 更新进度条# 确保进度条达到100%if pbar.n < num_epochs:pbar.update(num_epochs - pbar.n)  # 计算剩余的进度并更新

通过 tqdm 创建一个进度条对象,并在训练循环中根据设定的步长更新进度条。set_postfix 方法用于在进度条右侧显示实时数据,如当前损失值。这种方式能够更直观地观察模型训练的进展和效果。

模型推理

在完成模型训练后,我们可以在测试集上进行模型推理,评估模型的性能。

# 在测试集上评估模型
model.eval() # 设置模型为评估模式
with torch.no_grad(): # 禁用梯度计算,提高推理速度outputs = model(X_test)  # 对测试数据进行前向传播,获得预测结果_, predicted = torch.max(outputs, 1) # 获取预测类别correct = (predicted == y_test).sum().item() # 计算预测正确的样本数accuracy = correct / y_test.size(0)print(f'测试集准确率: {accuracy * 100:.2f}%')

将模型设置为评估模式,此时模型会关闭一些训练时特有的操作(如 dropout 等),以确保输出结果的稳定性和一致性。在推理过程中禁用梯度计算可以提高效率。通过计算预测结果与真实标签的匹配程度,得出模型在测试集上的准确率。

以上就是基于 PyTorch 实现鸢尾花分类模型的完整实践过程,涵盖了数据准备、模型构建、训练、可视化、进度条功能以及推理评估等各个环节,希望对大家学习和应用深度学习技术有所帮助。
浙大疏锦行


http://www.hkcw.cn/article/iACdTnrVNQ.shtml

相关文章

C++实现伽罗华域生成及四则运算(三)

目录 原来的求行列式代码改进后的求行列式代码GF(256)域幻方4阶矩阵求逆致敬 上一篇文章介绍了用C实现伽罗华域 G F ( 2 n ) GF(2^n) GF(2n)数的次幂、矩阵数乘、方阵次幂、求行列式、求伴随矩阵、逆矩阵的函数&#xff0c;本篇继续&#xff0c;对求行列式的函数进行改进&#…

【AIGC专栏】微软Bing ImageCreator 创建图片

图像创建器是一款产品,可帮助用户使用 DALLE 3 生成 AI 图像。得到文本提示后,AI 将生成一组与该提示匹配的图像。 当前的Bing Image Creator 依然处在预览周期,所以很多区域无法访问。特别是位于中国大陆和香港、澳门区域都无法正常使用这些服务。 另外还有两个条件: 注…

2025年05月总结及随笔之家电询价及以旧换新

1. 回头看 日更坚持了882天。 读《数据自助服务实践指南&#xff1a;数据开放与洞察提效》更新完成读《数据科学伦理&#xff1a;概念、技术和警世故事》开更并更新完成读《红蓝攻防&#xff1a;技术与策略》开更并持续更新 2023年至2025年05月底累计码字2356176字&#xff…

【IC】chip binning是什么

你刚买了一块新的CPU或显卡&#xff0c;在电脑上启动了它。它运行起来感觉很酷&#xff0c;于是你尝试了一下超频。频率越来越高&#xff0c;看起来你似乎得到了一些特别的东西。这肯定不应该吧&#xff1f; 于是你冲到互联网上分享你中了硅片大奖的兴奋之情&#xff0c;没过几…

mysql核心知识点

Server 层负责建立连接、分析和执行 SQL。MySQL 大多数的核心功能模块都在这实现&#xff0c;主要包括连接器&#xff0c;查询缓存、解析器、预处理器、优化器、执行器等。另外&#xff0c;所有的内置函数&#xff08;如日期、时间、数学和加密函数等&#xff09;和所有跨存储引…

可暂停Windows更新的便捷工具

软件介绍 今天给大家介绍一款能实现Windows暂停更新的实用工具。 吾爱jiedeng按照这个思路开发了一个小工具。该工具能将Windows更新暂停的上限设为7000天&#xff0c;这几乎相当于永久暂停了。 软件使用起来很简单&#xff0c;输入要暂停的天数后&#xff0c;点击【一键修改…

lidar和imu的标定(二)GRIL-Calib

直接看IV. PREPROCESSING这一部分&#xff0c; 首先是提取了地平面。 然后看这个公式&#xff0c;表示地平面和[0,0,1]向量之间构成的旋转矩阵&#xff0c;当然是使用角轴的方式构建的。 B. LiDAR Odometry Utilizing the Ground Plane Residual 使用平面约束的的激光里程计约…

怎么更改cursor chat中的字体大小

使用 ctrl 【Ctrl键和加号键一起按】增加所有窗口的字体大小 然后打开 VS Code 设置并减小文本编辑器字体大小即可

JMeter 直连数据库

1.直连数据库的使用场景 1.1 参数化&#xff0c;例如登录使用的账户名密码都可以从数据库中取得 1.2 断言&#xff0c;查看实际结果和数据库中的预期结果是否一致 1.3 清理垃圾数据&#xff0c;例如插入一个用户&#xff0c;它的ID不能相同&#xff0c;在测试插入功能后将数据删…

【火山引擎 大模型批量推理数据教程---详细讲解一篇过!】

0. 相关的文档 &#xff01;&#xff01;先注册火山引擎账号第一步&#xff01;&#xff01; 批量推理文档网页对象存储网页提交批量处理网页费用接口网页 1. 准备jsonl数据集 官网地址样例&#xff0c;需要根据你自己的数据进行需改 import json## 你的数据&#xff0c;自…

测量3D翼片的距离与角度

1&#xff0c;目的。 测量3D翼片的距离与角度。说明&#xff1a; 标注A 红色框选的区域即为翼片&#xff0c;本示例的3D 对象共有3个翼片待测。L1与L2的距离、L1与L2的角度即为所求的翼片距离与角度。 2&#xff0c;原理。 使用线结构光模型&#xff08;标定模式&#xff0…

单元测试-概述入门

目录 main方法测试缺点&#xff1a; 在pom.xm中&#xff0c;引入junit的依赖。,在test/java目录下&#xff0c;创建测试类&#xff0c;并编写对应的测试方法&#xff0c;并在方法上声明test注解。 练习&#xff1a;验证身份证合法性 测试成功 测试失败 main方法测试缺点&am…

模块联邦:更快的微前端方式!

什么是模块联邦 在前端项目中&#xff0c;不同团队之间的业务模块可能有耦合&#xff0c;比如A团队的页面里有一个富文本模块&#xff08;组件&#xff09;&#xff0c;而B团队 的页面恰好也需要使用这个富文本模块。 传统模式下&#xff0c;B团队只能去抄A团队的代码&#x…

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 时间事件处理部分)

揭秘高效存储模型与数据结构底层实现 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 时间事件&#xff1a;serverCron函数更新服务器时间缓存更新LRU时钟-lruclock更新服务器每秒执行命令次…

ZIP Cracker版本更新了

废话不多说(也不能多说&#xff0c;原因都懂吧)&#xff0c;上图&#xff0c;阿修大佬已经更新了新的版本 参考原文&#xff1a;https://mp.weixin.qq.com/s/7ptu8tLR_2huivLJdcFBzQ

云南独龙江乡全部通信网络已抢通 紧急抢修保畅通

近日,受持续强降雨影响,怒江傈僳族自治州贡山县独龙江乡遭遇山洪和滑坡等自然灾害,导致通信网络严重受损。5月31日上午10时37分,全乡通信网络站点大面积中断,中国移动云南公司怒江分公司使用卫星传输基站保障独龙江乡政府所在地的通信正常。怒江移动分公司迅速启动防汛应急…

跨越时空的科学对话:现代科学解160年前的遗传学密码 科学家精神熠熠生辉

点滴故事中,领略科学家精神的熠熠光辉。通过讲述一个个科学家的故事,展现他们的风采,记录科技事业的发展历程,弘扬科学家的精神内涵。2025年5月31日是端午节,传统文化中有纪念屈原的习俗。两千三百年前,屈原在汨罗江畔仰观宇宙,以《天问》叩击苍穹:“日月安属?列星安陈…

美国民众开始不愿意花钱了 对现有经济存“潜在焦虑情绪” 多重经济压力交织

近期,一系列数据和调查显示,美国民众对本国经济前景的信心正处于低谷。美国密歇根大学公布的5月消费者信心指数初值降至50.8,连续第五个月下降,为2022年6月以来的最低水平。这种悲观情绪反映出美国经济深层次的矛盾与挑战。通货膨胀一直是困扰美国民众的主要问题。尽管美联…

【GESP真题解析】第 4 集 GESP 三级 2023 年 6 月编程题 1:春游

大家好,我是莫小特。 这篇文章给大家分享 GESP 三级 2023 年 6 月编程题第 1 题:春游。 题目链接 洛谷链接:B3842 春游 一、完成输入 根据输入格式的描述,输入包括两个正整数 N 和 M,N 是 N 位同学,M 是 M 次报出编号,数据范围: 2 ≤ N , M ≤ 1000 2\le N,M \le 10…

遭邻居多次持刀砍门当事人发声 精神疾病患者惹争议

近日,有大连网友在社交平台发布视频称,5月1日和5月28日,疑似患有精神疾病的邻居两次持刀上门,用刀砍其家门,并进行踢踹。网传视频截图显示了这一情况。该网友表示,他们一家才搬来一年,与这名邻居素不相识,没有正面交流过。记者多次尝试联系该网友,但未获回复。6月1日,…