Pytorch Geometric官方例程pytorch_geometric/examples/link_pred.py环境安装教程及图数据集制作

article/2025/7/27 14:34:41

最近需要训练图卷积神经网络(Graph Convolution Neural Network, GCNN),在配置GCNN环境上总结了一些经验。

我觉得对于初学者而言,图神经网络的训练会有2个难点:

①环境配置

②数据集制作

一、环境配置

我最初光想到要给GCNN配环境就觉得有些困难,感觉相比于目标检测、分类识别这些任务用规则数据,图神经网络的模型、数据都是图,所以内心觉得会比较难。

我之前更有一个误区,就是觉得不规则结构的图数据不能用CUDA进行并行加速。实际上,图,在电脑里也是以张量这种规则结构数据存在的,完全能用CUDA进行加速计算,训练GCN前配置CUDA完全OK。


以下是我配置的环境,可用CUDA成功运行link_pred.py

几个关键包的版本:

torch                                2.4.1
torch-geometric               2.3.1
torchaudio                       2.4.1
torchvision                       0.14.0
torchviz                            0.0.2

pandas                             1.0.3

numpy                              1.20.0

 CUDA: 11.8

注意要先安装好CUDA,显示了:

 

再安装GPU版本的torch,不然python检测安装的是cpu版本的torch。这时,就得卸载重新安装了

环境配置成功:

print(torch.__version__)
print(torch.cuda.is_available())

如果CUDA环境安装失败,会打印:

2.4.1+cpu
False

其实只安装torch和CUDA还好,如果你的python中有numpy和pandas可能解决版本之间的冲突会耗费不少时间,我就是在numpy和pandas版本上试了很久,最终找到现在的版本是相互兼容的。

CUDA的版本切换可以参考我的另一篇博客:

CUDA版本切换

二、数据集制作

掌握图数据集制作的关键在于掌握slices切片:

for ...data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index,             edge_label=Edge_label)data_list.append(data)
data_, slices = self.collate(data_list)  # 将不同大小的图数据对齐,填充
torch.save((data_, slices), self.processed_paths[0])

和CNN不同的是,GCN没有样本维度,需要把所有样本拼成一张大图喂给GCN进行训练 

数据集生成代码:

#作者:zhouzhichao
#创建时间:2025/5/30
#内容:生成200个样本的PYG数据集import h5py
import hdf5storage
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import negative_samplingbase_dir = "D:\\无线通信网络认知\\论文1\\experiment\\直推式拓扑推理实验\\拓扑生成\\200样本\\"N = 30
grapg_size = N
train_n = 31
M = 3000class graph_data(InMemoryDataset):def __init__(self, root, signals=None, tp_list = None, transform=None, pre_transform=None):# self.Signals = Signals# self.Tp_list = Tp_listself.signals = signalsself.tp_list = tp_listsuper().__init__(root, transform, pre_transform)# self.data, self.slices = torch.load(self.processed_paths[0])self.data = torch.load(self.processed_paths[0])# 返回process方法所需的保存文件名。你之后保存的数据集名字和列表里的一致@propertydef processed_file_names(self):return ['gcn_data.pt']# 生成数据集所用的方法def process(self):# data_list = []# for k in range(200):# signals = self.Signals[:, :, k]# tp_list = np.array(mat_file[self.Tp_list[0, k]])signals = self.signalstp_list =self.tp_list# tp = Tp[:,:,k]X = torch.tensor(signals, dtype=torch.float)# 所有的边Edge_index = torch.tensor(tp_list, dtype=torch.long)# 所有的边1标签edge_label = np.ones((tp_list.shape[1]))# edge_label = np.zeros((tp_list.shape[1]))Edge_label = torch.tensor(edge_label, dtype=torch.float)neg_edge_index = negative_sampling(edge_index=Edge_index, num_nodes=grapg_size,num_neg_samples=Edge_index.shape[1], method='sparse')# 拼接正负样本索引# c = 0# for i in range(31):#     for i in range(31):#         if torch.equal(Edge_index[:, i], neg_edge_index[:, i]):#             c = c + 1# print("c: ",c)Edge_label_index = Edge_indexperm = torch.randperm(Edge_index.size(1))Edge_index = Edge_index[:, perm]Edge_index = Edge_index[:, :train_n]Edge_label_index = torch.cat([Edge_label_index, neg_edge_index],dim=-1,)# 拼接正负样本Edge_label = torch.cat([Edge_label,Edge_label.new_zeros(neg_edge_index.size(1))], dim=0)# Edge_label = torch.cat([#     Edge_label,#     Edge_label.new_ones(neg_edge_index.size(1))# ], dim=0)data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index, edge_label=Edge_label)torch.save(data, self.processed_paths[0])# data_list.append(data)# data_, slices = self.collate(data_list)  # 将不同大小的图数据对齐,填充# torch.save((data_, slices), self.processed_paths[0])for snr in [0,20,40]:print("snr: ", snr)mat_file = h5py.File(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')# mat_file = hdf5storage.loadmat(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')# 获取数据集Signals = mat_file["Signals"][()]# signals = np.swapaxes(signals, 1, 0)Tp = mat_file["Tp"][()]Tp_list = mat_file["Tp_list"][()]# tp_list = tp_list - 1# 关闭文件# mat_file.close()# graph_data("gcn_data")# n = Signals.shape[2]n = 10for i in range(n):signals = Signals[:,:,i]tp_list = np.array(mat_file[Tp_list[0, i]])root = "gcn_data-"+str(i)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)graph_data(root, signals = signals, tp_list = tp_list)print("")print("...图数据生成完成...")

