[Enh] Refactor sum aggregator (PaddlePaddle#834)
* add Sum loss aggregator

* simplify loss aggregation code in train.py and add check for AGDA and PCGrad when used with amp

* add check for using L-BFGS with use_amp=True

* Refine Relobralo

* Fix docstring of timedomain.py

* remove unnecessary code in train.py

* automatically download *.pdeqn file if available when downloading a pretrained model

* wrap func generated by symbolic module with DDP

* fix Relobralo

* initialize loss with 0.0 instead of first loss
HydrogenSulfate committed Apr 7, 2024
1 parent 54f8b8d commit 17eff79
Showing 11 changed files with 117 additions and 45 deletions.
1 change: 1 addition & 0 deletions docs/zh/api/loss/mtl.md
@@ -8,5 +8,6 @@
   - LossAggregator
   - PCGrad
   - Relobralo
+  - Sum
  show_root_heading: true
  heading_level: 3
4 changes: 1 addition & 3 deletions docs/zh/examples/viv.md
@@ -11,9 +11,7 @@
=== "模型评估命令"

``` sh
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdeqn
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
python viv.py mode=eval EVAL.pretrained_model_path=./viv_pretrained
python viv.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
```

| 预训练模型 | 指标 |
8 changes: 4 additions & 4 deletions ppsci/geometry/timedomain.py
@@ -207,7 +207,7 @@ def random_points(
         Args:
             n (int): The total number of random points to generate.
-            random (string): Specifies the way to generate random points, default is "pseudo", which means that a pseudo-random number generator is used.
+            random (str): Specifies the way to generate random points, default is "pseudo", which means that a pseudo-random number generator is used.
             criteria (Optional[Callable]): A method that filters on the generated random points, default is None.

         Returns:
@@ -432,7 +432,7 @@ def random_boundary_points(
         Args:
             n (int): The total number of spatial-temporal points generated on a given geometry boundary.
-            random (string): Controls the way to generate random points. Default is "pseudo".
+            random (str): Controls the way to generate random points. Default is "pseudo".
             criteria (Optional[Callable]): Used to filter the generated boundary points, only points that meet certain conditions are retained. Default is None.

         Returns:
@@ -650,7 +650,7 @@ def random_initial_points(self, n: int, random: str = "pseudo"):
         Args:
             n (int): The total number of generated points.
-            random (string): Controls the way to generate random points. Default is "pseudo".
+            random (str): Controls the way to generate random points. Default is "pseudo".

         Returns:
             np.ndarray: A set of point coordinates randomly distributed on the spatial-temporal domain at the initial moment.
@@ -709,7 +709,7 @@ def sample_initial_interior(
         Args:
             n (int): The total number of interior points generated.
-            random (string): The method used to specify the initial point of generation. Default is "pseudo".
+            random (str): The method used to specify the initial point of generation. Default is "pseudo".
             criteria (Optional[Callable]): Used to filter the generated interior points, only points that meet certain conditions are retained. Default is None.
             evenly (bool): Indicates whether the initial points are generated evenly. Default is False.
             compute_sdf_derivatives (bool): Indicates whether to calculate the derivative of signed distance function or not. Default is False.
2 changes: 2 additions & 0 deletions ppsci/loss/mtl/__init__.py
@@ -18,12 +18,14 @@
 from ppsci.loss.mtl.base import LossAggregator
 from ppsci.loss.mtl.pcgrad import PCGrad
 from ppsci.loss.mtl.relobralo import Relobralo
+from ppsci.loss.mtl.sum import Sum

 __all__ = [
     "AGDA",
     "LossAggregator",
     "PCGrad",
     "Relobralo",
+    "Sum",
 ]


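With the export above, `Sum` becomes part of the public `mtl` namespace. A one-line sanity sketch:

```python
from ppsci.loss import mtl

assert "Sum" in mtl.__all__
aggregator = mtl.Sum()  # ready to pass as Solver's loss_aggregator
```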
2 changes: 1 addition & 1 deletion ppsci/loss/mtl/agda.py
@@ -57,7 +57,7 @@ def __init__(self, model: nn.Layer, M: int = 100, gamma: float = 0.999) -> None:
         self.Lf_tilde_acc = 0.0
         self.Lu_tilde_acc = 0.0

-    def __call__(self, losses, step: int = 0):
+    def __call__(self, losses, step: int = 0) -> "AGDA":
         if len(losses) != 2:
             raise ValueError(
                 f"Number of losses(tasks) for AGDA should be 2, but got {len(losses)}"
2 changes: 1 addition & 1 deletion ppsci/loss/mtl/base.py
@@ -32,7 +32,7 @@ def __init__(self, model: nn.Layer) -> None:
             if not param.stop_gradient:
                 self.param_num += 1

-    def __call__(self, losses, step: int = 0):
+    def __call__(self, losses, step: int = 0) -> "LossAggregator":
         self.losses = losses
         self.loss_num = len(losses)
         self.step = step
18 changes: 8 additions & 10 deletions ppsci/loss/mtl/relobralo.py
@@ -69,28 +69,28 @@ def __init__(
         self.register_buffer("losses_prev", paddle.zeros([self.num_losses]))
         self.register_buffer("lmbda", paddle.ones([self.num_losses]))

-    def _softmax(self, vec: paddle.Tensor) -> paddle.Tensor:
+    def _softmax(self, vec: "paddle.Tensor") -> "paddle.Tensor":
         max_item = vec.max()
         result = paddle.exp(vec - max_item) / paddle.exp(vec - max_item).sum()
         return result

     def _compute_bal(
-        self, losses_vec1: paddle.Tensor, losses_vec2: paddle.Tensor
-    ) -> paddle.Tensor:
+        self, losses_vec1: "paddle.Tensor", losses_vec2: "paddle.Tensor"
+    ) -> "paddle.Tensor":
         return self.num_losses * (
             self._softmax(losses_vec1 / (self.tau * losses_vec2 + self.eps))
         )

-    def __call__(self, losses: List[paddle.Tensor], step: int = 0) -> "Relobralo":
-        self.step = step
+    def __call__(self, losses: List["paddle.Tensor"], step: int = 0) -> "paddle.Tensor":
         assert len(losses) == self.num_losses, (
             f"Length of given losses({len(losses)}) should be equal to "
             f"num_losses({self.num_losses})."
         )
+        self.step = step
         losses_stacked = paddle.stack(losses)  # [num_losses, ]

         if self.step == 0:
-            self.loss = losses_stacked.sum()
+            loss = losses_stacked.sum()
             with paddle.no_grad():
                 paddle.assign(losses_stacked.detach(), self.losses_init)
         else:
@@ -110,12 +110,10 @@ def __call__(self, losses: List[paddle.Tensor], step: int = 0) -> "Relobralo":
             )

         # 3. compute reweighted total loss with lambda
-        self.loss = (losses_stacked * self.lmbda).sum()
+        loss = (losses_stacked * self.lmbda).sum()

         # update losses_prev at the end of each step
         with paddle.no_grad():
             paddle.assign(losses_stacked.detach(), self.losses_prev)
-        return self
-
-    def backward(self) -> None:
-        self.loss.backward()
+        return loss
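With this refinement, `Relobralo.__call__` returns the reweighted scalar loss instead of `self`, and the separate `backward()` method is gone, so callers drive backpropagation directly. A minimal usage sketch; the toy parameter, losses, and the `num_losses=2` constructor call are illustrative assumptions, not taken from this diff:

```python
import paddle
from ppsci.loss import mtl

# Two scalar losses sharing one parameter, standing in for constraint losses.
w = paddle.create_parameter([2, 1], dtype="float32")
x = paddle.randn([4, 2])
loss_a = paddle.matmul(x, w).mean()
loss_b = (paddle.matmul(x, w) ** 2).mean()

aggregator = mtl.Relobralo(num_losses=2)     # assumed constructor argument
total = aggregator([loss_a, loss_b], step=0)  # now a paddle.Tensor, not the aggregator
total.backward()                              # caller triggers backprop on the returned scalar
```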
50 changes: 50 additions & 0 deletions ppsci/loss/mtl/sum.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from typing import Sequence
+
+if TYPE_CHECKING:
+    import paddle
+
+from ppsci.loss.mtl.base import LossAggregator
+
+
+class Sum(LossAggregator):
+    r"""
+    **Default loss aggregator** which does simple summation of the given losses as below.
+
+    $$
+    loss = \sum_{i=1}^{N} losses_i
+    $$
+    """
+
+    def __init__(self) -> None:
+        self.step = 0
+
+    def __call__(
+        self, losses: Sequence["paddle.Tensor"], step: int = 0
+    ) -> paddle.Tensor:
+        assert (
+            len(losses) > 0
+        ), f"Number of given losses({len(losses)}) should be greater than 0."
+        self.step = step
+
+        loss = 0.0
+        for i in range(len(losses)):
+            loss += losses[i]
+
+        return loss
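A quick sketch of the new default aggregator in isolation (toy losses; not library code):

```python
import paddle
from ppsci.loss import mtl

w = paddle.create_parameter([3], dtype="float32")
losses = [(w ** 2).sum(), w.mean()]  # two scalar losses over one parameter

aggregator = mtl.Sum()
total = aggregator(losses, step=0)   # plain summation: losses[0] + losses[1]
total.backward()                     # one backward pass over the summed graph
```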
19 changes: 15 additions & 4 deletions ppsci/solver/solver.py
@@ -300,6 +300,10 @@ def __init__(

         # choosing an appropriate training function for different optimizers
         if misc.typename(self.optimizer) == "LBFGS":
+            if self.use_amp:
+                raise ValueError(
+                    "Auto Mix Precision is not supported for L-BFGS optimizer."
+                )
             self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
             if self.update_freq != 1:
                 self.update_freq = 1
@@ -398,8 +402,13 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:
             jit.enable_to_static(to_static)
             logger.info(f"Set to_static={to_static} for computational optimization.")

-        # use loss aggregator, use summation if None
-        self.loss_aggregator = loss_aggregator
+        # use loss aggregator, use Sum if None
+        if isinstance(loss_aggregator, (mtl.AGDA, mtl.PCGrad)) and self.use_amp:
+            raise ValueError(
+                "Auto Mix Precision does not support the AGDA or PCGrad loss "
+                "aggregator yet, please set use_amp=False."
+            )
+        self.loss_aggregator = loss_aggregator or mtl.Sum()

         # convert sympy to callable object if any exist
         extra_parameters = []
@@ -432,6 +441,10 @@ def convert_expr(
                 for name in container.output_expr:
                     if isinstance(container.output_expr[name], sp.Basic):
                         container.output_expr[name] = funcs[ind]
+                        if self.world_size > 1:
+                            container.output_expr[name] = dist_wrapper(
+                                container.output_expr[name]
+                            )
                         ind += 1

if self.constraint:
@@ -775,7 +788,6 @@ def export(
             )
             logger.message(f"ONNX model has been exported to: {export_path}.onnx")

-    @functools.lru_cache()
     def autocast_context_manager(
         self, enable: bool, level: Literal["O0", "O1", "O2", "OD"] = "O1"
     ) -> contextlib.AbstractContextManager:
@@ -820,7 +832,6 @@ def no_grad_context_manager(
             )
         return ctx_manager

-    @functools.lru_cache()
     def no_sync_context_manager(
         self,
         enable: bool,
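Net effect of the constructor changes: `self.loss_aggregator` is never `None` anymore, and unsupported AMP combinations fail fast at construction time. A hedged sketch of what user code now sees; the MLP arguments are illustrative and all other Solver arguments are elided:

```python
import ppsci
from ppsci.loss import mtl

model = ppsci.arch.MLP(("x", "y"), ("u",), 3, 16)  # illustrative architecture

# Omitting loss_aggregator now falls back to mtl.Sum() instead of None.
solver = ppsci.solver.Solver(model)
assert isinstance(solver.loss_aggregator, mtl.Sum)

# Gradient-surgery aggregators are rejected up front when AMP is enabled.
try:
    ppsci.solver.Solver(model, loss_aggregator=mtl.PCGrad(model), use_amp=True)
except ValueError as err:
    print(err)  # asks the user to set use_amp=False
```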
39 changes: 17 additions & 22 deletions ppsci/solver/train.py
@@ -45,7 +45,6 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
f"Training iteration {solver.global_step + 1}"
) # Training iteration

total_loss = 0.0
total_batch_size = 0
reader_cost = 0.0
batch_cost = 0.0
@@ -106,31 +105,30 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
             if solver.nvtx_flag:  # only for nsight analysis
                 core.nvprof_nvtx_push("Loss aggregator")

+            total_loss = solver.loss_aggregator(
+                constraint_losses, solver.global_step
+            )
+            if solver.update_freq > 1:
+                total_loss = total_loss / solver.update_freq
+
             for i, _constraint in enumerate(solver.constraint.values()):
-                total_loss += constraint_losses[i]
-                loss_dict[_constraint.name] += (
+                loss_dict[_constraint.name] = (
                     float(constraint_losses[i]) / solver.update_freq
                 )
-            if solver.update_freq > 1:
-                total_loss = total_loss / solver.update_freq
+            loss_dict["loss"] = float(total_loss)

             if solver.nvtx_flag:  # only for nsight analysis
                 core.nvprof_nvtx_pop()  # Loss aggregator

-            loss_dict["loss"] = float(total_loss)
-
             # backward
             if solver.nvtx_flag:  # only for nsight analysis
                 core.nvprof_nvtx_push("Loss backward")

-            if solver.loss_aggregator is None:
-                if solver.use_amp:
-                    total_loss_scaled = solver.scaler.scale(total_loss)
-                    total_loss_scaled.backward()
-                else:
-                    total_loss.backward()
+            if solver.use_amp:
+                total_loss_scaled = solver.scaler.scale(total_loss)
+                total_loss_scaled.backward()
             else:
-                solver.loss_aggregator(constraint_losses, solver.global_step).backward()
+                total_loss.backward()

             if solver.nvtx_flag:  # only for nsight analysis
                 core.nvprof_nvtx_pop()  # Loss backward
@@ -233,7 +231,6 @@ def closure() -> paddle.Tensor:
                 Returns:
                     paddle.Tensor: Computed loss scalar.
                 """
-                total_loss = 0
                 with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
                     with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
                         # forward for every constraint, including model and equation expression
@@ -248,20 +245,18 @@ def closure() -> paddle.Tensor:
                             label_dicts,
                             weight_dicts,
                         )

+                        total_loss = solver.loss_aggregator(
+                            constraint_losses, solver.global_step
+                        )
                         # accumulate all losses
                         for i, _constraint in enumerate(solver.constraint.values()):
-                            total_loss += constraint_losses[i]
                             loss_dict[_constraint.name] = float(constraint_losses[i])
                         loss_dict["loss"] = float(total_loss)

                 # backward
                 solver.optimizer.clear_grad()
-                if solver.loss_aggregator is None:
-                    total_loss.backward()
-                else:
-                    solver.loss_aggregator(
-                        constraint_losses, solver.global_step
-                    ).backward()
+                total_loss.backward()

                 if solver.world_size > 1:
                     # fuse + allreduce manually before optimization if use DDP model
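One subtlety makes the unified `total_loss.backward()` call safe: Sum and Relobralo now return a `paddle.Tensor`, while AGDA and PCGrad still return the aggregator object itself (see their `-> "AGDA"` / `-> "LossAggregator"` annotations above), and both shapes expose `.backward()`. Only the Tensor shape can go through `scaler.scale()`, which appears to be why solver.py now rejects AGDA/PCGrad under AMP. A toy stand-in (not library code) showing the shared call shape:

```python
import paddle

class GradSurgeryAgg:
    """Stand-in for AGDA/PCGrad: __call__ returns self; backward() does the work."""

    def __call__(self, losses, step: int = 0):
        self.loss = sum(losses)  # real aggregators perform gradient surgery here
        return self

    def backward(self) -> None:
        self.loss.backward()

w = paddle.create_parameter([2], dtype="float32")
losses = [(w ** 2).sum(), w.mean()]

total = GradSurgeryAgg()(losses, step=0)  # or a Tensor from mtl.Sum()/Relobralo
total.backward()  # train.py issues this call without caring which shape it got
```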
17 changes: 17 additions & 0 deletions ppsci/utils/save_load.py
@@ -98,8 +98,25 @@ def load_pretrain(
         ...     path="path/to/pretrain_model")  # doctest: +SKIP
     """
     if path.startswith("http"):
         # download from path(url) and get its physical path
+        eqn_path = path.replace(".pdparams", ".pdeqn", 1)
         path = download.get_weights_path_from_url(path)
+
+        # automatically download additional equation weights if available
+        def is_url_accessible(url: str):
+            try:
+                import requests
+
+                response = requests.head(url, timeout=5)
+                return response.status_code == requests.codes.ok
+            except requests.RequestException:
+                return False
+            except Exception:
+                return False
+
+        if is_url_accessible(eqn_path):
+            download.get_weights_path_from_url(eqn_path)
+
     # remove ".pdparams" in suffix of path for convenience
     if path.endswith(".pdparams"):
         path = path[:-9]
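With this change, pointing `load_pretrain` at a `*.pdparams` URL also fetches the sibling `*.pdeqn` equation file whenever a HEAD request for it succeeds. A hedged sketch of the resulting call, reusing the URL from the viv.md change above; the MLP arguments are illustrative and must really match the checkpoint:

```python
import ppsci
from ppsci.utils import save_load

model = ppsci.arch.MLP(("t_f",), ("eta",), 5, 50)  # assumed architecture for VIV
save_load.load_pretrain(
    model,
    path="https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams",
)
# If .../viv_pretrained.pdeqn answers HEAD with 200, it is downloaded next to
# the weights; otherwise the check fails quietly and only weights are loaded.
```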
