2. 手写数字预测 gui版

article/2025/6/18 6:01:57

2. 手写数字预测 gui版

  • 背景
  • 1.界面绘制
  • 2.处理图片
  • 3. 加载模型
  • 4. 预测
  • 5.结果
  • 6.一点小问题

在这里插入图片描述

背景

做了手写数字预测的模型,但是老是跑模型太无聊了,就配合pyqt做了一个可视化界面出来玩一下

源代码可以去这里https://github.com/Leezed525/pytorch_toy拿

1.界面绘制

在这里插入图片描述

整个页面布局逻辑很简单,搭建一下就好了

class MainWindow(QMainWindow):def __init__(self):super().__init__()self.net = self.get_net()  # 获取数字预测模型self.setWindowTitle("PyQt 数字预测")self.setGeometry(100, 100, 500, 550)  # 设置主窗口的初始位置和大小,留出空间给按钮self.setFixedSize(500, 550)self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.WindowMaximizeButtonHint)central_widget = QWidget()  # 创建一个中央 QWidgetself.setCentralWidget(central_widget)  # 设置中央 QWidget 为主窗口的中心部件layout = QVBoxLayout(central_widget)  # 为中央 QWidget 创建一个垂直布局# 创建一个水平布局operation_layer = QHBoxLayout()  # 创建一个水平布局用于放置操作区域left_operation_layer = QVBoxLayout()right_operation_layer = QVBoxLayout()self.canvas = DrawingCanvas(self)  # 创建 DrawingCanvas 实例canvas_label = QLabel("请在此处绘制数字")  # 创建一个标签,提示用户在画布上绘制数字canvas_label.setAlignment(Qt.AlignmentFlag.AlignCenter)canvas_label.setStyleSheet("font-size: 20px;")  # 设置标签的样式left_operation_layer.addWidget(canvas_label)  # 将标签添加到左侧操作区域布局中left_operation_layer.addWidget(self.canvas)left_operation_layer.setStretch(0, 1)left_operation_layer.setStretch(1, 10)  # 设置画布的伸缩比例,使其占据更多空间operation_layer.addLayout(left_operation_layer)  # 将左侧操作区域布局添加到操作层布局中# 右侧操作区域self.predict_label = QLabel("预测结果: ")  # 创建一个标签,显示预测结果right_operation_layer.addWidget(self.predict_label)self.predict_digit_labels = []for i in range(10):predict_digit_label = QLabel(f"数字 {i}: 0.00%")  # 创建标签显示每个数字的预测概率self.predict_digit_labels.append(predict_digit_label)  # 将标签添加到列表中for label in self.predict_digit_labels:right_operation_layer.addWidget(label)operation_layer.addLayout(right_operation_layer)  # 将右侧操作区域布局添加到操作层布局中operation_layer.setStretch(0, 10)operation_layer.setStretch(1, 1)layout.addLayout(operation_layer)  # 将操作层布局添加到主布局中# 按钮区布局button_layout = QHBoxLayout()  # 创建一个垂直布局用于放置按钮clear_button = QPushButton("清空画布")  # 清空画布按钮clear_button.clicked.connect(self.canvas.clear_canvas)  # 连接按钮的点击信号到清空画布方法predict_button = QPushButton("预测")  # 清空画布按钮predict_button.clicked.connect(self.predict)  # 连接按钮的点击信号到预测方法button_layout.addStretch(6)button_layout.addWidget(clear_button)button_layout.addWidget(predict_button)layout.addLayout(button_layout)  # 将按钮布局添加到主布局中

其中稍微有点心智压力的区域就是画图区域,这里配合ai然后再自行修改一下就好了,逻辑就是鼠标按住然后绘制,松开后停止绘制。

canvas代码

