PyTorch中nn.Module详解

article/2025/7/22 6:40:20

直接print(dir(nn.Module)),得到如下内容:
在这里插入图片描述

一、模型结构与参数

  1. parameters()

    • 用途:返回模块的所有可训练参数(如权重、偏置)。
    • 示例
      for param in model.parameters():print(param.shape)
      
  2. named_parameters()

    • 用途:返回带名称的参数迭代器,便于调试和访问特定参数。
    • 示例
      for name, param in model.named_parameters():if 'weight' in name:print(name, param.shape)
      
  3. children()

    • 用途:返回直接子模块的迭代器。
    • 示例
      for child in model.children():print(type(child))
      
  4. modules()

    • 用途:递归返回所有子模块(包括自身)。
    • 示例
      for module in model.modules():if isinstance(module, nn.Conv2d):print(module.kernel_size)
      

二、模型状态与模式

  1. train()eval()

    • 用途:切换训练/推理模式(影响Dropout、BatchNorm等层)。
    • 示例
      model.train()  # 训练模式
      model.eval()   # 推理模式
      
  2. training

    • 用途:布尔属性,指示当前模式(True 为训练,False 为推理)。
    • 示例
      print(model.training)  # 输出:True/False
      

三、模型保存与加载

  1. state_dict()

    • 用途:返回包含模型所有参数的字典(OrderedDict)。
    • 示例
      torch.save(model.state_dict(), 'model.pth')
      
  2. load_state_dict()

    • 用途:从字典加载模型参数。
    • 示例
      model.load_state_dict(torch.load('model.pth'))
      

四、设备与数据类型

  1. to()

    • 用途:将模型移动到指定设备(如GPU)或转换数据类型。
    • 示例
      model.to('cuda')          # 移动到GPU
      model.to(torch.float16)   # 转换为半精度
      
  2. cpu()cuda()

    • 用途:快捷方法,分别将模型移动到CPU或GPU。
    • 示例
      model.cuda()  # 等价于 model.to('cuda')
      

五、前向传播与计算

  1. forward()

    • 用途:定义模型的前向传播逻辑(需在自定义模块中重写)。
    • 示例
      class MyModel(nn.Module):def forward(self, x):return self.layer(x)
      
  2. __call__()

    • 用途:调用模型实例时触发(内部调用 forward(),支持钩子函数)。
    • 示例
      output = model(x)  # 等价于 output = model.forward(x)
      

六、参数初始化与优化

  1. zero_grad()

    • 用途:清空所有参数的梯度(通常在每个训练步骤前调用)。
    • 示例
      optimizer.zero_grad()  # 等价于 model.zero_grad()
      
  2. requires_grad_()

    • 用途:设置参数是否需要梯度(用于冻结部分模型)。
    • 示例
      for param in model.parameters():param.requires_grad = False  # 冻结所有参数
      

七、调试与信息

  1. extra_repr()

    • 用途:自定义模块打印信息(需在子类中重写)。
    • 示例
      class MyModel(nn.Module):def extra_repr(self):return f"hidden_size={self.hidden_size}"
      
  2. dump_patches()

    • 用途:打印模型的补丁信息(用于调试版本差异)。

八、其他实用方法

  1. apply()

    • 用途:递归应用函数到所有子模块(如初始化权重)。
    • 示例
      def init_weights(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)
      model.apply(init_weights)
      
  2. register_forward_hook()

    • 用途:注册前向传播钩子(用于捕获中间输出,调试或特征提取)。

总结

日常使用中,最频繁的方法包括:

  • 模型构建parameters(), children(), modules()
  • 训练与推理train(), eval(), zero_grad(), forward()
  • 保存与加载state_dict(), load_state_dict()
  • 设备管理to(), cuda(), cpu()

其他方法根据具体需求选择使用,例如钩子函数用于高级调试,apply() 用于统一初始化。

与nn.Sequential对比:

1. 继承关系与基础属性

  • nn.Module

    • 是所有神经网络模块的基类,提供最基础的功能(如参数管理、钩子机制)。
    • 包含核心属性:_parameters, _modules, _buffers 等。
  • nn.Sequential

    • nn.Module 的子类,继承了所有基础功能。
    • 额外添加了与顺序执行相关的属性(如 __getitem__append)。

2. 核心差异对比

