From 1fb540580602e73ba6481ab5e4e86187690d2dbf Mon Sep 17 00:00:00 2001 From: Mattie Tesfaldet Date: Mon, 18 Sep 2023 15:54:07 -0400 Subject: [PATCH] DDP-related improvements to datamodule and logging (#594) * Dividing batch size by number of devices in MNISTDataModule's setup fn * .log file is now the same across devices when training in a DDP setting * Adding rank-aware pylogger --- configs/data/mnist.yaml | 2 +- configs/hydra/default.yaml | 2 +- src/data/mnist_datamodule.py | 16 +++++++++-- src/eval.py | 18 ++++++++---- src/train.py | 24 ++++++++++------ src/utils/__init__.py | 2 +- src/utils/instantiators.py | 2 +- src/utils/logging_utils.py | 4 +-- src/utils/pylogger.py | 56 +++++++++++++++++++++++++++--------- src/utils/rich_utils.py | 4 +-- src/utils/utils.py | 2 +- 11 files changed, 93 insertions(+), 39 deletions(-) diff --git a/configs/data/mnist.yaml b/configs/data/mnist.yaml index f63bc8947..51bfaff09 100644 --- a/configs/data/mnist.yaml +++ b/configs/data/mnist.yaml @@ -1,6 +1,6 @@ _target_: src.data.mnist_datamodule.MNISTDataModule data_dir: ${paths.data_dir} -batch_size: 128 +batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) train_val_test_split: [55_000, 5_000, 10_000] num_workers: 0 pin_memory: False diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml index 5f617fbee..a61e9b3a3 100644 --- a/configs/hydra/default.yaml +++ b/configs/hydra/default.yaml @@ -16,4 +16,4 @@ job_logging: handlers: file: # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 - filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + filename: ${hydra.runtime.output_dir}/${task_name}.log diff --git a/src/data/mnist_datamodule.py b/src/data/mnist_datamodule.py index a77053034..88879ca94 100644 --- a/src/data/mnist_datamodule.py +++ b/src/data/mnist_datamodule.py @@ -83,6 +83,8 @@ def __init__( self.data_val: Optional[Dataset] = None self.data_test: Optional[Dataset] = None + self.batch_size_per_device = batch_size + @property def num_classes(self) -> int: """Get the number of classes. @@ -112,6 +114,14 @@ def setup(self, stage: Optional[str] = None) -> None: :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. """ + # Divide batch size by the number of devices. + if self.trainer is not None: + if self.hparams.batch_size % self.trainer.world_size != 0: + raise RuntimeError( + f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." 
+ ) + self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size + # load and split datasets only if not loaded already if not self.data_train and not self.data_val and not self.data_test: trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms) @@ -130,7 +140,7 @@ def train_dataloader(self) -> DataLoader[Any]: """ return DataLoader( dataset=self.data_train, - batch_size=self.hparams.batch_size, + batch_size=self.batch_size_per_device, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=True, @@ -143,7 +153,7 @@ def val_dataloader(self) -> DataLoader[Any]: """ return DataLoader( dataset=self.data_val, - batch_size=self.hparams.batch_size, + batch_size=self.batch_size_per_device, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=False, @@ -156,7 +166,7 @@ def test_dataloader(self) -> DataLoader[Any]: """ return DataLoader( dataset=self.data_test, - batch_size=self.hparams.batch_size, + batch_size=self.batch_size_per_device, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=False, diff --git a/src/eval.py b/src/eval.py index 5924d4d24..b70faae8b 100644 --- a/src/eval.py +++ b/src/eval.py @@ -24,12 +24,18 @@ # more info: https://github.com/ashleve/rootutils # ------------------------------------------------------------------------------------ # -from src import utils +from src.utils import ( + RankedLogger, + extras, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) -log = utils.get_pylogger(__name__) +log = RankedLogger(__name__, rank_zero_only=True) -@utils.task_wrapper +@task_wrapper def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Evaluates given checkpoint on a datamodule testset. @@ -48,7 +54,7 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: model: LightningModule = hydra.utils.instantiate(cfg.model) log.info("Instantiating loggers...") - logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) log.info(f"Instantiating trainer <{cfg.trainer._target_}>") trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger) @@ -63,7 +69,7 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: if logger: log.info("Logging hyperparameters!") - utils.log_hyperparameters(object_dict) + log_hyperparameters(object_dict) log.info("Starting testing!") trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) @@ -84,7 +90,7 @@ def main(cfg: DictConfig) -> None: """ # apply extra utilities # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) - utils.extras(cfg) + extras(cfg) evaluate(cfg) diff --git a/src/train.py b/src/train.py index 955eefaf4..4adbcf442 100644 --- a/src/train.py +++ b/src/train.py @@ -26,12 +26,20 @@ # more info: https://github.com/ashleve/rootutils # ------------------------------------------------------------------------------------ # -from src import utils +from src.utils import ( + RankedLogger, + extras, + get_metric_value, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) -log = utils.get_pylogger(__name__) +log = RankedLogger(__name__, rank_zero_only=True) -@utils.task_wrapper +@task_wrapper def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Trains the model. Can additionally evaluate on a testset, using best weights obtained during training. 
@@ -53,10 +61,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: model: LightningModule = hydra.utils.instantiate(cfg.model) log.info("Instantiating callbacks...") - callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) + callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) log.info("Instantiating loggers...") - logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) log.info(f"Instantiating trainer <{cfg.trainer._target_}>") trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) @@ -72,7 +80,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: if logger: log.info("Logging hyperparameters!") - utils.log_hyperparameters(object_dict) + log_hyperparameters(object_dict) if cfg.get("train"): log.info("Starting training!") @@ -106,13 +114,13 @@ def main(cfg: DictConfig) -> Optional[float]: """ # apply extra utilities # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) - utils.extras(cfg) + extras(cfg) # train the model metric_dict, _ = train(cfg) # safely retrieve metric value for hydra-based hyperparameter optimization - metric_value = utils.get_metric_value( + metric_value = get_metric_value( metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") ) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 435aeab99..5b0707ca5 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,5 +1,5 @@ from src.utils.instantiators import instantiate_callbacks, instantiate_loggers from src.utils.logging_utils import log_hyperparameters -from src.utils.pylogger import get_pylogger +from src.utils.pylogger import RankedLogger from src.utils.rich_utils import enforce_tags, print_config_tree from src.utils.utils import extras, get_metric_value, task_wrapper diff --git a/src/utils/instantiators.py b/src/utils/instantiators.py index ada7a5253..82b9278a4 100644 --- a/src/utils/instantiators.py +++ b/src/utils/instantiators.py @@ -7,7 +7,7 @@ from src.utils import pylogger -log = pylogger.get_pylogger(__name__) +log = pylogger.RankedLogger(__name__, rank_zero_only=True) def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: diff --git a/src/utils/logging_utils.py b/src/utils/logging_utils.py index 899defd84..360abcdce 100644 --- a/src/utils/logging_utils.py +++ b/src/utils/logging_utils.py @@ -1,11 +1,11 @@ from typing import Any, Dict -from lightning.pytorch.utilities import rank_zero_only +from lightning_utilities.core.rank_zero import rank_zero_only from omegaconf import OmegaConf from src.utils import pylogger -log = pylogger.get_pylogger(__name__) +log = pylogger.RankedLogger(__name__, rank_zero_only=True) @rank_zero_only diff --git a/src/utils/pylogger.py b/src/utils/pylogger.py index 616006780..c4ee8675e 100644 --- a/src/utils/pylogger.py +++ b/src/utils/pylogger.py @@ -1,21 +1,51 @@ import logging +from typing import Mapping, Optional -from lightning.pytorch.utilities import rank_zero_only +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only -def get_pylogger(name: str = __name__) -> logging.Logger: - """Initializes a multi-GPU-friendly python command line logger. +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" - :param name: The name of the logger, defaults to ``__name__``. 
+ def __init__( + self, + name: str = __name__, + rank_zero_only: bool = False, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. - :return: A logger object. - """ - logger = logging.getLogger(name) + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only - # this ensures all logging levels get marked with the rank zero decorator - # otherwise logs would get multiplied for each GPU process in multi-GPU setup - logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") - for level in logging_levels: - setattr(logger, level, rank_zero_only(getattr(logger, level))) + def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. - return logger + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/src/utils/rich_utils.py b/src/utils/rich_utils.py index 430590adf..aeec6806b 100644 --- a/src/utils/rich_utils.py +++ b/src/utils/rich_utils.py @@ -5,13 +5,13 @@ import rich.syntax import rich.tree from hydra.core.hydra_config import HydraConfig -from lightning.pytorch.utilities import rank_zero_only +from lightning_utilities.core.rank_zero import rank_zero_only from omegaconf import DictConfig, OmegaConf, open_dict from rich.prompt import Prompt from src.utils import pylogger -log = pylogger.get_pylogger(__name__) +log = pylogger.RankedLogger(__name__, rank_zero_only=True) @rank_zero_only diff --git a/src/utils/utils.py b/src/utils/utils.py index a4d4eb1e0..02b55765a 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -6,7 +6,7 @@ from src.utils import pylogger, rich_utils -log = pylogger.get_pylogger(__name__) +log = pylogger.RankedLogger(__name__, rank_zero_only=True) def extras(cfg: DictConfig) -> None:
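
Reviewer note (not part of the patch): a minimal usage sketch of the new `RankedLogger` introduced above. The manual `rank_zero_only.rank` assignment and the chosen rank value are illustrative assumptions for running the snippet outside a Lightning launch; inside a launched run, Lightning sets `rank_zero_only.rank` itself, and the adapter raises `RuntimeError` if it is unset, as the diff shows.

```python
# Illustrative sketch only -- not part of this patch.
import logging

from lightning_utilities.core.rank_zero import rank_zero_only

from src.utils import RankedLogger  # exported by the updated src/utils/__init__.py

# Ensure INFO-level records are emitted somewhere for this standalone demo.
logging.basicConfig(level=logging.INFO)

# Lightning normally sets `rank_zero_only.rank` when a run is launched; outside
# a launch we set it manually so RankedLogger.log() does not raise RuntimeError.
if not hasattr(rank_zero_only, "rank"):
    rank_zero_only.rank = 0

# Logs on every process, with each message prefixed by its rank.
log = RankedLogger(__name__, rank_zero_only=False)
log.info("Visible on all ranks, prefixed with the emitting rank.")
log.info("Visible only on rank 1 (no-op on other ranks).", rank=1)

# Logs only on rank 0, regardless of any `rank` argument.
zero_only_log = RankedLogger(__name__, rank_zero_only=True)
zero_only_log.info("Visible on rank 0 only.")
```

The `rank=...` keyword works because `logging.LoggerAdapter.info()` forwards keyword arguments to the overridden `log()` method, which consumes `rank` before delegating to the wrapped logger.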