R3GAN训练自己的数据集

article/2025/8/27 8:22:13

简介

简介:这篇论文挑战了"GANs难以训练"的广泛观点,通过提出一个更稳定的损失函数和现代化的网络架构,构建了一个简洁而高效的GAN基线模型R3GAN。作者证明了通过合适的理论基础和架构设计,GANs可以稳定训练并达到优异性能。

论文题目:The GAN is dead; long live the GAN! A Modern Baseline GAN

会议:NeurIPS 2024

源码地址:https://www.github.com/brownvc/R3GAN

本文在调试代码的时候对代码做了一些修改,如果有遇到报错的问题可以直接复制我这篇博客修改后的代码:R3GAN利用配置好的Pytorch训练自己的数据集-CSDN博客这篇论文挑战了"GANs难以训练"的广泛观点,通过提出一个更稳定的损失函数和现代化的网络架构,构建了一个简洁而高效的GAN基线模型R3GAN。作者证明了通过合适的理论基础和架构设计,GANs可以稳定训练并达到优异性能。 https://blog.csdn.net/LJ1147517021/article/details/148315781?fromshare=blogdetail&sharetype=blogdetail&sharerId=148315781&sharerefer=PC&sharesource=LJ1147517021&sharefrom=from_link

摘要:论文反驳了GANs难以训练的普遍观点,提出了一个理论有保障的现代GAN基线。首先,推导出一个良好行为的正则化相对论GAN损失函数,解决了模式丢弃和不收敛问题,并数学证明了其局部收敛性。其次,该损失函数允许丢弃所有经验性技巧,用现代架构替换常见GANs中的过时骨干网络。以StyleGAN2为例,展示了简化和现代化的路线图,产生了新的极简基线R3GAN。尽管简单,该方法在FFHQ、ImageNet、CIFAR和Stacked MNIST数据集上超越了StyleGAN2,与最先进的GANs和扩散模型相比表现优异。

模型结构

生成器架构

核心设计原则:

  • 基于现代化ResNet架构,摒弃VGG-like设计
  • 每个分辨率阶段包含一个过渡层和两个残差块
  • 采用分组卷积和倒置瓶颈设计

关键特性:

  • 无归一化层:避免批量归一化等数据相关的归一化
  • Fix-up初始化:零初始化每个残差块的最后一层卷积
  • 双线性插值:用于上采样,避免棋盘效应

鉴别器架构

设计特点:

  • 与生成器完全对称的架构
  • 相同的残差块结构和过渡层设计
  • 分类器头:全局4×4深度卷积 + 线性层

损失函数

相对论配对GAN损失 (RpGAN):

L(θ,ψ) = E[f(D_ψ(G_θ(z)) - D_ψ(x))]

R1正则化:

R1(ψ) = (γ/2) * E[||∇_x D_ψ(x)||²]  (x~p_D)

R2正则化:

R2(θ,ψ) = (γ/2) * E[||∇_x D_ψ(x)||²]  (x~p_θ)

训练自己的数据集

1. 准备数据集

首先使用 dataset_tool.py 将您的图像数据转换为适合训练的格式:

# 从文件夹创建数据集
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip# 如果需要调整分辨率和裁剪
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip \--resolution=256x256 --transform=center-crop

数据集要求:

  • 图像必须是正方形(如256x256, 512x512)
  • 分辨率必须是2的幂次(64, 128, 256, 512, 1024等)
  • 支持RGB或灰度图像
  • 可以是文件夹或ZIP格式

2. 创建自定义训练配置

train.py 中添加您自己的预设配置。参考现有预设,在 main() 函数中添加:

if opts.preset == 'YOUR_DATASET':# 网络架构参数WidthPerStage = [768, 768, 768, 512, 256]  # 每阶段宽度BlocksPerStage = [2, 2, 2, 2, 2]           # 每阶段块数CardinalityPerStage = [96, 96, 96, 48, 24] # 每阶段基数FP16Stages = [-1, -2, -3, -4]              # FP16优化的阶段NoiseDimension = 64                         # 噪声维度# 如果是条件生成(有类别标签)if opts.cond:c.G_kwargs.ConditionEmbeddingDimension = NoiseDimensionc.D_kwargs.ConditionEmbeddingDimension = WidthPerStage[0]# 训练调度参数ema_nimg = 500 * 1000      # EMA开始的图像数decay_nimg = 2e7           # 总衰减图像数# 各种调度器c.ema_scheduler = { 'base_value': 0, 'final_value': ema_nimg, 'total_nimg': decay_nimg }c.aug_scheduler = { 'base_value': 0, 'final_value': 0.3, 'total_nimg': decay_nimg }c.lr_scheduler = { 'base_value': 2e-4, 'final_value': 5e-5, 'total_nimg': decay_nimg }c.gamma_scheduler = { 'base_value': 2, 'final_value': 0.2, 'total_nimg': decay_nimg }c.beta2_scheduler = { 'base_value': 0.9, 'final_value': 0.99, 'total_nimg': decay_nimg }

