Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enh] add validation for hydra config #769

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ppsci/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def build_dataloader(_dataset, cfg):
num_workers=cfg.get("num_workers", _DEFAULT_NUM_WORKERS),
use_shared_memory=cfg.get("use_shared_memory", False),
worker_init_fn=init_fn,
# TODO: Do not enable persistent_workers' below for
# TODO: Do not enable 'persistent_workers' below for
# 'IndexError: pop from empty list ...' will be raised in certain cases
# persistent_workers=cfg.get("num_workers", _DEFAULT_NUM_WORKERS) > 0,
)
Expand Down
12 changes: 10 additions & 2 deletions ppsci/solver/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def _eval_by_dataset(
solver.eval_time_info["batch_cost"].update(batch_cost)
batch_size = next(iter(input_dict.values())).shape[0]
printer.update_eval_loss(solver, loss_dict, batch_size)
if iter_id == 1 or iter_id % log_freq == 0:
if (
iter_id == 1
or iter_id % log_freq == 0
or iter_id == len(_validator.data_loader)
):
printer.log_eval_info(
solver,
batch_size,
Expand Down Expand Up @@ -247,7 +251,11 @@ def _eval_by_batch(
solver.eval_time_info["reader_cost"].update(reader_cost)
solver.eval_time_info["batch_cost"].update(batch_cost)
printer.update_eval_loss(solver, loss_dict, batch_size)
if iter_id == 1 or iter_id % log_freq == 0:
if (
iter_id == 1
or iter_id % log_freq == 0
or iter_id == len(_validator.data_loader)
):
printer.log_eval_info(
solver,
batch_size,
Expand Down
29 changes: 25 additions & 4 deletions ppsci/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class Solver:
validator (Optional[Dict[str, ppsci.validate.Validator]]): Validator dict. Defaults to None.
visualizer (Optional[Dict[str, ppsci.visualize.Visualizer]]): Visualizer dict. Defaults to None.
use_amp (bool, optional): Whether use AMP. Defaults to False.
amp_level (Literal["O1", "O2", "O0"], optional): AMP level. Defaults to "O0".
amp_level (Literal["O0", "O1", "O2", "OD"], optional): AMP level. Defaults to "O1".
pretrained_model_path (Optional[str]): Pretrained model path. Defaults to None.
checkpoint_path (Optional[str]): Checkpoint path. Defaults to None.
compute_metric_by_batch (bool, optional): Whether calculate metrics after each batch during evaluation. Defaults to False.
Expand All @@ -86,7 +86,7 @@ class Solver:
Examples:
>>> import ppsci
>>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20)
>>> opt = ppsci.optimizer.AdamW(1e-3)((model,))
>>> opt = ppsci.optimizer.AdamW(1e-3)(model)
>>> geom = ppsci.geometry.Rectangle((0, 0), (1, 1))
>>> pde_constraint = ppsci.constraint.InteriorConstraint(
... {"u": lambda out: out["u"]},
Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(
validator: Optional[Dict[str, ppsci.validate.Validator]] = None,
visualizer: Optional[Dict[str, ppsci.visualize.Visualizer]] = None,
use_amp: bool = False,
amp_level: Literal["O1", "O2", "O0"] = "O0",
amp_level: Literal["O0", "O1", "O2", "OD"] = "O1",
pretrained_model_path: Optional[str] = None,
checkpoint_path: Optional[str] = None,
compute_metric_by_batch: bool = False,
Expand All @@ -152,7 +152,28 @@ def __init__(
# set optimizer
self.optimizer = optimizer
# set learning rate scheduler
self.lr_scheduler = lr_scheduler
if lr_scheduler is not None:
logger.warning(
"The argument: 'lr_scheduler' now automatically retrieves from "
"'optimizer._learning_rate' when 'optimizer' is given, so it is "
"recommended to remove it from the Solver's initialization arguments."
)
self.lr_scheduler = (
optimizer._learning_rate
if (
isinstance(optimizer, optim.Optimizer)
and isinstance(optimizer._learning_rate, optim.lr.LRScheduler)
)
else None
)
if isinstance(self.optimizer, ppsci.optimizer.OptimizerList):
self.lr_scheduler = ppsci.optimizer.lr_scheduler.SchedulerList(
tuple(
opt._learning_rate
for opt in self.optimizer
if isinstance(opt._learning_rate, optim.lr.LRScheduler)
)
)

# set training hyper-parameter
self.epochs = epochs
Expand Down
48 changes: 43 additions & 5 deletions ppsci/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import inspect
import sys
from os import path as osp
from typing import Any

from hydra.experimental.callback import Callback
from omegaconf import DictConfig

from ppsci.utils import config as config_module
from ppsci.utils import logger
from ppsci.utils import misc

RUNTIME_EXIT_CODE = 1 # for other errors
VALIDATION_ERROR_EXIT_CODE = 2 # for invalid argument detected in config file


class InitCallback(Callback):
"""Callback class for:
1. Fixing random seed to 'config.seed'
2. Initialize logger while creating output directory(if not exist).
1. Parse config dict from given yaml file and check its validity, completing missing items with their default values.
2. Fix random seed to 'config.seed'.
3. Initialize logger while creating the output directory (if it does not exist).

NOTE: This callback is mainly for reducing unnecessary duplicate code in each
example's code when running with hydra.
Expand All @@ -52,10 +60,40 @@ class InitCallback(Callback):
"""

def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
# check given cfg using pre-defined pydantic schema in 'SolverConfig', error(s) will be raised
# if any checking failed at this step
if importlib.util.find_spec("pydantic") is not None:
from pydantic import ValidationError
else:
logger.error(
f"ModuleNotFoundError at {__file__}:{inspect.currentframe().f_lineno}\n"
"Please install pydantic with `pip install pydantic` when set callbacks"
" in your config yaml."
)
sys.exit(RUNTIME_EXIT_CODE)

# check given cfg using pre-defined pydantic schema in 'SolverConfig',
# error(s) will be printed and exit program if any checking failed at this step
try:
_model_pydantic = config_module.SolverConfig(**dict(config))
# complete missing items with default values pre-defined in pydantic schema in
# 'SolverConfig'
full_cfg = DictConfig(_model_pydantic.model_dump())
except ValidationError as e:
print(e)
sys.exit(VALIDATION_ERROR_EXIT_CODE)
except Exception as e:
print(e)
sys.exit(RUNTIME_EXIT_CODE)

# fix random seed for reproducibility
misc.set_random_seed(config.seed)
misc.set_random_seed(full_cfg.seed)

# create output directory
# initialize logger while creating output directory
logger.init_logger(
"ppsci", osp.join(config.output_dir, f"{config.mode}.log"), "info"
"ppsci",
osp.join(full_cfg.output_dir, f"{full_cfg.mode}.log")
if full_cfg.output_dir
else None,
full_cfg.log_level,
)
Loading