训练代码:

#作者:zhouzhichao
#创建时间:25年5月29日
#内容:统计图中有关系节点和无关系节点的GCN特征欧式距离import sys
import torch
import random
import numpy as np
import pandas as pd
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score
sys.path.append('D:\无线通信网络认知\论文1\experiment\直推式拓扑推理实验\GCN推理')
from gcn_dataset import graph_data
print(torch.__version__)
print(torch.cuda.is_available())mode = "gcn"class Net(torch.nn.Module):def __init__(self):super().__init__()self.conv1 = GCNConv(Input_L, 1000)self.conv2 = GCNConv(1000, 20)def encode(self, x, edge_index):x1 = self.conv1(x, edge_index)x1_1 = x1.relu()x2 = self.conv2(x1_1, edge_index)x2_2 = x2.relu()return x2_2def decode(self, z, edge_label_index):# 节点和边都是矩阵,不同的计算方法致使:节点->节点,节点->边# nodes_relation = (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)# distances  = torch.norm(z[edge_label_index[0]] - z[edge_label_index[1]], dim=-1)distance_squared = torch.sum((z[edge_label_index[0]] - z[edge_label_index[1]]) ** 2, dim=-1)# print("distance_squared: ",distance_squared)return distance_squareddef decode_all(self, z):prob_adj = z @ z.t()  # 得到所有边概率矩阵return (prob_adj > 0).nonzero(as_tuple=False).t()  # 返回概率大于0的边,以edge_index的形式@torch.no_grad()def test(self,input_data):model.eval()z = model.encode(input_data.x, input_data.edge_index)out = model.decode(z, input_data.edge_label_index).view(-1)out = 1 - outN = 30
train_n = 31
M = 3000
# snr = -20
# for train_n in range(1,51):
# for M in range(3000, 499, -100):
for snr in [0,20,40]:print("snr: ", snr)for I in range(10):root = "gcn_data-"+str(I)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)gcn_data = graph_data(root)Input_L = gcn_data.x.shape[1]model = Net()# model = Net().to(device)optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()z = model.encode(gcn_data.x, gcn_data.edge_index)# out = model.decode(z, train_data.edge_label_index).view(-1).sigmoid()out = model.decode(z, gcn_data.edge_label_index).view(-1)out = 1 - outloss = criterion(out, gcn_data.edge_label)loss.backward()optimizer.step()return lossmin_loss = 99999count = 0#早停for epoch in range(10000):loss = train()if loss<min_loss:min_loss = losscount = 0count = count + 1if count>100:breakprint("epoch:  ",epoch,"   loss: ",round(loss.item(),2), "   min_loss: ",round(min_loss.item(),2))z = model.encode(gcn_data.x, gcn_data.edge_index)out = model.decode(z, gcn_data.edge_label_index).view(-1)list_0 = []list_1 = []for i in range(len(gcn_data.edge_label)):true_label = gcn_data.edge_label[i].item()euclidean_distance_value = out[i].item()if true_label==1:list_1.append(euclidean_distance_value)if true_label==0:list_0.append(euclidean_distance_value)minlength = min(len(list_1), len(list_0))list_1 = random.sample(list_1, minlength)list_0 = random.sample(list_0, minlength)value = list_1 + list_0large_class = list(np.full(len(value), snr))small_class = list(np.full(len(list_1), 1)) + list(np.full(len(list_0), 0))data = {'large_class': large_class,'small_class': small_class,'value': value}# 创建一个 DataFramedf = pd.DataFrame(data)## # 保存到 Excel 文件file_path = 'D:\无线通信网络认知\论文1\大修意见\图聚类、阈值相似性图实验补充\\' + mode + '_similarity_' + str(snr) + 'db_'+str(I)+'.xlsx'df.to_excel(file_path, index=False)


http://www.hkcw.cn/article/KLSwWTFhYB.shtml

相关文章

AI预测3D新模型百十个定位预测+胆码预测+去和尾2025年5月30日第93弹

从今天开始&#xff0c;咱们还是暂时基于旧的模型进行预测&#xff0c;好了&#xff0c;废话不多说&#xff0c;按照老办法&#xff0c;重点8-9码定位&#xff0c;配合三胆下1或下2&#xff0c;杀1-2个和尾&#xff0c;再杀4-5个和值&#xff0c;可以做到100-300注左右。 (1)定…

架构加速-深度学习教程

由于RK、jetson nano和电脑的GPU不相同&#xff0c;对应的pytorch也不同&#xff0c;因此不能直接将电脑训练好的模型丢到板端运行&#xff0c;因为训练的模型框架不同。就像你torch1.13和torch2.0都不一定支持&#xff0c;更何况不同平台上的torch。因此需要进行onnx模型转化&…

顶会新热门:机器学习可解释性

&#x1f9c0;机器学习模型的可解释性一直是研究的热点和挑战之一&#xff0c;同样也是近两年各大顶会的投稿热门。 &#x1f9c0;这是因为模型的决策过程不仅需要高准确性&#xff0c;还需要能被我们理解&#xff0c;不然我们很难将它迁移到其它的问题中&#xff0c;也很难进…

MicroPython+L298N+ESP32控制电机转速

要使用MicroPython控制L298N电机驱动板来控制电机的转速&#xff0c;你可以通过PWM&#xff08;脉冲宽度调制&#xff09;信号来调节电机速度。L298N是一个双H桥驱动器&#xff0c;可以同时控制两个电机的正反转和速度。 硬件准备&#xff1a; 1. L298N 电机控制板 2. ESP32…

Chainlink:连接 Web2 与 Web3 的去中心化桥梁

区块链技术通过智能合约实现了去中心化的自动执行&#xff0c;但智能合约无法直接访问链下数据&#xff0c;限制了其在现实世界的应用。Chainlink 作为去中心化预言机网络&#xff0c;以信任最小化的方式解决了这一问题&#xff0c;成为连接传统互联网&#xff08;Web2&#xf…

杨传辉:构建 Data × AI 能力,打造 AI 时代的一体化数据底座|OceanBase 开发者大会实录

5 月 17 日&#xff0c;OceanBase 在广州举办第三届开发者大会。主论坛环节&#xff0c;OceanBase CTO 杨传辉系统阐述了 Data AI 战略&#xff0c;并正式推出三大产品&#xff1a;PowerRAG、共享存储 及OceanBase桌面版。 杨传辉指出&#xff0c;数据与AI模型的一体化融合&a…

AU6825集成音频DSP的2x32W数字型ClaSSD音频功率放大器(替代TAS5825)

1.特性 ● 输出配置 - 立体声 2.0: 2 x 32W (8Ω,24V,THD N 10%) - 立体声 2.0: 2 x 26W (8Ω,21V,THD N 1%) ● 供电电压范围 - PVDD:4.5V -26.4V - DVDD: 1.8V 或者 3.3V ● 静态功耗 - 37mA at PVDD12V ● 音频性能指标 - THDN ≤ 0.02% at 1W,1kHz - SNR ≥ 107dB (A-wei…

关于ADS分辨率问题

笔记本上使用ADS&#xff08;Advanced Design System &#xff09;默认的界面挺大的&#xff0c;图标和字体都大&#xff0c;界面清新&#xff0c;给人一种呆呆易上手的感觉。 整个屏幕的截图 直到我打开了这个OPTIM的选项卡&#xff0c;它太长了&#xff0c;由于缩放太大&am…

海外DeepLink方案复杂?用openinstall一站式链接世界

App出海难免水土不服&#xff0c;商业模型、用户画像、增长方向没有一样是省心的&#xff0c;国内标配的DeepLink&#xff08;深度链接&#xff09;方案如果照搬出海同样无法达到最佳体验。 要知道国内外移动端生态是截然不同的&#xff0c;除了主流的URL Scheme和iOS Univers…

Ollama(1)知识点配置篇

ollama已经成功安装成功后&#xff0c;通常大家会对模型的下载位置和访问权限进行配置 1.模型下载位置修改 都是修改系统环境变量。 &#xff08;1&#xff09;默认下载位置 macOS: ~/.ollama/modelsLinux: /usr/share/ollama/.ollama/modelsWindows: C:\Users\你的电脑用户…

C# SolidWorks二次开发-实战1,找文件名不同实体相同的零件。

今天这篇文章话题来源于群里的聊天&#xff0c;在讨论有些插件功能的开发原理。 如标题&#xff0c;今天讲的是如何查找零件文件名不一样&#xff0c;但实际可能是同一个东西的办法。 - 题外话 熟悉Solidworks的人都知道&#xff0c;Solidworks有一个比较零件或者特征不同点的…

ES5时代的残党(被ES6淘汰的JS写法)

近年来&#xff0c;JavaScript语言经历了翻天覆地的变化。ES6(ECMAScript 2015)的发布标志着JavaScript进入了现代化时代&#xff0c;带来了大量新特性和更优雅的写法。但时至今日&#xff0c;许多开发者仍然固守着ES5时代的老旧模式&#xff0c;这不仅使代码显得过时&#xff…

【Python】4.字典和文件

文章目录 一、字典1、字典是什么&#xff1f;2、创建字典3、查找 key4、新增/修改元素5、删除元素6、遍历字典元素7、取出所有 key 和 value8、合法的 key 类型小结 二、文件1、文件是什么&#xff1f;2、文件路径3、文件操作1&#xff09;打开文件2&#xff09;关闭文件3&…

物流项目第十一期(智能调度之分配快递员)

本项目专栏&#xff1a; 物流项目_Auc23的博客-CSDN博客 整体核心业务流程 关键流程说明&#xff1a; 用户下单后&#xff0c;会产生取件任务&#xff0c;该任务也是由调度中心进行调度的订单转运单后&#xff0c;会发送消息到调度中心&#xff0c;在调度中心中对相同节点的运…

React 项目中封装 Excel 导入导出组件:技术分享与实践

文章目录 前言一、为什么需要封装 Excel 组件&#xff1f;二、技术选型三、核心实现1. 安装依赖2. 封装Excel导出3. 封装导入组件 &#xff08;UploadExcel&#xff09; 总结 前言 在 React 项目中&#xff0c;处理 Excel 文件的导入和导出是常见的业务需求。无论是导出报表数…

用calibredrv提取版图中指定类型cell,保留位置信息并输出新的gds

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 现在有一个gds,其中的bump位置信息是我们需要的,如何从现有的gds中提取我们需要的部分呢? 需要用到工具calibredrv,如果数量少,可以用图形界面操作,方法如下: 01 打开gds calibredrv -m inp…

iOS 使用CocoaPods 添加Alamofire 提示错误的问题

Sandbox: rsync(59817) deny(1) file-write-create /Users/aaa/Library/Developer/Xcode/DerivedData/myApp-bpwnzikesjzmbadkbokxllvexrrl/Build/Products/Debug-iphoneos/myApp.app/Frameworks/Alamofire.framework/Alamofire.bundle把这个改成 no 2 设置配置文件

Python基本运算符

White graces&#xff1a;个人主页 &#x1f439;今日诗词:相恨不如潮有信&#xff0c;相思始觉海非深&#x1f439; ⛳️点赞 ☀️收藏⭐️关注&#x1f4ac;卑微小博主&#x1f64f; ⛳️点赞 ☀️收藏⭐️关注&#x1f4ac;卑微小博主&#x1f64f; 目录 &#x1f9ee; Pyt…

nginx: [emerg] bind() to 0.0.0.0:80 failed (10013: 80端口被占用

Nginx启动报错&#xff1a;nginx: [emerg] bind() to 0.0.0.0:80 failed (10013: An attempt was made to access a socket in a way forbidden by its access permissions) 这个报错代表80端口被占用 先查看占用80的端口 netstat -aon | findstr :80 把它杀掉&#xff0c;强…

vscode命令行debug

vscode命令行debug 一般命令行debug会在远程连服务器的时候用上&#xff0c;命令行debug的本质是在执行时暴露一个监听端口&#xff0c;通过进入这个端口&#xff0c;像本地调试一样进行。 这里提供两种方式&#xff1a; 直接在命令行中添加debugpy&#xff0c;适用于python…