Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

关于三类边界条件 #1083

Open
ZBDat opened this issue Feb 24, 2025 · 15 comments
Open

关于三类边界条件 #1083

ZBDat opened this issue Feb 24, 2025 · 15 comments

Comments

@ZBDat
Copy link

ZBDat commented Feb 24, 2025

请提出你的问题 Please ask your question

请问边界条件类是否支持neumann边界条件?第三类边界条件呢?比如du/dx = h (u - u0)

@HydrogenSulfate
Copy link
Collaborator

HydrogenSulfate commented Feb 24, 2025

你好,PaddleScience里并没有直接实现某一种类型的边界条件,而是希望用户基于自己的公式手动编写,从而支持任意类型的边界条件。
根据你给的公式,你需要的边界条件大概是这样:

x = sympy.symbols('x')
u = sympy.Function('u')(x)

def du_dx_pred(data):
   # 注意这里的data['h']和data['u0']是你需要在dataloader中准备好的输入数据,data['u']是模型本身的输出,会自动添加到data中
   return data['h'] * (data['u'] - data['u0'])

bc_inlet = ppsci.constraint.BoundaryConstraint(
    {"du/dx": u.diff(x)},  # 等式左侧公式
    {"du/dx": du_dx_pred}, # 等式右侧计算公式
    your_geometry_object,
    {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc_inlet},
    ppsci.loss.MSELoss("sum"),
    name="inlet",
)

我们有很多的带有复杂边界条件的案例可以参考,如:https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/aneurysm/#342

@pecanjk
Copy link

pecanjk commented Mar 4, 2025

如果一开始方程的定义是用的ppsci里带的方程,并没有自己去使用sympy定义输入输出变量,比如equation = {"laplace": ppsci.equation.Laplace(dim=2)}},边界条件有导数,比如du/dx, du/dy,那怎么做呢?

是再定义一遍

x = sympy.symbols('x')
y = sympy.symbols('y')
u = sympy.Function('u')(y)
bc_inlet = ppsci.constraint.BoundaryConstraint(
    {"du/dx": u.diff(x), "du/dy": u.diff(y), }, 
    {"du/dx": 0,"du/dy": 0}, 
    your_geometry_object,
    {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc_inlet},
    ppsci.loss.MSELoss("sum")
)

还是能够下面这样?

bc_inlet = ppsci.constraint.BoundaryConstraint(
    {"du/dx": lambda out: out["u"].diff(x), "du/dy": lambda out: out["u"].diff(y) },  #会报错,没有定义x,y
    {"du/dx": 0,"du/dy": 0}, 
    your_geometry_object,
    {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc_inlet},
    ppsci.loss.MSELoss("sum")
)

@pecanjk
Copy link

pecanjk commented Mar 4, 2025

如果一开始方程的定义是用的ppsci里带的方程,并没有自己去使用sympy定义输入输出变量,比如equation = {"laplace": ppsci.equation.Laplace(dim=2)}},边界条件有导数,比如du/dx, du/dy,那怎么做呢?

是再定义一遍

x = sympy.symbols('x')
y = sympy.symbols('y')
u = sympy.Function('u')(y)
bc_inlet = ppsci.constraint.BoundaryConstraint(
    {"du/dx": u.diff(x), "du/dy": u.diff(y), }, 
    {"du/dx": 0,"du/dy": 0}, 
    your_geometry_object,
    {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc_inlet},
    ppsci.loss.MSELoss("sum")
)

还是能够下面这样?

bc_inlet = ppsci.constraint.BoundaryConstraint(
    {"du/dx": lambda out: out["u"].diff(x), "du/dy": lambda out: out["u"].diff(y) },  #会报错,没有定义x,y
    {"du/dx": 0,"du/dy": 0}, 
    your_geometry_object,
    {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc_inlet},
    ppsci.loss.MSELoss("sum")
)

测试了一下,如果这样就可以。不是很理解原理

equation = {"laplace": ppsci.equation.Laplace(dim=2)}}
u = sympy.symbols('u')
x = sympy.symbols('x')
y = sympy.symbols('y')
bc_inlet = ppsci.constraint.BoundaryConstraint(
    {"du/dx": u.diff(x), "du/dy": u.diff(y) }, 
    {"du/dx": 0,"du/dy": 0}, 
    your_geometry_object,
    {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc_inlet},
    ppsci.loss.MSELoss("sum")
)

