Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "id" support. Refactor Writers. Add Writer additional format extensibility. #78

Merged
merged 22 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b24ab22
Remove loss logging when predicting
ibro45 Aug 11, 2023
4a13721
Add "id" for each batch sample ID-ing purposes. Refactor Writers, add…
ibro45 Aug 13, 2023
2ea66bd
Remove interval arg, group_tensors fn, and on pred epoch end writing.…
ibro45 Aug 14, 2023
ffc26ae
Small fixes
ibro45 Aug 15, 2023
6c4563d
Remove multi opt and scheduler support. Replace remaininig sys.exit's.
ibro45 Aug 15, 2023
f8d689b
Update configure_optimizers docstring
ibro45 Aug 15, 2023
57b4447
Fix index ID issue in DDP writing. Replace broadcast with gather in t…
ibro45 Aug 15, 2023
605764a
Add missing if DDP check
ibro45 Aug 15, 2023
ae1a452
Update docstrings, rename and refactor parse_data
ibro45 Aug 16, 2023
cb63a9e
Add freezer to init file
ibro45 Aug 16, 2023
736d9f6
Change property to attribute
ibro45 Aug 16, 2023
fe7693a
Add support for dict metrics. Refactor system.
ibro45 Aug 16, 2023
57bcd7a
Fix typos
ibro45 Aug 16, 2023
799719e
Remove unused imports
ibro45 Aug 16, 2023
21d84f3
Update logger.py to support the temp ModuleDict fix
ibro45 Aug 18, 2023
c8eedea
Add continue to freezer and detach cpu to image logging
ibro45 Aug 20, 2023
8874399
Remove multi_pred, refactor Writer, Logger, and optional imports
ibro45 Sep 14, 2023
9ca6f7d
Bump gitpython from 3.1.32 to 3.1.35
dependabot[bot] Sep 14, 2023
01bfcd3
Bump certifi from 2023.5.7 to 2023.7.22
dependabot[bot] Sep 14, 2023
441a23d
Merge remote-tracking branch 'origin/dependabot/pip/certifi-2023.7.22…
ibro45 Sep 14, 2023
d0193b8
Merge remote-tracking branch 'origin/dependabot/pip/gitpython-3.1.35'…
ibro45 Sep 14, 2023
7ccce95
Remove add_batch_dim
ibro45 Sep 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lighter/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .freezer import LighterFreezer
from .logger import LighterLogger
from .writer.file import LighterFileWriter
from .writer.table import LighterTableWriter
15 changes: 9 additions & 6 deletions lighter/callbacks/freezer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def on_test_batch_start(
self._on_batch_start(trainer, pl_module)

def on_predict_batch_start(
self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int, dataloader_idx: int
self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> None:
self._on_batch_start(trainer, pl_module)

Expand Down Expand Up @@ -122,20 +122,23 @@ def _set_model_requires_grad(self, model: Union[Module, LighterSystem], requires
# Leave the excluded-from-freezing parameters trainable.
if self.except_names and name in self.except_names:
param.requires_grad = True
elif self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with):
continue
if self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with):
param.requires_grad = True
continue

# Freeze/unfreeze the specified parameters, based on the `requires_grad` argument.
elif self.names and name in self.names:
if self.names and name in self.names:
param.requires_grad = requires_grad
frozen_layers.append(name)
elif self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with):
continue
if self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with):
param.requires_grad = requires_grad
frozen_layers.append(name)
continue

# Otherwise, leave the parameter trainable.
else:
param.requires_grad = True
param.requires_grad = True

self._frozen_state = not requires_grad
# Log only when freezing the parameters.
Expand Down
152 changes: 73 additions & 79 deletions lighter/callbacks/logger.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from typing import Any, Dict, Union

import itertools
import sys
from datetime import datetime
from pathlib import Path

import torch
from loguru import logger
from monai.utils.module import optional_import
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor

from lighter import LighterSystem
from lighter.callbacks.utils import get_lighter_mode, is_data_type_supported, parse_data, preprocess_image
from lighter.callbacks.utils import get_lighter_mode, preprocess_image
from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS


Expand Down Expand Up @@ -62,8 +60,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
stage (str): stage of the training process. Passed automatically by PyTorch Lightning.
"""
if trainer.logger is not None:
logger.error("When using LighterLogger, set Trainer(logger=None).")
sys.exit()
raise ValueError("When using LighterLogger, set Trainer(logger=None).")

if not trainer.is_global_zero:
return
Expand All @@ -76,8 +73,6 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:

# Tensorboard initialization.
if self.tensorboard:
# Tensorboard is a part of PyTorch, no need to check if it is not available.
OPTIONAL_IMPORTS["tensorboard"], _ = optional_import("torch.utils.tensorboard")
tensorboard_dir = self.log_dir / "tensorboard"
tensorboard_dir.mkdir()
self.tensorboard = OPTIONAL_IMPORTS["tensorboard"].SummaryWriter(log_dir=tensorboard_dir)
Expand All @@ -86,10 +81,6 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:

# Wandb initialization.
if self.wandb:
OPTIONAL_IMPORTS["wandb"], wandb_available = optional_import("wandb")
if not wandb_available:
logger.error("Weights & Biases not installed. To install it, run `pip install wandb`. Exiting.")
sys.exit()
wandb_dir = self.log_dir / "wandb"
wandb_dir.mkdir()
self.wandb = OPTIONAL_IMPORTS["wandb"].init(project=self.project, dir=wandb_dir, config=self.config)
Expand Down Expand Up @@ -165,6 +156,7 @@ def _log_image(self, name: str, image: torch.Tensor, global_step: int) -> None:
image (torch.Tensor): image to be logged.
global_step (int): current global step.
"""
image = image.detach().cpu()
if self.tensorboard:
self.tensorboard.add_image(name, image, global_step=global_step)
if self.wandb:
Expand All @@ -179,7 +171,6 @@ def _log_histogram(self, name: str, tensor: torch.Tensor, global_step: int) -> N
global_step (int): current global step.
"""
tensor = tensor.detach().cpu()

if self.tensorboard:
self.tensorboard.add_histogram(name, tensor, global_step=global_step)
if self.wandb:
Expand All @@ -193,49 +184,44 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None:
outputs (Dict): output dict from the model.
trainer (Trainer): Trainer, passed automatically by PyTorch Lightning.
"""
if not trainer.sanity_checking:
mode = get_lighter_mode(trainer.state.stage)
# Accumulate the loss.
if mode in ["train", "val"]:
self.loss[mode] += outputs["loss"].item()
# Logging frequency. Log only on rank 0.
if trainer.is_global_zero and self.global_step_counter[mode] % trainer.log_every_n_steps == 0:
# Get global step.
global_step = self._get_global_step(trainer)

# Log loss.
if outputs["loss"] is not None:
self._log_scalar(f"{mode}/loss/step", outputs["loss"], global_step)

# Log metrics.
if outputs["metrics"] is not None:
for name, metric in outputs["metrics"].items():
self._log_scalar(f"{mode}/metrics/{name}/step", metric, global_step)

# Log input, target, and pred.
for name in ["input", "target", "pred"]:
if self.log_types[name] is None:
continue
# Ensure data is of a valid type.
if not is_data_type_supported(outputs[name]):
raise ValueError(
f"`{name}` has to be a Tensor, List[Tensor], Tuple[Tensor], Dict[str, Tensor], "
f"Dict[str, List[Tensor]], or Dict[str, Tuple[Tensor]]. `{type(outputs[name])}` is not supported."
)
for identifier, item in parse_data(outputs[name]).items():
item_name = f"{mode}/data/{name}" if identifier is None else f"{mode}/data/{name}_{identifier}"
self._log_by_type(item_name, item, self.log_types[name], global_step)

# Log learning rate stats. Logs at step if a scheduler's interval is step-based.
if mode == "train":
lr_stats = self.lr_monitor.get_stats(trainer, "step")
for name, value in lr_stats.items():
self._log_scalar(f"{mode}/optimizer/{name}/step", value, global_step)

# Increment the step counters.
self.global_step_counter[mode] += 1
if mode in ["train", "val"]:
self.epoch_step_counter[mode] += 1
if trainer.sanity_checking:
return

mode = get_lighter_mode(trainer.state.stage)

# Accumulate the loss.
if mode in ["train", "val"]:
self.loss[mode] += outputs["loss"].item()

# Log only on rank 0 and according to the `log_every_n_steps` parameter. Otherwise, only increment the step counters.
if not trainer.is_global_zero or self.global_step_counter[mode] % trainer.log_every_n_steps != 0:
self._increment_step_counters(mode)
return

