Skip to content

Commit

Permalink
rename pira to pirate
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed May 22, 2024
1 parent ed4bbfa commit ba04ee2
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 40 deletions.
2 changes: 1 addition & 1 deletion docs/zh/api/arch.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
- AMGNet
- MLP
- ModifiedMLP
- PiraNet
- PirateNet
- DeepONet
- DeepPhyLSTM
- LorenzEmbedding
Expand Down
46 changes: 23 additions & 23 deletions docs/zh/examples/allen_cahn.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
python allen_cahn_piranet.py
python allen_cahn_piratenet.py
```

=== "模型评估命令"
Expand All @@ -19,13 +19,13 @@
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
python allen_cahn_piranet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piranet_pretrained.pdparams
python allen_cahn_piratenet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python allen_cahn_piranet.py mode=export
python allen_cahn_piratenet.py mode=export
```

=== "模型推理命令"
Expand All @@ -35,12 +35,12 @@
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --output ./dataset/antiderivative_unaligned_train.npz
python allen_cahn_piranet.py mode=infer
python allen_cahn_piratenet.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [allen_cahn_piranet_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piranet_pretrained.pdparams) | L2Rel.u: 8.32403e-06 |
| [allen_cahn_piratenet_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams) | L2Rel.u: 8.32403e-06 |

## 1. 背景简介

Expand Down Expand Up @@ -84,27 +84,27 @@ $$
### 3.1 模型构建

在 Allen-Cahn 问题中,每一个已知的坐标点 $(t, x)$ 都有对应的待求解的未知量 $(u)$,
,在这里使用 PiraNet 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:
,在这里使用 PirateNet 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:

$$
u = f(t, x)
$$

上式中 $f$ 即为 PiraNet 模型本身,用 PaddleScience 代码表示如下
上式中 $f$ 即为 PirateNet 模型本身,用 PaddleScience 代码表示如下

``` py linenums="63"
--8<--
examples/allen_cahn/allen_cahn_piranet.py:63:64
examples/allen_cahn/allen_cahn_piratenet.py:63:64
--8<--
```

为了在计算时,准确快速地访问具体变量的值,在这里指定网络模型的输入变量名是 `("t", "x")`,输出变量名是 `("u")`,这些命名与后续代码保持一致。

接着通过指定 PiraNet 的层数、神经元个数,就实例化出了一个拥有 3 个 PiraBlock,每个 PiraBlock 的隐层神经元个数为 256 的神经网络模型 `model`, 并且使用 `tanh` 作为激活函数。
接着通过指定 PirateNet 的层数、神经元个数,就实例化出了一个拥有 3 个 PiraBlock,每个 PiraBlock 的隐层神经元个数为 256 的神经网络模型 `model`, 并且使用 `tanh` 作为激活函数。

``` yaml linenums="34"
--8<--
examples/allen_cahn/conf/allen_cahn_piranet.yaml:34:40
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:34:40
--8<--
```

Expand All @@ -114,7 +114,7 @@ Allen-Cahn 微分方程可以用如下代码表示:

``` py linenums="66"
--8<--
examples/allen_cahn/allen_cahn_piranet.py:66:67
examples/allen_cahn/allen_cahn_piratenet.py:66:67
--8<--
```

Expand All @@ -124,7 +124,7 @@ examples/allen_cahn/allen_cahn_piranet.py:66:67

``` py linenums="69"
--8<--
examples/allen_cahn/allen_cahn_piranet.py:69:81
examples/allen_cahn/allen_cahn_piratenet.py:69:81
--8<--
```

Expand All @@ -136,7 +136,7 @@ examples/allen_cahn/allen_cahn_piranet.py:69:81

``` py linenums="94"
--8<--
examples/allen_cahn/allen_cahn_piranet.py:94:110
examples/allen_cahn/allen_cahn_piratenet.py:94:110
--8<--
```

Expand All @@ -155,7 +155,7 @@ examples/allen_cahn/allen_cahn_piranet.py:94:110

``` yaml linenums="41"
--8<--
examples/allen_cahn/conf/allen_cahn_piranet.yaml:41:42
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:41:42
--8<--
```