class DrawingCanvas(QWidget):"""一个自定义的 QWidget 类,用作绘图画布。用户可以在此画布上用鼠标点击并拖动来绘制线条。"""def __init__(self, parent=None):super().__init__(parent)  # 调用父类 QWidget 的构造函数self.setWindowTitle("绘图画布")  # 设置窗口标题self.setGeometry(100, 100, 280, 280)  # 设置窗口的初始位置和大小 (x, y, width, height)self.setMinimumSize(280, 280)# 创建一个 QImage 对象作为绘图缓冲区# 所有的绘图操作都在这个 QImage 上进行,然后整体绘制到屏幕,可以避免闪烁。# QImage.Format.Format_RGB32 是 PyQt6 中推荐的 RGBA 格式,支持透明度。self.image = QImage(self.size(), QImage.Format.Format_RGB32)# 将 QImage 填充为白色。self.image.fill(Qt.GlobalColor.white)self.drawing = False  # 一个布尔标志,指示当前是否正在进行鼠标拖拽绘图self.last_point = QPoint()  # 存储鼠标上次的位置,用于绘制连续的线条# 同样,颜色常量需要通过 Qt.GlobalColor 访问。self.pen_color = Qt.GlobalColor.blackself.pen_size = 20def paintEvent(self, event):"""绘制事件处理函数。每当窗口需要被重新绘制时(例如,首次显示、窗口大小改变、调用 update() 时),Qt 就会自动调用这个方法。"""painter = QPainter(self)  # 创建一个 QPainter 对象,指定在当前 QWidget (self) 上进行绘制# 将 self.image (绘图缓冲区) 的内容绘制到当前 QWidget 的整个矩形区域内。painter.drawImage(self.rect(), self.image, self.image.rect())def mousePressEvent(self, event):# 检查是否是鼠标左键被按下。if event.button() == Qt.MouseButton.LeftButton:self.drawing = True  # 设置绘图标志为 Trueself.last_point = event.pos()  # 记录当前鼠标位置作为线条的起始点def mouseMoveEvent(self, event):"""鼠标移动事件处理函数。当鼠标在窗口内移动时触发。"""# 只有当正在绘图 (self.drawing 为 True) 并且鼠标左键被按住时才执行绘图操作。# event.buttons() 返回当前按下的所有鼠标按钮的位掩码,Qt.MouseButton.LeftButton 用于检查左键是否按下。if self.drawing and event.buttons() & Qt.MouseButton.LeftButton:painter = QPainter(self.image)  # 在 QImage (绘图缓冲区) 上创建 QPainter 进行绘制# 设置画笔的颜色、粗细和样式。painter.setPen(QPen(QColor(self.pen_color), self.pen_size,Qt.PenStyle.SolidLine, Qt.PenCapStyle.RoundCap, Qt.PenJoinStyle.RoundJoin))# 绘制从上次记录的点到当前鼠标位置的直线painter.drawLine(self.last_point, event.pos())self.last_point = event.pos()  # 更新 last_point 为当前鼠标位置,为下一次绘制做准备self.update()  # 请求窗口重绘。这会间接调用 paintEvent,将 QImage 的最新内容显示到屏幕上。def mouseReleaseEvent(self, event):"""鼠标释放事件处理函数。当用户释放鼠标按钮时触发。"""# 检查是否是鼠标左键被释放。if event.button() == Qt.MouseButton.LeftButton:self.drawing = False  # 停止绘图def resizeEvent(self, event):"""窗口大小改变事件处理函数。当窗口大小改变时触发。"""# 如果新窗口的宽度或高度大于当前 QImage 的尺寸,则需要创建一个新的 QImage。if self.width() > self.image.width() or self.height() > self.image.height():new_image = QImage(self.size(), QImage.Format.Format_RGB32)# 填充新图像为白色new_image.fill(Qt.GlobalColor.white)painter = QPainter(new_image)# 将旧图像的内容绘制到新图像上,以保留已有的绘图。painter.drawImage(QPoint(0, 0), self.image)self.image = new_image  # 更新 self.image 为新的 QImageself.update()  # 请求重绘窗口def clear_canvas(self):"""清空画布内容,将整个 QImage 重新填充为白色。"""self.image.fill(Qt.GlobalColor.white)self.update()  # 请求重绘以显示空白画布def set_pen_size(self, size):"""设置画笔粗细。"""self.pen_size = size

2.处理图片

