PyTorch实战——基于生成对抗网络生成服饰图像

article/2025/7/22 8:58:36

PyTorch实战——基于生成对抗网络生成服饰图像

    • 0. 前言
    • 1. 模型分析与数据准备
    • 2. 判别器
    • 3. 生成器
    • 4. 模型训练
    • 5. 模型保存与加载
    • 相关链接

0. 前言

我们已经学习了生成对抗网络 (Generative Adversarial Network, GAN) 的工作原理,接下来,将学习如何将其应用于生成其他形式的内容。在本节中,介绍使用 GAN 创建灰度图像,包括外套、衬衫、凉鞋等服饰,学习在设计生成器网络时如何镜像判别器网络。在本节中,生成器和判别器网络使用全连接层,全连接层的每个神经元都与前一层和后一层的所有神经元相连接。

1. 模型分析与数据准备

在本节中,我们将训练一个生成对抗网络 (Generative Adversarial Network, GAN) 模型,生成如凉鞋、T恤、外套和包等服装的灰度图像。在使用 GAN 生成图像时,首先需要获取训练数据。然后,从零开始创建一个判别器网络。在创建生成器网络时,将镜像判别器网络的架构。最后,训练 GAN,并使用训练好的模型来生成图像。接下来,让我们通过实现一个简单的 GAN 模型来生成灰度服装图像。
准备训练数据,以创建使用批数据的迭代器。训练集包含 60,000 张图像,在图像分类模型中,我们通常将训练集进一步划分为训练集和验证集,使用验证集的损失来判断模型参数是否已收敛,从而决定是否停止训练。但 GAN 的训练方法与传统的监督学习模型不同,由于生成样本的质量在训练过程中不断提高,判别器的训练变得越来越困难。因此,判别器网络的损失不能很好地反映模型的质量。通常评估 GAN 性能的方法是通过视觉检查,评估生成图像的质量和真实性。也可以通过与训练样本的比较来评估生成样本的质量,并使用如 Inception Score 之类的评估方法来评估 GAN 的表现。但研究表明这类评估方法存在缺陷,Inception Score 在模型比较时未能提供有用的指导。在本节中,我们将定期使用视觉检查来检查生成样本的质量,并确定何时停止训练。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as Ttransform=T.Compose([T.ToTensor(),T.Normalize([0.5],[0.5])])train_set=torchvision.datasets.FashionMNIST(root=".",train=True,download=True,transform=transform)batch_size=32
train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True)import matplotlib.pyplot as plt
from torchvision.utils import make_gridimages, labels = next(iter(train_loader))
grid = make_grid(0.5-images/2, 8, 4)
plt.imshow(grid.numpy().transpose((1, 2, 0)),cmap="gray_r")
plt.axis("off")
plt.show()

2. 判别器

判别器网络类似于二分类器,将样本分类为真实或虚假。

(1) 使用 PyTorch 创建判别器神经网络 D

