rework a whole whack of image logger
neggles committed Jan 27, 2024
1 parent 6ec95f5 commit 993b584
Showing 1 changed file with 76 additions and 113 deletions.
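For orientation, a minimal, hypothetical sketch of attaching the reworked ImageLogger to a Lightning Trainer follows. Only every_n_train_steps, max_images, clamp, and rescale are visible in the constructor in the diff; any further options (e.g. enabled, log_before_start, extra_log_keys) are collapsed there, so they are omitted here.

from lightning.pytorch import Trainer

from neurosis.trainer.callbacks.image_logger import ImageLogger

# Sketch only: constructor arguments beyond these four are collapsed in the diff below.
image_logger = ImageLogger(
    every_n_train_steps=100,  # log every N global steps
    max_images=4,             # cap on samples per log call
    clamp=True,               # clamp image tensors to [-1, 1]
    rescale=True,             # map [-1, 1] to [0, 1] for display
)

trainer = Trainer(callbacks=[image_logger])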
189 changes: 76 additions & 113 deletions src/neurosis/trainer/callbacks/image_logger.py
@@ -1,26 +1,21 @@
import logging
import pickle
from enum import Enum
from math import ceil, sqrt
from os import PathLike
from pathlib import Path
from typing import Optional, Union
from warnings import warn

import numpy as np
import torch
import wandb
from diffusers.image_processor import VaeImageProcessor
from lightning.pytorch import Callback, LightningModule, Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities import rank_zero_only
from matplotlib import pyplot as plt
from matplotlib.image import AxesImage
from PIL import Image
from torch import Tensor
from torch.amp.autocast_mode import autocast
from torchvision.utils import make_grid

from neurosis.utils import isheatmap, ndimage_to_u8
from neurosis.utils.image.convert import pt_to_pil
from neurosis.utils.image.grid import CaptionGrid
from neurosis.utils.text import np_text_decode
@@ -39,7 +34,6 @@ class ImageLogger(Callback):
def __init__(
self,
every_n_train_steps: int = 100,
# every_n_epochs: int = 1, # doesn't work without wasting memory
max_images: int = 4,
clamp: bool = True,
rescale: bool = True,
@@ -73,7 +67,7 @@ def __init__(

self.__last_logged_step: int = -1
self.__trainer: Trainer = None
self.processor = VaeImageProcessor(do_resize=False, vae_scale_factor=8, do_convert_rgb=True)
self._first_run = True

def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
self.__trainer = trainer
@@ -134,69 +128,70 @@ def check_step_idx(self, global_step: int, batch_idx: int) -> bool:
@rank_zero_only
def log_local(
self,
save_dir: PathLike,
split: str,
log_dict: dict[str, Tensor],
step: int,
epoch: int,
batch_idx: int,
images: dict[str, np.ndarray | Tensor] = {},
batch: dict[str, np.ndarray | Tensor | str] = {},
step: int = ...,
epoch: int = ...,
batch_idx: int = ...,
pl_module: Optional[LightningModule] = None,
):
if save_dir is None:
return
save_dir = Path(save_dir).joinpath("images", split)
save_dir = self.local_dir.joinpath("images", split)
save_dir.mkdir(exist_ok=True, parents=True)

wandb_dict = {"trainer/global_step": step}
if log_strings := log_dict.get("strings", None):
if "caption" in log_strings and "samples" in log_dict:
samples = log_dict.pop("samples")
if isinstance(samples, Tensor):
samples = pt_to_pil(samples)

log_strings["samples"] = [wandb.Image(x, mode="RGB") for x in samples]
for k in images:
imgpath = save_dir / f"{k}_gs-{step:06}_e-{epoch:06}_b-{batch_idx:06}.png"
if isinstance(images[k], Image.Image):
img = images[k]
img.save(imgpath)
wandb_dict.update({f"{split}/{k}": wandb.Image(img)})

captions = log_strings["caption"]
if isinstance(captions, list):
captions = [captions]
elif isinstance(images[k], (Tensor, list)):
images[k] = [pt_to_pil(x) for x in images[k]]

grid = CaptionGrid()(
samples, captions, title=f"E{epoch:06} S{step:06} B{batch_idx:06} samples"
if k == "samples" and "caption" in batch:
captions = batch["caption"]
capgrid: CaptionGrid = CaptionGrid()
img: Image.Image = capgrid(
images[k], captions, title=f"GS{step:06} E{epoch:06} B{batch_idx:06} samples"
)
log_dict["samples"] = grid

table = wandb.Table()
for k in log_strings:
table.add_column(k, log_strings[k])
wandb_dict[f"{split}/sample_table"] = table

for k, val in log_dict.items():
imgpath = save_dir / f"{k}_gs-{step:06}_e-{epoch:06}_b-{batch_idx:06}.png"
img = None
if isinstance(val, Image.Image):
img = val
elif isheatmap(val):
fig, ax = plt.subplots()
ax.set_axis_off()
img: AxesImage = ax.matshow(val.cpu().numpy(), cmap="hot", interpolation="lanczos")
fig.colorbar(img, ax=ax)
fig.savefig(imgpath)
plt.close(fig)
img = Image.open(imgpath)

elif isinstance(val, Tensor):
if val.ndim == 3:
val = val.unsqueeze(0)
try:
img = ndimage_to_u8(make_grid(val, nrow=ceil(sqrt(val.shape[0]))).cpu().numpy())
img = Image.fromarray(img)
img.save(imgpath)
except Exception as e:
logger.exception(e)
continue

if img is not None:
wandb_dict.update({f"{split}/{k}": wandb.Image(img)})
img.save(imgpath)
wandb_dict.update({f"{split}/{k}": wandb.Image(img, caption="Sample Grid")})

table_dict = {}
if "samples" in images:
table_dict["samples"] = [wandb.Image(x, mode="RGB") for x in images.pop("samples", [])]

for k in batch:
if isinstance(batch[k], list):
if isinstance(batch[k][0], (str, np.bytes_)):
table_dict[k] = np_text_decode(batch[k], aslist=True)
elif isinstance(batch[k][0], Tensor):
if batch[k][0].ndim == 3 and batch[k][0].shape[0] == 3:
val = pt_to_pil(batch[k])
batch[k] = torch.stack(batch[k], dim=0)

if isinstance(batch[k], list) and isinstance(batch[k][0], Tensor):
if batch[k][0].ndim == 3 and batch[k][0].shape[0] == 3:
batch[k] = torch.stack(batch[k], dim=0)

# hacky: this was deliberately not a Tensor earlier, but re-stack it into one here anyway
batch[k] = torch.stack(batch[k], dim=0)

if isinstance(batch[k], Tensor):
val = batch[k].detach().cpu()
if val.ndim == 4 and val.shape[1] == 3:
val = pt_to_pil(val)
elif val.ndim == 3:
val = [pt_to_pil(val)]
elif val.ndim == 2:
val = [x.cpu().item() for x in val]
table_dict[k] = val

if len(table_dict) > 0:
wandb_dict[f"{split}/batch"] = wandb.Table(data=table_dict)

if pl_module is not None:
for pl_logger in [x for x in pl_module.loggers if isinstance(x, WandbLogger)]:
@@ -221,10 +216,10 @@ def maybe_log_images(

# now make sure the module has a log_images method that we can call
if not hasattr(pl_module, "log_images"):
warn(f"{pl_module.__class__.__name__} has no log_images method")
logger.warning(f"{pl_module.__class__.__name__} has no log_images method")
return
if not callable(pl_module.log_images):
warn(f"{pl_module.__class__.__name__}'s log_images method is not callable! ")
logger.warning(f"{pl_module.__class__.__name__}'s log_images method is not callable! ")
return

# confirmed we're logging, save the step number
@@ -239,86 +234,54 @@ def maybe_log_images(
)
# call the actual log_images method
with torch.inference_mode(), autocast(**autocast_kwargs):
log_dict: list[Tensor] = pl_module.log_images(
images: list[Tensor] = pl_module.log_images(
batch, num_img=self.max_images, split=split, **self.log_func_kwargs
)

# if the model returned None, warn and return early
if log_dict is None:
if images is None:
warn(f"{pl_module.__class__.__name__} returned None from log_images")
return

for k in log_dict:
if isinstance(log_dict[k], Tensor):
log_dict[k] = log_dict[k].detach().float().cpu()
if not isheatmap(log_dict[k]):
log_dict[k] = log_dict[k][: min(log_dict[k].shape[0], self.max_images)]
for k in images:
images[k] = images[k][: self.max_images]
if isinstance(images[k], Tensor):
images[k] = images[k].detach().float().cpu()
if self.clamp:
log_dict[k] = log_dict[k].clamp(min=-1.0, max=1.0)
images[k] = images[k].clamp(min=-1.0, max=1.0)
if self.rescale:
log_dict[k] = (log_dict[k] + 1.0) / 2.0

log_strings = {}
for k in self.extra_log_keys:
try:
if k in batch:
log_strings[k] = batch[k][: self.max_images]
if isinstance(log_strings[k][0], np.bytes_):
log_strings[k] = np_text_decode(log_strings[k], aslist=True)
images[k] = (images[k] + 1.0) / 2.0

log_strings[k] = [np_text_decode(x) for x in log_strings[k]]
except Exception as e:
logger.exception(e)
continue

if len(log_strings) > 0:
log_dict["strings"] = log_strings
for k in batch:
batch[k] = batch[k][: self.max_images]
if isinstance(batch[k][0], (str, np.bytes_)):
batch[k] = [np_text_decode(x) for x in batch[k]]

# log the images
self.log_local(
self.local_dir,
split,
log_dict,
images,
batch,
trainer.global_step,
pl_module.current_epoch,
batch_idx,
pl_module=pl_module,
)

@rank_zero_only
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx,
):
def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx):
if self.enabled:
self.maybe_log_images(trainer, pl_module, batch, batch_idx, split="train")

@rank_zero_only
def on_train_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch,
batch_idx,
):
if self.enabled and self.log_before_start and trainer.global_step == 0:
def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx):
if self.enabled and trainer.global_step == 0 and self.log_before_start:
logger.info(f"{self.__class__.__name__} running log before training...")
self.maybe_log_images(trainer, pl_module, batch, batch_idx, split="train", force=True)

@rank_zero_only
def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx,
*args,
**kwargs,
self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, *args, **kwargs
):
if self.enabled and trainer.global_step > 0:
self.maybe_log_images(trainer, pl_module, batch, batch_idx, split="val")

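And a standalone sketch of the shape of the wandb payload the reworked log_local assembles: per-key grids logged as wandb.Image under f"{split}/{k}", plus a per-batch table. The project name, captions, and images below are placeholders, and the table is built with the documented columns/add_data form rather than the dict passed to wandb.Table in the diff.

import wandb
from PIL import Image

run = wandb.init(project="neurosis-demo")  # placeholder project name

grid = Image.new("RGB", (256, 256))                       # stand-in for the CaptionGrid output
samples = [Image.new("RGB", (64, 64)) for _ in range(4)]  # stand-ins for decoded sample images
captions = ["a", "b", "c", "d"]                           # stand-ins for batch captions

table = wandb.Table(columns=["caption", "samples"])
for cap, img in zip(captions, samples):
    table.add_data(cap, wandb.Image(img, mode="RGB"))

run.log(
    {
        "trainer/global_step": 100,
        "train/samples": wandb.Image(grid, caption="Sample Grid"),
        "train/batch": table,
    }
)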