【LUT技术专题】图像自适应3DLUT代码讲解

article/2025/6/8 4:58:09

本文是对图像自适应3DLUT技术的代码解读,原文解读请看图像自适应3DLUT文章讲解

1、原文概要

结合3D LUT和CNN,使用成对和非成对的数据集进行训练,训练后能够完成自动的图像增强,同时还可以做到极低的资源消耗。下图为整个模型的结构示意图,本篇代码讲解只讲解成对数据的情况,非成对是类似的。
在这里插入图片描述

2、代码结构

代码整体结构如下
在这里插入图片描述

image_adaptive_lut_train_paired.py是成对数据训练脚本,models.py文件中是网络结构和损失函数。

3 、核心代码模块

models.py 文件

这个文件包含了3DLUT中CNN weight predictor的实现、三次插值的实现和两个正则损失(平滑损失和单调损失)的计算。

1. Classifier类

此为CNN weight predictor的实现。

class Classifier(nn.Module):def __init__(self, in_channels=3):super(Classifier, self).__init__()self.model = nn.Sequential(nn.Upsample(size=(256,256),mode='bilinear'),nn.Conv2d(3, 16, 3, stride=2, padding=1),nn.LeakyReLU(0.2),nn.InstanceNorm2d(16, affine=True),*discriminator_block(16, 32, normalization=True),*discriminator_block(32, 64, normalization=True),*discriminator_block(64, 128, normalization=True),*discriminator_block(128, 128),#*discriminator_block(128, 128, normalization=True),nn.Dropout(p=0.5),nn.Conv2d(128, 3, 8, padding=0),)def forward(self, img_input):return self.model(img_input)

可以看到,输入首先进行resize到256分辨率,即HR->LR的过程,然后经过一系列卷积和归一化模块,最终经过一个kernel_size为8,输出通道为3的卷积,变成只有3个输出的weight,后续可以作用于LUT上。

其中的discriminator_block实现如下:

def discriminator_block(in_filters, out_filters, normalization=False):"""Returns downsampling layers of each discriminator block"""layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]layers.append(nn.LeakyReLU(0.2))if normalization:layers.append(nn.InstanceNorm2d(out_filters, affine=True))#layers.append(nn.BatchNorm2d(out_filters))return layers

其实就是一个简单的卷积,搭配了一个激活函数,根据normalization选项的不同插入InstanceNorm。

2. TrilinearInterpolation

该类实现了3DLUT中会使用到的插值方法:

class TrilinearInterpolation(torch.autograd.Function):def forward(self, LUT, x):x = x.contiguous()output = x.new(x.size())dim = LUT.size()[-1]shift = dim ** 3binsize = 1.0001 / (dim-1)W = x.size(2)H = x.size(3)batch = x.size(0)self.x = xself.LUT = LUTself.dim = dimself.shift = shiftself.binsize = binsizeself.W = Wself.H = Hself.batch = batchif x.is_cuda:if batch == 1:trilinear.trilinear_forward_cuda(LUT,x,output,dim,shift,binsize,W,H,batch)elif batch > 1:output = output.permute(1,0,2,3).contiguous()trilinear.trilinear_forward_cuda(LUT,x.permute(1,0,2,3).contiguous(),output,dim,shift,binsize,W,H,batch)output = output.permute(1,0,2,3).contiguous()else:trilinear.trilinear_forward(LUT,x,output,dim,shift,binsize,W,H,batch)return outputdef backward(self, grad_x):grad_LUT = torch.zeros(3,self.dim,self.dim,self.dim,dtype=torch.float)if grad_x.is_cuda:grad_LUT = grad_LUT.cuda()if self.batch == 1:trilinear.trilinear_backward_cuda(self.x,grad_x,grad_LUT,self.dim,self.shift,self.binsize,self.W,self.H,self.batch)elif self.batch > 1:trilinear.trilinear_backward_cuda(self.x.permute(1,0,2,3).contiguous(),grad_x.permute(1,0,2,3).contiguous(),grad_LUT,self.dim,self.shift,self.binsize,self.W,self.H,self.batch)else:trilinear.trilinear_backward(self.x,grad_x,grad_LUT,self.dim,self.shift,self.binsize,self.W,self.H,self.batch)return grad_LUT, None