@HydrogenSulfate
Copy link
Collaborator

HydrogenSulfate commented Mar 4, 2025

@pecanjk ppsci中的XXXConstraint有两个关键参数:output_exprlabel_dict,两个参数都是一个字典,它们的key表示需要被约束的量的名字,比如"du/dx",必须是字符串类型;而value则表示这个被约束的量是如何计算的,所以value可以是一个函数,或者是一个sympy的表达式,如果是函数,在这里被调用,expr就是你传入的函数,

for name, expr in expr_dict.items():
output_dict[name] = expr(data_dict)

如果是sympy表达式,那么就会由ppsci.lambdify将其转换为函数,如下所示

def convert_expr(
container_dict: Union[
Dict[str, ppsci.constraint.Constraint],
Dict[str, ppsci.validate.Validator],
Dict[str, ppsci.visualize.Visualizer],
]
) -> None:
for container in container_dict.values():
exprs = [
expr
for expr in container.output_expr.values()
if isinstance(expr, sp.Basic)
]
if len(exprs) > 0:
funcs = ppsci.lambdify(
exprs,
self.model,
extra_parameters=extra_parameters,
# graph_filename=osp.join(self.output_dir, "symbolic_graph_visual"), # HACK: Activate it for DEBUG.
fuse_derivative=True,
)
ind = 0
for name in container.output_expr:
if isinstance(container.output_expr[name], sp.Basic):
container.output_expr[name] = funcs[ind]
# FIXME: Equation with parameter not support yet.
# if self.world_size > 1:
# container.output_expr[name] = dist_wrapper(
# container.output_expr[name]
# )
ind += 1
if self.constraint:
convert_expr(self.constraint)
if self.validator:
convert_expr(self.validator)
if self.visualizer:
convert_expr(self.visualizer)

然后再在expression.py中被同样的方式调用

你写的那份能跑通的代码其实是不正确的,因为你写的u是一个sympy变量对象,而不是一个sympy函数对象,所以u.diff(x)其实结果是0,而不是表达式,所以你的代码仅仅能跑通,而并不是正确的。

需要自定义方程,请参考我们的文档:https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/development/#241,【2.4 构建方程】章节

@pecanjk
Copy link

pecanjk commented Mar 4, 2025

方程自定义后,还是没法拿到输出变量的一阶导,除非自定义方程中直接把一阶导数也用self.add_equation("du/dx", expr1_compute_func)添加进去,但是这样还不如我就不用ppsci自带的那些方程了,直接全都在自己代码里用symbol写。

我看了nvidia modulus的案例的modulus-sym/examples/wave_equation/wave_1d.py,对于wave equation, du/dt(0)一阶导的初值边界条件会经常用,它提供了'u__t'这样的计算
https://github.com/NVIDIA/modulus-sym/blob/24813033e0bb3e68604e5e995bb364c62573ad3a/examples/wave_equation/wave_1d.py#L57C1-L65C6

@HydrogenSulfate
Copy link
Collaborator

方程自定义后,还是没法拿到输出变量的一阶导,除非自定义方程中直接把一阶导数也用self.add_equation("du/dx", expr1_compute_func)添加进去,但是这样还不如我就不用ppsci自带的那些方程了,直接全都在自己代码里用symbol写。

我看了nvidia modulus的案例的modulus-sym/examples/wave_equation/wave_1d.py,对于wave equation, du/dt(0)一阶导的初值边界条件会经常用,它提供了'u__t'这样的计算 https://github.com/NVIDIA/modulus-sym/blob/24813033e0bb3e68604e5e995bb364c62573ad3a/examples/wave_equation/wave_1d.py#L57C1-L65C6

输出变量的一阶导数是如何计算的,需要用户指定,modulus里用的是双下划线 u__t,而ppsci里用的是 u.diff(t)

@HydrogenSulfate
Copy link
Collaborator

或者你可以提供你自己的写的方程代码,我可以帮你看一下

@pecanjk
Copy link

pecanjk commented Mar 4, 2025