Expand All @@ -165,15 +165,15 @@ examples/allen_cahn/conf/allen_cahn_piranet.yaml:41:42

``` py linenums="112"
--8<--
examples/allen_cahn/allen_cahn_piranet.py:112:125
examples/allen_cahn/allen_cahn_piratenet.py:112:125
--8<--
```

在微分方程约束、初值约束构建完毕之后,以刚才的命名为关键字,封装到一个字典中,方便后续访问。

``` py linenums="126"
--8<--
examples/allen_cahn/allen_cahn_piranet.py:126:130
examples/allen_cahn/allen_cahn_piratenet.py:126:130
--8<--
```

Expand All @@ -183,7 +183,7 @@ examples/allen_cahn/allen_cahn_piranet.py:126:130

``` yaml linenums="50"
--8<--
examples/allen_cahn/conf/allen_cahn_piranet.yaml:50:63
examples/allen_cahn/conf/allen_cahn_piratenet.yaml:50:63
--8<--
```

Expand All @@ -193,7 +193,7 @@ examples/allen_cahn/conf/allen_cahn_piranet.yaml:50:63

``` py linenums="132"
--8<--
examples/allen_cahn/allen_cahn_piranet.py:132:136
examples/allen_cahn/allen_cahn_piratenet.py:132:136
--8<--
```

Expand All @@ -203,7 +203,7 @@ examples/allen_cahn/allen_cahn_piranet.py:132:136

``` py linenums="138"
--8<--
examples/allen_cahn/allen_cahn_piranet.py:138:156
examples/allen_cahn/allen_cahn_piratenet.py:138:156
--8<--
```

Expand All @@ -213,15 +213,15 @@ examples/allen_cahn/allen_cahn_piranet.py:138:156

``` py linenums="158"
--8<--
examples/allen_cahn/allen_cahn_piranet.py:158:172
examples/allen_cahn/allen_cahn_piratenet.py:158:184
--8<--
```

## 4. 完整代码