3. 开始训练

# 无条件生成(如人脸、风景等)
python train.py \--outdir=./training-runs \--data=./datasets/your_dataset.zip \--gpus=4 \--batch=256 \--mirror=1 \--aug=1 \--preset=YOUR_DATASET \--tick=1 \--snap=200# 条件生成(有类别标签)
python train.py \--outdir=./training-runs \--data=./datasets/your_dataset.zip \--gpus=4 \--batch=256 \--mirror=1 \--aug=1 \--cond=1 \--preset=YOUR_DATASET \--tick=1 \--snap=200

4. 参数说明

  • --gpus: GPU数量
  • --batch: 总批次大小
  • --mirror: 是否启用水平翻转增强
  • --aug: 是否启用数据增强
  • --cond: 是否训练条件模型(需要标签)
  • --tick: 多少kimg输出一次进度
  • --snap: 多少tick保存一次模型

5. 生成图像

训练完成后,使用保存的模型生成图像:

# 生成8张图像
python gen_images.py \--seeds=0-7 \--outdir=generated_images \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl# 条件生成(指定类别)
python gen_images.py \--seeds=0-7 \--outdir=generated_images \--class=5 \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

6. 评估指标

python calc_metrics.py \--metrics=fid50k_full,kid50k_full \--data=./datasets/your_dataset.zip \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

7.报错指南

1.UnboundLocalError: local variable 'NoiseDimension' referenced before assignment

解决办法:在 train.py 中,NoiseDimension 只在特定的预设配置块中定义(如 CIFAR10、FFHQ-64 等)。如果您使用的 --preset 参数不匹配任何现有预设,这个变量就不会被定义,导致使用时出错。可以使用作者定义好的预先设置。

--preset=CIFAR10
--preset=FFHQ-64  
--preset=FFHQ-256
--preset=ImageNet-32
--preset=ImageNet-64

2.RuntimeError: Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "R3GAN\torch_utils\custom_ops.py".

解决办法:这个错误是因为R3GAN使用了自定义的CUDA操作符,需要C++编译器来编译。在Windows系统上缺少MSVC/GCC/CLANG编译器。

修改 torch_utils/custom_ops.py:找到 get_plugin 函数(大约第84行),在函数开头添加:

def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):# 禁用所有自定义插件return Nonedef bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):# 强制使用 'ref' 实现impl = 'ref'


http://www.hkcw.cn/article/wXCqVArCQY.shtml

相关文章

HackMyVM-Dejavu

信息搜集 主机发现 ┌──(root㉿kali)-[~] └─# arp-scan -l Interface: eth0, type: EN10MB, MAC: 00:0c:29:39:60:4c, IPv4: 192.168.43.126 Starting arp-scan 1.10.0 with 256 hosts (https://github.com/royhills/arp-scan) 192.168.43.1 c6:45:66:05:91:88 …

vue-seamless-scroll 结束从头开始,加延时后滚动

今天遇到一个大屏需求: 1️⃣初始进入页面停留5秒,然后开始滚动 2️⃣最后一条数据出现在最后一行时候暂停5秒,然后返回1️⃣ 依次循环,发现vue-seamless-scroll的方法 ScrollEnd是监测最后一条数据消失在第一行才回调&#xff…

【实证分析】上市公司全要素生产率+5种测算方式(1999-2024年)

上市公司的全要素生产率(TFP)衡量企业在资本、劳动及中间投入之外,通过技术进步、管理效率和规模效应等因素提升产出的能力。与单纯的劳动生产率或资本生产率不同,TFP综合反映了企业创新能力、资源配置效率和组织优化水平&#xf…

在 Ubuntu 上安装 NVM (Node Version Manager) 的步骤

NVM (Node Version Manager) 是一个用于管理多个 Node.js 版本的工具,它允许您在同一台设备上安装、切换和管理不同版本的 Node.js。以下是在 Ubuntu 上安装 NVM 的详细步骤: 安装前准备 可先在windows上安装ubuntu 参考链接:https://blog.…

4. Observer / Event(观察者模式) C++

4. Observer / Event(观察者模式) C++ 1. 动机(场景) 适用于观察者对象(可以有多个)在观察某个对象(目标对象)的状态,如果该对象的状态发生改变,观察者对象都将收到通知。 举个例子,当我们要做一个文件分割器(就是将一个大文件分割成指定大小的小文件),这时还需…

多模态融合新方向:光学+AI如何智能分拣,提升塑料回收率?

【导读】 面对触目惊心的全球塑料污染(每分钟百万瓶、年耗五万亿袋)以及较低的塑料回收率,本研究聚焦提升回收效率的核心环节——自动分拣技术。尽管AMP Robotics等公司利用结合现代机器学习(如R-CNN、YOLO系列)的光学…

GlobalExceptionHandler 自定义异常类 + 处理validation的异常

在 Spring Boot 项目中,​自定义异常通常用于处理特定的业务逻辑错误,并结合全局异常处理器(ControllerAdvice)统一返回结构化的错误信息。 一.全局异常处理器: 1. 自定义异常类​ 定义一个继承自 RuntimeExceptio…

零基础设计模式——结构型模式 - 代理模式

第三部分:结构型模式 - 代理模式 (Proxy Pattern) 在学习了享元模式如何通过共享对象来优化资源使用后,我们来探讨结构型模式的最后一个模式——代理模式。代理模式为另一个对象提供一个替身或占位符以控制对这个对象的访问。 核心思想:为其…

从 0 到 1 的显示革命:九天画芯张锦解码铁电液晶技术进化史

一、显示技术困局:传统液晶的天花板在哪里? 在消费电子与工业显示高速发展的今天,传统液晶技术正遭遇物理极限挑战。受 “边缘场效应” 制约,液晶分子因粘附像素格电极边框,仅中心区域可自由旋转,边缘分子的…

MySql(六)

插入数据 对mysql的表中的数据进行插入数据操作 语法格式: insert into 表名 (字段名1,字段名2..) values (字段值1,字段值2...) 这个有点类似键值对的关系。 一对一 1)首先准备一张表 /* Navicat Pre…

leetcode:372. 超级次方(python3解法,数学相关算法题)

难度:中等 你的任务是计算 ab 对 1337 取模,a 是一个正整数,b 是一个非常大的正整数且会以数组形式给出。 示例 1: 输入:a 2, b [3] 输出:8示例 2: 输入:a 2, b [1,0] 输出&…

C++ —(详述c++特性)

一 namespeace(命名空间) namespace是一个自定义的空间,这个空间相当于一个总文件夹,总文件可以有好多个,里面的小文件夹或者其他文件,也可以有其他各种各样的文件, 定义:命名空间是…

20250529-C#知识:属性

C#知识:属性 在开发过程中,在需要public读取并且不允许从外界修改的情况下经常会用到属性。本文简单介绍一下属性。 1、主要内容及代码示例 属性类似成员变量属性包括get和set语句块属性能单独为get和set设置访问权限属性能为get和set操作添加处理逻辑g…

知识课堂|sCMOS相机可编程快门模式解析

sCMOS相机凭借高灵敏度、高动态、低读出噪声特性,成为生命科学成像领域的核心设备。在光片荧光显微镜LSFM成像应用中,传统卷帘快门的时序限制可能引发运动伪影或光片照明不均匀问题。可编程快门模式通过精确控制传感器曝光时序,实现与激光扫描…

Apache Kafka 实现原理深度解析:生产、存储与消费全流程

Apache Kafka 实现原理深度解析:生产、存储与消费全流程 引言 Apache Kafka 作为分布式流处理平台的核心,其高吞吐、低延迟、持久化存储的设计使其成为现代数据管道的事实标准。本文将从消息生产、持久化存储、消息消费三个阶段拆解 Kafka 的核心实现原…

已解决:.NetCore控制台程序(WebAPI)假死,程序挂起接口不通

本问题已得到解决,请看以下小结: 关于《.NetCore控制台程序(WebAPI)假死,程序暂停接口不通》的解决方案 记录备注报错时间2025年报错版本VS2022 WINDOWS10报错复现鼠标点一下控制台,会卡死报错描述——报错截图——报错原因 控制台启用了“快…

π0-通用VLA模型-2024.11.13-开源

π 0 π0 π0在2025.2.4开源,目前在github有3.4k的星标,说他是通用策略表现在两点上: 做的任务是多元的而且都比较复杂,比如叠衣服,从洗衣机里拿出衣服等等既可以控制单臂,又可以双臂,还可以控…

基于本地化大模型的智能编程助手全栈实践:从模型部署到IDE深度集成学习心得

近年来,随着ChatGPT、Copilot等AI编程工具的爆发式增长,开发者生产力获得了前所未有的提升。然而,云服务的延迟、隐私顾虑及API调用成本促使我探索一种更自主可控的方案:基于开源大模型构建本地化智能编程助手。本文将分享我构建本…

机器视觉2,硬件选型

机器视觉1,学习了硬件的基本知识和选型,现在另外的教材巩固知识 选相机 工业相机选型的保姆级教程_哔哩哔哩_bilibili 1.先看精度多少mm,被检测物体长宽多少mm》分辨率, 选出合理范围内的相机 2.靶面尺寸,得出分…

预处理,咕咕咕

1.预定义符号 _FILE_ //编译的源文件 _LINE_ //文件行号 _DATA_ //文件编译日期 _TIME_ //文件编译时间 _STDC_ //如果文件编译遵循ANSI C,其值为一,否则未定义 printf("%d",_FILE_,_LINE_);2.#define定义常量 #define name stuff #define MAX 100…