作者将其封装成了一个Function,前向和反向的gpu计算过程作者用cuda文件编写的(作者也实现了cpu的版本),具体的实现在trilinear_c/src/trilinear_kernel.cu(对应cpu的版本是trilinear_c/src/trilinear.c)文件中,TriLinearForward和TriLinearBackward是实际调用会使用到的核函数,前向核函数每一个thread实现的逻辑跟我们讲到的实际插值的过程是一致的,这里就不做代码讲解了。

3. TV_3D

该类实现的是两个正则化的损失函数。

class TV_3D(nn.Module):def __init__(self, dim=33):super(TV_3D,self).__init__()self.weight_r = torch.ones(3,dim,dim,dim-1, dtype=torch.float)self.weight_r[:,:,:,(0,dim-2)] *= 2.0self.weight_g = torch.ones(3,dim,dim-1,dim, dtype=torch.float)self.weight_g[:,:,(0,dim-2),:] *= 2.0self.weight_b = torch.ones(3,dim-1,dim,dim, dtype=torch.float)self.weight_b[:,(0,dim-2),:,:] *= 2.0self.relu = torch.nn.ReLU()def forward(self, LUT):dif_r = LUT.LUT[:,:,:,:-1] - LUT.LUT[:,:,:,1:]dif_g = LUT.LUT[:,:,:-1,:] - LUT.LUT[:,:,1:,:]dif_b = LUT.LUT[:,:-1,:,:] - LUT.LUT[:,1:,:,:]tv = torch.mean(torch.mul((dif_r ** 2),self.weight_r)) + torch.mean(torch.mul((dif_g ** 2),self.weight_g)) + torch.mean(torch.mul((dif_b ** 2),self.weight_b))mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b))return tv, mn

这个没有特别需要讲解的,基本上是照着论文给出的公式将其翻译成代码,tv代表平滑性损失,mn代表单调性损失,因此这个类会同时输出两个损失,至于平滑损失中的w正则会在后续的训练中看到。

image_adaptive_lut_train_paired.py 文件

存放着跟训练相关的代码。以一个epoch的一个batch的一次iteration为例:

for epoch in range(opt.epoch, opt.n_epochs):mse_avg = 0psnr_avg = 0classifier.train()for i, batch in enumerate(dataloader):# Model inputsreal_A = Variable(batch["A_input"].type(Tensor))real_B = Variable(batch["A_exptC"].type(Tensor))# ------------------#  Train Generators# ------------------optimizer_G.zero_grad()fake_B, weights_norm = generator_train(real_A)# Pixel-wise lossmse = criterion_pixelwise(fake_B, real_B)tv0, mn0 = TV3(LUT0)tv1, mn1 = TV3(LUT1)tv2, mn2 = TV3(LUT2)#tv3, mn3 = TV3(LUT3)#tv4, mn4 = TV3(LUT4)tv_cons = tv0 + tv1 + tv2 #+ tv3 + tv4mn_cons = mn0 + mn1 + mn2 #+ mn3 + mn4loss = mse + opt.lambda_smooth * (weights_norm + tv_cons) + opt.lambda_monotonicity * mn_conspsnr_avg += 10 * math.log10(1 / mse.item())mse_avg += mse.item()loss.backward()optimizer_G.step()

real_A 和real_B分别是增强前图像和增强后的HQ,generator_train是根据LUT生成图像的过程,实现如下所示:

def generator_train(img):pred = classifier(img).squeeze()if len(pred.shape) == 1:pred = pred.unsqueeze(0)gen_A0 = LUT0(img)gen_A1 = LUT1(img)gen_A2 = LUT2(img)#gen_A3 = LUT3(img)#gen_A4 = LUT4(img)weights_norm = torch.mean(pred ** 2)combine_A = img.new(img.size())for b in range(img.size(0)):combine_A[b,:,:,:] = pred[b,0] * gen_A0[b,:,:,:] + pred[b,1] * gen_A1[b,:,:,:] + pred[b,2] * gen_A2[b,:,:,:] #+ pred[b,3] * gen_A3[b,:,:,:] + pred[b,4] * gen_A4[b,:,:,:]return combine_A, weights_norm

这里的classifier是我们刚讲到的网络结构,LUT0-2分别是预设置好的3条LUT,根据3条LUT生成3幅图像A0-A2,最后根据pred对gen图像进行加权后就可以输出了,顺带计算w的L2,即weights_norm。

之后是计算损失的过程:

        # Pixel-wise lossmse = criterion_pixelwise(fake_B, real_B)tv0, mn0 = TV3(LUT0)tv1, mn1 = TV3(LUT1)tv2, mn2 = TV3(LUT2)#tv3, mn3 = TV3(LUT3)#tv4, mn4 = TV3(LUT4)tv_cons = tv0 + tv1 + tv2 #+ tv3 + tv4mn_cons = mn0 + mn1 + mn2 #+ mn3 + mn4loss = mse + opt.lambda_smooth * (weights_norm + tv_cons) + opt.lambda_monotonicity * mn_cons

包含mse损失和正则损失,正则损失使用的是我们前面讲到的TV_3D类。

3、总结

代码实现核心的部分讲解完毕,Classifier和LUT0-2对应于CNN和LUT的结合,最终在数据集上学习到的LUT可以对应于预设的3条LUT曲线,Classifier预测的3个权重对他们进行加权得到最终的一条3DLUT作用于实际图像上。该文章是3DLUT的开山之作,相信也已经得到业界的应用。
代码中也有作者的预训练权重,读者可以自己自行实验下效果。


感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。


http://www.hkcw.cn/article/KjfdYWAcDN.shtml

相关文章

Docker 在 AI 开发中的实践:GPU 支持与深度学习环境的容器化

人工智能(AI)和机器学习(ML),特别是深度学习,正以前所未有的速度发展。然而,AI 模型的开发和部署并非易事。开发者常常面临复杂的依赖管理(如 Python 版本、TensorFlow/PyTorch 版本、CUDA、cuDNN)、异构硬件(CPU 和 GPU)支持以及环境复现困难等痛点。这些挑战严重阻…

COMSOL多边形骨料堆积混凝土水化热传热模拟

混凝土水化热温降研究对保障结构安全与耐久性至关重要,温升后温差易引发温度应力,导致裂缝。本案例介绍在COMSOL内建立多边形骨料堆积混凝土细观模型,并对水化热产生后的传热及温度变化进行仿真模拟。 骨料堆积混凝土细观模型采用CAD多边形…

vue入门环境搭建及demo运行

提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 vue简介:第一步:安装node.jsnode简介第二步:安装vue.js第三步:安装vue-cli工具第四步 :安装webpack第五步…

OpenCV CUDA模块图像处理------图像融合函数blendLinear()

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 该函数执行 线性融合(加权平均) 两个图像 img1 和 img2,使用对应的权重图 weights1 和 weights2。 融合公式…

“packageManager“: “pnpm@9.6.0“ 配置如何正确启动项目?

今天在学习开源项目的时候,在安装依赖时遇到了一个报错 yarn add pnpm9.6.0 error This projects package.json defines "packageManager": "yarnpnpm9.6.0". However the current global version of Yarn is 1.22.22.Presence of the "…

物联网数据归档之数据存储方案选择分析

在上一篇文章中《物联网数据归档方案选择分析》中凯哥分析了归档设计的两种方案,并对两种方案进行了对比。这篇文章咱们就来分析分析,归档后数据应该存储在哪里?及存储方案对比。 这里就选择常用的mysql及taos数据库来存储归档后的数据吧。 你在处理设备归档表存储方案时对…

八.MySQL复合查询

一.基本查询回顾 分组统计 group by 函数作用示例语句说明count(*)统计记录条数select deptno, count(*) from emp group by deptno;每个部门有多少人?sum(sal)某字段求和select deptno, sum(sal) from emp group by deptno;每个部门总工资avg(sal)求平均值select…

智能补丁管理:终结安全漏洞,开启自动化运维新时代

漏洞风暴:数字化时代的隐形战场 全球安全态势的范式转移 近年来,终端层漏洞已成为企业安全防线的最大突破口。根据美国国家标准与技术研究院(NIST)的监测数据,2023年新披露的高危漏洞数量同比增长62%,其中…

大模型 提示模板 设计