@HydrogenSulfate

您看一下,wave equation

##main.py
import hydra
import numpy as np
from omegaconf import DictConfig
import paddle
import ppsci

from wave_equation import Wave
c=1
# compute ground truth function
def u_solution_func(data):
    """compute ground truth for u as label data"""
    t, x = data["t"], data["x"]
    return np.sin(np.pi * x) * np.cos(c * np.pi * t) + np.sin(2 * np.pi * x) * np.cos(2 * c * np.pi * t)

def train(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)

    # set equation
    equation = {"wave": Wave(dim=1, time=True, c=c)}

    # set geometry domain
    t_domain = ppsci.geometry.TimeDomain(**cfg.TIME_COORD) # 时间域
    geom_domain = ppsci.geometry.Interval(**cfg.GEOM_COORD)
    domain = ppsci.geometry.TimeXGeometry(t_domain, geom_domain)
    
    # set model
    model = ppsci.arch.MLP(**cfg.MODEL)

    train_dataloader_cfg = {
        "dataset": "IterableNamedArrayDataset",
        "iters_per_epoch": cfg.TRAIN.iters_per_epoch,
    }
    # set constraint
    pde_constraint = ppsci.constraint.InteriorConstraint(
        equation["wave"].equations,
        {"wave": 0},
        domain,
        {**train_dataloader_cfg, "batch_size": cfg.NPOINT_TOTAL},
        ppsci.loss.CausalMSELoss(20,"mean",2),
        evenly=True,
        name="EQ",
    )
    bc = ppsci.constraint.BoundaryConstraint(
        {"u": lambda out: out["u"]},
        {"u": 0},
        domain,
        {**train_dataloader_cfg, "batch_size": cfg.NPOINT_BC},
        ppsci.loss.MSELoss("mean"),
        evenly=True,
        name="BC",
    )
    ic = ppsci.constraint.InitialConstraint(
        {"u0": lambda out: out["u"], "u_t0": u.diff(t)}, ##这里需要一阶导的计算
        {"u0": lambda _in: np.sin(np.pi * _in["x"]) + np.sin(2* np.pi * _in["x"]), "u_t0": 0},
        domain,
        {**train_dataloader_cfg, "batch_size": cfg.NPOINT_IC},
        ppsci.loss.MSELoss("mean"),
        evenly=True,
        name="IC",
    )

    # wrap constraints together
    constraint = {
        pde_constraint.name: pde_constraint,
        bc.name: bc,
        ic.name: ic,
    }

    # set optimizer
    optimizer = ppsci.optimizer.Adam(learning_rate=cfg.TRAIN.learning_rate)(model)


    # set validator
    mse_metric = ppsci.validate.GeometryValidator(
        {"u": lambda out: out["u"]},
        {"u": u_solution_func},
        domain,
        {
            "dataset": "IterableNamedArrayDataset",
            "total_size": cfg.NPOINT_TOTAL,
        },
        ppsci.loss.MSELoss(),
        evenly=True,
        metric={"MSE": ppsci.metric.MSE()},
        with_initial=False,
        name="MSE_Metric",
    )
    validator = {mse_metric.name: mse_metric}

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        output_dir=cfg.output_dir,
        optimizer=optimizer,
        validator=validator,
        cfg=cfg
    )
    # train model
    solver.train()

