Skip to content

Commit

Permalink
[Fix]Add logger initialation of topopt example (PaddlePaddle#704)
Browse files Browse the repository at this point in the history
* [Fix]Add logger initialation of topopt example

* update1 doc
  • Loading branch information
lijialin03 committed Dec 20, 2023
1 parent 2116200 commit fa63860
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
44 changes: 22 additions & 22 deletions docs/zh/examples/topopt.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ examples/topopt/functions.py:68:101
--8<--
```

``` py linenums="37"
``` py linenums="40"
--8<--
examples/topopt/topopt.py:37:46
examples/topopt/topopt.py:40:48
--8<--
```

Expand All @@ -79,9 +79,9 @@ examples/topopt/topopt.py:37:46
经过 SIMP 的 $N_{0}$ 次初始迭代步骤得到的图像 $I$ 可以看作是模糊了的最终结构。由于最终的优化解给出的图像 $I^*$ 并不包含中间过程的信息,因此 $I^*$ 可以被解释为图像 $I$ 的掩码。于是 $I \rightarrow I^*$ 这一优化过程可以看作是二分类的图像分割或者前景-背景分割过程,因此构建 Unet 模型进行预测,具体网络结构如图所示:
![Unet](https://ai-studio-static-online.cdn.bcebos.com/7a0e54df9c9d48e5841423546e851f620e73ea917f9e4258aefc47c498bba85e)

``` py linenums="87"
``` py linenums="90"
--8<--
examples/topopt/topopt.py:87:89
examples/topopt/topopt.py:90:91
--8<--
```

Expand All @@ -98,9 +98,9 @@ examples/topopt/conf/topopt.yaml:49:54
--8<--
```

``` py linenums="33"
``` py linenums="36"
--8<--
examples/topopt/topopt.py:33:36
examples/topopt/topopt.py:36:38
--8<--
```

Expand All @@ -120,9 +120,9 @@ examples/topopt/functions.py:102:133

在本案例中,我们采用监督学习方式进行训练,所以使用监督约束 `SupervisedConstraint`,代码如下:

``` py linenums="47"
``` py linenums="50"
--8<--
examples/topopt/topopt.py:47:73
examples/topopt/topopt.py:50:75
--8<--
```

Expand Down Expand Up @@ -152,19 +152,19 @@ examples/topopt/functions.py:23:67
--8<--
```

``` py linenums="77"
``` py linenums="80"
--8<--
examples/topopt/topopt.py:77:79
examples/topopt/topopt.py:80:81
--8<--
```

### 3.7 优化器构建

训练过程会调用优化器来更新模型参数,此处选择 `Adam` 优化器。

``` py linenums="90"
``` py linenums="93"
--8<--
examples/topopt/topopt.py:90:94
examples/topopt/topopt.py:93:96
--8<--
```

Expand All @@ -191,9 +191,9 @@ $$

loss 构建代码如下:

``` py linenums="260"
``` py linenums="263"
--8<--
examples/topopt/topopt.py:260:273
examples/topopt/topopt.py:263:274
--8<--
```

Expand All @@ -211,9 +211,9 @@ $$
其中 $n_{0} = w_{00} + w_{01}$ , $n_{1} = w_{10} + w_{11}$ ,$w_{tp}$ 表示实际是 $t$ 类且被预测为 $p$ 类的像素点的数量
metric 构建代码如下:

``` py linenums="274"
``` py linenums="277"
--8<--
examples/topopt/topopt.py:274:316
examples/topopt/topopt.py:277:317
--8<--
```

Expand All @@ -230,9 +230,9 @@ examples/topopt/conf/topopt.yaml:29:31

训练代码如下:

``` py linenums="74"
``` py linenums="77"
--8<--
examples/topopt/topopt.py:74:110
examples/topopt/topopt.py:77:111
--8<--
```

Expand All @@ -245,9 +245,9 @@ examples/topopt/topopt.py:74:110
#### 3.10.1 评估器构建
为应用 PaddleScience API,此处在每一次评估时构建一个评估器 SupervisedValidator 进行评估:

``` py linenums="215"
``` py linenums="218"
--8<--
examples/topopt/topopt.py:215:242
examples/topopt/topopt.py:218:245
--8<--
```

Expand All @@ -258,9 +258,9 @@ examples/topopt/topopt.py:215:242

使用 `ppsci.utils.misc.plot_curve()` 方法直接绘制 Binary Accuracy 和 IoU 的结果:

``` py linenums="182"
``` py linenums="185"
--8<--
examples/topopt/topopt.py:182:192
examples/topopt/topopt.py:185:193
--8<--
```

Expand Down
9 changes: 6 additions & 3 deletions examples/topopt/topopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@


def train(cfg: DictConfig):
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(cfg.seed)
# initialize logger
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

# 4 training cases parameters
LEARNING_RATE = cfg.TRAIN.learning_rate / (1 + cfg.TRAIN.epochs // 15)
Expand Down Expand Up @@ -110,10 +113,10 @@ def train(cfg: DictConfig):

# evaluate 4 models
def evaluate(cfg: DictConfig):
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(cfg.seed)

# initialize logger for evaluation
logger.init_logger("ppsci", osp.join(cfg.output_dir, "results.log"), "info")
# initialize logger
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

# fixed iteration stop times for evaluation
iterations_stop_times = range(5, 85, 5)
Expand Down

0 comments on commit fa63860

Please sign in to comment.