Python 训练营打卡 Day 33-神经网络

article/2025/6/29 13:20:47
简单神经网络的流程

1.数据预处理(归一化、转换成张量)
2.模型的定义
    继承nn.Module类
    定义每一个层
    定义前向传播流程

3.定义损失函数和优化器
4.定义训练过程
5.可视化loss过程

预处理补充:
分类任务中,若标签是整数(如 0/1/2 类别),需转为long类型(对应 PyTorch 的torch.long),否则交叉熵损失函数会报错
回归任务中,标签需转为float类型(如torch.float32)

数据的准备

以4特征,3分类的鸢尾花数据集作为我们今天的数据集
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 打印下尺寸
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

# 归一化数据,神经网络对于输入数据的尺寸敏感,归一化是最常见的处理方式
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test) #确保训练集和测试集是相同的缩放# 将数据转换为 PyTorch 张量,因为 PyTorch 使用张量进行训练
# y_train和y_test是整数,所以需要转化为long类型,如果是float32,会输出1.0 0.0
X_train = torch.FloatTensor(X_train)
y_train = torch.LongTensor(y_train)
X_test = torch.FloatTensor(X_test)
y_test = torch.LongTensor(y_test)

模型架构定义

定义一个简单的全连接神经网络模型,包含一个输入层、一个隐藏层和一个输出层
定义层数+定义前向传播顺序

class MLP(nn.Module): # 定义一个多层感知机(MLP)模型,继承父类nn.Moduledef __init__(self): # 初始化函数super(MLP, self).__init__() # 调用父类的初始化函数# 前三行是八股文,后面的是自定义的self.fc1 = nn.Linear(4, 10)  # 输入层到隐藏层self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)  # 隐藏层到输出层
# 输出层不需要激活函数,因为后面会用到交叉熵函数cross_entropy,交叉熵函数内部有softmax函数,会把输出转化为概率def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 实例化模型
model = MLP()

这个网络结构非常简单:
输入层:4个特征
隐藏层:10个神经元,使用ReLU激活
输出层:3个神经元(适合3分类问题)
没有dropout或batch normalization等复杂结构,这是一个典型的前馈神经网络,适用于简单的分类或回归任务

模型训练

定义损失函数和优化器

# 分类问题使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 使用随机梯度下降优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)# # 使用自适应学习率的化器
# optimizer = optim.Adam(model.parameters(), lr=0.001)

使用交叉熵损失函数(CrossEntropyLoss),适用于多分类问题
会自动对输出进行softmax处理并计算损失
常用于分类任务,特别是当输出是类别概率时

使用随机梯度下降(SGD)优化器
优化对象是模型的所有可训练参数( model.parameters() )
学习率(lr)设置为0.01

这个配置是训练神经网络的标准设置:
交叉熵损失适用于分类任务
SGD是最基础的优化算法
学习率0.01是一个常用的初始值

循环训练

# 训练模型
num_epochs = 20000 # 训练的轮数# 用于存储每个 epoch 的损失值
losses = []for epoch in range(num_epochs): # range是从0开始,所以epoch是从0开始# 前向传播outputs = model.forward(X_train)   # 显式调用forward函数# outputs = model(X_train)  # 常见写法隐式调用forward函数,其实是用了model类的__call__方法loss = criterion(outputs, y_train) # output是模型预测值,y_train是真实标签# 反向传播和优化optimizer.zero_grad() #梯度清零,因为PyTorch会累积梯度,所以每次迭代需要清零,梯度累计是那种小的bitchsize模拟大的bitchsizeloss.backward() # 反向传播计算梯度optimizer.step() # 更新参数# 记录损失值losses.append(loss.item())# 打印训练信息if (epoch + 1) % 100 == 0: # range是从0开始,所以epoch+1是从当前epoch开始,每100个epoch打印一次print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

可视化结果

