Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fea] Add PirateNet and update allen_cahn document #907

Merged
merged 5 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/zh/api/arch.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- AMGNet
- MLP
- ModifiedMLP
- PirateNet
- DeepONet
- DeepPhyLSTM
- LorenzEmbedding
Expand Down
71 changes: 42 additions & 29 deletions docs/zh/examples/allen_cahn.md
Original file line number Diff line number Diff line change
@@ -1,34 +1,46 @@
# Allen-Cahn

<!-- <a href="TODO" class="md-button md-button--primary" style>AI Studio快速体验</a> -->
<a href="https://aistudio.baidu.com/projectdetail/7927786" class="md-button md-button--primary" style>AI Studio快速体验</a>

=== "模型训练命令"

``` sh
python allen_cahn_default.py
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
python allen_cahn_piratenet.py
```

=== "模型评估命令"

``` sh
python allen_cahn_default.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/allen_cahn/allen_cahn_default_pretrained.pdparams
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
python allen_cahn_piratenet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python allen_cahn_default.py mode=export
python allen_cahn_piratenet.py mode=export
```

=== "模型推理命令"

``` sh
python allen_cahn_default.py mode=infer
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
python allen_cahn_piratenet.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [allen_cahn_default_pretrained.pdparams](TODO) | TODO |
| [allen_cahn_piratenet_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams) | L2Rel.u: 8.32403e-06 |

## 1. 背景简介

Expand Down Expand Up @@ -72,27 +84,27 @@ $$
### 3.1 模型构建

在 Allen-Cahn 问题中,每一个已知的坐标点 $(t, x)$ 都有对应的待求解的未知量 $(u)$,
在这里使用比较简单的 MLP(Multilayer Perceptron, 多层感知机) 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:
在这里使用 PirateNet 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:

$$
u = f(t, x)
$$

上式中 $f$ 即为 MLP 模型本身,用 PaddleScience 代码表示如下
上式中 $f$ 即为 PirateNet 模型本身,用 PaddleScience 代码表示如下

``` py linenums="63"
--8<--
examples/allen_cahn/allen_cahn_default.py:63:64
examples/allen_cahn/allen_cahn_piratenet.py:63:64
--8<--
```

为了在计算时,准确快速地访问具体变量的值,在这里指定网络模型的输入变量名是 `("t", "x")`,输出变量名是 `("u")`,这些命名与后续代码保持一致。

接着通过指定 MLP 的层数、神经元个数,就实例化出了一个拥有 4 层隐藏神经元,每层神经元数为 256 的神经网络模型 `model`,使用 `tanh` 作为激活函数。
接着通过指定 PirateNet 的层数、神经元个数,就实例化出了一个拥有 3 个 PiraBlock,每个 PiraBlock 的隐层神经元个数为 256 的神经网络模型 `model`, 并且使用 `tanh` 作为激活函数。

``` yaml linenums="35"
``` yaml linenums="34"
--8<--
examples/allen_cahn/conf/allen_cahn_default.yaml:35:41
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:34:40
--8<--
```

Expand All @@ -102,7 +114,7 @@ Allen-Cahn 微分方程可以用如下代码表示:

``` py linenums="66"
--8<--
examples/allen_cahn/allen_cahn_default.py:66:67
examples/allen_cahn/allen_cahn_piratenet.py:66:67
--8<--
```

Expand All @@ -112,7 +124,7 @@ examples/allen_cahn/allen_cahn_default.py:66:67

``` py linenums="69"
--8<--
examples/allen_cahn/allen_cahn_default.py:69:81
examples/allen_cahn/allen_cahn_piratenet.py:69:81
--8<--
```

Expand All @@ -124,7 +136,7 @@ examples/allen_cahn/allen_cahn_default.py:69:81

``` py linenums="94"
--8<--
examples/allen_cahn/allen_cahn_default.py:94:110
examples/allen_cahn/allen_cahn_piratenet.py:94:110
--8<--
```

Expand All @@ -139,11 +151,11 @@ examples/allen_cahn/allen_cahn_default.py:94:110
#### 3.4.2 周期边界约束

此处我们采用 hard-constraint 的方式,在神经网络模型中,对输入数据使用cos、sin等周期函数进行周期化,从而让$u_{\theta}$在数学上直接满足方程的周期性质。
根据方程可得函数$u(t, x)$在$x$轴上的周期为2,因此将该周期设置到模型配置里即可。
根据方程可得函数$u(t, x)$在$x$轴上的周期为 2,因此将该周期设置到模型配置里即可。