``` py linenums="1" title="allen_cahn_piranet.py"
``` py linenums="1" title="allen_cahn_piratenet.py"
--8<--
examples/allen_cahn/allen_cahn_piranet.py
examples/allen_cahn/allen_cahn_piratenet.py
--8<--
```

Expand All @@ -230,7 +230,7 @@ examples/allen_cahn/allen_cahn_piranet.py
在计算域上均匀采样出 $201\times501$ 个点,其预测结果和解析解如下图所示。

<figure markdown>
![allen_cahn_piranet.jpg](https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_piranet_ac.png){ loading=lazy }
![allen_cahn_piratenet.jpg](https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_piratenet_ac.png){ loading=lazy }
<figcaption> 左侧为 PaddleScience 预测结果,中间为解析解结果,右侧为两者的差值</figcaption>
</figure>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def plot(

def train(cfg: DictConfig):
# set model
model = ppsci.arch.PiraNet(**cfg.MODEL)
model = ppsci.arch.PirateNet(**cfg.MODEL)

# set equation
equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)}
Expand Down Expand Up @@ -186,7 +186,7 @@ def gen_label_batch(input_batch):

def evaluate(cfg: DictConfig):
# set model
model = ppsci.arch.PiraNet(**cfg.MODEL)
model = ppsci.arch.PirateNet(**cfg.MODEL)

data = sio.loadmat(cfg.DATA_PATH)
u_ref = data["usol"].astype(dtype) # (nt, nx)
Expand Down Expand Up @@ -235,7 +235,7 @@ def evaluate(cfg: DictConfig):

def export(cfg: DictConfig):
# set model
model = ppsci.arch.PiraNet(**cfg.MODEL)
model = ppsci.arch.PirateNet(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(model, cfg=cfg)
Expand Down Expand Up @@ -271,7 +271,7 @@ def inference(cfg: DictConfig):


@hydra.main(
version_base=None, config_path="./conf", config_name="allen_cahn_piranet.yaml"
version_base=None, config_path="./conf", config_name="allen_cahn_piratenet.yaml"
)
def main(cfg: DictConfig):
if cfg.mode == "train":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ defaults:
hydra:
run:
# dynamic output directory according to running time and override name
dir: outputs_allen_cahn_piranet/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
dir: outputs_allen_cahn_piratenet/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
job:
name: ${mode} # name of logfile
chdir: false # keep current working directory unchanged
Expand Down Expand Up @@ -79,7 +79,7 @@ EVAL:

# inference settings
INFER:
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piranet_pretrained.pdparams
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams
export_path: ./inference/allen_cahn
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
Expand Down
4 changes: 2 additions & 2 deletions ppsci/arch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ppsci.arch.amgnet import AMGNet # isort:skip
from ppsci.arch.mlp import MLP # isort:skip
from ppsci.arch.mlp import ModifiedMLP # isort:skip
from ppsci.arch.mlp import PiraNet # isort:skip
from ppsci.arch.mlp import PirateNet # isort:skip
from ppsci.arch.deeponet import DeepONet # isort:skip
from ppsci.arch.embedding_koopman import LorenzEmbedding # isort:skip
from ppsci.arch.embedding_koopman import RosslerEmbedding # isort:skip
Expand Down Expand Up @@ -52,7 +52,7 @@
"AMGNet",
"MLP",
"ModifiedMLP",
"PiraNet",
"PirateNet",
"DeepONet",
"DeepPhyLSTM",
"LorenzEmbedding",
Expand Down
16 changes: 8 additions & 8 deletions ppsci/arch/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,8 @@ def forward(self, x):
return y


class PiraNetBlock(nn.Layer):
r"""Basic block of PiraNet.
class PirateNetBlock(nn.Layer):
r"""Basic block of PirateNet.
$$
\begin{align*}
Expand Down Expand Up @@ -548,8 +548,8 @@ def forward(self, x, u, v):
return out


class PiraNet(base.Arch):
r"""PiraNet.
class PirateNet(base.Arch):
r"""PirateNet.
[PIRATENETS: PHYSICS-INFORMED DEEP LEARNING WITHRESIDUAL ADAPTIVE NETWORKS](https://arxiv.org/pdf/2402.00326.pdf)
Expand All @@ -564,15 +564,15 @@ class PiraNet(base.Arch):
\mathbf{g}^{(l)} &= \sigma\left(\mathbf{W}_2^{(l)} \mathbf{z}_1^{(l)}+\mathbf{b}_2^{(l)}\right) \\
\mathbf{z}_2^{(l)} &= \mathbf{g}^{(l)} \odot \mathbf{U}+\left(1-\mathbf{g}^{(l)}\right) \odot \mathbf{V} \\
\mathbf{h}^{(l)} &= \sigma\left(\mathbf{W}_3^{(l)} \mathbf{z}_2^{(l)}+\mathbf{b}_3^{(l)}\right) \\
\mathbf{x}^{(l+1)} &= \text{PiraBlock}^{(l)}\left(\mathbf{x}^{(l)}\right), l=1...L-1\\
\mathbf{x}^{(l+1)} &= \text{PirateBlock}^{(l)}\left(\mathbf{x}^{(l)}\right), l=1...L-1\\
\mathbf{u}_\theta &= \mathbf{W}^{(L+1)} \mathbf{x}^{(L)}
\end{align*}
$$
Args:
input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w").
num_blocks (int): Number of PiraBlocks.
num_blocks (int): Number of PirateBlocks.
hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size.
An integer for all layers, or list of integer specify each layer's size.
activation (str, optional): Name of activation function. Defaults to "tanh".
Expand All @@ -590,7 +590,7 @@ class PiraNet(base.Arch):
Examples:
>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.PiraNet(
>>> model = ppsci.arch.PirateNet(
... input_keys=("x", "y"),
... output_keys=("u", "v"),
... num_blocks=3,
Expand Down Expand Up @@ -676,7 +676,7 @@ def __init__(

for i, _size in enumerate(hidden_size):
self.blocks.append(
PiraNetBlock(
PirateNetBlock(
cur_size,
activation=activation,
random_weight=random_weight,
Expand Down

0 comments on commit ba04ee2

Please sign in to comment.