@hydra.main(version_base=None, config_path="./conf", config_name="wave_1d.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    # elif cfg.mode == "export":
    #     export(cfg)
    # elif cfg.mode == "infer":
    #     inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )

if __name__ == "__main__":
    main()
##自定义的wave_equation
from __future__ import annotations
from typing import Optional
from typing import Tuple
from ppsci.equation import PDE

class Wave(PDE):
    def __init__(self, dim: int=1, time: bool=True, c = 1.0, detach_keys: Optional[Tuple[str, ...]] = None):
        super().__init__()
        self.detach_keys = detach_keys
        self.dim = dim
        self.time = time

        t, x, y, z = self.create_symbols("t x y z")

        if dim == 1:
            invars = (x,)
        elif dim == 2:
            invars = (x, y)
        elif dim == 3:
            invars = (x, y, z)
        else:
            raise ValueError("Dimension should be 1, 2 or 3.")
        
        if time:
            invars = (t,) + invars
        
        u = self.create_function("u", invars)

        if time:
            u_tt = u.diff(t,2)
            coord_var = invars[1:]
        else:
            u_tt = 0
            coord_var = invars
        
        poisson = 0
        for invar in coord_var:
            poisson += u.diff(invar, 2)
        
        self.add_equation("wave", u_tt - c**2 * poisson)
        
        self._apply_detach()
defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_

hydra:
  run:
    # dynamic output directory according to running time and override name
    dir: outputs_wave_1d/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
  job:
    name: ${mode} # name of logfile
    chdir: false # keep current working directory unchanged
  callbacks:
    init_callback:
      _target_: ppsci.utils.callbacks.InitCallback
  sweep:
    # output directory for multirun
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: train # running mode: train/eval
seed: 201
output_dir: ${hydra:run.dir}
log_freq: 20

# set geometry
GEOM_COORD:
  l: 0
  r: 1

TIME_COORD:
  t0: 0.0
  t1: 1.0
  time_step: 0.05

NPOINT_TOTAL: 20000
NPOINT_BC: 100
NPOINT_IC: 200


# model settings
MODEL:
  input_keys: ["t", "x"]
  output_keys: ["u"]
  num_layers: 3
  hidden_size: 60
  activation: "tanh"


# training settings
TRAIN:
  epochs: 4000
  iters_per_epoch: 1
  eval_during_train: true
  eval_freq: 100
  learning_rate: 1.0e-3
  pretrained_model_path: null

# evaluation settings
EVAL:
  pretrained_model_path: null

# inference settings
INFER:
  pretrained_model_path: null

@HydrogenSulfate
Copy link
Collaborator

HydrogenSulfate commented Mar 4, 2025

看到你的方程仅定义在Wave中,所以你的sympy函数、变量也只有在Wave这个PDE中生效,所以可以在main.py中再次定义函数u、变量x、t,然后使用;或者用ppsci.autodiff下的两个接口也可以

## main.py
### 第一种方式
t, x = sympy.symbols('t x')
u = sympy.Function('u')(t, x)
ic = ppsci.constraint.InitialConstraint(
    {"u0": lambda out: out["u"], "u_t0": u.diff(t)},
    {"u0": lambda _in: np.sin(np.pi * _in["x"]) + np.sin(2* np.pi * _in["x"]), "u_t0": 0},
    domain,
    {**train_dataloader_cfg, "batch_size": cfg.NPOINT_IC},
    ppsci.loss.MSELoss("mean"),
    evenly=True,
    name="IC",
)

### 第二种方式
from ppsci.autodiff import jacobian, hessian

ic = ppsci.constraint.InitialConstraint(
    {"u0": lambda out: out["u"], "u_t0": lambda out: jacobian(out["u"], out["t"])},
    {"u0": lambda _in: np.sin(np.pi * _in["x"]) + np.sin(2* np.pi * _in["x"]), "u_t0": 0},
    domain,
    {**train_dataloader_cfg, "batch_size": cfg.NPOINT_IC},
    ppsci.loss.MSELoss("mean"),
    evenly=True,
    name="IC",
)

@pecanjk
Copy link

pecanjk commented Mar 4, 2025

感谢,能否加入官方案例中去,目前还没有wave equation的支持

@HydrogenSulfate
Copy link
Collaborator

HydrogenSulfate commented Mar 4, 2025

感谢,能否加入官方案例中去,目前还没有wave equation的支持

加入官方案例的话需要开发者提交PR,包含代码、文档、训练权重,可以参考:https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/development/#3。如果你有时间和兴趣的话,可以把你的这个案例完善一下,然后提交PR到这个仓库里,我会帮忙review

@ZBDat
Copy link
Author

ZBDat commented Mar 10, 2025

非常感谢讲解 @HydrogenSulfate ,现在我有点了解应该怎么写边界条件这部分了。如果我的约束条件中的u也是网络输出,而我希望约束du/dx,是不是能像下面这么写:

bc = ppsci.constraint.BoundaryConstraint(
    {"du/dx": lambda out: out['u'].diff(x)},
    {"du/dx": lambda out: h * (out['u'] - u0)},
    ....
)

其中h和u0是定义好的常量。

@HydrogenSulfate
Copy link
Collaborator

HydrogenSulfate commented Mar 10, 2025

你理解的是正确的,但是需要注意一下sympy、Tensor两种不同类型下的微分写法,out是Dict[str, Tensor]类型,所以你下面这行代码需要改一下

{"du/dx": lambda out: out['u'].diff(x)},

改为:
lambda out: jacobian(out['u'], out['x'])(虽然'x'是输入,但是也会放到out里面,这里的out名字改为data其实更合)适
或者是: {"du/dx": u.diff(x)},其中u是sympy.Function类型,x是sympy.Symbol类型

Tensor和sympy两种写法虽然可以同时共存使用,但这是两套体系,是不能相互混用对方的函数

然后下面这行没问题

{"du/dx": lambda out: h * (out['u'] - u0)},

@ZBDat
Copy link
Author

ZBDat commented Mar 10, 2025

@HydrogenSulfate
不好意思又麻烦了。请帮忙看下我这段代码有什么问题:

def main(cfg: DictConfig):
    model = ppsci.arch.MLP(**cfg.MODEL)
    geom = {"cuboid": ppsci.geometry.Cuboid(
        (-14.0, -2.0, -1.0),
        (14.0, 2.0, 1.0))
    }
    train_dataloader_cfg = {
        "dataset": "IterableNamedArrayDataset",
        "iters_per_epoch": cfg.TRAIN.iters_per_epoch,
    }

    equation = {"heat_pde": ppsci.equation.HeatTransfer(alpha=45.0, dim=3)}

    pde_constraint = ppsci.constraint.InteriorConstraint(
        equation["heat_pde"].equations,
        {"HeatTransfer": 0},
        geom["cuboid"],
        {**train_dataloader_cfg, "batch_size": 10},
        ppsci.loss.MSELoss("mean"),
        evenly=True,
        name="EQ",
    )
    bc_top = ppsci.constraint.BoundaryConstraint(
        {"du/dx": lambda out: jacobian(out['u'], out['z'])},
        {"du/dx": lambda out: 13.0 * (out["u"] - 25.0)},
        geom["cuboid"],
        {**train_dataloader_cfg, "batch_size": 10},
        ppsci.loss.MSELoss("mean"),
        weight_dict={"u": cfg.TRAIN.weight.bc_top},
        criteria=lambda x, y, z: np.isclose(y, 1),
        name="BC_top",
    )

    constraint = {
        pde_constraint.name: pde_constraint,
        bc_top.name: bc_top,
    }


if __name__ == '__main__':
    main()

像上面那样写会报错:

  File "E:\PaddleScience\test.py", line 37, in main
    bc_top = ppsci.constraint.BoundaryConstraint(
  File "E:\PaddleScience\ppsci\constraint\boundary_constraint.py", line 99, in __init__
    input = geom.sample_boundary(
  File "E:\PaddleScience\ppsci\geometry\geometry.py", line 320, in sample_boundary
    raise ValueError(
ValueError: Sample boundary points failed, please check correctness of geometry and given criteria.

其中equation = {"heat_pde": ppsci.equation.HeatTransfer(alpha=45.0, dim=3)}这个是我自己的方程


class HeatTransfer(base.PDE):
    def __init__(self, alpha, dim: int = 3):
        super().__init__()
        space_vars = self.create_symbols("x y z")[:dim]
        t = self.create_symbols("t")
        invars = space_vars + (t,)
        u = self.create_function("u", invars)

        grads = 0
        for space_var in space_vars:
            grads += u.diff(space_var, 2)

        u_t = u.diff(t, 1)
        heat_tr = u_t - alpha * grads

        self.add_equation("HeatTransfer", heat_tr)

@HydrogenSulfate
Copy link
Collaborator

HydrogenSulfate commented Mar 10, 2025

@ZBDat 在一个三维的立方体geom表面随机撒点,筛选出足够多的y=1其实是比较困难的,所以建议你的bc可以再重新构建一个y=1附近的三维的薄片(仍然可以使用Cuboid,把对角线y的上下界设置的非常接近1即可),然后把这个三维薄片作为bc的几何即可

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants