华为深度学习面试手撕题:手写nn.Conv2d()函数

article/2025/8/5 7:16:59

题目

只允许利用numpy包,实现Pytorch二维卷积函数nn.Conv2d()

解答

此代码考察二维卷积的概念,详见:

6.2. 图像卷积 — 动手学深度学习 2.0.0 documentation

6.3. 填充和步幅 — 动手学深度学习 2.0.0 documentation

6.4. 多输入多输出通道 — 动手学深度学习 2.0.0 documentation

代码实现:

import numpy as np
import torch
import torch.nn as nndef conv2d(input, weight, bias=None, stride=1, padding=0):"""实现二维卷积操作参数:input:  输入数据, 形状为 (batch_size, in_channels, height, width)weight: 卷积核, 形状为 (out_channels, in_channels, kernel_h, kernel_w)bias:   偏置项, 形状为 (out_channels,)stride: 步长, 可以是整数或元组 (stride_h, stride_w)padding: 填充, 可以是整数或元组 (pad_h, pad_w)返回:输出特征图, 形状为 (batch_size, out_channels, out_h, out_w)"""# 解析步长和填充参数if isinstance(stride, int):stride_h = stride_w = strideelse:stride_h, stride_w = strideif isinstance(padding, int):pad_h = pad_w = paddingelse:pad_h, pad_w = padding# 获取输入尺寸batch_size, in_channels, in_h, in_w = input.shapeout_channels, _, kernel_h, kernel_w = weight.shape# 计算输出尺寸out_h = (in_h + 2 * pad_h - kernel_h) // stride_h + 1out_w = (in_w + 2 * pad_w - kernel_w) // stride_w + 1# 添加填充if pad_h > 0 or pad_w > 0:# 使用零填充padded_input = np.pad(input, ((0, 0), (0, 0), (pad_h, pad_h), (pad_w, pad_w)),mode='constant')else:padded_input = input# 初始化输出数组output = np.zeros((batch_size, out_channels, out_h, out_w))# 执行卷积操作for b in range(batch_size):for c_out in range(out_channels):for h_out in range(out_h):for w_out in range(out_w):# 计算输入窗口位置h_start = h_out * stride_hw_start = w_out * stride_wh_end = h_start + kernel_hw_end = w_start + kernel_w# 提取输入窗口window = padded_input[b, :, h_start:h_end, w_start:w_end]# 计算点积 (卷积操作)conv_val = np.sum(window * weight[c_out])# 添加偏置if bias is not None:conv_val += bias[c_out]# 存储结果output[b, c_out, h_out, w_out] = conv_valreturn outputimport torch
import torch.nn as nnif __name__ == "__main__":# 创建测试数据np.random.seed(42)# 输入数据: (batch_size=2, in_channels=3, height=5, width=5)input_data = np.random.randn(2, 3, 5, 5).astype(np.float32)# 卷积核: (out_channels=2, in_channels=3, kernel_h=3, kernel_w=3)weights = np.random.randn(2, 3, 3, 3).astype(np.float32)# 偏置: (out_channels=2)bias = np.array([0.5, -0.5], dtype=np.float32)# 转换为 PyTorch 张量input_torch = torch.tensor(input_data)weights_torch = torch.tensor(weights)bias_torch = torch.tensor(bias)# 测试1: 无填充, 步长=1print("测试1: 无填充, 步长=1")output1 = conv2d(input_data, weights, bias, stride=1, padding=0)# 创建 PyTorch 卷积层conv1_nn = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3, stride=1, padding=0, bias=True)# 设置权重和偏置with torch.no_grad():conv1_nn.weight.data = weights_torchconv1_nn.bias.data = bias_torch# 计算 PyTorch 输出output1_nn = conv1_nn(input_torch).detach().numpy()# 比较结果print("自定义实现与PyTorch输出是否一致:", np.allclose(output1, output1_nn, atol=1e-6))print(f"输出形状: {output1.shape}")print("自定义实现输出 (第一个样本的第一个通道前2x2):")print(output1[0, 0, :2, :2])print("PyTorch输出 (第一个样本的第一个通道前2x2):")print(output1_nn[0, 0, :2, :2])# 测试2: 填充=1, 步长=1print("\n测试2: 填充=1, 步长=1")output2 = conv2d(input_data, weights, bias, stride=1, padding=1)# 创建 PyTorch 卷积层conv2_nn = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3, stride=1, padding=1, bias=True)with torch.no_grad():conv2_nn.weight.data = weights_torchconv2_nn.bias.data = bias_torchoutput2_nn = conv2_nn(input_torch).detach().numpy()print("自定义实现与PyTorch输出是否一致:", np.allclose(output2, output2_nn, atol=1e-6))print(f"输出形状: {output2.shape}")print("自定义实现输出 (第一个样本的第一个通道前2x2):")print(output2[0, 0, :2, :2])print("PyTorch输出 (第一个样本的第一个通道前2x2):")print(output2_nn[0, 0, :2, :2])# 测试3: 无填充, 步长=2print("\n测试3: 无填充, 步长=2")output3 = conv2d(input_data, weights, bias, stride=2, padding=0)# 创建 PyTorch 卷积层conv3_nn = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3, stride=2, padding=0, bias=True)with torch.no_grad():conv3_nn.weight.data = weights_torchconv3_nn.bias.data = bias_torchoutput3_nn = conv3_nn(input_torch).detach().numpy()print("自定义实现与PyTorch输出是否一致:", np.allclose(output3, output3_nn, atol=1e-6))print(f"输出形状: {output3.shape}")print("自定义实现输出 (第一个样本的第一个通道):")print(output3[0, 0])print("PyTorch输出 (第一个样本的第一个通道):")print(output3_nn[0, 0])# 测试4: 无偏置print("\n测试4: 无偏置")output4 = conv2d(input_data, weights, None, stride=1, padding=0)# 创建 PyTorch 卷积层conv4_nn = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3, stride=1, padding=0, bias=False)with torch.no_grad():conv4_nn.weight.data = weights_torchoutput4_nn = conv4_nn(input_torch).detach().numpy()print("自定义实现与PyTorch输出是否一致:", np.allclose(output4, output4_nn, atol=1e-6))print("自定义实现输出 (第一个样本的第一个通道前2x2):")print(output4[0, 0, :2, :2])print("PyTorch输出 (第一个样本的第一个通道前2x2):")print(output4_nn[0, 0, :2, :2])'''
测试1: 无填充, 步长=1
自定义实现与PyTorch输出是否一致: True
输出形状: (2, 2, 3, 3)
自定义实现输出 (第一个样本的第一个通道前2x2):
[[-6.4546895  -2.49435902][-6.27663374  3.31103873]]
PyTorch输出 (第一个样本的第一个通道前2x2):
[[-6.4546895 -2.4943593][-6.276634   3.3110385]]测试2: 填充=1, 步长=1
自定义实现与PyTorch输出是否一致: True
输出形状: (2, 2, 5, 5)
自定义实现输出 (第一个样本的第一个通道前2x2):
[[ 1.17402518  1.28695214][-0.09722954 -6.4546895 ]]
PyTorch输出 (第一个样本的第一个通道前2x2):
[[ 1.1740253   1.2869523 ][-0.09722958 -6.4546895 ]]测试3: 无填充, 步长=2
自定义实现与PyTorch输出是否一致: True
输出形状: (2, 2, 2, 2)
自定义实现输出 (第一个样本的第一个通道):
[[-6.4546895   1.38441801][ 3.1934371  -1.1537782 ]]
PyTorch输出 (第一个样本的第一个通道):
[[-6.4546895  1.3844179][ 3.1934366 -1.1537789]]测试4: 无偏置
自定义实现与PyTorch输出是否一致: True
自定义实现输出 (第一个样本的第一个通道前2x2):
[[-6.9546895  -2.99435902][-6.77663374  2.81103873]]
PyTorch输出 (第一个样本的第一个通道前2x2):
[[-6.9546895 -2.9943593][-6.776634   2.811039 ]]
'''

http://www.hkcw.cn/article/wNuIwVtSJM.shtml

相关文章

MMR 最大边际相关性详解

最大边际相关性(MMR,max_marginal_relevance_search)的基本思想是同时考量查询与文档的 相关度,以及文档之间的 相似度。相关度 确保返回结果对查询高度相关,相似度 则鼓励不同语义的文档被包含进结果集。具体来说&…

美业+智能体,解锁行业转化新密码(2/6)

摘要:中国美业市场近年蓬勃发展,规模持续扩大,预计不久将突破万亿级别,但同时也面临着诸多挑战,如获客成本攀升、服务质量不稳定、难以满足消费者多元化个性化需求等。智能体技术的出现为美业带来了新的发展机遇&#…

Mybatis-Plus 学习

Mybatis-Plus 简介 官网:https://baomidou.com/ github 地址:https://github.com/baomidou/mybatis-plus 什么是 Mybatis-Plus MyBatis-Plus(简称 MP)是 MyBatis 的增强工具库,旨在简化开发流程,减少样…

Linux开发追踪(IMX6ULL篇_第一部分)

前言 参数:cortex-A7 698Mhz flash 8GB RAM 512M DDR3 2个100M网口 单核 初期: 一、安装完虚拟机之后,第一步先设置文件之间可以相互拷贝复制,以及通过CRT连接到虚拟机等 折磨死人了啊啊啊啊啊啊 1、关于SSH怎么安装…

中国观鸟数据集(CSV)

数据简介 今天我们分享的数据是观鸟数据集,该数据整理中国观鸟记录中心的鸟类报告数据,在2024年获取了该网站种鸟类的报告信息,详情信息以及鸟种信息,分别整理为各省的数据,方便大家研究使用,方便大家研究使…

【AI论文】SWE-rebench:一个用于软件工程代理的任务收集和净化评估的自动化管道

摘要:基于LLM的代理在越来越多的软件工程(SWE)任务中显示出有前景的能力。 然而,推进这一领域面临着两个关键挑战。 首先,高质量的训练数据稀缺,尤其是反映现实世界软件工程场景的数据,在这些场…

【计算机系统结构】习题2

目录 1.有一条静态多功能流水线由5段组成,加法用1、2、4、5段,乘法用1、3、5段,第3段时间为,其余各段为,且流水线的输出可直接返回输入端或暂存器,若计算,试计算吞吐量、加速比、效率 2.有一动…

多模态大语言模型arxiv论文略读(103)

Are Bigger Encoders Always Better in Vision Large Models? ➡️ 论文标题:Are Bigger Encoders Always Better in Vision Large Models? ➡️ 论文作者:Bozhou Li, Hao Liang, Zimo Meng, Wentao Zhang ➡️ 研究机构: 北京大学 ➡️ 问题背景&…

[ElasticSearch] RestAPI

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…

【irregular swap】An Examination of Fairness of AI Models for Deepfake Detection

文章目录 An Examination of Fairness of AI Models for Deepfake Detection背景points贡献深伪检测深伪检测审计评估检测器主要发现评估方法审计结果训练分布和方法偏差An Examination of Fairness of AI Models for Deepfake Detection 会议/期刊:IJCAI 2021 作者: 背景…

初学大模型部署以及案例应用(windows+wsl+dify+mysql+Ollama+Xinference)

大模型部署以及案例应用(windowswsldifymysqlOllamaXinference) 1.wsl 安装①安装wsl②测试以及更新③安装Ubuntu系统查看系统以及版本安装Ubuntu系统进入Ubuntu系统 2、docker安装①下载安装包②安装③docker配置 3、安装dify①下载dify②安装③生成.en…

【Linux系统编程】Ext系列文件系统

目录 磁盘文件系统的必要性 认识磁盘结构 理解硬件 磁盘的物理结构 磁盘的存储结构 磁盘的逻辑结构 引入磁盘文件系统 引入"块"概念 引入"分区"概念 引入"分组"概念 ext*系列文件系统 inode、inode Bitmap、inode Table Block Bitm…

基于ZYNQ ARM+FPGA异构平台的声呐数据采集系统设计

0 引 言 近年来,随着海洋工程技术的发展,水下无人 航行器 (underwater unmanned vehicle, UUV)) 因其 灵活性、低风险性以及多功能性的优点,在维护国 家海洋权益以及海洋安全发挥着日益重要的作用 [1-3] 。 UUV 在完成目标搜索、…

前端基础学习html+css+js

HTML 区块 div标签,块级标签 span包装小部分文本,行内元素 表单 CSS css选择器 css属性 特性blockinlineinline-block是否换行✅ 换行❌ 不换行❌ 不换行可设置宽高✅ 支持❌ 不支持✅ 支持常见元素div容器 p段落 h标题span文本容器 a超链接img图片…

Client-Side Path Traversal 漏洞学习笔记

近年来,随着Web前端技术的飞速发展,越来越多的数据请求和处理逻辑被转移到客户端(浏览器)执行。这大大提升了用户体验,但也带来了新的安全威胁。其中,Client-Side Path Traversal(客户端路径穿越,CSPT)作为一种新兴的漏洞类型,逐渐受到安全研究者和攻击者的关注。本文…

关于神经网络中的梯度和神经网络的反向传播以及梯度与损失的关系

这篇博客用通俗的话介绍一下什么是梯度以及神经网络中的反向传播。 什么是梯度 可以把神经网络想象成一个 “猜答案的机器”。比如你让它猜一张图片是不是猫,它会先 “猜” 一个概率(比如猜是猫的概率是 30%),然后你告诉它 “猜…

保持本地Git仓库与远程仓库同步-业务场景示例

业务场景:团队协作开发电商网站 背景: 5人团队使用GitHub协作开发Node.js电商项目。每位开发者负责独立功能模块(如支付、商品展示、购物车)。核心痛点:频繁出现本地代码与远程仓库冲突,导致测试环境部署失…

【中国企业数字化转型之路】企业的资源投入与数字化转型的产出效益平衡探索(上篇)

在数字化转型的浪潮中,企业面临着前所未有的挑战与机遇。这一转型过程不仅需要大量的技术、人才、管理和时间投入,更需要在投入与产出之间找到精准的平衡点,以确保转型的效益最大化。技术投入方面,企业需斥巨资引进云计算、大数据…

AR/MR实时光照阴影开发教程

一、效果演示 1、PICO4 Ultra MR 发光的球 2、AR实时光照 二、实现原理 PICO4 Ultra MR开发时,通过空间网格能力扫描周围环境,然后将扫描到的环境网格材质替换为一个透明材质并停止扫描;基于Google ARCore XR Plugin和ARFoundation进行安卓手…

图文详解Java集合面试题

文章目录 1、集合框架2、ArrayList、LinkedList3、HashMap、红黑树4、HashMap的put流程 1、集合框架 两条大支线: ①Collection接口:最基本的集合框架,提供添加、删除、清空等基本操作,主要有三个子接口:i&#xff1a…