【python深度学习】Day 39 图像数据与显存

article/2025/9/4 9:36:46
知识点
  1. 图像数据的格式:灰度和彩色数据
  2. 模型的定义
  3. 显存占用的4种地方
    1. 模型参数+梯度参数
    2. 优化器参数
    3. 数据批量所占显存
    4. 神经元输出中间状态
  4. batchisize和训练的关系

作业:今日代码较少,理解内容即可

一、图像数据的介绍

结构化数据(如表格)的形状通常是 (样本数, 特征数)

图像数据的形状更复杂,需要保留空间信息(高度、宽度、通道),不能直接用一维向量表示。其中,颜色信息往往是最开始输入数据的通道的含义,因为每个颜色可以用红绿蓝三原色表示,因此一般输入数据的通道数是 3。

1.灰度图像

2.彩色图像

在 PyTorch 中,图像数据的形状的格式是 (通道数, 高度, 宽度) ,即 Channel First 格式。这与常见的 (高度, 宽度, 通道数)格式不同,Channel Last 格式,如 NumPy 数组。---注意顺序关系

注意点:

1. 如果用matplotlib库来画图,需要转换下顺序,我们后续介绍

2. 模型输入还需要加入 批次维度(Batch Size),形状变为 (批次大小, 通道数, 高度, 宽度)。例如,批量输入 10 张 MNIST 图像时,形状为 (10, 1, 28, 28)。

二、图像相关神经网络定义

1.黑白图像

# 先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
import matplotlib.pyplot as plt# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)
# 定义两层MLP神经网络
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 将28x28的图像展平为784维向量self.layer1 = nn.Linear(784, 128)  # 第一层:784个输入,128个神经元self.relu = nn.ReLU()  # 激活函数self.layer2 = nn.Linear(128, 10)  # 第二层:128个输入,10个输出(对应10个数字类别)def forward(self, x):x = self.flatten(x)  # 展平图像x = self.layer1(x)   # 第一层线性变换x = self.relu(x)     # 应用ReLU激活函数x = self.layer2(x)   # 第二层线性变换,输出logitsreturn x# 初始化模型
model = MLP()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # 将模型移至GPU(如果可用)from torchsummary import summary  # 导入torchsummary库
print("\n模型结构信息:")
summary(model, input_size=(1, 28, 28))  # 输入尺寸为MNIST图像尺寸

 

图像MLP与之前结构化MLP的差异对比

(1)输入图像,需要展平操作

        MLP 的输入层要求输入是一维向量,但 MNIST 图像是二维结构(28×28 像素),形状为 [1, 28, 28](通道 × 高 × 宽)。

        通过nn.Flatten()展平操作,将二维图像 “拉成” 一维向量(784=28×28 个元素),使其符合全连接层的输入格式。其中不定义这个flatten方法,直接在前向传播的过程中用 x = x.view(-1, 28 * 28) 将图像展平为一维向量也可以实现

(2)输入数据的尺寸包含了通道数input_size=(1, 28, 28)

(3)参数的计算

         第一层 layer1(全连接层)

            权重参数:输入维度 × 输出维度 = 784 × 128 = 100,352

            偏置参数:输出维度 = 128

            合计:100,352 + 128 = 100,480

        第二层 layer2(全连接层)

           权重参数:输入维度 × 输出维度 = 128 × 10 = 1,280

            偏置参数:输出维度 = 10

            合计:1,280 + 10 = 1,290

- 总参数:100,480(layer1) + 1,290(layer2) = 101,770

2.彩色图像

class MLP(nn.Module):def __init__(self, input_size=3072, hidden_size=128, num_classes=10):super(MLP, self).__init__()# 展平层:将3×32×32的彩色图像转为一维向量# 输入尺寸计算:3通道 × 32高 × 32宽 = 3072self.flatten = nn.Flatten()# 全连接层self.fc1 = nn.Linear(input_size, hidden_size)  # 第一层self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)  # 输出层def forward(self, x):x = self.flatten(x)  # 展平:[batch, 3, 32, 32] → [batch, 3072]x = self.fc1(x)      # 线性变换:[batch, 3072] → [batch, 128]x = self.relu(x)     # 激活函数x = self.fc2(x)      # 输出层:[batch, 128] → [batch, 10]return x# 初始化模型
model = MLP()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # 将模型移至GPU(如果可用)from torchsummary import summary  # 导入torchsummary库
print("\n模型结构信息:")
summary(model, input_size=(3, 32, 32))  # CIFAR-10 彩色图像(3×32×32)