当布局完成后就只需要处理将图片变成输入的过程就好了,先给代码,在讲解

    def get_image(self):"""获取当前画布上的图像数据。返回一个 QImage 对象,包含当前画布的绘图内容。"""image = self.canvas.image# 将图像缩放到 28x28 像素并转换为灰度图scaled_image = image.scaled(28, 28,Qt.AspectRatioMode.IgnoreAspectRatio,  # 不保持宽高比Qt.TransformationMode.SmoothTransformation  # 平滑缩放)# 转换为 8 位灰度图grayscale_image = scaled_image.convertToFormat(QImage.Format.Format_Grayscale8)# 使用 qimage2ndarray.byte_view() 获取 NumPy 数组arr_3d = qimage2ndarray.byte_view(grayscale_image)arr = arr_3d.squeeze()# 将 NumPy 数组转换为 PyTorch 张量tensor_image = torch.from_numpy(arr).float()# --- 关键修正:添加颜色反转和标准化 ---# 1. 将像素值从 [0, 255] 归一化到 [0.0, 1.0]tensor_image = tensor_image / 255.0# 2. 颜色反转:如果你的模型是基于白色数字黑色背景训练的 而画布是黑色数字白色背景,则需要反转颜色tensor_image = 1.0 - tensor_image# 3. 标准化:应用训练时使用的均值和标准差# MNIST 均值和标准差mean = 0.1307std = 0.3081tensor_image = (tensor_image - mean) / std# 添加批次维度和通道维度,使形状变为 (1, 1, 28, 28)tensor_image = tensor_image.unsqueeze(0).unsqueeze(0).cuda()# --- 可视化 PyTorch 张量 ---# 为了可视化,我们先将其恢复到 [0,1] 范围,否则标准化后的值可能很难看# 逆标准化 (用于可视化,不影响模型输入)# visual_tensor = tensor_image * std + mean# # 确保在 [0,1] 范围内# visual_tensor = torch.clamp(visual_tensor, 0.0, 1.0)# plt.figure(figsize=(2, 2))# plt.imshow(visual_tensor.cpu().squeeze().numpy(), cmap='gray')# plt.title("input")# plt.axis('off')# plt.show()return tensor_image

其中有几个注意点
1.
目前的画布是白色的,画笔是黑色,但是mnist数据集的底是黑色的,画笔是白色的,因此需要使用

tensor_image = 1.0 - tensor_image

来将颜色取反,不然跟训练数据不一样模型无法良好运行。
2.
QT中的image是Qimage,转换成numpy代码有点麻烦,我这里图省事直接用了qimage2ndarray库,因此只需一行代码

arr_3d = qimage2ndarray.byte_view(grayscale_image)

就完成了这个操作。
3.
在输入到模型之前,要进行数据预处理,如上面的代码中

        # 3. 标准化:应用训练时使用的均值和标准差# MNIST 均值和标准差mean = 0.1307std = 0.3081tensor_image = (tensor_image - mean) / std

来优化模型效果。

3. 加载模型

这里的预训练权重就直接用了上一篇文章中训练出来的权重,还给她放到cuda上了,不过这么小的模型其实放不放其实都无所谓,没有太大的影响。

    def get_net(self):"""获取数字预测模型。返回一个 DigitCNN 模型实例。"""# 创建并返回一个 DigitCNN 模型实例net = DigitCNN()net.eval()net.cuda()net.load_state_dict(torch.load('./digit_CNN.pth'))return net

4. 预测

这里就没什么好说的了,就是简单地预测然后将结果同步到gui上了。

    def predict(self):"""预测当前画布上绘制的数字。这里可以调用模型进行预测,并更新预测结果标签。"""input = self.get_image()  # 获取当前画布上的图像数据# 使用模型进行预测with torch.no_grad():output = self.net(input)# 获取预测结果self.update_predict_result(output)def update_predict_result(self, output):_, predict = output.max(1)  # 获取预测的数字类别predict = predict.cpu().numpy()[0]# 更新预测结果标签self.predict_label.setText(f"预测结果: {predict}")# 更新每个数字的预测概率probabilities = torch.softmax(output, dim=1).cpu().numpy()[0]for i, label in enumerate(self.predict_digit_labels):label.setText(f"数字 {i}: {probabilities[i] * 100:.2f}%")

5.结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

6.一点小问题

现在模型是可以用了,但是因为Mnist数据集本身的局限性,已经网络也比较小,泛化性能比较差(但是没差到不能用的地步),所以预测结果又是后会比较奇怪,例如:

.在这里插入图片描述
这是mnist数据集中的数据,可以看出这里的0大部分都是上面闭合,导致模型预测奇怪位置的闭合的0会失准。

还有其中的4大部分都是开口的,并没有闭合4上面的开口,导致写一个很标准的4反倒有时候会预测出错,还有其他的一些问题我就不赘述了。