global_step = self._get_global_step(trainer)

# Loss.
if outputs["loss"] is not None:
self._log_scalar(f"{mode}/loss/step", outputs["loss"], global_step)

# Metrics.
if outputs["metrics"] is not None:
for name, metric in outputs["metrics"].items():
self._log_scalar(f"{mode}/metrics/{name}/step", metric, global_step)

# Input, target, and pred.
for name in ["input", "target", "pred"]:
if self.log_types[name] is not None:
self._log_by_type(f"{mode}/data/{name}", outputs[name], self.log_types[name], global_step)

# LR info. Logs at step if a scheduler's interval is step-based.
if mode == "train":
lr_stats = self.lr_monitor.get_stats(trainer, "step")
for name, value in lr_stats.items():
self._log_scalar(f"{mode}/optimizer/{name}/step", value, global_step)

# Increment the step counters.
self._increment_step_counters(mode)

def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
"""Performs logging at the end of an epoch. Logs the epoch number, the loss, and the metrics.
Expand All @@ -249,46 +235,44 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
mode = get_lighter_mode(trainer.state.stage)
loss, metrics = None, None

# Loss
# Get the accumulated loss over the epoch and processes.
if mode in ["train", "val"]:
# Get the accumulated loss.
loss = self.loss[mode]
# Reduce the loss and average it on each rank.
loss = trainer.strategy.reduce(loss, reduce_op="mean")
# Divide the accumulated loss by the number of steps in the epoch.
loss /= self.epoch_step_counter[mode]

# Metrics
# Get the torchmetrics.
metric_collection = pl_module.metrics[mode]
# TODO: Remove the "_" prefix when fixed https://github.com/pytorch/pytorch/issues/71203
metric_collection = pl_module.metrics["_" + mode]
if metric_collection is not None:
# Compute the epoch metrics.
metrics = metric_collection.compute()
# Reset the metrics for the next epoch.
metric_collection.reset()

# Log. Only on rank 0.
if trainer.is_global_zero:
# Get global step.
global_step = self._get_global_step(trainer)
# Log only on rank 0.
if not trainer.is_global_zero:
return

# Log epoch number.
self._log_scalar("epoch", trainer.current_epoch, global_step)
global_step = self._get_global_step(trainer)

# Log loss.
if loss is not None:
self._log_scalar(f"{mode}/loss/epoch", loss, global_step)
# Epoch number.
self._log_scalar("epoch", trainer.current_epoch, global_step)

# Log metrics.
if metrics is not None:
for name, metric in metrics.items():
self._log_scalar(f"{mode}/metrics/{name}/epoch", metric, global_step)
# Loss.
if loss is not None:
self._log_scalar(f"{mode}/loss/epoch", loss, global_step)

# Log learning rate stats. Logs at epoch if a scheduler's interval is epoch-based, or if no scheduler is used.
if mode == "train":
lr_stats = self.lr_monitor.get_stats(trainer, "epoch")
for name, value in lr_stats.items():
self._log_scalar(f"{mode}/optimizer/{name}/epoch", value, global_step)
# Metrics.
if metrics is not None:
for name, metric in metrics.items():
self._log_scalar(f"{mode}/metrics/{name}/epoch", metric, global_step)

# LR info. Logged at epoch if the scheduler's interval is epoch-based, or if no scheduler is used.
if mode == "train":
lr_stats = self.lr_monitor.get_stats(trainer, "epoch")
for name, value in lr_stats.items():
self._log_scalar(f"{mode}/optimizer/{name}/epoch", value, global_step)

def _get_global_step(self, trainer: Trainer) -> int:
"""Return the global step for the current mode. Note that when Trainer
Expand All @@ -309,6 +293,16 @@ def _get_global_step(self, trainer: Trainer) -> int:
return self.global_step_counter["train"]
return self.global_step_counter[mode]

def _increment_step_counters(self, mode: str) -> None:
"""Increment the global step and epoch step counters for the specified mode.

Args:
mode (str): mode to increment the global step counter for.
"""
self.global_step_counter[mode] += 1
if mode in ["train", "val"]:
self.epoch_step_counter[mode] += 1

def on_train_epoch_start(self, trainer: Trainer, pl_module: LighterSystem) -> None:
# Reset the loss and the epoch step counter for the next epoch.
self.loss["train"] = 0
Expand Down
Loading
Loading