3.batch_size——批次维度

PyTorch 模型会自动处理 batch 维度(即第一维),无论 batch_size 是多少,模型的计算逻辑都不变。

class MLP(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten() # nn.Flatten()会将每个样本的图像展平为 784 维向量,但保留 batch 维度。self.layer1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.layer2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)  # 输入:[batch_size, 1, 28, 28] → [batch_size, 784]x = self.layer1(x)   # [batch_size, 784] → [batch_size, 128]x = self.relu(x)x = self.layer2(x)   # [batch_size, 128] → [batch_size, 10]return x

 三、占用显存的主要部分

1. 模型参数与梯度:模型的权重(Parameters)和对应的梯度(Gradients)会占用显存,尤其是深度神经网络(如 Transformer、ResNet 等),一个 1 亿参数的模型(如 BERT-base),单精度(float32)参数占用约 400MB(1e8×4Byte),加上梯度则翻倍至 800MB(每个权重参数都有其对应的梯度)

2. 优化器:部分优化器(如 Adam)会为每个参数存储动量(Momentum)和平方梯度(Square Gradient),进一步增加显存占用(通常为参数大小的 2-3 倍)

3. 其他开销:数据批次、前向或反向传播过程中的中间变量


http://www.hkcw.cn/article/QhEZhZAenp.shtml

相关文章

mongodb源码分析session接受客户端find命令过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程,并且验证connection是否超过限制。 现在继续研究ASIOSession和connection是怎么接受客户端命令的? mongo/transport/service_state_machine.cpp核心方法有&#xf…

酒店管理破局:AI 引领智能化转型

一、酒店行业现状:规模扩张加速与效率瓶颈并存 根据中国五矿证券《中国酒店行业格局分析》报告,国内酒店行业呈现头部集中化趋势。截至2024年第三季度,锦江酒店以13,186家门店、125.8万间客房的规模稳居行业第一[1]。华住集团则以178.68亿元…

大模型深度学习之双塔模型

前言 双塔模型(Two-Tower Model)是一种在推荐系统、信息检索和自然语言处理等领域广泛应用的深度学习架构。其核心思想是通过两个独立的神经网络(用户塔和物品塔)分别处理用户和物品的特征,并在共享的语义空间中通过相…

【Java Web】速通CSS

参考笔记:JavaWeb 速通CSS_java css-CSDN博客 目录 一、CSS入门 1. 基本介绍 2. 作用 二、CSS的3种引入方式 1. 行内式 1.1 示例代码 1.2 存在问题 2. 写在head标签的style子标签中 2.1 示例代码 2.2 存在问题 3.以外部文件的形式引入(开发中推荐使用)⭐⭐⭐ 3.1 说明 3…

PostgreSQL安装

我们使用开源的对象关系型数据库--PostgreSQL,它具有高性能、可扩展和支持复杂查询的特性,非常适合现在学习使用。 一.安装PostgreSQL 我用的windows,就在windows上安装。 1.首先访问 PostgreSQL 官方网站https://www.postgresql.org/dow…

C++:栈帧、命名空间、引用

一、前置知识 1.1、栈区(Stack) 1.1.1、内存分配与回收机制 分配方式​​:由编译器自动管理,通过调整栈指针(ESP/RSP)实现。 函数调用时,栈指针下移(栈从高地址向低地址增长&…

【HarmonyOS 5】鸿蒙应用px,vp,fp概念详解

【HarmonyOS 5】鸿蒙应用px,vp,fp概念详解 一、前言 目前的鸿蒙开发者,大多数是从前端或者传统移动端开发方向,转到鸿蒙应用开发方向。 前端开发同学对于开发范式很熟悉,但是对于工作流程和开发方式是会有不适感&am…

[Rust_1] 环境配置 | vs golang | 程序运行 | 包管理

目录 Rust 环境安装 GoLang和Rust 关于Go 关于Rust Rust vs. Go,优缺点 GoLang的优点 GoLang的缺点 Rust的优点 Rust的缺点 数据告诉我们什么? Rust和Go的主要区别 (1) 性能 (2) 并发性 (3) 内存安全性 (4) 开发速度 (5) 开发者体验 Ru…

Codeforces Round 1024 (Div. 2)

Problem - A - Codeforces 思维题&#xff1a; 如果n不能整除p&#xff0c;就会多出一部分&#xff0c;这个部分可以作为调和者&#xff0c;使整个数组符合要求。 如果n能整除p&#xff0c;没有调和空间&#xff0c;只有看n/p*qm 来看代码&#xff1a; #include <bits/s…

【东枫科技】KrakenSDR 天线阵列设置

标准测向需要五根相同的全向天线。您可以折衷使用更少的天线&#xff0c;但为了获得最佳性能&#xff0c;我们建议使用全部五根天线。这些天线通常是磁铁安装的鞭状天线&#xff0c;或偶极子天线。我们建议始终使用均匀圆形阵列 (UCA) 天线&#xff0c;因为它可以确定来自各个方…

包含Javascript的HTML静态页面调取本机摄像头

在实际业务开发中&#xff0c;需要在带有摄像头的工作机上拍摄施工现场工作过程的图片&#xff0c;然后上传到服务器备存。 这便需要编写可以运行在浏览器上的代码&#xff0c;并在代码中实现Javascript调取摄像头、截取帧保存为图片的功能。 为了使用户更快掌握JS调取摄像头…

2023年6月第三套第二篇

找和脑子有关系的rather than 不是的意思&#xff0c;不用看 instead表示递进的解释 even when即使不重要&#xff0c;看前方主句 d选项是even when和前方主句的杂糅&#xff0c;往往是错的 instead of 而不是 这道题&#xff0c;有的人觉得避免模仿这时候你会笑&#xff0c;所…

Redis的大Key问题如何解决?

大家好&#xff0c;我是锋哥。今天分享关于【Redis的大Key问题如何解决&#xff1f;】面试题。希望对大家有帮助&#xff1b; Redis的大Key问题如何解决&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Redis中的“大Key”问题是指某个键的值占用了过多…

magic-api配置Git插件教程

一、配置gitee.com 1&#xff0c;生成rsa密钥&#xff0c;在你的电脑右键使用管理员身份运行&#xff08;命令提示符&#xff09;&#xff0c;执行下面命令 ssh-keygen -t rsa -b 2048 -m PEM一直按回车键&#xff0c;不需要输入内容 找到 你电脑中的~/.ssh/id_rsa.pub 文件…

Virtuoso中对GDS文件进行工艺库转换的方法

如果要对相同工艺节点下进行性能评估&#xff0c;可以尝试将一个厂商的GDS文件转换到另一个厂商&#xff0c;不过要注意的是不同厂商&#xff08;比如SMIC和TSMC&#xff09;之间的DRC规则&#xff0c;尽量采用两个DRC中的约束较为紧张的厂商进行设计&#xff0c;以免转换到另外…

【二】9.关于设备树简介

1.什么是设备树&#xff1a; &#xff08;DTS&#xff09;采用树形结构描述扳级设备&#xff0c;也就是开发板上的设备信息&#xff0c;每个设备都是一个节点。 一个SOC可以做出很多不同的板子&#xff0c;这些不同的板子肯定是有共同的信息&#xff0c;将这些共同的信息提取出…

VSCode远程开发-本地SSH隧道保存即时修改

工作环境是一个网站团队几人同时在改&#xff0c;为了减少冲突&#xff0c;我们选择在自己公司服务器上先部署一版线上通用&#xff0c;再连接到不同的本地&#xff0c;这样我们团队可以在线上即时看到他人修改的结果&#xff0c;不用频繁拉取提交推送代码 在线上服务器建一个…

Embedded IDE下载及调试

安装cortex_debug插件 我这边用jlink烧录&#xff0c;其他的根据你自己的来 jlink路径在左下角齿轮设置里 设置位置&#xff1a; 芯片名称配置的都是自动生成的&#xff0c;在eide.json的这里改为你jflash芯片包的设置 调试里也会自动生成一个cortex_debug的调试选项 点旁边的…

lua注意事项

感觉是lua的一大坑啊&#xff0c;它还不如函数内部就局部变量呢 注意函数等内部&#xff0c;全部给加上local得了

【第4章 图像与视频】4.4 离屏 canvas

文章目录 前言为什么要使用 offscreenCanvas为什么要使用 OffscreenCanvas如何使用 OffscreenCanvas第一种使用方式第二种使用方式 计算时长超过多长时间适合用Web Worker 前言 在 Canvas 开发中&#xff0c;我们经常需要处理复杂的图形和动画&#xff0c;这些操作可能会影响页…