功能类别nn.Modulenn.Sequential
模块构建需要手动实现 forward 方法自动按顺序执行子模块,无需定义 forward
子模块访问通过属性名(如 self.conv1通过索引或命名(如 model[0]
动态修改需手动管理子模块支持 appendextendinsert 等操作
适用场景复杂网络结构(如ResNet、U-Net)简单顺序结构(如LeNet卷积部分)

3. 具体方法对比

3.1 公共方法(两者都有)
# 模型参数与结构
['parameters', 'named_parameters', 'children', 'modules', 'named_children', 'named_modules']# 模型状态
['train', 'eval', 'training', 'zero_grad', 'requires_grad_']# 设备与数据类型
['to', 'cpu', 'cuda', 'float', 'double', 'half', 'bfloat16']# 保存与加载
['state_dict', 'load_state_dict']# 钩子机制
['register_forward_hook', 'register_backward_hook']
3.2 nn.Sequential 特有的方法
# 列表操作(动态修改模块顺序)
['__getitem__', '__setitem__', '__delitem__', '__len__', 'append', 'extend', 'insert', 'pop']# 索引相关
['_get_item_by_idx']
3.3 nn.Module 特有的方法
# 自定义实现
['forward', 'extra_repr']# 高级管理
['add_module', 'register_module', 'register_parameter', 'register_buffer']

4. 示例对比

4.1 创建模型
# nn.Module(需自定义 forward)
class CustomModel(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(3, 64, 3)self.relu = nn.ReLU()def forward(self, x):return self.relu(self.conv(x))# nn.Sequential(自动按顺序执行)
seq_model = nn.Sequential(nn.Conv2d(3, 64, 3),nn.ReLU()
)
4.2 访问子模块
# nn.Module
custom_model.conv  # 通过属性名访问# nn.Sequential
seq_model[0]       # 通过索引访问
seq_model.append(nn.MaxPool2d(2))  # 动态添加模块

5. 总结

特性nn.Modulenn.Sequential
灵活性高(自定义任意逻辑)低(仅支持顺序执行)
代码复杂度较高(需手动实现 forward低(自动处理前向传播)
动态修改不支持直接操作(需手动管理)支持 appendinsert 等操作
适用场景复杂网络、分支结构、自定义操作简单堆叠模块(如CNN的卷积部分)

建议:

  • 对于简单的顺序网络,优先使用 nn.Sequential 以减少代码量。
  • 对于包含复杂逻辑(如残差连接、多输入输出)的网络,使用 nn.Module 自定义实现。

http://www.hkcw.cn/article/yJGCKFPnam.shtml

相关文章

若依项目天气模块

在若依项目里添加了一个天气模块,记录一下过程。 一、功能结构与组件布局 天气模块以卡片形式(el-card)展示,包含以下核心功能: 实时天气:显示当前城市、温度、天气状况(如晴、多云&#xff…

APM32芯得 EP.06 | APM32F407移植uC/OS-III实时操作系统经验分享

《APM32芯得》系列内容为用户使用APM32系列产品的经验总结,均转载自21ic论坛极海半导体专区,全文未作任何修改,未经原文作者授权禁止转载。 最近我开始学习 uC/OS-III 实时操作系统,并着手将其移植到APM32F407 开发板上。在这个过…

图解gpt之注意力机制原理与应用

大家有没有注意到,当序列变长时,比如翻译一篇长文章,或者处理一个长句子,RNN这种编码器就有点力不从心了。它把整个序列信息压缩到一个固定大小的向量里,信息丢失严重,而且很难记住前面的细节,特…

更新密码--二阶注入攻击的原理

1.原理知识: 二阶SQL注入攻击(Second-Order SQL Injection)原理详解 一、基本概念 二阶注入是一种"存储型"SQL注入,攻击流程分为两个阶段: ​​首次输入​​:攻击者将恶意SQL片段存入数据库​…

RFID技术助力托盘运输线革新

RFID技术助力托盘运输线革新 湖北某工厂托盘运输线使用上存在的问题: 1、托盘在运输线上受信息录入时间等问题影响,导致效率低下; 2、原先托盘上粘贴的条码容易污损,并且时常需要更新更换,导致信息录入、出入库等步…

EasyRTC嵌入式音视频通信SDK助力1v1实时音视频通话全场景应用

一、方案概述​ 在数字化通信需求日益增长的今天,EasyRTC作为一款全平台互通的实时视频通话方案,实现了设备与平台间的跨端连接。它支持微信小程序、APP、PC客户端等多端协同,开发者通过该方案可快速搭建1v1实时音视频通信系统,适…

java.io.IOException: ZIP entry size is too large or invalid

java.io.IOException: ZIP entry size is too large or invalid 解决方案&#xff1a;pom.xml添加<nonFilteredFileExtension>xlsx</nonFilteredFileExtension>

vue3 项目配置多语言支持,如何从服务端拿多语言配置

在 Vue3 项目中实现多语言支持并从服务端获取配置&#xff0c;可以使用 Vue I18n 库。在初始化阶段可以发送请求获取多语言配置或者通过本地文件加载json文件的方式&#xff0c;都可以实现。我这里是tauri项目&#xff0c;所以使用的是invoke从tauri端拿到配置文件&#xff0c;…

龙舟竞渡与芯片制造的共通逻辑:华芯邦的文化破局之道

端午节承载着中华民族数千年的精神密码&#xff0c;龙舟最初是古人沟通天地、祈求风调雨顺的仪式载体。战国时期&#xff0c;屈原投江的悲壮故事为端午注入了家国情怀&#xff0c;龙舟竞渡从此兼具纪念英雄与祈福避疫的双重意义。这种文化内核&#xff0c;与深圳市华芯邦“以科…

OS9.【Linux】基本权限(下)

目录 1.默认权限 掩码 修改权限掩码 目录的权限说明 r权限 w权限 x权限 结论 家目录权限 2.共享目录 粘滞位t 承接OS8.【Linux】基本权限(上)文章 1.默认权限 创建用户时拥有者所属组都是该用户,而且对其他人没有任何权限 掩码 新建文件new.txt1和目录folder后…

【容器docker】启动容器kibana报错:“message“:“Error: Cannot find module ‘./logs‘

说明&#xff1a; 1、服务器数据盘挂了&#xff0c;然后将以前的数据用rsync拷贝过去&#xff0c;启动容器kibana服务&#xff0c;报错信息如下图所示&#xff1a; 2、可能是拷贝docker文件夹&#xff0c;有些文件没有拷贝过去&#xff0c;导致无论是给文件夹授权用户kibana或者…

【25-cv-05917】HSP律所代理Le Petit Prince 小王子商标维权案

Le Petit Prince 小王子 案件号&#xff1a;25-cv-05917 立案时间&#xff1a;2025年5月28日 原告&#xff1a;SOCIETE POUR LOEUVRE ET LA MEMOIRE DANTOINE DE SAINT EXUPERY - SUCCESSION DE SAINT EXUPERY-DAGAY 代理律所&#xff1a;HSP 原告介绍 《小王子》&#x…

信创国产化

一、硬件国产化 1. 飞腾E2000Q 二、操作系统国产化 1. 麒麟系统 1.1 麒麟嵌入式支持飞腾E2000Q 1.1.1 启动安装盘制作 1. 下载rufus工具,安装,下载麒麟系统ISO镜像文件。 2. 使用rufus制作启动盘,U盘插入(注先备份数据,会格式化盘符),配置参数如图。 3. 点击…

一、Sqoop历史发展及原理

作者&#xff1a;IvanCodes 日期&#xff1a;2025年5月30日 专栏&#xff1a;Sqoop教程 在大数据时代&#xff0c;数据往往分散存储在各种不同类型的系统中。其中&#xff0c;传统的关系型数据库 (RDBMS) 如 MySQL, Oracle, PostgreSQL 等&#xff0c;仍然承载着大量的关键业务…

2.从0开始搭建vue项目(node.js,vue3,Ts,ES6)

从“0到跑起来一个 Vue 项目”&#xff0c;重点是各个工具之间的关联关系、职责边界和技术演化脉络。 从你写代码 → 到代码能跑起来 → 再到代码可以部署上线&#xff0c;每一步都有不同的工具参与。 &#x1f63a;&#x1f63a;1. 安装 Node.js —— 万事的根基 Node.js 是…

包管理工具

npx工具 npx是什么捏&#xff1f; npx是npm5.2之后自带的一个命令 npx的作用非常之多&#xff0c;但是比较常见的是它用来调用项目中的某个模块的指令 现在假设一个场景&#xff1a; 你在项目里安装了webpack&#xff0c;也在全局中安装了webpack&#xff0c;但是这俩版本…

信号发生器幅值和偏置设置

Vrms是有效幅度 Vpp是幅度峰峰值 Vp是幅度最大值 幅度 2Vpp, 偏置 0V: 信号范围&#xff1a; -1V (谷底) 到 1V (峰顶) -> 中心点在 0V。 幅度 2Vpp, 偏置 1V: 信号范围&#xff1a; (-1V 1V) 0V (谷底) 到 (1V 1V) 2V (峰顶) -> 中心点在 1V。 形状和 Vpp (2…

深入浅出:Spring IOCDI

什么是IOC IOC IOC(Inversion of Control)&#xff0c;是一种设计思想&#xff0c;在之前的SpringMVC里就在类上添加RestController和Controller注解就是使用了IOC&#xff0c;这两个注解就是在Spring中创建一个对象&#xff0c;并将注解下的类交给Spring管理&#xff0c;Spr…

Java并发

一、进程和线程 进程&#xff1a; 程序的一次执行过程&#xff0c;是系统运行程序的基本单位&#xff0c;因此进程是动态。系统运行一个程序即是一个进程从创建&#xff0c;运行到消亡的过程。 在Java中&#xff0c;当我们启动main函数时其实就是启动了一个JVM进程&#xff…

通过回调函数注册定时器触发事件

1、说明 使用回调函数&#xff0c;注册定时器触发事件的模式&#xff0c;提高定时器中断的可操作性&#xff0c;那如何实现呢&#xff1f; 2、.h文件 下面是定时器句柄的声明 3、.c文件 3.1、静态定时器句柄头 3.2、定时器回调函数处理 下面的函数是放在1ms的中断中的&#…