Python_day43

article/2025/7/1 13:15:10

DAY 43 复习日

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

关于 Dataset

从谷歌图片中抓取了 1000 多张猫和狗的图片。问题陈述是构建一个模型,该模型可以尽可能准确地在图像中的猫和狗之间进行分类。

图像大小范围从大约 100x100 像素到 2000x1000 像素。

图像格式为 jpeg。

已删除重复项。

猫狗图像分类 --- Cats and Dogs image classification

步骤

导入所需的模块

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
from torchvision import transforms, datasets
import random
import os

数据准备和预处理

# 设置随机种子确保可复现
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)# 设置设备(优先使用GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# --- 关键修改1:调整为本地绝对路径并检查目录存在性 ---
data_dir = "d:\\code\\trae\\python_60\\Cat_and_Dog"  # 你的本地项目根目录
train_dir = os.path.join(data_dir, "train")  # 指向你的实际训练数据目录(需包含类别子文件夹)# 检查训练目录是否存在
if not os.path.isdir(train_dir):raise FileNotFoundError(f"训练目录不存在: {train_dir}\n""请按以下结构准备数据:\n"f"{data_dir}\n""└── train\n""    ├── cat\n"   # 类别子文件夹1(如猫)"    └── dog\n"   # 类别子文件夹2(如狗)"(每个子文件夹存放对应类别的图片)")# --- 关键修改2:优化数据划分逻辑(修正索引生成问题) ---
proportion = 0.2    # 验证集比例
batch_size = 32     # 批量大小# 加载数据集(使用训练目录)
data = datasets.ImageFolder(root=train_dir, transform=transforms.Compose([transforms.Resize(256),                 # 缩放到256x256transforms.CenterCrop(224),             # 中心裁剪224x224transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转transforms.ColorJitter(                 # 颜色抖动增强brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(                   # ImageNet标准化参数mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
]))n_total = len(data)  # 总样本数
all_indices = list(range(n_total))  # 生成0~n_total-1的索引(修正原range(1,n)的0索引遗漏问题)
random.shuffle(all_indices)         # 打乱索引确保随机划分# 按比例分割训练集和验证集
n_val = int(proportion * n_total)
val_indices = all_indices[:n_val]       # 前n_val个作为验证集
train_indices = all_indices[n_val:]     # 剩余作为训练集train_set = Subset(data, train_indices)
val_set = Subset(data, val_indices)# 数据加载器(补充num_workers提升加载效率)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4)

定义卷积神经网络模型

实例化模型并移至计算设备(GPU或CPU)

定义损失函数和优化器(调整学习率和权重衰减)

学习率调度(移除不兼容的verbose参数)

# 定义卷积神经网络模型(优化版)
class SimpleCNN(nn.Module):def __init__(self, dropout_rate=0.5):super().__init__()# 卷积特征提取模块(含残差连接)self.conv_layers = nn.Sequential(# 第一层:输入3通道(RGB)→16通道nn.Conv2d(3, 16, kernel_size=3, padding=1),nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(2),  # 224x224 → 112x112# 第二层:16→32通道 + 残差连接nn.Conv2d(16, 32, kernel_size=3, padding=1),nn.BatchNorm2d(32),nn.ReLU(),nn.Conv2d(32, 32, kernel_size=3, padding=1),  # 残差分支nn.BatchNorm2d(32),nn.MaxPool2d(2),  # 112x112 → 56x56# 第三层:32→64通道nn.Conv2d(32, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.Dropout2d(0.1),nn.MaxPool2d(2),  # 56x56 → 28x28# 第四层:64→128通道nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.Dropout2d(0.1),nn.MaxPool2d(2)   # 28x28 → 14x14(与原计算一致))# 动态计算全连接层输入维度(避免硬编码错误)with torch.no_grad():  # 虚拟输入计算特征尺寸dummy_input = torch.randn(1, 3, 224, 224)  # 输入尺寸与数据预处理一致dummy_output = self.conv_layers(dummy_input)self.feature_size = dummy_output.view(1, -1).size(1)# 全连接分类模块(增加正则化)self.fc_layers = nn.Sequential(nn.Linear(self.feature_size, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(dropout_rate),nn.Linear(512, 256),nn.BatchNorm1d(256),nn.ReLU(),nn.Dropout(dropout_rate),nn.Linear(256, 2)  # 修正:二分类输出维度为2)def forward(self, x):x = self.conv_layers(x)x = x.view(x.size(0), -1)  # 展平特征x = self.fc_layers(x)return x# 实例化模型并移至计算设备(GPU或CPU)
model = SimpleCNN(dropout_rate=0.3).to(device)  # 调整Dropout率(0.3比0.5更温和)# 定义损失函数和优化器(调整学习率和权重衰减)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)  # 学习率降至0.001,权重衰减微调# 学习率调度(移除不兼容的verbose参数)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',  # 监控指标为验证准确率(越大越好)factor=0.5,    # 学习率衰减因子patience=2     # 等待2个epoch无提升再衰减
)

