python打卡day42

article/2025/8/11 10:54:10

Grad-CAM与Hook函数

知识点回顾

  1. 回调函数
  2. lambda函数
  3. hook函数的模块钩子和张量钩子
  4. Grad-CAM的示例

在深度学习中,我们经常需要查看或修改模型中间层的输出或梯度,但标准的前向传播和反向传播过程通常是一个黑盒,很难直接访问中间层的信息。PyTorch 提供了一种强大的工具——hook 函数,它允许我们在不修改模型结构的情况下,获取或修改中间层的信息。常用场景如下:

  1. 调试与可视化中间层输出
  2. 特征提取:如在图像分类模型中提取高层语义特征用于下游任务
  3. 梯度分析与修改: 在训练过程中,对某些层进行梯度裁剪或缩放,以改变模型训练的动态
  4. 模型压缩:在推理阶段对特定层的输出应用掩码(如剪枝后的模型权重掩码),实现轻量化推理

1、回调函数

Hook本质是回调函数,所以我们先介绍一下回调函数。回调函数是作为参数传递给其他函数的函数,其目的是在某个特定事件发生时被调用执行。这种机制允许代码在运行时动态指定需要执行的逻辑,其中回调函数作为参数传入,所以在定义的时候一般用callback来命名

在 PyTorch 的 Hook API 中,回调参数通常命名为 hook,PyTorch 的 Hook 机制基于其动态计算图系统:

  1. 当你注册一个 Hook 时,PyTorch 会在计算图的特定节点(如模块或张量)上添加一个回调函数
  2. 当计算图执行到该节点时(前向或反向传播),自动触发对应的 Hook 函数
  3. Hook 函数可以访问或修改流经该节点的数据(如输入、输出或梯度)

2、lambda函数

在hook中常常用到lambda函数,它是一种匿名函数(没有正式名称的函数),最大特点是用完即弃,无需提前命名和定义。它的语法形式非常简约,仅需一行即可完成定义,格式:lambda 参数列表: 表达式

  • 参数列表:可以是单个参数、多个参数或无参数
  • 表达式:函数的返回值(无需 return 语句,表达式结果直接返回)

举个例子

# 定义匿名函数:计算平方
square = lambda x: x ** 2# 调用
print(square(5))  # 输出: 25

3、hook函数

PyTorch 提供了两种主要的 hook:

  • Module Hooks(模块钩子):用于监听整个模块的输入和输出
  • Tensor Hooks:用于监听张量的梯度

(1)模块钩子

允许我们在模块的输入或输出经过时进行监听。PyTorch 提供了两种模块钩子:

  1. register_forward_hook:在模块的前向传播完成后立即被调用,这个函数可以访问模块的输入和输出,但不能修改
  2. register_backward_hook:在反向传播过程中被调用的,可以用来获取或修改梯度信息

前向钩子举个例子

# 创建模型实例
model = SimpleModel()# 创建一个列表用于存储中间层的输出
conv_outputs = []# 定义前向钩子函数 - 用于在模型前向传播过程中获取中间层信息
def forward_hook(module, input, output):print(f"钩子被调用!模块类型: {type(module)}")print(f"输入形状: {input[0].shape}") #  input是一个元组,对应 (image, label)print(f"输出形状: {output.shape}")# 保存卷积层的输出用于后续分析# 使用detach()避免追踪梯度,防止内存泄漏conv_outputs.append(output.detach())# 在卷积层注册前向钩子
# register_forward_hook返回一个句柄,用于后续移除钩子
hook_handle = model.conv.register_forward_hook(forward_hook)# 创建一个随机输入张量 (批次大小=1, 通道=1, 高度=4, 宽度=4)
x = torch.randn(1, 1, 4, 4)# 执行前向传播 - 此时会自动触发钩子函数
output = model(x)# 释放钩子 - 重要!防止在后续模型使用中持续调用钩子造成意外行为或内存泄漏
hook_handle.remove()

反向钩子

# 定义一个存储梯度的列表
conv_gradients = []# 定义反向钩子函数
def backward_hook(module, grad_input, grad_output):print(f"反向钩子被调用!模块类型: {type(module)}")print(f"输入梯度数量: {len(grad_input)}")print(f"输出梯度数量: {len(grad_output)}")# 保存梯度供后续分析conv_gradients.append((grad_input, grad_output))# 在卷积层注册反向钩子
hook_handle = model.conv.register_backward_hook(backward_hook)# 创建一个随机输入并进行前向传播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)# 定义一个简单的损失函数并进行反向传播
loss = output.sum()
loss.backward()# 释放钩子
hook_handle.remove()

(2)张量钩子

PyTorch 还提供了张量钩子,允许我们直接监听和修改张量的梯度。张量钩子有两种:

  1. register_hook:用于监听张量的梯度
  2. register_full_backward_hook:用于在完整的反向传播过程中监听张量的梯度
