python训练营打卡第42天

article/2025/7/6 22:29:45

Grad-CAM与Hook函数

知识点回顾

  1. 回调函数
  2. lambda函数
  3. hook函数的模块钩子和张量钩子
  4. Grad-CAM的示例

作业:理解下今天的代码即可

1.回调函数

def handle_result(result):"""处理计算结果的回调函数"""print(f"计算结果是: {result}")def with_callback(callback):"""装饰器工厂:创建一个将计算结果传递给回调函数的装饰器"""def decorator(func):"""实际的装饰器,用于包装目标函数"""def wrapper(a, b):"""被装饰后的函数,执行计算并调用回调"""result = func(a, b)  # 执行原始计算callback(result)     # 调用回调函数处理结果return result        # 返回计算结果(可选)return wrapperreturn decorator# 使用装饰器包装原始计算函数
@with_callback(handle_result)
def calculate(a, b):"""执行加法计算"""return a + b# 直接调用被装饰后的函数
calculate(3, 5) 

输出结果:

计算结果是: 8

2.lamda匿名函数

square = lambda x : x ** 2
print(square(5))

输出结果:

25

3.hook函数

import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 解决OpenMP运行时库冲突问题
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # 临时解决方案
# 或者限制OpenMP线程数(可选)
# os.environ['OMP_NUM_THREADS'] = '1'# 定义简单的卷积神经网络
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()# 卷积层:输入1通道,输出2通道,3x3卷积核,填充1保持尺寸self.conv = nn.Conv2d(1, 2, kernel_size=3, padding=1)self.relu = nn.ReLU()# 全连接层:输入2*4*4特征,输出10分类self.fc = nn.Linear(2 * 4 * 4, 10)def forward(self, x):x = self.conv(x)x = self.relu(x)x = x.view(-1, 2 * 4 * 4)x = self.fc(x)return xdef analyze_model_hooks():# 创建模型model = SimpleModel()# 存储中间层输出conv_outputs = []# 前向钩子函数def forward_hook(module, input, output):"""前向钩子函数,在模块前向传播后自动调用参数:module: 当前模块实例input: 输入张量元组output: 输出张量"""print(f"前向钩子被调用 - 模块类型: {type(module)}")print(f"输入形状: {input[0].shape}")print(f"输出形状: {output.shape}")# 保存卷积层输出用于分析conv_outputs.append(output.detach())# 注册前向钩子forward_hook_handle = model.conv.register_forward_hook(forward_hook)# 创建输入张量x = torch.randn(1, 1, 4, 4)# 执行前向传播output = model(x)# 释放前向钩子forward_hook_handle.remove()# 打印并可视化卷积层输出if conv_outputs:print(f"\n卷积层输出形状: {conv_outputs[0].shape}")print(f"卷积层第一个输出通道示例:\n{conv_outputs[0][0, 0, :, :]}")# 尝试可视化,如果环境支持try:plt.figure(figsize=(12, 4))# 输入图像plt.subplot(1, 3, 1)plt.title('输入图像')plt.imshow(x[0, 0].detach().numpy(), cmap='gray')# 第一个卷积核输出plt.subplot(1, 3, 2)plt.title('卷积核1输出')plt.imshow(conv_outputs[0][0, 0].detach().numpy(), cmap='gray')# 第二个卷积核输出plt.subplot(1, 3, 3)plt.title('卷积核2输出')plt.imshow(conv_outputs[0][0, 1].detach().numpy(), cmap='gray')plt.tight_layout()plt.show()except Exception as e:print(f"无法显示图像: {e}. 可能需要GUI环境。")# 存储梯度conv_gradients = []# 反向钩子函数def backward_hook(module, grad_input, grad_output):"""反向钩子函数,在模块反向传播时自动调用参数:module: 当前模块实例grad_input: 输入梯度元组grad_output: 输出梯度元组"""print(f"\n反向钩子被调用 - 模块类型: {type(module)}")print(f"输入梯度数量: {len(grad_input)}")print(f"输出梯度数量: {len(grad_output)}")# 保存梯度用于分析conv_gradients.append((grad_input, grad_output))# 注册反向钩子backward_hook_handle = model.conv.register_backward_hook(backward_hook)# 创建带梯度的输入并执行前向传播x = torch.randn(1, 1, 4, 4, requires_grad=True)output = model(x)# 定义损失函数并执行反向传播loss = output.sum()loss.backward()# 释放反向钩子backward_hook_handle.remove()# 张量钩子示例def demonstrate_tensor_hook():print("\n=== 张量钩子示例 ===")# 创建带梯度的张量x = torch.tensor([2.0], requires_grad=True)y = x ** 2z = y ** 3# 梯度修改钩子def tensor_hook(grad):print(f"原始梯度: {grad}")# 修改梯度(例如减半)return grad / 2# 注册张量钩子tensor_hook_handle = y.register_hook(tensor_hook)# 执行反向传播z.backward()print(f"修改后的梯度: {x.grad}")  # 应显示修改后的梯度# 释放张量钩子tensor_hook_handle.remove()# 运行张量钩子示例demonstrate_tensor_hook()if __name__ == "__main__":analyze_model_hooks()    

