From 7b33f2741d69f1ca10311453a4acc3f460fca349 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 23 Jan 2024 05:21:28 +0000 Subject: [PATCH 1/4] add validation for hydra config --- ppsci/data/__init__.py | 2 +- ppsci/solver/eval.py | 12 +- ppsci/solver/solver.py | 21 ++- ppsci/utils/callbacks.py | 46 +++++- ppsci/utils/config.py | 303 +++++++++++++++++++++++++++++++++++++++ pyproject.toml | 1 + requirements.txt | 1 + 7 files changed, 374 insertions(+), 12 deletions(-) diff --git a/ppsci/data/__init__.py b/ppsci/data/__init__.py index 3d8aab7ef..d8f34e4ef 100644 --- a/ppsci/data/__init__.py +++ b/ppsci/data/__init__.py @@ -160,7 +160,7 @@ def build_dataloader(_dataset, cfg): num_workers=cfg.get("num_workers", _DEFAULT_NUM_WORKERS), use_shared_memory=cfg.get("use_shared_memory", False), worker_init_fn=init_fn, - # TODO: Do not enable persistent_workers' below for + # TODO: Do not enable 'persistent_workers' below for # 'IndexError: pop from empty list ...' will be raised in certain cases # persistent_workers=cfg.get("num_workers", _DEFAULT_NUM_WORKERS) > 0, ) diff --git a/ppsci/solver/eval.py b/ppsci/solver/eval.py index c28feb491..38c7380be 100644 --- a/ppsci/solver/eval.py +++ b/ppsci/solver/eval.py @@ -128,7 +128,11 @@ def _eval_by_dataset( solver.eval_time_info["batch_cost"].update(batch_cost) batch_size = next(iter(input_dict.values())).shape[0] printer.update_eval_loss(solver, loss_dict, batch_size) - if iter_id == 1 or iter_id % log_freq == 0: + if ( + iter_id == 1 + or iter_id % log_freq == 0 + or iter_id == len(_validator.data_loader) + ): printer.log_eval_info( solver, batch_size, @@ -247,7 +251,11 @@ def _eval_by_batch( solver.eval_time_info["reader_cost"].update(reader_cost) solver.eval_time_info["batch_cost"].update(batch_cost) printer.update_eval_loss(solver, loss_dict, batch_size) - if iter_id == 1 or iter_id % log_freq == 0: + if ( + iter_id == 1 + or iter_id % log_freq == 0 + or iter_id == len(_validator.data_loader) + ): printer.log_eval_info( solver, batch_size, diff --git a/ppsci/solver/solver.py b/ppsci/solver/solver.py index e4fd6e087..1c4b7fa1e 100644 --- a/ppsci/solver/solver.py +++ b/ppsci/solver/solver.py @@ -74,7 +74,7 @@ class Solver: validator (Optional[Dict[str, ppsci.validate.Validator]]): Validator dict. Defaults to None. visualizer (Optional[Dict[str, ppsci.visualize.Visualizer]]): Visualizer dict. Defaults to None. use_amp (bool, optional): Whether use AMP. Defaults to False. - amp_level (Literal["O1", "O2", "O0"], optional): AMP level. Defaults to "O0". + amp_level (Literal["O0", "O1", "O2", "OD"], optional): AMP level. Defaults to "O1". pretrained_model_path (Optional[str]): Pretrained model path. Defaults to None. checkpoint_path (Optional[str]): Checkpoint path. Defaults to None. compute_metric_by_batch (bool, optional): Whether calculate metrics after each batch during evaluation. Defaults to False. @@ -86,7 +86,7 @@ class Solver: Examples: >>> import ppsci >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20) - >>> opt = ppsci.optimizer.AdamW(1e-3)((model,)) + >>> opt = ppsci.optimizer.AdamW(1e-3)(model) >>> geom = ppsci.geometry.Rectangle((0, 0), (1, 1)) >>> pde_constraint = ppsci.constraint.InteriorConstraint( ... {"u": lambda out: out["u"]}, @@ -134,7 +134,7 @@ def __init__( validator: Optional[Dict[str, ppsci.validate.Validator]] = None, visualizer: Optional[Dict[str, ppsci.visualize.Visualizer]] = None, use_amp: bool = False, - amp_level: Literal["O1", "O2", "O0"] = "O0", + amp_level: Literal["O0", "O1", "O2", "OD"] = "O1", pretrained_model_path: Optional[str] = None, checkpoint_path: Optional[str] = None, compute_metric_by_batch: bool = False, @@ -152,7 +152,20 @@ def __init__( # set optimizer self.optimizer = optimizer # set learning rate scheduler - self.lr_scheduler = lr_scheduler + if lr_scheduler is not None: + logger.warning( + "The argument: 'lr_scheduler' now automatically retrieves from " + "'optimizer._learning_rate' when 'optimizer' is given, so it is " + "recommended to remove it from the Solver's initialization arguments." + ) + self.lr_scheduler = ( + optimizer._learning_rate + if ( + isinstance(optimizer, optim.Optimizer) + and isinstance(optimizer._learning_rate, optim.lr.LRScheduler) + ) + else None + ) # set training hyper-parameter self.epochs = epochs diff --git a/ppsci/utils/callbacks.py b/ppsci/utils/callbacks.py index 5a9a1af9e..94ed1a0ee 100644 --- a/ppsci/utils/callbacks.py +++ b/ppsci/utils/callbacks.py @@ -12,20 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util +import inspect +import sys from os import path as osp from typing import Any from hydra.experimental.callback import Callback from omegaconf import DictConfig +from ppsci.utils import config as config_module from ppsci.utils import logger from ppsci.utils import misc +RUNTIME_EXIT_CODE = 1 # for other errors +VALIDATION_ERROR_EXIT_CODE = 2 # for invalid argument detected in config file + class InitCallback(Callback): """Callback class for: - 1. Fixing random seed to 'config.seed' - 2. Initialize logger while creating output directory(if not exist). + 1. Parse config dict from given yaml file and check its validity, complete missing items by its' default values. + 2. Fixing random seed to 'config.seed'. + 3. Initialize logger while creating output directory(if not exist). NOTE: This callback is mainly for reducing unnecessary duplicate code in each examples code when runing with hydra. @@ -52,10 +60,38 @@ class InitCallback(Callback): """ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None: + # check given cfg using pre-defined pydantic schema in 'SolverConfig', error(s) will be raised + # if any checking failed at this step + if importlib.util.find_spec("pydantic") is not None: + from pydantic import ValidationError + else: + logger.error( + f"ModuleNotFoundError at {__file__}:{inspect.currentframe().f_lineno}\n" + "Please install pydantic with `pip install pydantic` when set callbacks" + " in your config yaml." + ) + sys.exit(RUNTIME_EXIT_CODE) + + # check given cfg using pre-defined pydantic schema in 'SolverConfig', + # error(s) will be printed and exit program if any checking failed at this step + try: + _model_pydantic = config_module.SolverConfig(**dict(config)) + # complete missing items with default values pre-defined in pydantic schema in + # 'SolverConfig' + full_cfg = DictConfig(_model_pydantic.model_dump()) + except ValidationError as e: + print(e) + sys.exit(VALIDATION_ERROR_EXIT_CODE) + except Exception as e: + print(e) + sys.exit(RUNTIME_EXIT_CODE) + # fix random seed for reproducibility - misc.set_random_seed(config.seed) + misc.set_random_seed(full_cfg.seed) - # create output directory + # initialze logger while creating output directory logger.init_logger( - "ppsci", osp.join(config.output_dir, f"{config.mode}.log"), "info" + "ppsci", + osp.join(full_cfg.output_dir, f"{full_cfg.mode}.log"), + full_cfg.log_level, ) diff --git a/ppsci/utils/config.py b/ppsci/utils/config.py index f6b3faccf..91cf2d52e 100644 --- a/ppsci/utils/config.py +++ b/ppsci/utils/config.py @@ -16,15 +16,318 @@ import argparse import copy +import importlib.util import os +from typing import Mapping +from typing import Optional import yaml from paddle import static +from typing_extensions import Literal from ppsci.utils import logger +from ppsci.utils import misc __all__ = ["get_config", "replace_shape_with_inputspec_", "AttrDict"] +if importlib.util.find_spec("pydantic") is not None: + from pydantic import BaseModel + from pydantic import field_validator + from pydantic_core.core_schema import FieldValidationInfo + + __all__.append("SolverConfig") + + class TrainConfig(BaseModel): + """ + Schema of training config for pydantic validation. + """ + + epochs: int = 0 + iters_per_epoch: int = 20 + update_freq: int = 1 + save_freq: int = 0 + eval_during_train: bool = False + start_eval_epoch: int = 1 + eval_freq: int = 1 + checkpoint_path: Optional[str] = None + pretrained_model_path: Optional[str] = None + + @field_validator("epochs") + def epochs_check(cls, v): + if not isinstance(v, int): + raise ValueError( + f"'epochs' should be a int or None, but got {misc.typename(v)}" + ) + elif v <= 0: + raise ValueError( + "'epochs' should be a positive integer when is type of int, " + f"but got {v}" + ) + return v + + @field_validator("iters_per_epoch") + def iters_per_epoch_check(cls, v): + if not isinstance(v, int): + raise ValueError( + "'iters_per_epoch' should be a int or None" + f", but got {misc.typename(v)}" + ) + elif v <= 0: + raise ValueError( + "'iters_per_epoch' should be a positive integer when is type of int" + f", but got {v}" + ) + return v + + @field_validator("update_freq") + def update_freq_check(cls, v): + if v is not None: + if not isinstance(v, int): + raise ValueError( + "'update_freq' should be a int or None" + f", but got {misc.typename(v)}" + ) + elif v <= 0: + raise ValueError( + "'update_freq' should be a positive integer when is type of int" + f", but got {v}" + ) + return v + + @field_validator("save_freq") + def save_freq_check(cls, v): + if v is not None: + if not isinstance(v, int): + raise ValueError( + "'save_freq' should be a int or None" + f", but got {misc.typename(v)}" + ) + elif v < 0: + raise ValueError( + "'save_freq' should be a non-negtive integer when is type of int" + f", but got {v}" + ) + return v + + @field_validator("eval_during_train") + def eval_during_train_check(cls, v): + if not isinstance(v, bool): + raise ValueError( + "'eval_during_train' should be a bool" + f", but got {misc.typename(v)}" + ) + return v + + @field_validator("start_eval_epoch") + def start_eval_epoch_check(cls, v, info: FieldValidationInfo): + if info.data["eval_during_train"]: + if not isinstance(v, int): + raise ValueError( + f"'start_eval_epoch' should be a int, but got {misc.typename(v)}" + ) + if v <= 0: + raise ValueError( + f"'start_eval_epoch' should be a positive integer when " + f"'eval_during_train' is True, but got {v}" + ) + return v + + @field_validator("eval_freq") + def eval_freq_check(cls, v, info: FieldValidationInfo): + if info.data["eval_during_train"]: + if not isinstance(v, int): + raise ValueError( + f"'eval_freq' should be a int, but got {misc.typename(v)}" + ) + if v <= 0: + raise ValueError( + f"'eval_freq' should be a positive integer when " + f"'eval_during_train' is True, but got {v}" + ) + return v + + @field_validator("pretrained_model_path") + def pretrained_model_path_check(cls, v): + if v is not None and not isinstance(v, str): + raise ValueError( + "'pretrained_model_path' should be a str or None, " + f"but got {misc.typename(v)}" + ) + return v + + @field_validator("checkpoint_path") + def checkpoint_path_check(cls, v): + if v is not None and not isinstance(v, str): + raise ValueError( + "'checkpoint_path' should be a str or None, " + f"but got {misc.typename(v)}" + ) + return v + + class EvalConfig(BaseModel): + """ + Schema of evaluation config for pydantic validation. + """ + + pretrained_model_path: Optional[str] = None + eval_with_no_grad: bool = False + compute_metric_by_batch: bool = False + + @field_validator("pretrained_model_path") + def pretrained_model_path_check(cls, v): + if v is not None and not isinstance(v, str): + raise ValueError( + "'pretrained_model_path' should be a str or None, " + f"but got {misc.typename(v)}" + ) + return v + + @field_validator("eval_with_no_grad") + def eval_with_no_grad_check(cls, v): + if not isinstance(v, bool): + raise ValueError( + f"'eval_with_no_grad' should be a bool, but got {misc.typename(v)}" + ) + return v + + @field_validator("compute_metric_by_batch") + def compute_metric_by_batch_check(cls, v): + if not isinstance(v, bool): + raise ValueError( + f"'compute_metric_by_batch' should be a bool, but got {misc.typename(v)}" + ) + return v + + class SolverConfig(BaseModel): + """ + Schema of global config for pydantic validation. + """ + + # Training related config + TRAIN: Optional[TrainConfig] = None + + # Evaluation related config + EVAL: Optional[EvalConfig] = None + + # Global settings config + mode: Literal["train", "eval"] = "train" + output_dir: Optional[str] = None + log_freq: int = 20 + seed: int = 42 + use_vdl: bool = False + use_wandb: bool = False + wandb_config: Optional[Mapping] = None + device: Literal["cpu", "gpu", "xpu"] = "gpu" + use_amp: bool = False + amp_level: Literal["O0", "O1", "O2", "0D"] = "O0" + to_static: bool = False + log_level: Literal["debug", "info", "warning", "error"] = "info" + + @field_validator("mode") + def mode_check(cls, v): + if v not in ["train", "eval"]: + raise ValueError( + f"'mode' should be one of ['train', 'eval'], but got {v}" + ) + return v + + @field_validator("output_dir") + def output_dir_check(cls, v): + if v is not None and not isinstance(v, str): + raise ValueError( + "'output_dir' should be a string or None" + f"but got {misc.typename(v)}" + ) + return v + + @field_validator("log_freq") + def log_freq_check(cls, v): + if not isinstance(v, int): + raise ValueError( + "'log_freq' should be a int" f", but got {misc.typename(v)}" + ) + elif v <= 0: + raise ValueError( + "'log_freq' should be a non-negtive integer when is type of int" + f", but got {v}" + ) + return v + + @field_validator("seed") + def seed_check(cls, v): + if not isinstance(v, int): + raise ValueError(f"'seed' should be a int, but got {misc.typename(v)}") + if v < 0: + raise ValueError(f"'seed' should be a non-negtive integer, but got {v}") + return v + + @field_validator("use_vdl") + def use_vdl_check(cls, v): + if not isinstance(v, bool): + raise ValueError( + f"'use_vdl' should be a bool, but got {misc.typename(v)}" + ) + return v + + @field_validator("use_wandb") + def use_wandb_check(cls, v, info: FieldValidationInfo): + if not isinstance(v, bool): + raise ValueError( + f"'use_wandb' should be a bool, but got {misc.typename(v)}" + ) + if not isinstance(info.data["wandb_config"], dict): + raise ValueError( + f"'wandb_config' should be a dict when 'use_wandb' is True, " + f"but got {misc.typename(info.data['wandb_config'])}" + ) + return v + + @field_validator("device") + def device_check(cls, v): + if not isinstance(v, str): + raise ValueError( + f"'device' should be a str, but got {misc.typename(v)}" + ) + if v not in ["cpu", "gpu"]: + raise ValueError( + f"'device' should be one of ['cpu', 'gpu'], but got {v}" + ) + return v + + @field_validator("use_amp") + def use_amp_check(cls, v): + if not isinstance(v, bool): + raise ValueError( + f"'use_amp' should be a bool, but got {misc.typename(v)}" + ) + return v + + @field_validator("amp_level") + def amp_level_check(cls, v): + v = v.upper() + if v not in ["O0", "O1", "O2", "OD"]: + raise ValueError( + f"'amp_level' should be one of ['O0', 'O1', 'O2', 'OD'], but got {v}" + ) + return v + + @field_validator("to_static") + def to_static_check(cls, v): + if not isinstance(v, bool): + raise ValueError( + f"'to_static' should be a bool, but got {misc.typename(v)}" + ) + return v + + @field_validator("log_level") + def log_level_check(cls, v): + if v not in ["debug", "info", "warning", "error"]: + raise ValueError( + "'log_level' should be one of ['debug', 'info', 'warning', 'error']" + f", but got {v}" + ) + return v + class AttrDict(dict): def __getattr__(self, key): diff --git a/pyproject.toml b/pyproject.toml index d1fa0066e..bc87deffe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "matplotlib", "meshio==5.3.4", "numpy>=1.20.0,<=1.23.1", + "pydantic", "pyevtk", "pyvista==0.37.0", "pyyaml", diff --git a/requirements.txt b/requirements.txt index 8c26b4809..6238d42ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ imageio matplotlib meshio==5.3.4 numpy>=1.20.0,<=1.23.1 +pydantic pyevtk pyvista==0.37.0 pyyaml From 5f4c7165f4ff7e0e51825d36c05129f03771a2b1 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 23 Jan 2024 06:47:58 +0000 Subject: [PATCH 2/4] update unitest for pydantic --- ppsci/utils/callbacks.py | 4 +- test/experimental/test_gaussian_integrate.py | 14 ++++ test/experimental/test_trapezoid_integrate.py | 14 ++++ test/utils/test_config.py | 69 +++++++++++++++++++ test/utils/test_writer.py | 1 - 5 files changed, 100 insertions(+), 2 deletions(-) create mode 100644 test/utils/test_config.py diff --git a/ppsci/utils/callbacks.py b/ppsci/utils/callbacks.py index 94ed1a0ee..dbc73dd35 100644 --- a/ppsci/utils/callbacks.py +++ b/ppsci/utils/callbacks.py @@ -92,6 +92,8 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None: # initialze logger while creating output directory logger.init_logger( "ppsci", - osp.join(full_cfg.output_dir, f"{full_cfg.mode}.log"), + osp.join(full_cfg.output_dir, f"{full_cfg.mode}.log") + if full_cfg.output_dir + else None, full_cfg.log_level, ) diff --git a/test/experimental/test_gaussian_integrate.py b/test/experimental/test_gaussian_integrate.py index 1943274f5..ff6430fd4 100644 --- a/test/experimental/test_gaussian_integrate.py +++ b/test/experimental/test_gaussian_integrate.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Callable from typing import List diff --git a/test/experimental/test_trapezoid_integrate.py b/test/experimental/test_trapezoid_integrate.py index 442e42f98..34ded2ec8 100644 --- a/test/experimental/test_trapezoid_integrate.py +++ b/test/experimental/test_trapezoid_integrate.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Callable import numpy as np diff --git a/test/utils/test_config.py b/test/utils/test_config.py new file mode 100644 index 000000000..5f650685c --- /dev/null +++ b/test/utils/test_config.py @@ -0,0 +1,69 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hydra +import paddle +import pytest +from omegaconf import DictConfig + +paddle.seed(1024) + + +@pytest.mark.parametrize( + "epochs,mode,seed", + [ + (-1, "train", 1024), + (20, "wrong_mode", 1024), + (10, "eval", -1), + ], +) +def test_invalid_epochs( + epochs, + mode, + seed, +): + @hydra.main(version_base=None, config_path="./", config_name="test_config.yaml") + def main(cfg: DictConfig): + pass + + # sys.exit will be called when validation error in pydantic, so there we use + # SystemExit instead of other type of errors. + with pytest.raises(SystemExit): + cfg_dict = dict( + { + "TRAIN": { + "epochs": epochs, + }, + "mode": mode, + "seed": seed, + "hydra": { + "callbacks": { + "init_callback": { + "_target_": "ppsci.utils.callbacks.InitCallback" + } + } + }, + } + ) + # print(cfg_dict) + import yaml + + with open("test_config.yaml", "w") as f: + yaml.dump(dict(cfg_dict), f) + + main() + + +if __name__ == "__main__": + pytest.main() diff --git a/test/utils/test_writer.py b/test/utils/test_writer.py index 7bd56a7c3..6e960bee2 100644 --- a/test/utils/test_writer.py +++ b/test/utils/test_writer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import numpy as np import pytest From 4adf21f6e1a9e27b6ee195d0379de08ec5671c94 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 23 Jan 2024 09:45:28 +0000 Subject: [PATCH 3/4] fix for OptimizerList --- ppsci/solver/solver.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ppsci/solver/solver.py b/ppsci/solver/solver.py index 1c4b7fa1e..7e57ad862 100644 --- a/ppsci/solver/solver.py +++ b/ppsci/solver/solver.py @@ -166,6 +166,14 @@ def __init__( ) else None ) + if isinstance(self.optimizer, ppsci.optimizer.OptimizerList): + self.lr_scheduler = ppsci.optimizer.lr_scheduler.SchedulerList( + tuple( + opt._learning_rate + for opt in self.optimizer + if isinstance(opt._learning_rate, optim.lr.LRScheduler) + ) + ) # set training hyper-parameter self.epochs = epochs From 83192b01b67c4eff176f3db77d0cc6a4e2a461af Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 23 Jan 2024 11:21:13 +0000 Subject: [PATCH 4/4] fix --- ppsci/optimizer/lr_scheduler.py | 7 ++----- ppsci/optimizer/optimizer.py | 3 +++ ppsci/utils/config.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ppsci/optimizer/lr_scheduler.py b/ppsci/optimizer/lr_scheduler.py index 697aac48c..1f1dd7576 100644 --- a/ppsci/optimizer/lr_scheduler.py +++ b/ppsci/optimizer/lr_scheduler.py @@ -739,7 +739,7 @@ class SchedulerList: """SchedulerList which wrap more than one scheduler. Args: scheduler_list (Tuple[lr.LRScheduler, ...]): Schedulers listed in a tuple. - by_epoch (bool, optional): Learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. + Examples: >>> import ppsci >>> sch1 = ppsci.optimizer.lr_scheduler.Linear(10, 2, 0.001)() @@ -747,12 +747,9 @@ class SchedulerList: >>> sch = ppsci.optimizer.lr_scheduler.SchedulerList((sch1, sch2)) """ - def __init__( - self, scheduler_list: Tuple[lr.LRScheduler, ...], by_epoch: bool = False - ): + def __init__(self, scheduler_list: Tuple[lr.LRScheduler, ...]): super().__init__() self._sch_list = scheduler_list - self.by_epoch = by_epoch def step(self): for sch in self._sch_list: diff --git a/ppsci/optimizer/optimizer.py b/ppsci/optimizer/optimizer.py index bb2268b31..674396a90 100644 --- a/ppsci/optimizer/optimizer.py +++ b/ppsci/optimizer/optimizer.py @@ -525,3 +525,6 @@ def __getitem__(self, idx): def __setitem__(self, idx, opt): raise NotImplementedError("Can not modify any item in OptimizerList.") + + def __iter__(self): + yield from iter(self._opt_list) diff --git a/ppsci/utils/config.py b/ppsci/utils/config.py index 91cf2d52e..de857db05 100644 --- a/ppsci/utils/config.py +++ b/ppsci/utils/config.py @@ -219,7 +219,7 @@ class SolverConfig(BaseModel): wandb_config: Optional[Mapping] = None device: Literal["cpu", "gpu", "xpu"] = "gpu" use_amp: bool = False - amp_level: Literal["O0", "O1", "O2", "0D"] = "O0" + amp_level: Literal["O0", "O1", "O2", "OD"] = "O1" to_static: bool = False log_level: Literal["debug", "info", "warning", "error"] = "info" @@ -244,7 +244,7 @@ def output_dir_check(cls, v): def log_freq_check(cls, v): if not isinstance(v, int): raise ValueError( - "'log_freq' should be a int" f", but got {misc.typename(v)}" + f"'log_freq' should be a int, but got {misc.typename(v)}" ) elif v <= 0: raise ValueError( @@ -277,7 +277,7 @@ def use_wandb_check(cls, v, info: FieldValidationInfo): ) if not isinstance(info.data["wandb_config"], dict): raise ValueError( - f"'wandb_config' should be a dict when 'use_wandb' is True, " + "'wandb_config' should be a dict when 'use_wandb' is True, " f"but got {misc.typename(info.data['wandb_config'])}" ) return v