Skip to content

Commit

Permalink
[Fea] Add PirateNet and update allen_cahn document (PaddlePaddle#907)
Browse files Browse the repository at this point in the history
* Add PiraNet and update allen_cahn document

* fix example code for mlp.py

* rename pira to pirate

* update AIStudio link for allen cahn
  • Loading branch information
HydrogenSulfate authored May 23, 2024
1 parent fceb3f2 commit bfee4f5
Show file tree
Hide file tree
Showing 10 changed files with 711 additions and 60 deletions.
1 change: 1 addition & 0 deletions docs/zh/api/arch.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- AMGNet
- MLP
- ModifiedMLP
- PirateNet
- DeepONet
- DeepPhyLSTM
- LorenzEmbedding
Expand Down
71 changes: 42 additions & 29 deletions docs/zh/examples/allen_cahn.md
Original file line number Diff line number Diff line change
@@ -1,34 +1,46 @@
# Allen-Cahn

<!-- <a href="TODO" class="md-button md-button--primary" style>AI Studio快速体验</a> -->
<a href="https://aistudio.baidu.com/projectdetail/7927786" class="md-button md-button--primary" style>AI Studio快速体验</a>

=== "模型训练命令"

``` sh
python allen_cahn_default.py
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
python allen_cahn_piratenet.py
```

=== "模型评估命令"

``` sh
python allen_cahn_default.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/allen_cahn/allen_cahn_default_pretrained.pdparams
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
python allen_cahn_piratenet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python allen_cahn_default.py mode=export
python allen_cahn_piratenet.py mode=export
```

=== "模型推理命令"

``` sh
python allen_cahn_default.py mode=infer
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
python allen_cahn_piratenet.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [allen_cahn_default_pretrained.pdparams](TODO) | TODO |
| [allen_cahn_piratenet_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams) | L2Rel.u: 8.32403e-06 |

## 1. 背景简介

Expand Down Expand Up @@ -72,27 +84,27 @@ $$
### 3.1 模型构建

在 Allen-Cahn 问题中,每一个已知的坐标点 $(t, x)$ 都有对应的待求解的未知量 $(u)$,
在这里使用比较简单的 MLP(Multilayer Perceptron, 多层感知机) 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:
在这里使用 PirateNet 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:

$$
u = f(t, x)
$$

上式中 $f$ 即为 MLP 模型本身,用 PaddleScience 代码表示如下
上式中 $f$ 即为 PirateNet 模型本身,用 PaddleScience 代码表示如下

``` py linenums="63"
--8<--
examples/allen_cahn/allen_cahn_default.py:63:64
examples/allen_cahn/allen_cahn_piratenet.py:63:64
--8<--
```

为了在计算时,准确快速地访问具体变量的值,在这里指定网络模型的输入变量名是 `("t", "x")`,输出变量名是 `("u")`,这些命名与后续代码保持一致。

接着通过指定 MLP 的层数、神经元个数,就实例化出了一个拥有 4 层隐藏神经元,每层神经元数为 256 的神经网络模型 `model`使用 `tanh` 作为激活函数。
接着通过指定 PirateNet 的层数、神经元个数,就实例化出了一个拥有 3 个 PiraBlock,每个 PiraBlock 的隐层神经元个数为 256 的神经网络模型 `model` 并且使用 `tanh` 作为激活函数。

``` yaml linenums="35"
``` yaml linenums="34"
--8<--
examples/allen_cahn/conf/allen_cahn_default.yaml:35:41
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:34:40
--8<--
```

Expand All @@ -102,7 +114,7 @@ Allen-Cahn 微分方程可以用如下代码表示:

``` py linenums="66"
--8<--
examples/allen_cahn/allen_cahn_default.py:66:67
examples/allen_cahn/allen_cahn_piratenet.py:66:67
--8<--
```

Expand All @@ -112,7 +124,7 @@ examples/allen_cahn/allen_cahn_default.py:66:67

``` py linenums="69"
--8<--
examples/allen_cahn/allen_cahn_default.py:69:81
examples/allen_cahn/allen_cahn_piratenet.py:69:81
--8<--
```

Expand All @@ -124,7 +136,7 @@ examples/allen_cahn/allen_cahn_default.py:69:81

``` py linenums="94"
--8<--
examples/allen_cahn/allen_cahn_default.py:94:110
examples/allen_cahn/allen_cahn_piratenet.py:94:110
--8<--
```

Expand All @@ -139,11 +151,11 @@ examples/allen_cahn/allen_cahn_default.py:94:110
#### 3.4.2 周期边界约束

此处我们采用 hard-constraint 的方式,在神经网络模型中,对输入数据使用cos、sin等周期函数进行周期化,从而让$u_{\theta}$在数学上直接满足方程的周期性质。
根据方程可得函数$u(t, x)$在$x$轴上的周期为2,因此将该周期设置到模型配置里即可。
根据方程可得函数$u(t, x)$在$x$轴上的周期为 2,因此将该周期设置到模型配置里即可。

``` yaml linenums="35"
``` yaml linenums="41"
--8<--
examples/allen_cahn/conf/allen_cahn_default.yaml:35:43
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:41:42
--8<--
```

Expand All @@ -153,25 +165,25 @@ examples/allen_cahn/conf/allen_cahn_default.yaml:35:43

``` py linenums="112"
--8<--
examples/allen_cahn/allen_cahn_default.py:112:125
examples/allen_cahn/allen_cahn_piratenet.py:112:125
--8<--
```

在微分方程约束、初值约束构建完毕之后,以刚才的命名为关键字,封装到一个字典中,方便后续访问。

``` py linenums="126"
--8<--
examples/allen_cahn/allen_cahn_default.py:126:130
examples/allen_cahn/allen_cahn_piratenet.py:126:130
--8<--
```

### 3.5 超参数设定

接下来需要指定训练轮数和学习率,此处按实验经验,使用 200 轮训练轮数,0.001 的初始学习率。
接下来需要指定训练轮数和学习率,此处按实验经验,使用 300 轮训练轮数,0.001 的初始学习率。

``` yaml linenums="51"
``` yaml linenums="50"
--8<--
examples/allen_cahn/conf/allen_cahn_default.yaml:51:73
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:50:63
--8<--
```

Expand All @@ -181,7 +193,7 @@ examples/allen_cahn/conf/allen_cahn_default.yaml:51:73

``` py linenums="132"
--8<--
examples/allen_cahn/allen_cahn_default.py:132:136
examples/allen_cahn/allen_cahn_piratenet.py:132:136
--8<--
```

Expand All @@ -191,7 +203,7 @@ examples/allen_cahn/allen_cahn_default.py:132:136

``` py linenums="138"
--8<--
examples/allen_cahn/allen_cahn_default.py:138:156
examples/allen_cahn/allen_cahn_piratenet.py:138:156
--8<--
```

Expand All @@ -201,15 +213,15 @@ examples/allen_cahn/allen_cahn_default.py:138:156

``` py linenums="158"
--8<--
examples/allen_cahn/allen_cahn_default.py:158:194
examples/allen_cahn/allen_cahn_piratenet.py:158:184
--8<--
```

## 4. 完整代码

``` py linenums="1" title="allen_cahn_default.py"
``` py linenums="1" title="allen_cahn_piratenet.py"
--8<--
examples/allen_cahn/allen_cahn_default.py
examples/allen_cahn/allen_cahn_piratenet.py
--8<--
```

Expand All @@ -218,12 +230,13 @@ examples/allen_cahn/allen_cahn_default.py
在计算域上均匀采样出 $201\times501$ 个点,其预测结果和解析解如下图所示。

<figure markdown>
![allen_cahn_default.jpg](https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_default.png){ loading=lazy }
![allen_cahn_piratenet.jpg](https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_piratenet_ac.png){ loading=lazy }
<figcaption> 左侧为 PaddleScience 预测结果,中间为解析解结果,右侧为两者的差值</figcaption>
</figure>

可以看到对于函数$u(t, x)$,模型的预测结果和解析解的结果基本一致。

## 6. 参考资料

- [PIRATENETS: PHYSICS-INFORMED DEEP LEARNING WITHRESIDUAL ADAPTIVE NETWORKS](https://arxiv.org/pdf/2402.00326.pdf)
- [Allen-Cahn equation](https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/examples/allen_cahn/README.md)
2 changes: 1 addition & 1 deletion examples/allen_cahn/allen_cahn_causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,12 @@ def inference(cfg: DictConfig):

input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
# mapping data to cfg.INFER.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
# mapping data to cfg.INFER.output_keys

plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)

Expand Down
23 changes: 4 additions & 19 deletions examples/allen_cahn/allen_cahn_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,9 @@ def gen_label_batch(input_batch):
solver = ppsci.solver.Solver(
model,
constraint,
cfg.output_dir,
optimizer,
epochs=cfg.TRAIN.epochs,
iters_per_epoch=cfg.TRAIN.iters_per_epoch,
save_freq=cfg.TRAIN.save_freq,
log_freq=cfg.log_freq,
eval_during_train=True,
eval_freq=cfg.TRAIN.eval_freq,
optimizer=optimizer,
equation=equation,
validator=validator,
pretrained_model_path=cfg.TRAIN.pretrained_model_path,
checkpoint_path=cfg.TRAIN.checkpoint_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
loss_aggregator=mtl.GradNorm(
model,
len(constraint),
Expand Down Expand Up @@ -226,11 +216,9 @@ def evaluate(cfg: DictConfig):
# initialize solver
solver = ppsci.solver.Solver(
model,
output_dir=cfg.output_dir,
log_freq=cfg.log_freq,
validator=validator,
pretrained_model_path=cfg.EVAL.pretrained_model_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
cfg=cfg,
)

# evaluate after finished training
Expand All @@ -250,10 +238,7 @@ def export(cfg: DictConfig):
model = ppsci.arch.MLP(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
solver = ppsci.solver.Solver(model, cfg=cfg)
# export model
from paddle.static import InputSpec

Expand All @@ -275,12 +260,12 @@ def inference(cfg: DictConfig):

input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
# mapping data to cfg.INFER.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
# mapping data to cfg.INFER.output_keys

plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)

Expand Down
Loading

0 comments on commit bfee4f5

Please sign in to comment.