小白的进阶之路系列之七----人工智能从初步到精通pytorch自动微分优化以及载入和保存模型

article/2025/9/6 4:21:20

本文将介绍Pytorch的以下内容

自动微分函数

优化

模型保存和载入

好了,我们首先介绍一下关于微分的内容。

在训练神经网络时,最常用的算法是反向传播算法。在该算法中,根据损失函数相对于给定参数的梯度来调整参数(模型权重)。

为了计算这些梯度,PyTorch有一个内置的微分引擎,名为torch.autograd。它支持任何计算图的梯度自动计算。

考虑最简单的单层神经网络,输入x,参数w和b,以及一些损失函数。它可以在PyTorch中以以下方式定义:

import torchx = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

张量、函数与计算图

这段代码定义了以下计算图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在这个网络中,w和b是我们需要优化的参数。因此,我们需要能够计算损失函数相对于这些变量的梯度。为了做到这一点,我们设置了这些张量的requires_grad属性。

我们应用于张量来构造计算图的函数实际上是函数类的对象。该对象知道如何在正向方向上计算函数,以及如何在反向传播步骤中计算其导数。对反向传播函数的引用存储在张量的grad_fn属性中。您可以在文档中找到Function的更多信息。

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")

输出为:

Gradient function for z = <AddBackward0 object at 0x0000022EDB445C30>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x0000022EDB445D20>

计算梯度

为了优化神经网络中参数的权重,我们需要计算损失函数对参数的导数,即我们需要∂loss/∂w和∂loss/∂B。为了计算这些导数,我们调用loss.backward(),然后从w.g grad和b.g grad中检索值:

loss.backward()
print(w.grad)
print(b.grad)

输出为:

tensor([[0.0549, 0.1796, 0.0399],[0.0549, 0.1796, 0.0399],[0.0549, 0.1796, 0.0399],[0.0549, 0.1796, 0.0399],[0.0549, 0.1796, 0.0399]])
tensor([0.0549, 0.1796, 0.0399])

禁用梯度跟踪

默认情况下,所有requires_grad=True的张量都在跟踪它们的计算历史并支持梯度计算。然而,在某些情况下,我们不需要这样做,例如,当我们训练了模型,只想将其应用于一些输入数据时,即我们只想通过网络进行前向计算。我们可以通过使用torch.no_grad()块包围我们的计算代码来停止跟踪计算:

z = torch.matmul(x, w)+b
print(z.requires_grad)with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)

输出为:

True
False

实现相同结果的另一种方法是在张量上使用detach()方法:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)

输出为:

False

你可能想要禁用渐变跟踪的原因如下:

  • 将神经网络中的一些参数标记为冻结参数。

  • 当你只做正向传递时,为了加快计算速度,因为在不跟踪梯度的张量上的计算会更有效率。

更多关于计算图的知识

从概念上讲,autograd在由Function对象组成的有向无环图(DAG)中保存数据(张量)和所有执行的操作(以及产生的新张量)的记录。在DAG中,叶是输入张量,根是输出张量。通过从根到叶的跟踪图,您可以使用链式法则自动计算梯度。

在向前传递中,autograd同时做两件事:

  • 运行请求的操作来计算结果张量

  • 在DAG中维持操作的梯度函数。

当在DAG根上调用.backward()时,向后传递开始。autograd:

  • 计算每个。grad_fn的梯度,

  • 在各自张量的.grad属性中累积它们

  • 利用链式法则,一直传播到叶张量。

[!TIP]

PyTorch中的dag是动态的,需要注意的重要一点是图形是从头开始重新创建的;在每次.backward()调用之后,autograd开始填充一个新图。这正是允许您在模型中使用控制流语句的原因;如果需要,您可以在每次迭代中更改形状、大小和操作

张量梯度和雅可比积

在很多情况下,我们有一个标量损失函数,我们需要计算关于一些参数的梯度。然而,在某些情况下,输出函数是一个任意张量。在这种情况下,PyTorch允许你计算所谓的雅可比积,而不是实际的梯度。

inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")

输出为:

First call
tensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.]])Second call
tensor([[8., 4., 4., 4., 4.],[4., 8., 4., 4., 4.],[4., 4., 8., 4., 4.],[4., 4., 4., 8., 4.]])Call after zeroing gradients
tensor([[4., 2., 2., 2., 

http://www.hkcw.cn/article/cvhXRIZZWL.shtml

相关文章

王树森推荐系统公开课 特征交叉01:Factorized Machine (FM) 因式分解机

对于FM的评价&#xff0c;引用视频底下的评论&#xff1a; FM算法在很久之前使用广泛&#xff0c;现在已逐渐淘汰。 线性模型只是加权和&#xff0c;没有考虑多个特征之间的交叉&#xff0c;在推荐系统中&#xff0c;特征交叉的作用是相当重要的。 如果 d d d 太大就不合适…

IAR无法跳转定义,IARstm8跳转显示路径出错,系统库文件文件名后有[RO]

当我们打开程序后&#xff0c;按下键盘F12无跳转或者显示路径出错 原因就是库文件是只读类型&#xff0c;在IAR里面无法跳转&#xff0c;可以看到后缀显示【RO】 解决办法就是&#xff0c;把IAR软件关闭&#xff0c;把标准库文件的只读给取消掉 重新打开IAR工程 然后修改头文件…

从零开始的云计算生活——第十一天,知识延续,程序管理。

一故事背景 今日整体内容是第十天的剩余部分再加上程序管理的开头部分&#xff0c;详细可以回到第十天看新增加内容&#xff0c;现在开始讲解新内容。 二Linux程序与进程 1程序,进程,线程的概念 程序&#xff1a;‌是一段静态的代码&#xff0c;它是应用软件执行的蓝本。程序…

STM32 单片机启动过程全解析:从上电到主函数的旅程

一、为什么要理解启动过程&#xff1f; STM32 的启动过程就像一台精密仪器的开机自检&#xff0c;它确保所有系统部件按既定方式初始化&#xff0c;才能顺利运行我们的应用代码。对初学者而言&#xff0c;理解启动过程能帮助解决常见“程序跑飞”“不进 main”“下载后无反应”…

2022 RoboCom 世界机器人开发者大赛(睿抗 caip) -高职组(国赛)解题报告 | 科学家

前言 题解 2022 RoboCom 世界机器人开发者大赛(睿抗 caip) -高职组&#xff08;国赛&#xff09;。 最后一题还考验能力&#xff0c;需要找到合适的剪枝。 RC-v1 智能管家 分值: 20分 签到题&#xff0c;map的简单实用 #include <bits/stdc.h>using namespace std;int…

typora插件下载链接和导入说明

1.引言 先看插件效果&#xff0c;本插件自带了历史文件tab切换、引用图片管理、思维导图、文档大纲、图排优化、文件模板、夜间模式等很多功能&#xff0c;插件的下载链接在本文最后。 2.安装插件 typora-0.9.98 之前的版本不推荐使用 插件解压为plugin文件夹&#xff0c;并移…

深化生态协同,宁盾身份域管完成与拓波软件兼容互认证

在信创产业蓬勃发展的浪潮下&#xff0c;行业生态的兼容适配决定了信创产品是否好用。近日&#xff0c;宁盾身份域管与拓波软件 TurboEX 邮件系统完成兼容互认证。测试结果显示宁盾身份域管&#xff08;信创版&#xff09;与 TurboEX 邮件服务器软件相互良好兼容&#xff0c;运…

Socket 编程 TCP

目录 1. TCP socket API 详解 1.1 socket 1.2 bind 1.3 listen 1.4 accept 1.5 read&&write 1.6 connect 1.7 recv 1.8 send 1.9 popen 1.10 fgets 2. EchoServer 3. 多线程远程命令执行 4. 引入线程池版本翻译 5. 验证TCP - windows作为client访问Linu…

SmolVLM2: The Smollest Video Model Ever(七)

编写测试代码与评价指标 现在的数据集里面只涉及tool的分类和手术phase的分类&#xff0c;所以编写的评价指标还是那些通用的&#xff0c;但是&#xff1a; predicted_labels:[The current surgical phase is CalotTriangleDissection, Grasper, Hook tool exists., The curre…

Cancer Cell丨肺癌早期干预新突破,TIM-3靶点或成关键

2025年5月8日&#xff0c;Cancer Cell 在线发表了一篇来自美国MD安德森癌症中心的研究文章Spatial and multiomics analysis of human and mouse lung adenocarcinoma precursors reveals TIM-3 as a putative target for precancer interception。作者整合了空间蛋白组、转录组…

全志V853挂载sd卡

参考文章:https://blog.csdn.net/weixin_59351001/article/details/127102440 1、插上sd卡 fdisk -l2、挂载SD卡到开发板 mount /dev/mmcblk1p1 /mnt/sdcard挂载失败(如下报错),需要格式化SD卡再进行挂载

性能测试-jmeter实战1

课程&#xff1a;B站大学 记录软件测试-性能测试学习历程、掌握前端性能测试、后端性能测试、服务端性能测试的你才是一个专业的软件测试工程师 性能测试-jmeter实战1 为什么需要性能测试呢&#xff1f;性能测试的作用&#xff1f;性能测试体系性能测试基础性能测试工具性能监控…

PABD 2025:大数据与智慧城市管理的融合之道

会议简介 2025年公共管理与大数据国际会议&#xff08;ICPMBD 2025&#xff09;确实在海口举办。本次会议将围绕公共管理与大数据的深度融合、数据分析在公共管理中的应用、大数据驱动的政策制定与优化等议题展开深入研讨。参会者将有机会聆听前沿学术报告&#xff0c;分享研究…

DL00924-基于深度学习YOLOv11的工程车辆目标检测含数据集

文末有代码完整出处 &#x1f697; 基于深度学习YOLOv11的工程车辆目标检测——引领智能识别新潮流&#xff01; &#x1f680; 随着人工智能技术的飞速发展&#xff0c; 目标检测 已经在各个领域取得了显著突破&#xff0c;尤其是在 工程车辆识别 这一关键技术上。今天&#…

Java 对接 Office 365 邮箱全攻略:OAuth2 认证 + JDK8 兼容 + Spring Boot 集成(2025 版)

&#x1f6a8; 重要通知&#xff1a;微软强制 OAuth2&#xff0c;传统认证已失效&#xff01; 2023 年 10 月起&#xff0c;Office 365 全面禁用用户名 密码认证&#xff0c;Java 开发者必须通过OAuth 2.0实现邮件发送。本文针对 CSDN 技术栈&#xff0c;提供从 Azure AD 配置…

秒杀/高并发解决方案+落地实现

前面我们防止超卖 是通过到数据库查询和到数据库抢购,来完成的, 代码如下:如果在短时间内,大量抢购冲击 DB, 造成洪峰, 容易压垮数据库解决方案:使用 Redis 完成预减库存,如果没有库存了,直接返回,减小对 DB 的压力。图示:Redis 的预减,已经存在了原子性,就是一条一条…

Baklib企业知识激活解决方案

Baklib知识中台构建路径 Baklib通过模块化架构设计与智能数据治理双轮驱动&#xff0c;为企业构建知识中台提供标准化实施路径。首先基于自然语言处理&#xff08;NLP&#xff09;技术实现非结构化文档的语义解析&#xff0c;打通CRM、ERP等业务系统间的数据孤岛&#xff1b;随…

【Gemini 深度研究】人形机器人:最新开发方案与未来展望 (2024-2025)

Gemini根据深度研究报告自动生成的html网页录屏 人形机器人&#xff1a;最新开发方案与未来展望 (2024-2025) I. 执行摘要 2024年至2025年&#xff0c;人形机器人正处于从科研探索向实际应用转型的关键时期&#xff0c;其作为通用型机器人的潜力日益显现。这一转变主要得益于具…

【动态规划:斐波那契数列模型】第 N 个泰波那契数

1、第 N 个泰波那契数&#xff08;easy&#xff09; 1137. 第 N 个泰波那契数 泰波那契序列 Tn 定义如下&#xff1a; ​ T0 0, T1 1, T2 1, 且在 n > 0 的条件下 Tn3 Tn Tn1 Tn2。给你整数 n&#xff0c;请返回第 n 个泰波那契数 Tn 的值。 示例 1&#xff1a; …

秋招Day11 - JVM - JVM调优

性能监控的命令行工具&#xff1f; 操作系统层面&#xff1a; 我用过top来查看cpu和内存的使用情况使用过vmstat查看过虚拟内存的统计信息使用过iostat查看过系统的io情况使用过netstat查看过系统的网络信息 JDK自带的命令层面&#xff0c;我使用过&#xff1a; jmap -heap…