import matplotlib.pyplot as plt
# 可视化损失曲线
plt.plot(range(num_epochs), losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

@浙大疏锦行


http://www.hkcw.cn/article/XpVqHeDRjZ.shtml

相关文章

TDengine 的 AI 应用实战——电力需求预测

作者: derekchen Demo数据集准备 我们使用公开的UTSD数据集里面的电力需求数据,作为预测算法的数据来源,基于历史数据预测未来若干小时的电力需求。数据集的采集频次为30分钟,单位与时间戳未提供。为了方便演示,按…

【03】完整开发腾讯云播放器SDK的UniApp官方UTS插件——优雅草上架插件市场-卓伊凡

【03】完整开发腾讯云播放器SDK的UniApp官方UTS插件——优雅草上架插件市场-卓伊凡 一、项目背景与转型原因 1.1 原定计划的变更 本系列教程最初规划是开发即构美颜SDK的UTS插件,但由于甲方公司内部战略调整,原项目被迫中止。考虑到: 技术…

(aaai2024) Omni-Kernel Network for Image Restoration

代码:https://github.com/c-yn/OKNet 研究动机:作者认为Transformer模型计算复杂度太高,因此提出了 omni-kernel module (OKM),可以有效的学习局部到全局的特征表示。该模块包括:全局、大分支、…

useMemo useCallback 自定义hook

useMemo & useCallback & 自定义hook useMemo 仅当依赖项发生变化的时候,才去重新计算;其他状态变化时则不去做不必要的计算。 useCallback 缓存函数。但是使用注意📢 ,useCallback没有特别明显的优化。 *合适的场景——父…

android binder(二)应用层编程实例

一、binder驱动浅析 从上图看出,binder的通讯主要涉及三个步骤。 在 Binder Server 端定义好服务,然后向 ServiceManager 注册服务在 Binder Client 中向 ServiceManager 获取到服务发起远程调用,调用 Binder Server 中定义好的服务 整个流…

GESP2024年3月认证C++二级( 第三部分编程题(2)小杨的日字矩阵)

参考程序&#xff1a; #include <iostream> using namespace std;int main() {int n;cin >> n; // 读入奇数 n// 外层循环控制每一行for (int i 0; i < n; i) {// 内层循环控制每一列for (int j 0; j < n; j) {char ch;// 如果当前列是最左或最右&#x…

BUUCTF[ACTF2020 新生赛]Exec 1题解

BUUCTF[ACTF2020 新生赛]Exec 1题解 分析解题过程总结: 分析 先分析题目&#xff1a;exc()是一个内部调用shell命令的函数&#xff0c;同样的函数还有system(), 创建靶机&#xff0c;打开网址&#xff0c;是一个和PING相关的网页&#xff0c;查看源代码&#xff0c;没有提示&a…

NX869NX874美光固态颗粒NX877NX883

NX869NX874美光固态颗粒NX877NX883 美光固态硬盘颗粒技术解析与市场展望 近年来&#xff0c;固态硬盘&#xff08;SSD&#xff09;市场呈现出蓬勃发展的态势&#xff0c;而作为核心组件的存储颗粒&#xff0c;其技术进展与市场动态自然吸引了众多关注。在众多品牌中&#xff…

CodeTop100 Day20

58、翻转字符串中的数字 class Solution {public String reverseWords(String s) {s s.trim(); int j s.length() - 1, i j;StringBuilder res new StringBuilder();while (i > 0) {while (i > 0 && s.charAt(i) ! ) i--…

重温经典算法——快速排序

版权声明 本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl 基本原理 快速排序基于分治思想&#xff0c;通过选取基准元素将数组划分为两个子数组&#xff08;小于基准和大于基准&#xff09;&#xff0c;递归排序子数组。平均时间复…

【机器学习】集成学习与梯度提升决策树

目录 一、引言 二、自举聚合与随机森林 三、集成学习器 四、提升算法 五、Python代码实现集成学习与梯度提升决策树的实验 六、总结 一、引言 在机器学习的广阔领域中,集成学习(Ensemble Learning)犹如一座闪耀的明星,它通过组合多个基本学习器的力量,创造出…

Python量化交易:K线形态识别与技术分析可视化

引言 在量化交易领域&#xff0c;K线形态识别是一种重要的技术分析方法&#xff0c;可以帮助投资者预测市场趋势并制定交易策略。本文将介绍如何使用Python实现K线形态的自动识别与可视化分析&#xff0c;无需依赖复杂的第三方库如TA-Lib&#xff0c;完全使用纯Python实现。通…

前端自动化测试利器:Playwright 全面介绍

目录 &#x1f9ea; 前端自动化测试利器&#xff1a;Playwright 全面介绍 ✨ 为什么选择 Playwright&#xff1f; 1. 跨浏览器支持 2. 多语言支持 3. 自动等待机制 4. 强大的页面交互能力 &#x1f527; Playwright 快速上手 &#x1f4f8; 更强的调试体验 &#x1f9…

华为云Flexus+DeepSeek征文|华为云 Dify 打造智慧水果分析助手,实现“知识库 + 大模型”精准赋能

前言 本文聚焦基于华为云平台部署的智慧水果分析助手 AI Agent&#xff0c;通过 Dify 平台集成 Embedding、Rerank 及 DeepSeek 模型&#xff0c;构建工作流&#xff0c;实现提问内容驱动的 “知识库 大模型” 与 “联网搜索 大模型” 智能切换。 ECS控制台&#xff1a;https…

【算法设计与分析】实验——改写二分搜索算法,众数问题(算法分析:主要算法思路),有重复元素的排列问题,整数因子分解问题(算法实现:过程,分析,小结)

说明&#xff1a;博主是大学生&#xff0c;有一门课是算法设计与分析&#xff0c;这是博主记录课程实验报告的内容&#xff0c;题目是老师给的&#xff0c;其他内容和代码均为原创&#xff0c;可以参考学习&#xff0c;转载和搬运需评论吱声并注明出处哦。 要求&#xff1a;2.…

MCP协议学习

MCP协议出现的背景 MCP&#xff08;Model Context Protocol&#xff0c;模型上下文协议&#xff09;由Anthropic公司于2024年11月推出&#xff0c;旨在解决大语言模型&#xff08;LLM&#xff09;与外部数据源、工具和服务之间的标准化交互问题。例如某金融科技公司需开发一款…

【笔记】Windows 部署 Suna 开源项目完整流程记录

#工作记录 因篇幅有限&#xff0c;所有涉及处理步骤的详细处理办法请参考文末资料。 Microsoft Windows [Version 10.0.27868.1000] (c) Microsoft Corporation. All rights reserved.(suna-py3.12) F:\PythonProjects\suna>python setup.py --admin███████╗██╗…

SQL Views(视图)

目录 Views Declaring Views Example: View Definition Example: Accessing a View Advantages of Views Triggers on Views Interpreting a View Insertion&#xff08;视图插入操作的解释&#xff09; The Trigger Views A view is a relation defined in terms of…

MySQL指令个人笔记

MySQL学习&#xff0c;SQL语言笔记 一、MySQL 1.1 启动、停止 启动 net start mysql83停止 net stop mysql831.2 连接、断开 连接 mysql -h localhost -P 3306 -u root -p断开 exit或者ctrlc 二、DDL 2.1 库管理 2.1.1 直接创建库 使用默认字符集和排序方式&#xf…

【redis实战篇】第七天

摘要&#xff1a; 本文介绍了黑马点评中点赞、关注和推送功能的实现方案。点赞功能采用Redis的ZSET结构存储用户点赞数据&#xff0c;实现点赞状态查询、热门博客排行和点赞用户展示。关注功能通过关系表和Redis集合实现用户关注关系管理&#xff0c;包含共同关注查询。推送功能…