# 创建一个需要计算梯度的张量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3# 定义一个钩子函数,用于修改梯度
def tensor_hook(grad):print(f"原始梯度: {grad}")# 修改梯度,例如将梯度减半return grad / 2# 在y上注册钩子
hook_handle = y.register_hook(tensor_hook)# 计算梯度,梯度会从z反向传播经过y到x,此时调用钩子函数
z.backward()print(f"x的梯度: {x.grad}")# 释放钩子
hook_handle.remove()

4、Grad-CAM

一个可视化算法,通过梯度信息用热力图显示图片中哪些区域让CNN做出了某个分类决定(比如为什么认为这是“猫”),原理:

  • 梯度计算:看最后几层特征图的梯度,哪个特征图对预测“猫”的贡献大
  • 加权融合:把重要的特征图合并成一张热力图(重要区域更亮)
  • 叠加显示:把热力图盖在原图上,一眼看出猫的脸/耳朵等关键部位被高亮了
# Grad-CAM实现
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注册钩子,用于获取目标层的前向传播输出和反向传播梯度self.register_hooks()def register_hooks(self):# 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)def forward_hook(module, input, output):self.activations = output.detach()# 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0].detach()# 在目标层注册前向钩子和反向钩子self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def generate_cam(self, input_image, target_class=None):# 前向传播,得到模型输出model_output = self.model(input_image)if target_class is None:# 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别target_class = torch.argmax(model_output, dim=1).item()# 清除模型梯度,避免之前的梯度影响self.model.zero_grad()# 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度one_hot = torch.zeros_like(model_output)one_hot[0, target_class] = 1model_output.backward(gradient=one_hot)# 获取之前保存的目标层的梯度和激活值gradients = self.gradientsactivations = self.activations# 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性weights = torch.mean(gradients, dim=(2, 3), keepdim=True)# 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果cam = torch.sum(weights * activations, dim=1, keepdim=True)# ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响cam = F.relu(cam)# 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)cam = cam - cam.min()cam = cam / cam.max() if cam.max() > 0 else camreturn cam.cpu().squeeze().numpy(), target_classidx = 102  # 选择测试集中的第101张图片 (索引从0开始)
image, label = testset[idx]
print(f"选择的图像类别: {classes[label]}")# 转换图像以便可视化
def tensor_to_np(tensor):img = tensor.cpu().numpy().transpose(1, 2, 0)mean = np.array([0.5, 0.5, 0.5])std = np.array([0.5, 0.5, 0.5])img = std * img + meanimg = np.clip(img, 0, 1)return img# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可视化
plt.figure(figsize=(12, 4))# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()

@浙大疏锦行


http://www.hkcw.cn/article/TigQNeLEIr.shtml

相关文章

[总结]前端性能指标分析、性能监控与分析、Lighthouse性能评分分析

前端性能分析大全 前端性能优化 LightHouse性能评分 性能指标监控分析 浏览器加载资源的全过程性能指标分析 性能指标 在实现性能监控前,先了解Web Vitals涉及的常见的性能指标 Web Vitals 是由 Google 推出的网页用户体验衡量指标体系,旨在帮助开发者量…

Linux 驱动之设备树

Linux 驱动之设备树 参考视频地址 【北京迅为】嵌入式学习之Linux驱动(第七期_设备树_全新升级)_基于RK3568_哔哩哔哩_bilibili 本章总领 1.设备树基本知识 什么是设备树? ​ Linux之父Linus Torvalds在2011年3月17日的ARM Linux邮件列表…

Unity Mono与IL2CPP比较

