DAY43打卡

article/2025/7/14 20:37:58

@浙大疏锦行

kaggle找到一个图像数据cnn网络进行训练并且grad-cam可视化

进阶并拆分成多个文件

fruit_cnn_project/
├─ data/                # 存放数据集(需手动创建,后续放入图片)
│  ├─ train/            # 训练集图像
│  └─ val/              # 验证集图像
├─ models/              # 模型定义
│  └─ cnn_model.py      # CNN网络结构
├─ utils/               # 工具函数
│  ├─ dataset_utils.py  # 数据加载与预处理
│  ├─ grad_cam.py       # Grad-CAM可视化
│  └─ train_utils.py    # 训练与评估
├─ main.py              # 主程序
└─ requirements.txt     # 依赖列表(可选)
# 第一部分:导入库
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline# 第二部分:数据加载与预处理
def load_data():data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])train_dataset = datasets.ImageFolder(root='data/train', transform=data_transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)test_dataset = datasets.ImageFolder(root='data/test', transform=data_transform)test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)return train_loader, test_loader# 第三部分:模型定义
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)self.fc1 = nn.Linear(32 * 56 * 56, 128)self.fc2 = nn.Linear(128, 2)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 32 * 56 * 56)x = self.relu(self.fc1(x))x = self.fc2(x)return x# 第四部分:模型训练
train_loader, _ = load_data()
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')torch.save(model.state_dict(), 'trained_model.pth')# 第五部分:模型测试
_, test_loader = load_data()
model = SimpleCNN()
model.load_state_dict(torch.load('trained_model.pth'))
model.eval()
correct = 0
total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the test images: {100 * correct / total}%')# 第六部分:Grad-CAM可视化(修复版)
def get_activation():activation = {}def hook(model, input, output):activation['target_layer'] = output.detach()return hook, activationdef grad_cam(model, image, target_class_index):hook, activation = get_activation()target_layer = model.conv2target_layer.register_forward_hook(hook)model.eval()image = image.unsqueeze(0)image.requires_grad_(True)output = model(image)one_hot = torch.zeros(1, output.size()[-1]).to(image.device)one_hot[0][target_class_index] = 1output.backward(gradient=one_hot, retain_graph=True)gradients = image.grad[0].cpu().numpy()# 从activation字典中获取激活图activation_map = activation['target_layer'].cpu().numpy()[0]weights = np.mean(gradients, axis=(1, 2))cam = np.zeros(activation_map.shape[1:], dtype=np.float32)for i, w in enumerate(weights):cam += w * activation_map[i]cam = np.maximum(cam, 0)cam = F.interpolate(torch.from_numpy(cam).unsqueeze(0).unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False)[0][0].numpy()cam = (cam - cam.min()) / (cam.max() - cam.min())return cam# 可视化前几张测试图片
dataiter = iter(test_loader)
images, labels = dataiter.next()for i in range(5):  # 可视化前5张图片image = images[i]label = labels[i].item()cam = grad_cam(model, image, label)plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.imshow(image.permute(1, 2, 0).numpy())plt.title(f'Original Image (Class: {label})')plt.axis('off')plt.subplot(1, 2, 2)plt.imshow(image.permute(1, 2, 0).numpy())plt.imshow(cam, cmap='jet', alpha=0.5)plt.title('Grad-CAM Visualization')plt.axis('off')plt.tight_layout()plt.show()


http://www.hkcw.cn/article/vTXCJISXJk.shtml

相关文章

蓝天影院订票网站的设计V3

1 绪 论 1.1 本课题研究背景 20世纪90年代中期以来,随着以Internet为代表的计算机技术,网络技术和信息技术的迅速发展,影院订票也逐渐转移到网络上[1][2]。伴随着我国计算机信息产业的飞速进步,计算机的开发应用已经遍布生活…

Python----目标检测(《YOLO9000: Better, Faster, Stronger》和YOLO-V2的原理与网络结构)