输出结果:

前向钩子被调用 - 模块类型: <class 'torch.nn.modules.conv.Conv2d'>
输入形状: torch.Size([1, 1, 4, 4])
输出形状: torch.Size([1, 2, 4, 4])卷积层输出形状: torch.Size([1, 2, 4, 4])     
卷积层第一个输出通道示例:
tensor([[-0.4173, -0.5642,  0.3407, -0.5395],[-0.0755,  0.8618,  0.5276, -0.2671],[-0.1973, -0.4461, -1.1223,  0.1371],[ 0.2291, -0.1226,  0.2020,  0.3061]])
反向钩子被调用 - 模块类型: <class 'torch.nn.modules.conv.Conv2d'>
输入梯度数量: 3
输出梯度数量: 1=== 张量钩子示例 ===
原始梯度: tensor([48.])
修改后的梯度: tensor([96.])

4.Grad-CAM

import os
# 解决OpenMP运行时库冲突问题
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # 允许重复加载OpenMP库
# 可选:限制OpenMP线程数以减少冲突
# os.environ['OMP_NUM_THREADS'] = '1'import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import warnings# 全局设置
warnings.filterwarnings("ignore")
plt.rcParams["font.family"] = ["SimHei"]  # 设置中文字体
plt.rcParams["axes.unicode_minus"] = False  # 解决负号显示问题
torch.manual_seed(42)  # 固定随机种子
np.random.seed(42)# 设备配置
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# ---------------------- 数据处理模块 ---------------------- #
class DataLoader:"""数据加载与预处理类"""def __init__(self, data_root="./data"):self.data_root = data_rootself.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])def load_dataset(self, train=True):"""加载CIFAR-10数据集"""try:dataset = torchvision.datasets.CIFAR10(root=self.data_root,train=train,download=True,transform=self.transform)return datasetexcept Exception as e:raise RuntimeError(f"数据集加载失败: {e}")def create_dataloader(self, dataset, batch_size=64, shuffle=True, num_workers=2):"""创建数据加载器"""return torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,pin_memory=(device.type == "cuda"),  # GPU优化persistent_workers=(num_workers > 0))# ---------------------- 模型定义模块 ---------------------- #
class SimpleCNN(nn.Module):"""简单CNN模型"""def __init__(self):super(SimpleCNN, self).__init__()self.conv_layers = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2, 2))self.fc_layers = nn.Sequential(nn.Linear(128 * 4 * 4, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.conv_layers(x)x = x.view(x.size(0), -1)x = self.fc_layers(x)return x# ---------------------- 训练模块 ---------------------- #
class Trainer:"""模型训练类"""def __init__(self, model, device):self.model = model.to(device)self.device = deviceself.criterion = nn.CrossEntropyLoss()self.optimizer = torch.optim.Adam(model.parameters(), lr=0.001)def train_epoch(self, dataloader):"""单轮训练"""self.model.train()running_loss = 0.0for i, (inputs, labels) in enumerate(dataloader, 1):inputs, labels = inputs.to(self.device), labels.to(self.device)self.optimizer.zero_grad()outputs = self.model(inputs)loss = self.criterion(outputs, labels)loss.backward()self.optimizer.step()running_loss += loss.item()if i % 100 == 0:print(f"批次 {i:4d} 平均损失: {running_loss / 100:.3f}")running_loss = 0.0def train(self, dataloader, epochs=1):"""完整训练流程"""for epoch in range(1, epochs + 1):print(f"\nEpoch {epoch}/{epochs}")self.train_epoch(dataloader)print("训练完成")# ---------------------- Grad-CAM可视化模块 ---------------------- #
class GradCAM:"""Grad-CAM可视化类(修复梯度跟踪和OpenMP冲突)"""def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = None  # 存储脱离梯度的梯度self.activations = None  # 存储脱离梯度的激活值self._register_hooks()  # 注册钩子def _register_hooks(self):"""注册前向/反向钩子函数,并脱离梯度"""def forward_hook(module, _, output):# 前向传播时保存脱离梯度的激活值self.activations = output.detach()def backward_hook(module, grad_input, grad_output):# 反向传播时保存脱离梯度的梯度self.gradients = grad_output[0].detach()self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def _generate_cam(self, input_tensor, target_class):"""生成类激活映射(CAM)"""# 确保模型处于训练模式(临时启用梯度计算)self.model.train()# 启用输入张量的梯度跟踪input_tensor.requires_grad_(True)# 前向传播获取输出outputs = self.model(input_tensor)# 构建目标类别的one-hot向量one_hot = torch.zeros_like(outputs)one_hot[0, target_class] = 1# 反向传播计算梯度(保留计算图)self.model.zero_grad()outputs.backward(gradient=one_hot, retain_graph=True)# 获取已脱离梯度的激活值和梯度activations = self.activations  # 已在钩子中detach()gradients = self.gradients      # 已在钩子中detach()# 计算通道权重(全局平均池化)weights = torch.mean(gradients, dim=(2, 3), keepdim=True)# 加权求和生成CAMcam = torch.sum(weights * activations, dim=1, keepdim=True)# ReLU激活(保留正贡献区域)cam = F.relu(cam)# 调整尺寸并归一化cam = F.interpolate(cam, size=32, mode="bilinear", align_corners=False)cam = (cam - cam.min()) / (cam.max() + 1e-8)  # 归一化到[0, 1]# 恢复模型为评估模式self.model.eval()# 转换为NumPy数组(已脱离梯度,安全转换)return cam.squeeze().numpy()def __call__(self, input_image, target_class=None):"""入口函数:生成热力图和预测类别"""input_tensor = input_image.unsqueeze(0).to(device)if target_class is None:# 不启用梯度模式下获取预测类别(节省计算)with torch.no_grad():target_class = self.model(input_tensor).argmax(dim=1).item()return self._generate_cam(input_tensor, target_class), target_class# ---------------------- 工具函数 ---------------------- #
def tensor_to_image(tensor):"""将Tensor转换为可视化图像(反归一化并脱离梯度)"""# 确保张量脱离梯度并转换为CPUtensor = tensor.detach().cpu()img = tensor.numpy().transpose(1, 2, 0)mean = np.array([0.5, 0.5, 0.5])std = np.array([0.5, 0.5, 0.5])return std * img + mean  # 反归一化处理def visualize_cam(image, heatmap, pred_class, true_class, save_path="grad_cam_result.png"):"""可视化热力图(支持无GUI环境)"""try:plt.figure(figsize=(12, 4))# 原始图像(处理梯度并转换)image_np = tensor_to_image(image)# 原始图像plt.subplot(1, 3, 1)plt.imshow(image_np)plt.title(f"真实类别: {classes[true_class]}")plt.axis("off")# 热力图plt.subplot(1, 3, 2)plt.imshow(heatmap, cmap="jet")plt.title(f"预测类别: {classes[pred_class]}")plt.axis("off")# 叠加图像plt.subplot(1, 3, 3)heatmap_colored = plt.cm.jet(heatmap)[:, :, :3]superimposed = image_np * 0.6 + heatmap_colored * 0.4plt.imshow(superimposed)plt.title("叠加热力图")plt.axis("off")plt.tight_layout()plt.savefig(save_path)print(f"结果已保存至 {save_path}")plt.show()except Exception as e:print(f"可视化失败: {e},可能缺少GUI环境,已跳过显示")plt.close()# ---------------------- 主流程 ---------------------- #
if __name__ == "__main__":classes = ("飞机", "汽车", "鸟", "猫", "鹿", "狗", "青蛙", "马", "船", "卡车")# 1. 初始化数据加载器data_loader = DataLoader()try:testset = data_loader.load_dataset(train=False)except RuntimeError as e:print(f"数据加载失败: {e},程序终止")exit(1)# 2. 初始化模型model = SimpleCNN()print("模型已创建")# 3. 加载或训练模型model_path = "cifar10_cnn.pth"try:model.load_state_dict(torch.load(model_path, map_location=device))print(f"成功加载预训练模型: {model_path}")except FileNotFoundError:print("未找到预训练模型,开始训练...")trainset = data_loader.load_dataset(train=True)trainloader = data_loader.create_dataloader(trainset)trainer = Trainer(model, device)trainer.train(trainloader, epochs=1)  # 可增加epochs参数torch.save(model.state_dict(), model_path)print(f"模型已保存至 {model_path}")except Exception as e:print(f"模型加载失败: {e},使用随机初始化模型")model = model.to(device)model.eval()  # 模型评估模式# 4. Grad-CAM可视化try:idx = 102  # 固定索引便于复现(可改为随机索引:np.random.randint(len(testset)))image, true_label = testset[idx]print(f"选择图像索引 {idx},真实类别: {classes[true_label]}")# 选择最后一个卷积层(conv3,即conv_layers中的第7层)target_layer = model.conv_layers[-3]  # 对应nn.Conv2d(64, 128, kernel_size=3, padding=1)grad_cam = GradCAM(model, target_layer)# 生成热力图(自动处理梯度跟踪)heatmap, pred_label = grad_cam(image, target_class=None)# 可视化结果(确保张量已脱离梯度)visualize_cam(image, heatmap, pred_label, true_label)except Exception as e:print(f"可视化失败: {e}")

输出结果:

​
使用设备: cpu
模型已创建
成功加载预训练模型: cifar10_cnn.pth
选择图像索引 102,真实类别: 青蛙
结果已保存至 grad_cam_result.png​

@浙大疏锦行


http://www.hkcw.cn/article/XoXvjHsSmr.shtml

相关文章

ISO18436-2 CATII级振动分析师能力矩阵

ISO18436-2021是当前针对针对分析师的一个标准&#xff0c;它对振动分析师的能力和知识体系做了4级分类&#xff0c;这里给出的是一家公司响应ISO18436的CATII级标准&#xff0c;做的一个专题培训的教学大纲。摘自&#xff1a; 【振動噪音產學技術聯盟】04/19-23 ISO 18436-2…

YARN应用日志查看

YARN应用日志查看 1、页面查看2、命令行查看1、页面查看 1.1、YARN ResourceManager Web UI Spark on YARN时,YARN的资源管理器(ResourceManager)和历史服务器(History Server)提供了强大的日志和监控功能,可以帮助用户查看和管理Spark作业 访问YARN ResourceManager的…

免费酒店管理系统+餐饮系统+小程序点餐——仙盟创梦IDE

酒店系统主屏幕 房间管理 酒店管理系统的房间管理&#xff0c;可实现对酒店所有房间的实时掌控。它能清晰显示房间状态&#xff0c;如已预订、已入住、空闲等&#xff0c;便于高效安排入住与退房&#xff0c;合理分配资源&#xff0c;提升服务效率&#xff0c;保障酒店运营有条…

29 C 语言内存管理与多文件编程详解:栈区、全局静态区、static 与 extern 深度解析

1 C 语言内存管理概述 1.1 内存分区模型解析 在 C 语言程序中&#xff0c;内存的合理管理是确保程序高效运行的核心。为了深入理解变量的作用域、生命周期及内存分配机制&#xff0c;我们需要先掌握内存分区模型。C 语言将内存划分为以下几个核心区域&#xff1a; 栈区&#…

JavaScript 性能优化实战:从原理到框架的全栈优化指南

在 Web 应用复杂度指数级增长的今天&#xff0c;JavaScript 性能优化已成为衡量前端工程质量的核心指标。本文将结合现代浏览器引擎特性与一线大厂实践经验&#xff0c;构建从基础原理到框架定制的完整优化体系&#xff0c;助你打造高性能 Web 应用。 一、性能优化基础&#x…

2025年十大AI幻灯片工具深度评测与推荐

我来告诉你一个好消息。 我们已经亲自测试和对比了市面上最优秀的AI幻灯片工具&#xff0c;让你无需再为选择而烦恼。 得益于AI技术的飞速发展&#xff0c;如今你可以快速制作出美观、专业的幻灯片。 这些智能平台的功能远不止于配色美化——它们能帮你头脑风暴、梳理思路、…

MATLAB 安装与使用详细教程

目录 第一部分&#xff1a;MATLAB 安装教程第二部分&#xff1a;MATLAB 界面介绍第三部分&#xff1a;MATLAB 基础使用第四部分&#xff1a;MATLAB 脚本编程第五部分&#xff1a;MATLAB 编程示例 第一部分&#xff1a;MATLAB 安装教程 1 下载 MATLAB 安装文件 访问 MathWor…

【C++进阶篇】C++11新特性(上篇)

&#x1f4a1; 解锁C11新技能&#xff1a;初始化、类型推导与智能指针的奥秘&#xff01; 一. C11简介1.1 C11发展历史 二. 初始化列表2.1 内置类型2.2 initializer_list详解 三. 简化声明3.1 auto 自动推导类型3.2.1 注意事项 3.3 decltype 获取推导类型3.3.1 没有括号3.3.2 有…

Unity中应对高速运动的物体,碰撞组件失效的问题?

尝试方法一:修改重力组件Rigidbody中的碰撞检测模式Collision Detection 把碰撞检测模式Collision Detection属性修改成Continuous Dynamic后,发现效果不是很明显,还会有碰撞组件失效的问题。 尝试方法二:射线检测替代物理碰撞 private Vector3 _prevPos;void Start() {…

高性能MYSQL(三):性能剖析

一、性能剖析概述 &#xff08;一&#xff09;关于性能优化 1.什么是性能&#xff1f; 我们将性能定义为完成某件任务所需要的时间度量&#xff0c;换句话说&#xff0c;性能即响应时间&#xff0c;这是一个非常重要的原则。 我们通过任务和时间而不是资源来测量性能。数据…

《深入解析SPI协议及其FPGA高效实现》-- 第二篇:SPI控制器FPGA架构设计

第二篇&#xff1a;SPI控制器FPGA架构设计 聚焦模块化设计、时序优化与资源管理 1. 系统级架构设计 1.1 模块化硬件架构 verilog module spi_controller (input wire clk, // 系统时钟 (100 MHz)input wire rst_n, // 异步复位// 配置接口…

rabbitmq Fanout交换机简介

给每个服务创建一个队列&#xff0c;然后每个业务订阅一个队列&#xff0c;进行消费。 如订单服务起个多个服务&#xff0c;代码是一样的&#xff0c;消费的也是同一个队列。加快了队列中的消息的消费速度。 可以看到两个消费者已经在消费了

Ⅱ.计算机二级选择题(运算符与表达式)

【注&#xff1a;重点题以及添加目录格式导航&#xff01;&#xff01;&#xff01;】 【重点题】&#xff08;第5题&#xff09; 【重点题】&#xff08;第18题&#xff09; 【重点题】&#xff08;第19题&#xff09; 【重点题】&#xff08;第35题&#xff09; 【重点题】&a…

使用Mathematica观察多形式根的分布随参数的变化

有两种方式观察多项式的根随着参数变化&#xff1a;&#xff08;1&#xff09;直接制作一个小的动态视频&#xff1b;&#xff08;2&#xff09;绘制所有根形成的痕迹&#xff08;locus&#xff09;。 制作动态视频&#xff1a; (*Arg-plane plotting routine with plotting …

腾答知识竞赛系统功能介绍

支持抢答题的局域网现场大屏知识竞赛抢答软件&#xff0c;无需网络只要有局域网或者WIFI就可以使用,现场大屏幕显示题目&#xff0c;支持基础题、抢答题、必答题、风险题等题目。 系统支持任何个人或者企业单位使用&#xff0c;使用无人员限制&#xff0c;可放心使用。 抢答时…

Python-matplotlib库之核心对象

matplotlib库之核心对象 FigureFigure作用Figure常用属性Figure常用方法Figure对象的创建隐式创建&#xff08;通过 pyplot&#xff09;显式创建使用subplots()一次性创建 Figure 和 Axes Axes&#xff08;绘图区&#xff09;Axes创建方式Axes基本绘图功能Axes绘图的常用参数Ax…

04powerbi-度量值-筛选引擎CALCULATE()

1、calculate calculate 的参数分两部分&#xff0c;分别是计算器和筛选器 2、多条件calculater与表筛选 多条件有不列的多条件 相同列的多条件 3、calculatertable &#xff08;表&#xff0c;筛选条件&#xff09;表筛选 与calculate用法一样&#xff0c;可以用创建表&…

深度学习原理与Pytorch实战

深度学习原理与Pytorch实战 第2版 强化学习人工智能神经网络书籍 python动手学深度学习框架书 TransformerBERT图神经网络&#xff1a; 技术讲解 编辑推荐 1.基于PyTorch新版本&#xff0c;涵盖深度学习基础知识和前沿技术&#xff0c;由浅入深&#xff0c;通俗易懂&#xf…

LabelImg: 开源图像标注工具指南

LabelImg: 开源图像标注工具指南 1. 简介 LabelImg 是一个图形化的图像标注工具&#xff0c;使用 Python 和 Qt 开发。它是目标检测任务中最常用的标注工具之一&#xff0c;支持 PASCAL VOC 和 YOLO 格式的标注输出。该工具开源、免费&#xff0c;并且跨平台支持 Windows、Lin…

React---day6、7

6、组件之间进行数据传递 **6.1 父传子&#xff1a;**props传递属性 父组件&#xff1a; <div><ChildCpn name"蒋乙菥" age"18" height"1,88" /> </div>子组件&#xff1a; export class ChildCpn extends React.Component…