Unity提供了两种主要的脚本后端(Scripting Backend)选项:Mono和IL2CPP。它们在性能、平台支持和功能特性上有显著差异。 Edit>Project Settings>Player>Other Settings Mono后端 特点: 基于开源的Mono项目(.NET运行时实现) 使用即时编译(JIT…

配置Ollama环境变量,实现远程访问

在安装 Ollama 时配置环境变量 OLLAMA_HOST0.0.0.0:11434的主要目的是允许 Ollama 服务被局域网或远程设备访问,而不仅仅是本地主机(localhost)。 以下是详细原因: 1. Ollama默认行为的限制 默认情况下,Ollama 的 API…

仓颉鸿蒙开发:制作底部标签栏

今天制作标签栏,标签栏里面的有4个区域:首页、社区、消息、我的,以及对应的图标。点击的区域显示为高亮,未点击的区域显示为灰色 简单的将视图上面区域做一下 一、制作顶部公共视图部分 internal import ohos.base.* internal …

AWS之数据分析

目录 数据分析产品对比 1. Amazon Athena 3. AWS Lake Formation 4. AWS Glue 5. Amazon OpenSearch Service 6. Amazon Kinesis Data Analytics 7. Amazon Redshift 8.Amazon Redshift Spectrum 搜索服务对比 核心功能与定位对比 适用场景 关键差异总结 注意事项 …

Linux进程间通信----简易进程池实现

进程池的模拟实现 1.进程池的原理: 是什么 进程池是一种多进程编程模式,核心思想是先创建好一定数量的子进程用作当作资源,这些进程可以帮助完成任务并且重复利用,避免频繁的进程的创建和销毁的开销。 下面我们举例子来帮助理…

【Oracle】安装单实例

个人主页:Guiat 归属专栏:Oracle 文章目录 1. 安装前的准备工作1.1 硬件和系统要求1.2 检查系统环境1.3 下载Oracle软件 2. 系统配置2.1 创建Oracle用户和组2.2 配置内核参数2.3 配置用户资源限制2.4 安装必要的软件包 3. 目录结构和环境变量3.1 创建Ora…

Pyecharts 库的概念与函数

基本概念 Pyecharts 是一个基于 ECharts 的 Python 数据可视化库,具有以下特点: 基于 ECharts:底层使用百度开源的 ECharts 图表库 多种图表类型:支持折线图、柱状图、饼图、散点图、地图等多种图表 交互式:生成的图…

【深入详解】C语言内存函数:memcpy、memmove的使用和模拟实现,memset、memcmp函数的使用

目录 一、memcpy、memmove使用和模拟实现 (一)memcpy的使用和模拟实现 1、代码演示: (1)memcpy拷贝整型 (2)memcpy拷贝浮点型 2、模拟实现 (二)memmove的使用和模…

设计模式——责任链设计模式(行为型)

摘要 责任链设计模式是一种行为型设计模式,旨在将请求的发送者与接收者解耦,通过多个处理器对象按链式结构依次处理请求,直到某个处理器处理为止。它包含抽象处理者、具体处理者和客户端等核心角色。该模式适用于多个对象可能处理请求的场景…

软件的兼容性如何思考与分析?

软件功能的兼容性是指软件在实现功能的时候,能够与其他软件、硬件、系统环境以及数据格式等相互协作、互不冲突,并且能够正确处理不同来源或不同版本的数据、接口和功能模块的能力。它确保软件在多种环境下能够正常运行,同时与其他系统和用户…

C++ —— STL容器——string类

1. 前言 本篇博客将会介绍 string 中的一些常用的函数,在使用 string 中的函数时,需要加上头文件 string。 2. string 中的常见成员函数 2.1 初始化函数 string 类中的常用的初始化函数有以下几种: 1. string() …

DFS每日刷题

目录 P1605 迷宫 P1451 求细胞数量 P1219 [USACO1.5] 八皇后 Checker Challenge P1605 迷宫 #include <iostream> using namespace std; int n, m, t; int a[20][20]; int startx, starty, endx, endy; bool vis[20][20]; int res; int dx[] {0, 1, 0, -1}; int dy[]…

USART 串口通信全解析:原理、结构与代码实战

文章目录 USARTUSART简介USART框图USART基本结构数据帧起始位侦测数据采样波特率发生器串口发送数据 主要代码串口接收数据与发送数据主要代码 USART USART简介 一、USART 的全称与基本定义 英文全称 USART&#xff1a;Universal Synchronous Asynchronous Receiver Transmi…

C# winform 教程(一)

一、安装方法 官网下载社区免费版&#xff0c;在线下载安装 VS2022官网下载地址 下载后双击启动&#xff0c;选择需要模块&#xff08;net桌面开发&#xff0c;通用window平台开发&#xff0c;或者其他自己想使用的模块&#xff0c;后期可以修改&#xff09;&#xff0c;选择…

ZLG ZCANPro,ECU刷新,bug分享

文章目录 摘要 📋问题的起因bug分享 ✨思考&反思 🤔摘要 📋 ZCANPro想必大家都不陌生,买ZLG的CAN卡,必须要用的上位机软件。在汽车行业中,有ECU软件升级的需求,通常都通过UDS协议实现程序的更新,满足UDS升级的上位机要么自己开发,要么用CANoe或者VFlash,最近…

Matlab作图之 subplot

1. subplot(m, n, p) 将当前图形划分为m*n的网格&#xff0c;在 p 指定的位置创建坐标轴 matlab 按照行号对子图的位置进行编号 第一个子图是第一行第一列&#xff0c;第二个子图是第二行第二列......... 如果指定 p 位置存在坐标轴&#xff0c; 此命令会将已存在的坐标轴设…

【STM32F1标准库】理论——外部中断

目录 一、中断介绍 二、外部引脚EXTI申请的中断 三、外部中断的适用场景 四、其他注意事项 一、中断介绍 STM32可以触发中断的外设有外部引脚(EXTI)、定时器、ADC、DMA、串口、I2C、SPI等 中断同一由NVIC管理 n表示一个外设可能同时占用多个中断通道 优先级的值越小优先…

SAP学习笔记 - 开发18 - 前端Fiori开发 应用描述符(manifest.json)的用途

上一章讲了 Component配置&#xff08;组件化&#xff09;。 本章继续讲Fiori的知识。 目录 1&#xff0c;应用描述符(Descriptor for Applications) 1&#xff09;&#xff0c; manifest.json 2&#xff09;&#xff0c;index.html 3&#xff09;&#xff0c;Component.…