import torch.nn as nndevice="cuda" if torch.cuda.is_available() else "cpu"D=nn.Sequential(nn.Linear(784, 1024), # 第一个全连接层有 784 个输入和 1,024 个输出nn.ReLU(),nn.Dropout(0.3),nn.Linear(1024, 512),nn.ReLU(),nn.Dropout(0.3),nn.Linear(512, 256),nn.ReLU(),nn.Dropout(0.3),nn.Linear(256, 1),  # 最后一个全连接层有 256 个输入和 1 个输出nn.Sigmoid()).to(device)

输入大小为 784,因为训练集中的每张灰度图像大小为 28 × 28 像素。由于全连接层只接收一维输入,因此在将图像传递给模型之前,需要先将图像展平。输出层只有一个神经元,判别器 D 的输出是一个单一的值,使用 sigmoid 激活函数将输出压缩到 [0, 1] 范围内,以便可以将其解释为样本为真实的概率 p,而 1 - p 则表示样本是虚假的概率。

3. 生成器

(1) 创建生成器的一种常见方法是将判别器网络中使用的架构进行镜像来创建生成器:

G=nn.Sequential(nn.Linear(100, 256),  # 生成器中的第一个层与判别器中的最后一层对称nn.ReLU(),nn.Linear(256, 512),  # 生成器中的第二个层与判别器中的倒数第二层对称nn.ReLU(),nn.Linear(512, 1024), # 生成器中的第三个层与判别器中的倒数第三层对称nn.ReLU(),nn.Linear(1024, 784), # 生成器中的最后一层与判别器中的第一层对称nn.Tanh()).to(device) # 使用 Tanh() 激活函数,使得输出值在 -1 和 1 之间,与图像中的值相同

下图展示了用于生成服装灰度图像的生成器和判别器网络的架构。如图所示,判别器的输入来自训练集的一个展平后的灰度图像(包含 28 × 28 = 784 个像素),依次通过判别器网络的四个全连接层,输出的是该图像为真实图像的概率。为了生成图像,生成器使用相同的四个全连接层,但顺序相反,从潜空间获取一个包含 100 个值的随机噪声向量,并将该向量依次通过这四个全连接层。在每一层中,判别器中的每个网络层的输入输出数目颠倒后,作为生成器中每层的输出和输入数目。最终,生成器生成一个包含 784 个值的张量,这个张量可以整形为一个 28 × 28 的灰度图像。

模型架构

上图中左侧是生成器网络,右侧是判别器网络。比较这两个网络,可以看到生成器如何镜像判别器的架构。具体来说,生成器包含四个类似的全连接层,但顺序相反,生成器中的第一层镜像判别器中的最后一层,生成器中的第二层镜像判别器中的倒数第二层,依此类推。生成器的输出为一个包含 784 个值的张量,这些值在经过 Tanh() 激活函数后位于 -11 之间,这与判别器网络的输入相匹配。

(1) 判别器 D 执行的是二分类任务,因此 GAN 模型的损失函数使用二元交叉熵损失。判别器和生成器都使用 Adam 优化器,学习率为 0.0001

loss_fn=nn.BCELoss()
lr=0.0001
optimD=torch.optim.Adam(D.parameters(),lr=lr)
optimG=torch.optim.Adam(G.parameters(),lr=lr)

接下来,使用训练数据集中服装图像训练本节创建的 GAN 模型。

4. 模型训练

(1) 在本节中,依靠视觉检查来判断模型是否训练完成,为此,定义 see_output() 函数,定期可视化生成器生成的虚假图像。需要注意的是,虽然我们可以使用 PyTorch 实现 Inception Score 来评估 GAN,但由于 Inception Score 评估方法的低效性,并不推荐使用 Inception Score 来评估生成模型:

import matplotlib.pyplot as pltdef see_output():noise=torch.randn(32,100).to(device=device)fake_samples=G(noise).cpu().detach()    # 生成 32 张虚假图像plt.figure(dpi=100,figsize=(20,10))for i in range(32):ax=plt.subplot(4, 8, i + 1)img=(fake_samples[i]/2+0.5).reshape(28, 28)plt.imshow(img)    # 图像可视化plt.xticks([])plt.yticks([])plt.show()see_output()

运行代码,可以看到生成的图像如下所示,它们完全不像服装,因为生成器还未经过训练。

结果可视化

(2) 为了训练 GAN 模型,定义函数:train_D_on_real()train_D_on_fake()train_G()

real_labels=torch.ones((batch_size,1)).to(device)
fake_labels=torch.zeros((batch_size,1)).to(device)def train_D_on_real(real_samples):r=real_samples.reshape(-1,28*28).to(device)out_D=D(r)    labels=torch.ones((r.shape[0],1)).to(device)loss_D=loss_fn(out_D,labels)    optimD.zero_grad()loss_D.backward()optimD.step()    return loss_Ddef train_D_on_fake():        noise=torch.randn(batch_size,100).to(device=device)generated_data=G(noise)preds=D(generated_data)loss_D=loss_fn(preds,fake_labels)optimD.zero_grad()loss_D.backward()optimD.step()return loss_Ddef train_G(): noise=torch.randn(batch_size,100).to(device=device)generated_data=G(noise)preds=D(generated_data)loss_G=loss_fn(preds,real_labels)optimG.zero_grad()loss_G.backward()optimG.step()return loss_G

(3) 接下来,训练模型,遍历训练数据集中的所有批数据。对于每个批数据,首先使用真实样本训练判别器。之后,生成器生成一批虚假样本,用这些虚假样本再次训练判别器。最后,使用生成器再次生成一批虚假样本,用它们来训练生成器。训练模型 50epoch,生成结果如下所示:

for i in range(50):gloss=0dloss=0for n, (real_samples,_) in enumerate(train_loader):loss_D=train_D_on_real(real_samples) # 使用真实样本训练判别器dloss+=loss_Dloss_D=train_D_on_fake()    # 使用虚假样本训练判别器dloss+=loss_Dloss_G=train_G()            # 训练生成器gloss+=loss_Ggloss=gloss/ndloss=dloss/n# 每隔 10 个 epoch 可视化生成图像if i % 10 == 9:print(f"at epoch {i+1}, dloss: {dloss}, gloss {gloss}")see_output()

生成结果

每训练 10epoch,可视化生成的服装,如上图所示。经过 10epoch 的训练后,模型已经能够生成明显可以作为真实服装的图像,能够明显的辨别出图像的外形,随着训练的进行,生成的图像质量越来越好。

5. 模型保存与加载

(1) 丢弃判别器,并保存训练好的生成器,以便生成样本:

import os
scripted = torch.jit.script(G) 
os.makedirs("files", exist_ok=True)
scripted.save('files/fashion_gen.pt')

(2) 将生成器保存在本地文件夹中后,要使用生成器,只需加载模型:

new_G=torch.jit.load('files/fashion_gen.pt',map_location=device)
new_G.eval()

(3) 生成器加载完成后,将其用于生成服装图像:

noise=torch.randn(batch_size,100).to(device=device)
fake_samples=new_G(noise).cpu().detach()
for i in range(32):ax = plt.subplot(4, 8, i + 1)plt.imshow((fake_samples[i]/2+0.5).reshape(28, 28))plt.xticks([])plt.yticks([])
plt.subplots_adjust(hspace=-0.6)
plt.show()

生成的服装如下图所示,可以看到,生成的服装与训练集中的服装非常接近。

生成结果

相关链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch实战(1)——神经网络与模型训练过程详解
PyTorch实战(2)——PyTorch基础
PyTorch实战(3)——使用PyTorch构建神经网络
PyTorch实战(4)——卷积神经网络详解
PyTorch实战(5)——分类任务详解
PyTorch实战(6)——生成模型(Generative Model)详解
PyTorch实战(7)——生成对抗网络实践详解
PyTorch实战——生成对抗网络数值数据生成


http://www.hkcw.cn/article/coyFPFptuF.shtml

相关文章

C++四种类型转换方式

const_cast,去掉(指针或引用)常量属性的一个类型转换,但需要保持转换前后类型一致static_cast,提供编译器认为安全的类型转换(最常使用)reinterpret_cast,类似于c语言风格的强制类型转换,不保证安全;dynamic_cast,主要用于继承结构中&#xf…

得物C++开发面试题及参考答案

HTTP/HTTPS 协议的区别及 HTTPS 加密过程 HTTP(超文本传输协议)是一种用于传输超文本的协议,它是明文传输的,这意味着数据在传输过程中容易被截取和篡改,存在较大的安全隐患。而 HTTPS(超文本传输安全协议…

头歌之动手学人工智能-Pytorch 之优化

目录 第1关:如何使用optimizer 任务描述 编程要求 测试说明 真正的科学家应当是个幻想家;谁不是幻想家,谁就只能把自己称为实践家。 —— 巴尔扎克开始你的任务吧,祝你成功! 第2关:optim.SGD 任务描述…

RV1126-OPENCV Mat理解和AT函数

一.Mat概念 Mat 是整个图像存储的核心也是所有图像处理的最基础的类,Mat 主要存储图像的矩阵类型,包括向量、矩阵、灰度或者彩色图像等等。Mat由两部分组成:矩阵头,矩阵数据。矩阵头是存储图像的长度、宽度、色彩信息等头部信息&a…

DeepSeek R1-0528模型:五大升级亮点,引领AI推理新高度

在AI技术迅猛发展的浪潮中,模型的迭代升级不断推动着行业的进步。DeepSeek R1-0528模型的推出,犹如一颗重磅炸弹,在AI领域激起千层浪。它究竟有何神奇之处?下面为你揭秘其五大全新升级亮点。 深度思考能力显著提升 DeepSeek R1-05…

司机缺氧离世有5个上学孩子 家庭重担引关注

近日,46岁的河南卡车司机常志荣在青藏高原离世,卡友团队一同将其骨灰接回老家。6月1日,常志荣已经在老家安葬。他去世后,留下了重组家庭的6个孩子,其中5个孩子还在上学。车友任先生透露,常大哥出发青藏线运输前,同行曾建议他至少携带两罐氧气,但他为省下30元费用,最终…

迪士尼情侣和一家三口打架 拍照争执引发冲突

6月1日,浦东公安分局接到报警称迪士尼乐园内发生打架事件。初步调查显示,闫某某(男,22岁)与女友在拍照时,因刘某某(男,36岁)夫妻的女儿进入拍摄画面,双方发生口角并引发肢体冲突,造成闫某某和刘某某互有皮外伤,小女孩未受伤。目前,调查处理工作正在进行中。当天,…

温度计“液泡”是什么,温度计为什么能测温?

温度计“液泡”是什么,温度计为什么能测温? 液体膨胀式温度计介绍 最近,有位小朋友说,他看的一本科普书上说:把温度计插在水里,水分子就会对液泡产生撞击,液泡里面的分子就会跟着动起来&#x…

C++学习过程分享

空指针:int *p NULL; 空指针:指针变量指向内存中编号为0的空间;用途:初始化指针变量注意:空指针指向的内存不允许访问注意:内存编号为0-255为系统占用空间,不允许用户访问 野指针:…

【IC】RTL功耗高精度预测

介绍 美国能源部(DOE)的一份综合报告“半导体供应链深度潜水评估”(2022年2月)呼吁将能源效率提高1000倍,以维持未来的需求,因为世界能源产量有限。能源效率是当今设计师的首要任务。能源效率的整体方法必…

美国要求澳大利亚上调军费 提升至GDP的3.5%

当地时间6月1日,美国国防部长赫格塞思在新加坡会见了澳大利亚副总理兼国防部长理查德马尔斯。双方讨论了美澳联盟的关键优先事项。赫格塞思在会谈中建议澳大利亚尽快将国防开支提高到国内生产总值的3.5%。尽管澳大利亚不是北约成员国,美国总统特朗普此前曾多次要求北约成员国…

高反缺氧去世卡友已回老家下葬 留下6个孩子引关注

近日,46岁的河南卡车司机常志荣在青藏高原不幸离世。他的卡友团队一同将其骨灰接回老家,并于6月1日在老家安葬。常志荣去世后,留下了重组家庭的六个孩子,其中五个孩子还在上学。车友任先生透露,常志荣出发前曾被建议至少携带两罐氧气,但他为了节省30元费用,最终只购买了…

亚历山大本赛季两战步行者场均39分 雷霆双胜步行者

NBA总决赛的对阵双方是雷霆和步行者。本赛季两队交手两次,雷霆分别以120-114和132-111取胜。这两场比赛中霍姆格伦都没有出战。亚历山大在这两场比赛中的表现非常出色,场均能够贡献39分、7篮板、8助攻、1抢断和1盖帽,场均罚球次数达到11.5次,真实命中率为71.1%。在面对步行…

六地将有大到暴雨 端午出行需谨慎

中央气象台今日6时继续发布暴雨蓝色预警和强对流天气蓝色预警。福建、广东、广西等六个地区将出现大到暴雨,多地还将遭遇8级以上的雷暴大风或冰雹天气。正值端午假期最后一天,出行前请务必关注天气情况。责任编辑:zhangxiaohua

017搜索之深度优先DFS——算法备赛

深度优先搜索 如果说广度优先搜索是逐层扩散,那深度优先搜索就是一条道走到黑。 深度优先遍历是用递归实现的,预定一条顺序规则(如上下左右顺序) ,一直往第一个方向搜索直到走到尽头或不满足要求后返回上一个叉路口按…

举办中国户外运动展哪个城市较理想

杭州:强劲经济引擎,中国户外运动展的理想之选! 为什么是杭州?—— 硬核实力,无可争议! 经济活力领跑: 浙江人均GDP超2.5万美元,人均收入与消费全国TOP 1!2024年省外人口…

JMeter接口自动化脚本框架

登录后的CRUD自动化脚本 内容: 用户自定义变量 ${}引用 HTTP请求默认值:复用内容 HTTP信息头:请求类型、token、cookie setUp、tearDown线程组:前后置操作 响应断言:文本、代码 Json提取器:提取响应…

缺氧离世卡车司机已下葬卡友发声 家庭重担引关注

近日,46岁的河南卡车司机常志荣在青藏高原离世,卡友团队将其骨灰接回老家。6月1日,他在老家安葬。常志荣去世后,留下了重组家庭的6个孩子,其中5个还在上学。据车友任先生透露,常志荣出发前同行曾建议他至少携带两罐氧气,但他为省下30元费用,最终只购买了一罐氧气。常志…

【C++】多态

目录 1. 多态的概念 2. 多态的定义和实现 2.1 构成多态的条件 2.2 虚函数 2.3 虚函数的重写(覆盖) 2.4 小试牛刀 3. 重载/重写/隐藏的对比 4. 纯虚函数和抽象类 5.多态的原理 5.1 虚表 5.2 虚表指针 5.3 对比虚函数、虚表、虚表指针 1. 多态的…

肖战演活了藏海 台湾观众好评如潮

正在热播的电视剧《藏海传》在台湾引起了广泛关注,不仅获得了岛内观众的一致好评,也得到了媒体的争相报道。这部剧以其精良的制作、紧凑的情节以及所展现的中华文化深深吸引了台湾观众。5月31日,“肖战演藏海在台湾刷屏”这一话题冲上了微博热搜榜。近年来,大陆电视剧在台湾…