Support subloss logging through loss dicts (#111)
* Support subloss logging through loss dicts

* Replace implicit subloss addition with user's "total" loss key

* Update lighter/system.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update lighter/system.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Fix codestyle

---------

Co-authored-by: Suraj Pai <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
3 people authored May 24, 2024
1 parent 538e286 commit 2e5614b
Showing 1 changed file with 37 additions and 15 deletions.
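
To illustrate the intended usage, here is a minimal sketch (not part of this commit) of a criterion that returns sublosses in the format the change expects; the function name, subloss names, and weighting below are hypothetical.

import torch
import torch.nn.functional as F

# Hypothetical composite criterion. Only the "total" key is required by the change
# below; the subloss names ("ce", "l1") and the 0.1 weight are illustrative.
def composite_criterion(pred: torch.Tensor, target: torch.Tensor) -> dict:
    ce = F.cross_entropy(pred, target)
    one_hot = F.one_hot(target, num_classes=pred.shape[1]).float()
    l1 = F.l1_loss(pred.softmax(dim=1), one_hot)
    return {"total": ce + 0.1 * l1, "ce": ce, "l1": l1}

pred = torch.randn(4, 3)            # logits for 4 samples, 3 classes
target = torch.randint(0, 3, (4,))  # class indices
loss = composite_criterion(pred, target)
# In "train" mode, the updated _log_stats logs train/loss/total, train/loss/ce, and
# train/loss/l1 (per step and per epoch), while loss["total"] is returned to
# Lightning for backpropagation. Omitting the "total" key raises a ValueError.
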
lighter/system.py: 52 changes (37 additions & 15 deletions)
@@ -205,41 +205,63 @@ def _base_step(self, batch: Dict, batch_idx: int, mode: str) -> Union[Dict[str,
         target = apply_fns(target, self.postprocessing["logging"]["target"])
         pred = apply_fns(pred, self.postprocessing["logging"]["pred"])

+        # Ensure that a dict of losses has a 'total' key.
+        if isinstance(loss, dict) and "total" not in loss:
+            raise ValueError(
+                "The loss dictionary must include a 'total' key that combines all sublosses. "
+                "Example: {'total': combined_loss, 'subloss1': loss1, ...}"
+            )
+
         # Logging
         self._log_stats(loss, metrics, mode, batch_idx)

-        return {"loss": loss, "metrics": metrics, "input": input, "target": target, "pred": pred, "id": id}
-
-    def _log_stats(self, loss: torch.Tensor, metrics: MetricCollection, mode: str, batch_idx: int) -> None:
+        # Return the loss as required by Lightning as well as other data that can be used in hooks or callbacks.
+        return {
+            "loss": loss["total"] if isinstance(loss, dict) else loss,
+            "metrics": metrics,
+            "input": input,
+            "target": target,
+            "pred": pred,
+            "id": id,
+        }
+
+    def _log_stats(
+        self, loss: Union[torch.Tensor, Dict[str, torch.Tensor]], metrics: MetricCollection, mode: str, batch_idx: int
+    ) -> None:
         """
         Logs the loss, metrics, and optimizer statistics.
         Args:
-            loss (torch.Tensor): Calculated loss.
+            loss (Union[torch.Tensor, Dict[str, torch.Tensor]]): Calculated loss or a dict of sublosses.
             metrics (MetricCollection): Calculated metrics.
             mode (str): Mode of operation (train/val/test/predict).
             batch_idx (int): Index of current batch.
         """
         if self.trainer.logger is None:
             return

-        # Arguments for self.log()
-        log_kwargs = {"logger": True, "batch_size": self.batch_size}
-        on_step_log_kwargs = {"on_epoch": False, "on_step": True, "sync_dist": False}
-        on_epoch_log_kwargs = {"on_epoch": True, "on_step": False, "sync_dist": True}
+        on_step_log = partial(self.log, logger=True, batch_size=self.batch_size, on_step=True, on_epoch=False, sync_dist=False)
+        on_epoch_log = partial(self.log, logger=True, batch_size=self.batch_size, on_step=False, on_epoch=True, sync_dist=True)

         # Loss
         if loss is not None:
-            self.log(f"{mode}/loss/step", loss, **log_kwargs, **on_step_log_kwargs)
-            self.log(f"{mode}/loss/epoch", loss, **log_kwargs, **on_epoch_log_kwargs)
+            if not isinstance(loss, dict):
+                on_step_log(f"{mode}/loss/step", loss)
+                on_epoch_log(f"{mode}/loss/epoch", loss)
+            else:
+                for name, subloss in loss.items():
+                    on_step_log(f"{mode}/loss/{name}/step", subloss)
+                    on_epoch_log(f"{mode}/loss/{name}/epoch", subloss)
         # Metrics
         if metrics is not None:
-            for k, v in metrics.items():
-                self.log(f"{mode}/metrics/{k}/step", v, **log_kwargs, **on_step_log_kwargs)
-                self.log(f"{mode}/metrics/{k}/epoch", v, **log_kwargs, **on_epoch_log_kwargs)
+            for name, metric in metrics.items():
+                if not isinstance(metric, Metric):
+                    raise TypeError(f"Expected type for metric is 'Metric', got '{type(metric).__name__}' instead.")
+                on_step_log(f"{mode}/metrics/{name}/step", metric)
+                on_epoch_log(f"{mode}/metrics/{name}/epoch", metric)
         # Optimizer's lr, momentum, beta. Logged in train mode and once per epoch.
         if mode == "train" and batch_idx == 0:
-            for k, v in get_optimizer_stats(self.optimizer).items():
-                self.log(f"{mode}/{k}", v, **log_kwargs, **on_epoch_log_kwargs)
+            for name, optimizer_stat in get_optimizer_stats(self.optimizer).items():
+                on_epoch_log(f"{mode}/{name}", optimizer_stat)

     def _base_dataloader(self, mode: str) -> DataLoader:
         """Instantiate the dataloader for a mode (train/val/test/predict).
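As a standalone illustration of the functools.partial pattern used in the new _log_stats, here is a short sketch; the dummy log function, batch size, and subloss values are hypothetical stand-ins for Lightning's self.log and a real loss dict.

from functools import partial

# Dummy stand-in for LightningModule.log; it only prints the key it would log.
def log(name, value, *, logger, batch_size, on_step, on_epoch, sync_dist):
    print(f"{name} = {value}")

on_step_log = partial(log, logger=True, batch_size=32, on_step=True, on_epoch=False, sync_dist=False)
on_epoch_log = partial(log, logger=True, batch_size=32, on_step=False, on_epoch=True, sync_dist=True)

mode = "train"
loss = {"total": 1.23, "ce": 0.83, "l1": 0.40}  # hypothetical subloss values
for name, subloss in loss.items():
    on_step_log(f"{mode}/loss/{name}/step", subloss)    # e.g. train/loss/total/step
    on_epoch_log(f"{mode}/loss/{name}/epoch", subloss)  # e.g. train/loss/total/epoch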