一、YOLO9000: Better, Faster, Stronger 1.1、基本信息 标题: YOLO9000: Better, Faster, Stronger 作者: Joseph Redmon, Ali Farhadi 机构: 华盛顿大学1, 艾伦人工智能研究所2 发布时间: 2016年(根据arXiv编号1612.08242推断) 论文链接: [1612.0…

力扣HOT100之动态规划:32. 最长有效括号

这道题放在动态规划里属实是有点难为人了,感觉用动态规划来做反而更难理解了,这道题用索引栈来做相当好理解,这里先讲下索引栈的思路。 索引栈做法 我们定义一个存放整数的栈,定义一个全局变量result来记录最长有效子串的长度&a…

操作系统:文件系统笔记

文件系统 参考资料: 12.10 虚拟文件系统_哔哩哔哩_bilibili7.1 文件系统全家桶 | 小林coding 基本组成 文件系统是操作系统中负责管理持久数据的子系统,说简单点,就是负责把用户的文件存到磁盘硬件中,因为即使计算机断电了&#…

Docker 安装 Redis 容器

系列文章目录 文章目录 系列文章目录前言1 获取redis镜像2 创建和部署redis容器3 查看redis是否启动成功4 使用Redis客户端验证连接总结 前言 搭建环境: ubuntu22.04.05 docker redis: 7.0.10 测试环境: windows: win11 Redis测试客户端:Ti…

Spring Boot 3.X 下Redis缓存的尝试(二):自动注解实现自动化缓存操作

前言 上文我们做了在Spring Boot下对Redis的基本操作,如果频繁对Redis进行操作而写对应的方法显示使用注释更会更高效; 比如: 依之前操作对一个业务进行定入缓存需要把数据拉取到后再定入; 而今天我们可以通过注释的方式不需要额外…

【Linux】Ubuntu 20.04 英文系统显示中文字体异常

英文系统显示中文字体异常 新安装的 Ubuntu 20.04 英文系统,显示中文字体有些奇怪,比如在谷歌浏览器中中文字体显示效果如下 参考 英文版ubuntu默认中文显示很奇怪 解决方案 - dbxxx - 博客园 编辑文件 sudo gedit /etc/fonts/conf.avail/64-languag…

每天总结一个html标签——a标签

文章目录 一、定义与使用说明二、支持的属性三、支持的事件四、默认样式五、常见用法1. 文本链接2. 图片链接3. 导航栏 在前端开发中,a标签(锚点标签)是最常用的HTML标签之一,主要用于创建超链接,实现页面间的跳转或下…

Day10

1. ArrayList和LinkedList的区别? 底层结构:ArrayList 是基于动态数组实现,支持索引快速访问;LinkedList 是基于双向链表实现,依赖指针访问前后元素。插入与删除效率:在尾部操作时,两者性能相近…

Flask + Celery 应用

目录 Flask Celery 应用项目结构1. 创建app.py2. 创建tasks.py3. 创建celery_worker.py4. 创建templates目录和index.html运行应用测试文件 Flask Celery 应用 对于Flask与Celery结合的例子,需要创建几个文件。首先安装必要的依赖: pip install flas…

鸿蒙电脑会在国内逐渐取代windows电脑吗?

点击上方关注 “终端研发部” 设为“星标”,和你一起掌握更多数据库知识 10年内应该不会 用Windows、MacOS操作系统的后果是你的个人信息可能会被美国FBI看到,但绝大多数人的信息FBI没兴趣去看 你用某家公司的电脑系统,那就得做好被某些人监视的下场,相信…

安全态势感知中的告警误报思考

如果说2020年还是小打小闹,那么2021年无疑是杀疯了,我带着我的误报识别系统和事件专家系统在流量态势感知领域杀疯了。先说说这个误报识别系统,我创新性的定义了灰色事件,认为告警不是非黑即白的,而是需要持续评估的&a…

电脑wifi显示已禁用怎么点都无法启用

一、重启路由器与电脑 有时候,简单的重启可以解决很多小故障。试着先断开电源让路由器休息一会儿再接通;对于电脑,则可选择重启系统看看情况是否有改善。 二、检查驱动程序 无线网卡驱动程序的问题也是导致WiFi无法启用的常见原因之一。我…

MAC电脑怎么通过触摸屏打开右键

在Mac电脑上,通过触摸屏打开右键菜单的方法如下: 法1:双指轻点:在触控板上同时用两根手指轻点,即可触发右键菜单。这是Mac上常用的右键操作方法。 法2:自定义触控板角落:可以设置触控板的右下角或左下角作为右键区域…

【音视频】 FFmpeg 硬件(AMD)解码H264

参考链接:https://trac.ffmpeg.org/wiki/HWAccelIntro 硬件编解码的概念 硬件编解码是⾮CPU通过烧写运⾏视频加速功能对⾼清视频流进⾏编解码,其中⾮CPU可包括GPU、FPGA或者ASIC等独⽴硬件模块,把CPU⾼使⽤率的视频解码⼯作从CPU⾥分离出来&…

【笔记】为 Python 项目安装图像处理与科学计算依赖(MINGW64 环境)

📝 为 Python 项目安装图像处理与科学计算依赖(MINGW64 环境) 🎯 安装目的说明 本次安装是为了在 MSYS2 的 MINGW64 工具链环境中,搭建一个完整的 Python 图像处理和科学计算开发环境。 主要目的是支持以下类型的 Pyth…

软件测评师教程 第2章 软件测试基础 笔记

第2章 软件测试基础 笔记 25.03.18 2.1 软件测试的基本概念 2.1.1 什么是软件测试 软件测试的定义: IEEE 使用人工或自动手段来运行或测定某个系统的过程,其目的在于检验它是否满足规定的需求 或是 弄清预期结果与实际结果之间的差异。 软件测试应尽…

[网页五子棋][匹配对战]落子实现思路、发送落子请求、处理落子响应

文章目录 落子实现思路发送落子请求处理落子响应两种棋盘的区别实现 handleTextMessage实现对弈功能控制台打印棋盘完善前端逻辑 落子实现思路 先来实现:点击棋盘,能发送落子请求 客户端 1 点击了棋盘位置,先不着急画子,而是给服…

【QT控件】QWidget 常用核心属性介绍 -- 万字详解

目录 一、控件概述 二、QWidget 核心属性 2.1 核心属性概览 2.2 enabled ​编辑 2.3 geometry 2.4 windowTitle 2.5 windowIcon 使用qrc文件管理资源 2.6 windowOpacity 2.7 cursor 2.8 font ​编辑 2.9 toolTip 2.10 focusPolicy 2.11 styleSheet QT专栏&…

【Java EE初阶】计算机是如何⼯作的

计算机是如何⼯作的 计算机发展史冯诺依曼体系(Von Neumann Architecture)CPU指令(Instruction)CPU 是如何执行指令的(重点) 操作系统(Operating System)进程(process) 进程 PCB 中的…