PINN for PDE(偏微分方程)1 - 正向问题

article/2025/8/7 7:54:03

PINN for PDE(偏微分方程)1 - 正向问题

目录

  • PINN for PDE(偏微分方程)1 - 正向问题
    • 一、什么是PINN的正问题
    • 二、求解的实际例子
    • 三、基于Pytorch实现的代码 - 分解
      • 3.1 引入库函数
      • 3.2 设置超参数
      • 3.3 设计随机种子,确保复现结果的一致性
      • 3.4 对于条件等式生成对应的训练数据
      • 3.5 定义PINN网络
      • 3.6 定义求导函数
      • 3.7 定义损失函数
      • 3.8 模型训练
      • 3.9 绘制结果图像
    • 结果
      • 4.1 真实结果图
      • 4.2 PINN预测结果图
      • 4.3 两者误差图
    • 参考

一、什么是PINN的正问题

​ 在 PINN(Physics-Informed Neural Networks,物理信息神经网络)中,**正问题(Forward Problem)**指的是:

给定一个偏微分方程(PDE)形式边界条件初始条件,求解该方程在定义域内的解函数

PINN 框架下,求解正问题的基本步骤如下

  • 构建神经网络 u θ ( x , t ) u_\theta(x, t) uθ(x,t) 作为 PDE 解的近似;
  • 利用自动微分计算偏导数,将 PDE 表达为损失函数项;
  • 同时将边界条件和初始条件也转化为损失函数;
  • 通过最小化总损失来训练神经网络参数 θ \theta θ,使其近似满足物理规律。

正问题的特征是:问题的数学描述是充分确定的,即 PDE 与条件信息足够多,因此目标是通过神经网络来模拟已知物理过程。

二、求解的实际例子

偏微分方程如下:
∂ 2 u ∂ x 2 − ∂ 4 u ∂ y 4 = ( 2 − x 2 ) e − y \frac{\partial ^2u}{\partial x^2}-\frac{\partial ^4u}{\partial y^4}=\left( 2-x^2 \right) e^{-y} x22uy44u=(2x2)ey
考虑以下边界条件,
u y y ( x , 0 ) = x 2 u y y ( x , 1 ) = x 2 e u ( x , 0 ) = x 2 u ( x , 1 ) = x 2 e u ( 0 , y ) = 0 u ( 1 , y ) = e − y u_{yy}\left( x,0 \right) =x^2 \\ u_{yy}\left( x,1 \right) =\frac{x^2}{e} \\ u\left( x,0 \right) =x^2 \\ u\left( x,1 \right) =\frac{x^2}{e} \\ u\left( 0,y \right) =0 \\ u\left( 1,y \right) =e^{-y} uyy(x,0)=x2uyy(x,1)=ex2u(x,0)=x2u(x,1)=ex2u(0,y)=0u(1,y)=ey

以上偏微分方程真解为:
u ( x , y ) = x 2 e − y u(x,y)=x^2 e^{-y} u(x,y)=x2ey
x x x y y y 的区域范围均为 [ 0 , 1 ] [0,1] [0,1]

三、基于Pytorch实现的代码 - 分解

3.1 引入库函数

import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

3.2 设置超参数

# ====================== 超参数 ======================
epochs = 10000         # 训练轮数
h = 100               # 作图时的网格密度
N = 1000              # PDE残差点(训练内点)
N1 = 100              # 边界条件点
N2 = 1000             # 数据点(已知解)
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设置GPU运行

3.3 设计随机种子,确保复现结果的一致性