总之如果想要模型想要获得更好的表现,一是可以增强一下模型的能力,第二个我觉得更重的是把数据好好清洗一下,有些数据真的太差了


http://www.hkcw.cn/article/rDqgVEbSps.shtml

相关文章

用JS实现植物大战僵尸(前端作业)

1. 先搭架子 整体效果&#xff1a; 点击开始后进入主场景 左侧是植物卡片 右上角是游戏的开始和暂停键 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevic…

巴黎球迷打出TIFO悼念恩里克女儿 感人至深的纪念

北京时间6月1日,巴黎圣日耳曼在欧冠决赛中以5-0战胜国际米兰,夺得本赛季欧冠冠军。赛后,安联球场展示了一个感人至深的TIFO,主角是巴黎圣日耳曼主教练恩里克和他的已故女儿Xana。十年前,恩里克带领巴塞罗那夺得欧冠冠军时,曾与女儿Xana一起将巴萨的旗帜插进球场。然而,X…

六一儿童节 实践我先行活动举行

5月30日,在“六一”国际儿童节来临之际,“实践我先行——2025年在宋庆龄奶奶生活过的地方过六一”活动在北京宋庆龄故居举行,逾百名中外少年儿童和教师代表参加。活动现场,北京市西城区金融街惠泽幼儿园的小朋友们表演了群鼓节目《华夏少年》。中国宋庆龄基金会党组书记、副…

阿什拉夫弑旧主 破门后拒绝庆祝 情深义重

在欧冠决赛中,巴黎圣日耳曼迎战国际米兰。上半场,阿什拉夫攻破了老东家的大门,帮助巴黎取得领先。这位现年26岁的摩洛哥后卫曾在2020年至2021年效力于国际米兰,并为蓝黑军团出场45次。比赛进行到第12分钟时,阿什拉夫推射空门得手,将比分改写为1-0。进球后,他举起双手,拒…

端午安康(Python)

端午节总算是回家了&#xff0c;感觉时间过得真快&#xff0c;马上就毕业了&#xff0c;用Python弄了一个端午节元素的界面&#xff0c;虽然有点不像&#xff0c;祝大家端午安康。端午节粽子&#xff08;python&#xff09;_python画粽子-CSDN博客https://blog.csdn.net/weixin…

10.安卓逆向2-frida hook技术-frida基本使用-frida指令(用于hook)

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 内容参考于&#xff1a;图灵Python学院 工具下载&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1bb8NhJc9eTuLzQr39lF55Q?pwdzy89 提取码&#xff1…

# CppCon 2014 学习: Quick game development with C++11/C++14

这是一个关于游戏开发与现代 C&#xff08;尤其是 C11/C14&#xff09;结合的技术分享或讲座的概要&#xff0c;结构清晰、内容分为几个部分&#xff1a; About This Talk — 内容结构 1. 导言部分&#xff08;Introductory part&#xff09; 介绍为什么选择游戏开发作为主题…

vscode不满足先决条件问题的解决——vscode的老版本安装与禁止更新(附安装包)

目录 起因 vscode更新设置的关闭 安装包 结语 起因 由于主包用的系统是centos的&#xff0c;且版本有点老了&#xff0c;再加上vscode现在不支持老版本的&#xff0c;这对主包来说更是雪上加霜啊 但是主包看了网上很多教程&#xff0c;眼花缭乱&#xff0c;好多配置要改&…

如何手搓扫雷(待扩展)

文章目录 一、扫雷游戏分析与设计1.1 扫雷游戏的功能说明1.2 游戏的分析和设计1.2.1 数据结构的分析1.2.2 文件结构设计 二、扫雷游戏的代码实现三、扫雷游戏的扩展总结 一、扫雷游戏分析与设计 扫雷游戏网页版 1.1 扫雷游戏的功能说明 使用控制台&#xff08;黑框框的程序&a…

Python打卡训练营学习记录Day41

DAY 41 简单CNN 知识回顾 数据增强卷积神经网络定义的写法batch归一化&#xff1a;调整一个批次的分布&#xff0c;常用与图像数据特征图&#xff1a;只有卷积操作输出的才叫特征图调度器&#xff1a;直接修改基础学习率 卷积操作常见流程如下&#xff1a; 1. 输入 → 卷积层 →…

我们来学mysql -- mysql8.4主从

mysql8.4主从 8.4安装主从原理主my.cnf启动创建复制用户 从my.cnf启动锁库&迁移数据连接主&开启复制检查复制 8.4安装 参考保姆级安装教程传送门 主从原理 从库准备 使用 CHANGE MASTER TO 配置主库信息并写入 master.info 文件。执行 START SLAVE 启动从库&#xff…

kafka学习笔记(三、消费者Consumer使用教程——消费性能多线程提升思考)

1.简介 KafkaConsumer是非线程安全的&#xff0c;它定义了一个acquire()方法来检测当前是否只有一个线程在操作&#xff0c;如不是则会抛出ConcurrentModifcationException异常。 acquire()可以看做是一个轻量级锁&#xff0c;它仅通过线程操作计数标记的方式来检测线程是否发…

记忆胶囊应用源码纯开源

下载地址&#xff1a;https://pan.quark.cn/s/729681531125 &#x1f4f1; 应用功能特点 核心功能&#xff1a; 创建记忆胶囊 - 用户可以创建包含文本内容的时间胶囊时间设定 - 设置胶囊的开启时间情感标签 - 为记忆添加情感标记&#xff08;开心、难过、兴奋等&#xff09;…

破题城市更新 老旧街区如何新生?南京这样干→

暮春4月,经过十年更新改造的南京小西湖街区游人纷纷,老南京风貌从更新过的街巷中透出,市井烟火气里交织着现代时尚感。但是,略微向深处走走,年久失修的房屋,私搭乱建的建筑,让小西湖少了一分西湖的美,多了几分棚户的乱。王卉在小西湖出生长大,箍桶巷33号是父亲留给她的…

郑钦文今日战萨姆索诺娃 法网1/8决赛焦点

法网6月1日赛程已公布,郑钦文与萨姆索诺娃的比赛将在苏珊-朗格伦球场第二场进行,比赛时间不早于19点。当天是法网第八比赛日,将展开单打第四轮的较量。在苏珊-朗格伦球场的第一场比赛是保罗对阵波佩林的男单第四轮。从交手记录来看,萨姆索诺娃以3-2领先郑钦文。不过,在双方…

俄罗斯布良斯克州一桥梁坍塌 已致数十人伤亡

总台记者获悉,当地时间5月31日,位于俄罗斯布良斯克州的一座桥梁发生坍塌,导致当时行经桥下、由莫斯科开往该州城市克利莫沃的列车脱轨。据俄罗斯BAZA网站报道,事件造成4人死亡,至少44人受伤。据悉,死亡人员分别是火车司机、副司机和两名乘客。有媒体报道称,不明身份者在…

neo4j 5.19.0安装、apoc csv导入导出 及相关问题处理

前言 突然有需求需要用apoc 导入 低版本的图谱数据&#xff0c;网上资料又比较少&#xff0c;所以就看官网资料并处理了apoc 导入的一些问题。 相关地址 apoc 官方安装网址 apoc 官方导出csv 教程地址 apoc 官方 导入 csv 地址 docker 安装 执行如下命令启动镜像 doc…

【Linux】进程地址空间揭秘(初步认识)

10.进程地址空间&#xff08;初步认识&#xff09; 文章目录 10.进程地址空间&#xff08;初步认识&#xff09;一、进程地址空间的实验现象解析二、进程地址空间三、虚拟内存管理补充&#xff1a;数据的写时拷贝&#xff08;浅谈&#xff09;补充&#xff1a;页表&#xff08;…

SEO长尾关键词优化进阶指南

内容概要 在流量竞争日趋激烈的数字营销环境中&#xff0c;长尾关键词作为精准获客的核心入口&#xff0c;已成为SEO进阶优化的战略重点。本指南将系统梳理从用户意图识别到可持续流量增长的完整技术路径&#xff0c;围绕“需求挖掘-资源构建-竞争突围”三大核心模块展开。通过…

[网页五子棋][对战模块]实现游戏房间页面,服务器开发(创建落子请求/响应对象)

实现游戏房间页面 创建 css/game_room.css #screen 用于显示当前的状态&#xff0c;例如“等待玩家连接中…”&#xff0c;“轮到你落子”&#xff0c;“轮到对方落子”等 #screen { width: 450px; height: 50px; margin-top: 10px; color: #8f4e19; font-size: 28px; …