``` yaml linenums="35"
``` yaml linenums="41"
--8<--
examples/allen_cahn/conf/allen_cahn_default.yaml:35:43
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:41:42
--8<--
```

Expand All @@ -153,25 +165,25 @@ examples/allen_cahn/conf/allen_cahn_default.yaml:35:43

``` py linenums="112"
--8<--
examples/allen_cahn/allen_cahn_default.py:112:125
examples/allen_cahn/allen_cahn_piratenet.py:112:125
--8<--
```

在微分方程约束、初值约束构建完毕之后,以刚才的命名为关键字,封装到一个字典中,方便后续访问。

``` py linenums="126"
--8<--
examples/allen_cahn/allen_cahn_default.py:126:130
examples/allen_cahn/allen_cahn_piratenet.py:126:130
--8<--
```

### 3.5 超参数设定

接下来需要指定训练轮数和学习率,此处按实验经验,使用 200 轮训练轮数,0.001 的初始学习率。
接下来需要指定训练轮数和学习率,此处按实验经验,使用 300 轮训练轮数,0.001 的初始学习率。

``` yaml linenums="51"
``` yaml linenums="50"
--8<--
examples/allen_cahn/conf/allen_cahn_default.yaml:51:73
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:50:63
--8<--
```

Expand All @@ -181,7 +193,7 @@ examples/allen_cahn/conf/allen_cahn_default.yaml:51:73

``` py linenums="132"
--8<--
examples/allen_cahn/allen_cahn_default.py:132:136
examples/allen_cahn/allen_cahn_piratenet.py:132:136
--8<--
```

Expand All @@ -191,7 +203,7 @@ examples/allen_cahn/allen_cahn_default.py:132:136

``` py linenums="138"
--8<--
examples/allen_cahn/allen_cahn_default.py:138:156
examples/allen_cahn/allen_cahn_piratenet.py:138:156
--8<--
```

Expand All @@ -201,15 +213,15 @@ examples/allen_cahn/allen_cahn_default.py:138:156

``` py linenums="158"
--8<--
examples/allen_cahn/allen_cahn_default.py:158:194
examples/allen_cahn/allen_cahn_piratenet.py:158:184
--8<--
```

## 4. 完整代码

``` py linenums="1" title="allen_cahn_default.py"
``` py linenums="1" title="allen_cahn_piratenet.py"
--8<--
examples/allen_cahn/allen_cahn_default.py
examples/allen_cahn/allen_cahn_piratenet.py
--8<--
```

Expand All @@ -218,12 +230,13 @@ examples/allen_cahn/allen_cahn_default.py
在计算域上均匀采样出 $201\times501$ 个点,其预测结果和解析解如下图所示。

<figure markdown>
![allen_cahn_default.jpg](https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_default.png){ loading=lazy }
![allen_cahn_piratenet.jpg](https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_piratenet_ac.png){ loading=lazy }
<figcaption> 左侧为 PaddleScience 预测结果,中间为解析解结果,右侧为两者的差值</figcaption>
</figure>

可以看到对于函数$u(t, x)$,模型的预测结果和解析解的结果基本一致。

## 6. 参考资料

- [PIRATENETS: PHYSICS-INFORMED DEEP LEARNING WITHRESIDUAL ADAPTIVE NETWORKS](https://arxiv.org/pdf/2402.00326.pdf)
- [Allen-Cahn equation](https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/examples/allen_cahn/README.md)
2 changes: 1 addition & 1 deletion examples/allen_cahn/allen_cahn_causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,12 @@ def inference(cfg: DictConfig):

input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
# mapping data to cfg.INFER.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
# mapping data to cfg.INFER.output_keys

plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)

Expand Down
23 changes: 4 additions & 19 deletions examples/allen_cahn/allen_cahn_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,9 @@ def gen_label_batch(input_batch):
solver = ppsci.solver.Solver(
model,
constraint,
cfg.output_dir,
optimizer,
epochs=cfg.TRAIN.epochs,
iters_per_epoch=cfg.TRAIN.iters_per_epoch,
save_freq=cfg.TRAIN.save_freq,
log_freq=cfg.log_freq,
eval_during_train=True,
eval_freq=cfg.TRAIN.eval_freq,
optimizer=optimizer,
equation=equation,
validator=validator,
pretrained_model_path=cfg.TRAIN.pretrained_model_path,
checkpoint_path=cfg.TRAIN.checkpoint_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
loss_aggregator=mtl.GradNorm(
model,
len(constraint),
Expand Down Expand Up @@ -226,11 +216,9 @@ def evaluate(cfg: DictConfig):
# initialize solver
solver = ppsci.solver.Solver(
model,
output_dir=cfg.output_dir,
log_freq=cfg.log_freq,
validator=validator,
pretrained_model_path=cfg.EVAL.pretrained_model_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
cfg=cfg,
)

# evaluate after finished training
Expand All @@ -250,10 +238,7 @@ def export(cfg: DictConfig):
model = ppsci.arch.MLP(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
solver = ppsci.solver.Solver(model, cfg=cfg)
# export model
from paddle.static import InputSpec

Expand All @@ -275,12 +260,12 @@ def inference(cfg: DictConfig):

input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
# mapping data to cfg.INFER.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
# mapping data to cfg.INFER.output_keys

plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)

Expand Down
Loading