def setup_seed(seed):torch.manual_seed(seed)  # 设置 PyTorch 中 CPU 上的随机数种子,使得所有如 torch.rand()、torch.randn() 等函数在 CPU 上的随机数生成具有可重复性。torch.cuda.manual_seed_all(seed) # 设置所有 GPU 设备上的随机数种子。GPU 上也有自己的随机数生成器,用于模型参数初始化、Dropout 等操作。torch.backends.cudnn.deterministic = True # 设置 cuDNN 后端为确定性模式。'''cuDNN 是 NVIDIA 为深度学习优化的加速库,但它为了加速有时使用了非确定性算法(比如卷积时自动选择最快的实现方式,某些可能会导致浮点计算顺序不同)。这个设置会强制它只使用确定性算法(牺牲一些速度),确保每次前向/反向传播都一致。'''# 设置随机数种子
setup_seed(888888)

3.4 对于条件等式生成对应的训练数据

'''
# 下面这些,都是在构建数据集,对每个微分方程(包括边界点和中间点)都构建对应的数据集,包含(x,y,对应的条件值(微分方程的真值部分))
'''
# Domain and Sampling,内点采样
def interior(n=N):# 生成 PDE(偏微分方程)区域内的训练点# 随机生成 n 个 x 坐标(范围在 [0, 1))x = torch.rand(n, 1).to(device)# 随机生成 n 个 y 坐标(范围在 [0, 1))y = torch.rand(n, 1).to(device)# 计算对应点的“条件值”(可能是解析解、真值或用于损失函数的目标值)# 这里定义为 cond = (2 - x²) * exp(-y),# 是里面的点偏导等于的一个值cond = (2 - x ** 2) * torch.exp(-y)# 返回的 x 和 y 启用自动求导功能,以便后续可用于计算梯度(如 PDE 中的导数)return x.requires_grad_(True), y.requires_grad_(True), cond# 下边界条件的第一个,对y的二阶导等于 x 平方
def down_yy(n=N1):# 下边界上的 u_yy(x, 0) = x² 条件# 随机生成 n 个 x 坐标,范围在 [0, 1)x = torch.rand(n, 1).to(device)# y 坐标全为 0,表示这是在 y=0 的边界上y = torch.zeros_like(x).to(device)# 条件值:u_yy(x, 0) = x²,即函数在边界上的二阶偏导值(对 y 的二阶导数)等于 x²cond = x ** 2# 返回启用自动求导的 x、y,以及边界条件值 condreturn x.requires_grad_(True), y.requires_grad_(True), cond# 这个是边界条件的第二个,对y的二阶导等于x平方除以e
def up_yy(n=N1):# 边界 u_yy(x,1)=x^2/ex = torch.rand(n, 1).to(device)y = torch.ones_like(x).to(device)cond = x ** 2 / torch.ereturn x.requires_grad_(True), y.requires_grad_(True), cond# 这个是边界条件的第三个,对u(x,0)等于x平方
def down(n=N1):# 边界 u(x,0)=x^2x = torch.rand(n, 1).to(device)y = torch.zeros_like(x).to(device)cond = x ** 2return x.requires_grad_(True), y.requires_grad_(True), conddef up(n=N1):# 边界 u(x,1)=x^2/ex = torch.rand(n, 1).to(device)y = torch.ones_like(x).to(device)cond = x ** 2 / torch.ereturn x.requires_grad_(True), y.requires_grad_(True), conddef left(n=N1):# 边界 u(0,y)=0y = torch.rand(n, 1).to(device)x = torch.zeros_like(y).to(device)cond = torch.zeros_like(x)return x.requires_grad_(True), y.requires_grad_(True), conddef right(n=N1):# 边界 u(1,y)=e^(-y)y = torch.rand(n, 1).to(device)x = torch.ones_like(y).to(device)cond = torch.exp(-y)return x.requires_grad_(True), y.requires_grad_(True), cond'''
# 真实数据模拟
'''
# 真实解的数据点(监督学习),也就是构建真实数据,(x,y,value),因为u=x^2 * exp(-y) 是解析解,所以是利用这个来模拟真实数据
def data_interior(n=N2):# 内点x = torch.rand(n, 1).to(device)y = torch.rand(n, 1).to(device)cond = (x ** 2) * torch.exp(-y)return x.requires_grad_(True), y.requires_grad_(True), cond'''
# 因此综合来说,解决一个PINN的正向问题,需要对应的真实数据,(输入(x,y),输出(value)),边界条件的数据,(x_边界,y_边界,value_边界条件)
# 训练的时候,输入网络输入信息(比如位置或者时间信息等等),输出为值,此时计算其数据loss,
如果是边界的位置上,需要计算其边界loss(因为正常来说,我们能拿到的数据都是中间的那些真实数据,我们都需要手动去构建边界的数据去使其满足边界条件)。
'''

3.5 定义PINN网络

# Neural Network,简单的一个的神经网络。
class MLP(torch.nn.Module):def __init__(self):super(MLP, self).__init__()self.net = torch.nn.Sequential(torch.nn.Linear(2, 32),torch.nn.Tanh(),torch.nn.Linear(32, 32),torch.nn.Tanh(),torch.nn.Linear(32, 32),torch.nn.Tanh(),torch.nn.Linear(32, 32),torch.nn.Tanh(),torch.nn.Linear(32, 1))def forward(self, x):return self.net(x)
# MSEloss,其实就是平方损失,L2距离
# Loss
loss = torch.nn.MSELoss()

3.6 定义求导函数


def gradients(u, x, order=1):"""计算函数 u 对变量 x 的高阶导数。参数:u (torch.Tensor): 待求导的函数输出。x (torch.Tensor): 自变量。order (int): 导数的阶数,默认为 1。返回:torch.Tensor: u 对 x 的导数,阶数为 order。""""""grad函数参数解释:参数名	解释u (outputs)	待求导的结果(标量或向量张量),即你想知道它对某些变量的导数。x (inputs)	自变量,通常是你需要对其求导的张量(需要 requires_grad=True)。grad_outputs=torch.ones_like(u)	通常用于处理非标量输出(比如 u 是向量)。注意:PyTorch 默认只能对标量求导,如果 u 是向量,grad_outputs 代表“如何把 u 合成一个标量”(通过对每个分量乘以 1,然后求和,相当于 $\sum u_i$)。所以这个值填写的是u里面每个数值的权重比例create_graph=True	创建一个可用于高阶导数的计算图(即反向传播的图也支持再次求导)。必须设置为 True 才能求二阶导。only_inputs=True	只计算 inputs 的梯度。一般设为 True。返回的是元组,里面是一个一个tensor,代表了每个input的梯度,如果input是[x,y],那返回的是(tensor([3.]), tensor([2.])),这里只有一个,因此,返回第一个的(也就是对x求导)就行"""if order == 1:return torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u),create_graph=True,only_inputs=True, )[0]# 嵌套求导else:return gradients(gradients(u, x), x, order=order - 1)

3.7 定义损失函数

# 以下7个损失是PDE损失,对每个构造的数据,进行计算loss,包含了6个边界损失和一个数据损失。
def l_interior(u):# 损失函数L1x, y, cond = interior()uxy = u(torch.cat([x, y], dim=1))return loss(gradients(uxy, x, 2) - gradients(uxy, y, 4), cond)def l_down_yy(u):# 损失函数L2x, y, cond = down_yy()uxy = u(torch.cat([x, y], dim=1))return loss(gradients(uxy, y, 2), cond)def l_up_yy(u):# 损失函数L3x, y, cond = up_yy()uxy = u(torch.cat([x, y], dim=1))return loss(gradients(uxy, y, 2), cond)def l_down(u):# 损失函数L4x, y, cond = down()uxy = u(torch.cat([x, y], dim=1))return loss(uxy, cond)def l_up(u):# 损失函数L5x, y, cond = up()uxy = u(torch.cat([x, y], dim=1))return loss(uxy, cond)def l_left(u):# 损失函数L6x, y, cond = left()uxy = u(torch.cat([x, y], dim=1))return loss(uxy, cond)def l_right(u):# 损失函数L7x, y, cond = right()uxy = u(torch.cat([x, y], dim=1))return loss(uxy, cond)# 构造数据损失
def l_data(u):# 损失函数L8x, y, cond = data_interior()uxy = u(torch.cat([x, y], dim=1))return loss(uxy, cond)

3.8 模型训练

# Trainingu = MLP().to(device) # 定义网络
opt = torch.optim.Adam(params=u.parameters()) # 定义优化器for i in range(epochs):opt.zero_grad() # 优化器清除梯度l = l_interior(u) \+ l_up_yy(u) \+ l_down_yy(u) \+ l_up(u) \+ l_down(u) \+ l_left(u) \+ l_right(u) \+ l_data(u)l.backward() # 损失反向传播opt.step() # 优化器,参数更新if i % 100 == 0: # 每一百次,输出现在的进度print(i)

3.9 绘制结果图像

# Inference
'''
# 推理,对空间内随便取点,然后利用解析解,解出真实值,然后利用网络得到数值解,最后计算每个之间从距离,得到每个位置的误差,最后绘制出了三个图,真实值图,预测值图,误差值图
'''xc = torch.linspace(0, 1, h).to(device)
xm, ym = torch.meshgrid(xc, xc, indexing='ij')
xx = xm.reshape(-1, 1)
yy = ym.reshape(-1, 1)
xy = torch.cat([xx, yy], dim=1)
u_pred = u(xy)
u_real = xx * xx * torch.exp(-yy)u_error = torch.abs(u_pred-u_real)
u_pred_fig = u_pred.reshape(h,h)
u_real_fig = u_real.reshape(h,h)
u_error_fig = u_error.reshape(h,h)
print("Max abs error is: ", float(torch.max(torch.abs(u_pred - xx * xx * torch.exp(-yy)))))
# 仅有PDE损失    Max abs error:  0.004852950572967529
# 带有数据点损失  Max abs error:  0.0018916130065917969# 作PINN数值解图
fig = plt.figure()
ax = Axes3D(fig)
fig.add_axes(ax)
ax.plot_surface(xm.cpu().detach().numpy(), ym.cpu().detach().numpy(), u_pred_fig.cpu().detach().numpy())
ax.text2D(0.5, 0.9, "PINN", transform=ax.transAxes)
plt.show()
fig.savefig("PINN solve.png")# 作真解图
fig = plt.figure()
ax = Axes3D(fig)
fig.add_axes(ax)
ax.plot_surface(xm.cpu().detach().numpy(), ym.cpu().detach().numpy(), u_real_fig.cpu().detach().numpy())
ax.text2D(0.5, 0.9, "real solve", transform=ax.transAxes)
plt.show()
fig.savefig("real solve.png")# 误差图
fig = plt.figure()
ax = Axes3D(fig)
fig.add_axes(ax)
ax.plot_surface(xm.detach().cpu().numpy(), ym.cpu().detach().numpy(), u_error_fig.cpu().detach().numpy())
ax.text2D(0.5, 0.9, "abs error", transform=ax.transAxes)
plt.show()
fig.savefig("abs error.png")

结果

4.1 真实结果图

在这里插入图片描述

4.2 PINN预测结果图

在这里插入图片描述

4.3 两者误差图

在这里插入图片描述

参考

该篇主要内容基础:PINN解偏微分方程实例1-CSDN博客

github资料:PINNs-for-PDE/PINN_exp1_cs at main · YanxinTong/PINNs-for-PDE


http://www.hkcw.cn/article/yrHfGHvJDj.shtml

相关文章

Adobe LiveCycle ES、LiveCycle DS 与 BlazeDS 关系解析与比较

Adobe LiveCycle 系列产品是企业级解决方案的重要组成部分,但在命名和功能上常常造成混淆。 产品定义 Adobe LiveCycle ES (Enterprise Suite) LiveCycle ES是一个基于SOA的平台,部署在J2EE应用服务器上。它提供开发、部署、配置和执行服务的功能。基…

Redis最佳实践——性能优化技巧之监控与告警详解

Redis 在电商应用的性能优化技巧之监控与告警全面详解 一、监控体系构建 1. 核心监控指标矩阵 指标类别关键指标计算方式/说明健康阈值(参考值)内存相关used_memoryINFO Memory 获取不超过 maxmemory 的 80%mem_fragmentation_ratio内存碎片率 used_m…

使用 DeepSeek API 搭建智能体《无间》- 卓伊凡的完整指南 -优雅草卓伊凡

使用 DeepSeek API 搭建智能体《无间》- 卓伊凡的完整指南 -优雅草卓伊凡 作者:卓伊凡 前言:为什么选择 DeepSeek API,而非私有化部署? 在开始搭建智能体之前,我想先说明 为什么推荐使用 DeepSeek API,而…

lidar和imu的标定(三)平面约束的方法

看了一篇:基于平面特征的地面机器人雷达-惯性里程计外参标定方法; 它和GRIL-Calib不同之处,就是采用了平面优化和栅格优化。 栅格优化就不介绍了,感觉工程上不。 平面优化则很容易懂,就是标定出来了激光雷达到IMU之…

CppCon 2014 学习: C++ on Mars

主要介绍了如何在火星探测器的飞行软件中使用 C。: 介绍了火星探测器(如 Sojourner, Spirit, Opportunity, Curiosity, Perseverance)。强调其复杂性和自主性。 延迟的现实:地球与火星之间的通信时延 单程信号延迟为 4 到 22 分…

【MFC】初识MFC

目录 01 模态和非模态对话框 02 静态文本 static text 01 模态和非模态对话框 首先我们需要知道模态对话框和非模态对话框的区别: 模态对话框是一种阻塞时对话框,它会阻止用户与应用程序的其他部分进行交互,直到用户与该对话框进行交互并关…

C#数字图像处理(二)

文章目录 1.灰度直方图1.1 灰度直方图定义1.2 灰度直方图编程实例 2.线性点运算2.1线性点运算定义2.2 线性点运算编程实例 3.全等级直方图灰度拉伸3.1 灰度拉伸定义3.2 灰度拉伸编程实例 4.直方图均衡化4.1 直方图均衡化定义4.2 直方图均衡化编程实例 5.直方图匹配5.1 直方图匹…

SOC-ESP32S3部分:24-WiFi配网

飞书文档https://x509p6c8to.feishu.cn/wiki/OD4pwTE8Jift2IkYKdNcSckOnfd 对于WiFi类设备,最重要的功能之一就是联网,WiFi需要联网,就需要知道我们家里路由的账号和密码,像手机类型的高端设备没什么问题,我们可以直接…

使用langchain实现五种分块策略:语义分块、父文档分块、递归分块、特殊格式、固定长度分块

文章目录 分块策略详解1. 固定长度拆分(简单粗暴)2. 递归字符拆分(智能切割)3. 特殊格式拆分(定向打击)Markdown分块 4. 语义分割(更智能切割)基于Embedding的语义分块基于模型的端到…

(七)【Linux进程的创建、终止和等待】

1 进程创建 1.1 在谈fork函数 #include <unistd.h> // 需要的头文件// 返回值&#xff1a;子进程中返回0&#xff0c;父进程返回子进程id&#xff0c;出错返回-1调用fork函数后&#xff0c;内核做了下面的工作&#xff1a; 创建了一个子进程的PCB结构体、并拷贝一份相…

EMO2:基于末端执行器引导的音频驱动虚拟形象视频生成

今天带来EMO2&#xff08;全称End-Effector Guided Audio-Driven Avatar Video Generation&#xff09;是阿里巴巴智能计算研究院研发的创新型音频驱动视频生成技术。该技术通过结合音频输入和静态人像照片&#xff0c;生成高度逼真且富有表现力的动态视频内容&#xff0c;值得…

Baklib知识中台加速企业服务智能化实践

知识中台架构体系构建 Baklib 通过构建多层级架构体系实现知识中台的底层支撑&#xff0c;其核心包含数据采集层、知识加工层、服务输出层及智能应用层。在数据采集端&#xff0c;系统支持对接CRM、ERP等业务系统&#xff0c;结合NLP技术实现非结构化数据的自动抽取&#xff1…

GpuGeek 618大促引爆AI开发新体验

随着生成式AI技术迅猛发展&#xff0c;高效可靠的算力资源已成为企业和开发者突破创新瓶颈的战略支点。根据赛迪顾问最新发布的《2025中国AI Infra平台市场发展研究报告》显示&#xff0c;2025年中国生成式人工智能企业应用市场规模将达到629.0亿元&#xff0c;作为AI企业级应用…

Linux线程同步实战:多线程程序的同步与调度

个人主页&#xff1a;chian-ocean 文章专栏-Linux Linux线程同步实战&#xff1a;多线程程序的同步与调度 个人主页&#xff1a;chian-ocean文章专栏-Linux 前言&#xff1a;为什么要实现线程同步线程饥饿&#xff08;Thread Starvation&#xff09;示例&#xff1a;抢票问题 …

任务22:创建示例Django项目,展示ECharts图形示例

任务描述 知识点&#xff1a; DjangoECharts 重 点&#xff1a; DjangoECharts 内 容&#xff1a; 创建Django项目掌握ECharts绘制图形通过官网ECharts示例&#xff0c;完成Django项目&#xff0c;并通过配置项进行修改图形 任务指导 1、创建web_test的Django项目 2…

深度学习入门Day1--Python基础

一、基础语法 1.变量 python是“动态类型语言”的编程语言。用户无需明确指出x的类型是int。 x10 #初始化 print(x) #输出x x100 #赋值 print(x) print(type(x))#输出x的类型<class int>2.算术计算 >>>4*5 >20 >>>3**3#**表示乘方&#xff08;3…

九坤:熵最小化加速LLM收敛

&#x1f4d6;标题&#xff1a;One-shot Entropy Minimization &#x1f310;来源&#xff1a;arXiv, 2505.20282 &#x1f31f;摘要 我们训练了 13,440 个大型语言模型&#xff0c;发现熵最小化只需要一个未标记的数据和 10 步优化&#xff0c;以实现比使用数千个数据获得的…

微服务面试(分布式事务、注册中心、远程调用、服务保护)

1.分布式事务 分布式事务&#xff0c;就是指不是在单个服务或单个数据库架构下&#xff0c;产生的事务&#xff0c;例如&#xff1a; 跨数据源的分布式事务跨服务的分布式事务综合情况 我们之前解决分布式事务问题是直接使用Seata框架的AT模式&#xff0c;但是解决分布式事务…

儿童节快乐,聊聊数字的规律和同余原理

某年的6月1日是星期日。那么&#xff0c;同一年的6月30日是星期几&#xff1f; 星期是7天一个循环。所以说&#xff0c;这一天是星期几&#xff0c;7天之后同样也是星期几。而6月30日是在6月1日的29天之后&#xff1a;29 7 4 ... 1用29除以7&#xff0c;可以得出余数为1。而…

视觉分析明火检测助力山东化工厂火情防控

视觉分析技术赋能化工厂火情防控&#xff1a;从山东事故看明火与烟雾检测的应用价值 一、背景&#xff1a;山东化工事故中的火情防控痛点 近期&#xff0c;山东高密友道化学有限公司、淄博润兴化工科技有限公司等企业接连发生爆炸事故&#xff0c;暴露出传统火情防控手段的局…