大模型 提示模板 设计 论文介绍:LangGPT - 从编程语言视角重构大语言模型结构化可复用提示设计框架 核心问题: 现有提示工程缺乏结构化设计模板,依赖经验优化,学习成本高且复用性低,难以支持提示的迭代更新。 创新思路: 受编程语言的结构化和可复用性启发,提出LangGP…

不连网也能跑大模型?

一、这是个什么 App? 你有没有想过,不用连网,你的手机也能像 ChatGPT 那样生成文字、识别图片、甚至回答复杂问题?Google 最近悄悄发布了一个实验性 Android 应用——AI Edge Gallery,就是为此而生的。 这个应用不在…

基于开源AI大模型与AI智能名片的S2B2C商城小程序源码优化:企业成本管理与获客留存的新范式

摘要:本文以企业成本管理的两大核心——外部成本与内部成本为切入点,结合开源AI大模型、AI智能名片及S2B2C商城小程序源码技术,构建了企业数字化转型的“技术-成本-运营”三维模型。研究结果表明,通过AI智能名片实现获客留存效率提…

【AFW+GRU(CNN+RNN)】Deepfakes Detection with Automatic Face Weighting

文章目录 Deepfakes Detection with Automatic Face Weighting背景pointsDeepfake检测挑战数据集方法人脸检测面部特征提取自动人脸加权门控循环单元训练流程提升网络测试时间增强实验结果Deepfakes Detection with Automatic Face Weighting 会议/期刊:CVPRW 2020 作者: …

【Zephyr 系列 6】使用 Zephyr + BLE 打造蓝牙广播与连接系统(STEVAL-IDB011V1 实战)

🧠关键词:Zephyr、BLE、广播、连接、GATT、低功耗蓝牙、STEVAL-IDB011V1 📌适合人群:希望基于 Zephyr 实现 BLE 通信的嵌入式工程师、蓝牙产品开发人员 🧭 前言:为什么选择 Zephyr 开发 BLE? 在传统 BLE 开发中,我们大多依赖于厂商 SDK(如 Nordic SDK、BlueNRG SD…

【前端后端环境】

学习视频【带小白做毕设02】从0开始手把手带你做Vue框架的快速搭建以及项目工程的讲解 C:\Users\Again>java -version java version "21.0.1" 2023-10-17 LTS Java(TM) SE Runtime Environment (build 21.0.112-LTS-29) Java HotSpot(TM) 64-Bit Server VM (buil…

机器学习:决策树和剪枝

本文目录: 一、决策树基本知识(一)概念(二)决策树建立过程 二、决策树生成(一)ID3决策树:基于信息增益构建的决策树。(二)C4.5决策树(三&#xff…

【Vmware】虚拟机安装、镜像安装、Nat网络模式、本地VM8、ssh链接保姆篇(图文教程)

文章目录 Vmware下载Vmware安装镜像安装虚拟机安装网络模式Nat模式设置ssh链接 更多相关内容可查看 Vmware下载 官网下载地址:https://vmoc.waltzsy.com/?bd_vid8868926919570357435#goods Vmware安装 以管理员身份运行 弹框如下,点击下一步 我同意&…

移动端测试岗位高频面试题及解析

文章目录 一、基础概念二、自动化测试三、性能测试四、专项测试五、安全与稳定性六、高级场景七、实战难题八、其他面题 一、基础概念 移动端测试与Web测试的核心区别? 解析:网络波动(弱网测试)、设备碎片化(机型适配&…

什么是“草台班子”?

“草台班子”是一个常用的汉语俗语,其含义在不同语境下略有差异,核心特点是强调组织或团队的非专业性、临时性和不规范性。以下从原意、引申义、常见用法三方面展开说明: 一、原意:传统戏曲中的流动演出团体 起源: 最…

无人机避障——感知部分(Ubuntu 20.04 复现Vins Fusion跑数据集)胎教级教程

硬件环境:NVIDIA Jeston Orin nx 系统:Ubuntu 20.04 任务:跑通 EuRoC MAV Dataset 数据集 展示结果: 编译Vins Fusion 创建工作空间vins_ws # 创建目录结构 mkdir -p ~/vins_ws/srccd ~/vins_ws/src# 初始化工作空间&#xf…