构建深度学习模型

训练主模型

# 训练模型主函数(优化版)
def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer, scheduler: optim.lr_scheduler._LRScheduler, epochs: int
) -> tuple[list[float], list[float], list[float], list[float]]:# 初始化训练和验证过程中的监控指标train_losses: list[float] = []  # 存储每个epoch的训练损失val_losses: list[float] = []    # 存储每个epoch的验证损失train_accuracies: list[float] = []  # 存储每个epoch的训练准确率val_accuracies: list[float] = []    # 存储每个epoch的验证准确率# 新增:早停相关变量(可选)best_val_loss: float = float('inf')early_stop_counter: int = 0early_stop_patience: int = 5  # 连续5个epoch无提升则停止# 主训练循环 - 遍历指定轮数for epoch in range(epochs):# 设置模型为训练模式(启用Dropout和BatchNorm等训练特定层)model.train()train_loss: float = 0.0  # 累积训练损失correct: int = 0         # 正确预测的样本数total: int = 0           # 总样本数# 批次训练循环 - 遍历训练数据加载器中的所有批次for inputs, targets in train_loader:# 将数据移至计算设备(GPU或CPU)inputs, targets = inputs.to(device), targets.to(device)# 梯度清零 - 防止梯度累积(每个批次独立计算梯度)optimizer.zero_grad()# 前向传播 - 通过模型获取预测结果outputs = model(inputs)# 计算损失 - 使用预定义的损失函数(如交叉熵)loss = criterion(outputs, targets)# 反向传播 - 计算梯度loss.backward()# 参数更新 - 根据优化器(如Adam)更新模型权重optimizer.step()# 统计训练指标train_loss += loss.item()  # 累积批次损失_, predicted = outputs.max(1)  # 获取预测类别total += targets.size(0)  # 累积总样本数correct += predicted.eq(targets).sum().item()  # 累积正确预测数# 计算当前epoch的平均训练损失和准确率train_loss /= len(train_loader)  # 平均批次损失train_accuracy = 100.0 * correct / total  # 计算准确率百分比train_losses.append(train_loss)  # 记录损失train_accuracies.append(train_accuracy)  # 记录准确率# 模型验证部分model.eval()  # 设置模型为评估模式(禁用Dropout等)val_loss: float = 0.0  # 累积验证损失correct = 0   # 正确预测的样本数total = 0     # 总样本数# 禁用梯度计算 - 验证过程不需要计算梯度,节省内存和计算资源with torch.no_grad():# 遍历验证数据加载器中的所有批次for inputs, targets in val_loader:# 将数据移至计算设备inputs, targets = inputs.to(device), targets.to(device)# 前向传播 - 获取验证预测结果outputs = model(inputs)# 计算验证损失loss = criterion(outputs, targets)# 统计验证指标val_loss += loss.item()  # 累积验证损失_, predicted = outputs.max(1)  # 获取预测类别total += targets.size(0)  # 累积总样本数correct += predicted.eq(targets).sum().item()  # 累积正确预测数# 计算当前epoch的平均验证损失和准确率val_loss /= len(val_loader)  # 平均验证损失val_accuracy = 100.0 * correct / total  # 计算验证准确率val_losses.append(val_loss)  # 记录验证损失val_accuracies.append(val_accuracy)  # 记录验证准确率# 打印当前epoch的训练和验证指标print(f'Epoch {epoch+1}/{epochs}')print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.2f}%')print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.2f}%')print('-' * 50)# 更新学习率调度器(修正mode为min,匹配验证损失)scheduler.step(val_loss)  # 传入验证损失,mode='min'# 新增:早停逻辑(可选)if val_loss < best_val_loss:best_val_loss = val_lossearly_stop_counter = 0# 可选:保存最佳模型权重torch.save(model.state_dict(), 'best_model.pth')else:early_stop_counter += 1if early_stop_counter >= early_stop_patience:print(f"Early stopping at epoch {epoch+1}")break# 返回训练和验证过程中的所有指标,用于后续分析和可视化return train_losses, val_losses, train_accuracies, val_accuracies# 训练模型(保持调用方式不变)
epochs = 20  
train_losses, val_losses, train_accuracies, val_accuracies = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs
)# 可视化训练过程(保持原函数不变)
def plot_training(train_losses, val_losses, train_accuracies, val_accuracies):plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_losses, label='Train Loss')plt.plot(val_losses, label='Validation Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.title('Training and Validation Loss')plt.subplot(1, 2, 2)plt.plot(train_accuracies, label='Train Accuracy')plt.plot(val_accuracies, label='Validation Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.title('Training and Validation Accuracy')plt.tight_layout()plt.show()plot_training(train_losses, val_losses, train_accuracies, val_accuracies)

模型评估结构

获取预测

@浙大疏锦行


http://www.hkcw.cn/article/cYGdpcrfOw.shtml

相关文章

【Quest开发】bug记录——Link界面无音频选项

此方法适用于这个不见了的情况 打开设备管理器&#xff0c;点击卸载 再到Oculus\Support\oculus-drivers找到oculus-driver点击重装驱动&#xff0c;重启电脑即可修复

汇编语言学习(二)——寄存器

目录 一、通用寄存器 二、数据存储 三、汇编指令 四、物理地址 五 、段寄存器 一、通用寄存器 在8086 CPU中&#xff0c;通用寄存器共有四个&#xff0c;分别是 AX、BX、CX 和 DX&#xff0c;它们通常用于存放一般性的数据&#xff0c;均为 16 位寄存器&#xff0c;可以存…

Error creating bean with name *.PageHelperAutoConfiguration 异常解析

一、问题报错 微服务安装成功&#xff0c;启动失败&#xff0c;报错如下&#xff1a; 二、 Spring Boot应用启动错误分析 错误概述 这是一个Spring Boot应用启动过程中出现的Bean创建错误。根据错误堆栈&#xff0c;主要问题在于无法创建PageHelper分页插件的自动配置类。 …

【Zephyr 系列 3】多线程与调度机制:让你的 MCU 同时干多件事

好的,下面是Zephyr 系列第 3 篇:聚焦 多线程与调度机制的实践应用,继续面向你这样的 Ubuntu + 真板实战开发者,代码清晰、讲解通俗、结构规范,符合 CSDN 高质量博客标准。 🧠关键词:Zephyr、线程调度、k_thread、k_sleep、RTOS、BluePill 📌适合人群:想从裸机开发进…

AI万能写作v1.0.12

AI万能写作是一款高度自动化、智能化、个性化的AI智能软件&#xff0c;旨在通过人工智能技术进行内容整合创作&#xff0c;为用户提供便捷高效的写作辅助。这款APP能够一键生成各类素材内容&#xff0c;帮助用户快速获取思路和灵感&#xff0c;成为写作、学习、工作以及日常生活…

【Linux网络篇】:HTTP协议深度解析---基础概念与简单的HTTP服务器实现

✨感谢您阅读本篇文章&#xff0c;文章内容是个人学习笔记的整理&#xff0c;如果哪里有误的话还请您指正噢✨ ✨ 个人主页&#xff1a;余辉zmh–CSDN博客 ✨ 文章所属专栏&#xff1a;Linux篇–CSDN博客 文章目录 一.三个预备知识认识域名认识URL认识URL编码和解码 二.http请求…

【JAVA后端入门基础001】Tomcat 是什么?通俗易懂讲清楚!

&#x1f4da;博客主页&#xff1a;代码探秘者 ✨专栏&#xff1a;《JavaSe》 其他更新ing… ❤️感谢大家点赞&#x1f44d;&#x1f3fb;收藏⭐评论✍&#x1f3fb;&#xff0c;您的三连就是我持续更新的动力❤️ &#x1f64f;作者水平有限&#xff0c;欢迎各位大佬指点&…

系统思考:成长与投资不足

最近认识了一位95后年轻创业者&#xff0c;短短2年时间&#xff0c;他的公司从十几个人发展到几百人&#xff0c;规模迅速扩大。随着团队壮大&#xff0c;用户池也在持续扩大&#xff0c;但令人困惑的是&#xff0c;业绩增长却没有明显提升&#xff0c;甚至人效持续下滑。尽管公…

PHP7+MySQL5.6 查立得轻量级公交查询系统

# PHP7MySQL5.6 查立得轻量级公交查询系统 ## 系统简介 本系统是一个基于PHP7和MySQL5.6的轻量级公交查询系统(40KB级)&#xff0c;支持线路查询、站点查询和换乘查询功能。系统采用原生PHPMySQL开发&#xff0c;无需第三方框架&#xff0c;适合手机端访问。 首发版本&#x…

【笔记】Windows系统部署suna基于 MSYS2的Poetry 虚拟环境backedn后端包编译失败处理

基于 MSYS2&#xff08;MINGW64&#xff09;中 Python 的 Poetry 虚拟环境包编译失败处理笔记 一、背景 在基于 MSYS2&#xff08;MINGW64&#xff09;中 Python 创建的 Poetry 虚拟环境里&#xff0c;安装 Suna 开源项目相关包时编译失败&#xff0c;阻碍项目正常部署。 后端…

docker可视化工具

一、portainer&#xff08;不常用&#xff09; 1、安装portainer [rootlocalhost /]# docker run -d -p 8088:9000 --name portainer --restartalways -v /var/run/docker.sock:/var/run/docker.sock -v portainer_data:/data --privilegedtrue portainer/portainer-c…

#16 学习日志软件测试

#16 #13布置的任务都没有wanc 反思一下 一个是贪玩 一个是懒 还有一个原因是学习方式 单看视频容易困 然后是一个进度宝贝 java ai 编程 完 挑着看的 廖雪峰教程 完 速看 很多过时 javaweb ai笔记 见到13.aop 小林coding 看到4.并发 java guide 还没开始 若依框架 笔…

【数据集】NCAR CESM Global Bias-Corrected CMIP5 Output to Support WRF/MPAS Research

目录 数据概述🔍 数据集简介:🧪 数据处理方法:📅 时间范围(Temporal Coverage):📈 模拟情景(Scenarios):🌡️ 关键变量(Variables):📏 垂直层级(Vertical Levels):💾 数据格式与获取方式:数据下载及处理参考🌍 数据集名称: NCAR CESM Global B…

如何用AI写作?

过去半年&#xff0c;我如何用AI高效写作&#xff0c;节省数倍时间 过去六个月&#xff0c;我几乎所有文章都用AI辅助完成。我的朋友——大多是文字工作者&#xff0c;对语言极为敏感——都说看不出我的文章是AI写的还是亲手创作的。 我的AI写作灵感部分来自丘吉尔。这位英国…

dvwa4——File Inclusion

LOW: 先随便点开一个文件&#xff0c;可以观察到url栏变成这样&#xff0c;说明?page是dvwa当前关卡用来加载文件的参数 http://10.24.8.35/DVWA/vulnerabilities/fi/?pagefile1.php 我们查看源码 &#xff0c;没有什么过滤&#xff0c;直接尝试访问其他文件 在url栏的pag…

mysql数据库实现分库分表,读写分离中间件sharding-sphere

一 概述 1.1 sharding-sphere 作用&#xff1a; 定位关系型数据库的中间件&#xff0c;合理在分布式环境下使用关系型数据库操作&#xff0c;目前有三个产品 1.sharding-jdbc&#xff0c;sharding-proxy 1.2 sharding-proxy实现读写分离的api版本 4.x版本 5.x版本 1.3 说明…

Doris环境部署与应用开发

部署的方式有几种,可以下载github上的源码编译,这里直接下载官方最新的二进制包,差不多有4G。 wget -c https://apache-doris-releases.oss-accelerate.aliyuncs.com/apache-doris-3.0.5-bin-x64.tar.gz tar -zxvf apache-doris-3.0.5-bin-x64.tar.gz mv apache-doris-3.0.…

Dify在Windows 11上的安装实战

一、引言 随着人工智能技术的飞速发展&#xff0c;大语言模型&#xff08;LLM&#xff09;的应用场景日益丰富&#xff0c;从智能客服到内容生成&#xff0c;再到复杂的数据分析&#xff0c;LLM正逐步渗透到各行各业。Dify&#xff0c;作为一个专注于AI应用开发的开源平台&…

C++之动态数组vector

Vector 一、什么是 std::vector&#xff1f;二、std::vector 的基本特性&#xff08;一&#xff09;动态扩展&#xff08;二&#xff09;随机访问&#xff08;三&#xff09;内存管理 三、std::vector 的基本操作&#xff08;一&#xff09;定义和初始化&#xff08;二&#xf…

Spring Boot Starter 自动装配原理全解析:从概念到实践

Spring Boot Starter 自动装配原理全解析&#xff1a;从概念到实践 在Spring Boot开发中&#xff0c;Starter和自动装配是两个核心概念&#xff0c;它们共同构成了“开箱即用”的开发体验。通过引入一个Starter依赖&#xff0c;开发者可以快速集成第三方组件&#xff08;如Red…