diff --git a/diffengine/configs/_base_/datasets/pickascore_xl.py b/diffengine/configs/_base_/datasets/pickascore_xl.py
new file mode 100644
index 0000000..50a85b2
--- /dev/null
+++ b/diffengine/configs/_base_/datasets/pickascore_xl.py
@@ -0,0 +1,61 @@
+import torchvision
+from mmengine.dataset import DefaultSampler
+
+from diffengine.datasets import HFDPODataset
+from diffengine.datasets.transforms import (
+ ComputeTimeIds,
+ ConcatMultipleImgs,
+ PackInputs,
+ RandomCrop,
+ RandomHorizontalFlip,
+ SaveImageShape,
+ TorchVisonTransformWrapper,
+)
+from diffengine.engine.hooks import SDCheckpointHook, VisualizationHook
+
+train_pipeline = [
+ dict(type=SaveImageShape),
+ dict(type=TorchVisonTransformWrapper,
+ transform=torchvision.transforms.Resize,
+ size=1024, interpolation="bilinear"),
+ dict(type=RandomCrop, size=1024),
+ dict(type=RandomHorizontalFlip, p=0.5),
+ dict(type=ComputeTimeIds),
+ dict(type=TorchVisonTransformWrapper,
+ transform=torchvision.transforms.ToTensor),
+ dict(type=TorchVisonTransformWrapper,
+ transform=torchvision.transforms.Normalize, mean=[0.5], std=[0.5]),
+ dict(type=ConcatMultipleImgs),
+ dict(type=PackInputs, input_keys=["img", "text", "time_ids"]),
+]
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=2,
+ dataset=dict(
+ type=HFDPODataset,
+ dataset="kashif/pickascore",
+ split="validation",
+ image_columns=["jpg_0", "jpg_1"],
+ caption_column="caption",
+ pipeline=train_pipeline),
+ sampler=dict(type=DefaultSampler, shuffle=True),
+)
+
+val_dataloader = None
+val_evaluator = None
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
+
+custom_hooks = [
+ dict(
+ type=VisualizationHook,
+ prompt=[
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography", # noqa
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", # noqa
+ ],
+ height=1024,
+ width=1024),
+ dict(type=SDCheckpointHook),
+]
diff --git a/diffengine/configs/_base_/models/stable_diffusion_xl_dpo.py b/diffengine/configs/_base_/models/stable_diffusion_xl_dpo.py
new file mode 100644
index 0000000..53fd329
--- /dev/null
+++ b/diffengine/configs/_base_/models/stable_diffusion_xl_dpo.py
@@ -0,0 +1,26 @@
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection
+
+from diffengine.models.editors import StableDiffusionXLDPO
+
+base_model = "stabilityai/stable-diffusion-xl-base-1.0"
+model = dict(type=StableDiffusionXLDPO,
+ model=base_model,
+ tokenizer_one=dict(type=AutoTokenizer.from_pretrained,
+ subfolder="tokenizer",
+ use_fast=False),
+ tokenizer_two=dict(type=AutoTokenizer.from_pretrained,
+ subfolder="tokenizer_2",
+ use_fast=False),
+ scheduler=dict(type=DDPMScheduler.from_pretrained,
+ subfolder="scheduler"),
+ text_encoder_one=dict(type=CLIPTextModel.from_pretrained,
+ subfolder="text_encoder"),
+ text_encoder_two=dict(type=CLIPTextModelWithProjection.from_pretrained,
+ subfolder="text_encoder_2"),
+ vae=dict(
+ type=AutoencoderKL.from_pretrained,
+ pretrained_model_name_or_path="madebyollin/sdxl-vae-fp16-fix"),
+ unet=dict(type=UNet2DConditionModel.from_pretrained,
+ subfolder="unet"),
+ gradient_checkpointing=True)
diff --git a/diffengine/configs/diffusion_dpo/README.md b/diffengine/configs/diffusion_dpo/README.md
new file mode 100644
index 0000000..62b6030
--- /dev/null
+++ b/diffengine/configs/diffusion_dpo/README.md
@@ -0,0 +1,76 @@
+# Diffusion DPO
+
+[Diffusion Model Alignment Using Direct Preference Optimization](https://arxiv.org/abs/2311.12908)
+
+## Abstract
+
+Large language models (LLMs) are fine-tuned using human comparison data with Reinforcement Learning from Human Feedback (RLHF) methods to make them better aligned with users' preferences. In contrast to LLMs, human preference learning has not been widely explored in text-to-image diffusion models; the best existing approach is to fine-tune a pretrained model using carefully curated high quality images and captions to improve visual appeal and text alignment. We propose Diffusion-DPO, a method to align diffusion models to human preferences by directly optimizing on human comparison data. Diffusion-DPO is adapted from the recently developed Direct Preference Optimization (DPO), a simpler alternative to RLHF which directly optimizes a policy that best satisfies human preferences under a classification objective. We re-formulate DPO to account for a diffusion model notion of likelihood, utilizing the evidence lower bound to derive a differentiable objective. Using the Pick-a-Pic dataset of 851K crowdsourced pairwise preferences, we fine-tune the base model of the state-of-the-art Stable Diffusion XL (SDXL)-1.0 model with Diffusion-DPO. Our fine-tuned base model significantly outperforms both base SDXL-1.0 and the larger SDXL-1.0 model consisting of an additional refinement model in human evaluation, improving visual appeal and prompt alignment. We also develop a variant that uses AI feedback and has comparable performance to training on human preferences, opening the door for scaling of diffusion model alignment methods.
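+
+As a reference for the formulation, the training objective implemented by `StableDiffusionXLDPO.loss` in this PR corresponds to the following Diffusion-DPO loss (a sketch, with the expectation over the batch and sampled timesteps left implicit; $\epsilon_\theta$ is the trained UNet, $\epsilon_\mathrm{ref}$ the frozen reference UNet, $x^w_t$ / $x^l_t$ the noised preferred / rejected images, and $\beta$ the `beta_dpo` penalty):
+
+$$
+\mathcal{L}_\mathrm{DPO} = -\log \sigma\Big(-\tfrac{\beta}{2}\big[(\lVert\epsilon-\epsilon_\theta(x^w_t)\rVert^2 - \lVert\epsilon-\epsilon_\mathrm{ref}(x^w_t)\rVert^2) - (\lVert\epsilon-\epsilon_\theta(x^l_t)\rVert^2 - \lVert\epsilon-\epsilon_\mathrm{ref}(x^l_t)\rVert^2)\big]\Big)
+$$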
+
+
+
+
+
+## Citation
+
+```
+```
+
+## Run Training
+
+Run Training
+
+```bash
+# single gpu
+$ diffengine train ${CONFIG_FILE}
+# multi gpus
+$ NPROC_PER_NODE=${GPU_NUM} diffengine train ${CONFIG_FILE}
+
+# Example.
+$ diffengine train stable_diffusion_xl_dpo_pickascore
+```
+
+## Inference with diffusers
+
+Once you have trained a model, specify the path to the saved model and use it for inference with the `diffusers.pipeline` module.
+
+Before running inference, convert the weights to the Diffusers format:
+
+```bash
+$ diffengine convert ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
+# Example
+$ diffengine convert stable_diffusion_xl_dpo_pickascore work_dirs/stable_diffusion_xl_dpo_pickascore/epoch_50.pth work_dirs/stable_diffusion_xl_dpo_pickascore --save-keys unet
+```
+
+Then we can run inference.
+
+```py
+import torch
+from diffusers import DiffusionPipeline, UNet2DConditionModel, AutoencoderKL
+
+prompt = 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k'
+checkpoint = 'work_dirs/stable_diffusion_xl_dpo_pickascore'
+
+unet = UNet2DConditionModel.from_pretrained(
+ checkpoint, subfolder='unet', torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained(
+ 'madebyollin/sdxl-vae-fp16-fix',
+ torch_dtype=torch.float16,
+)
+pipe = DiffusionPipeline.from_pretrained(
+ 'stabilityai/stable-diffusion-xl-base-1.0', unet=unet, vae=vae, torch_dtype=torch.float16)
+pipe.to('cuda')
+
+image = pipe(
+ prompt,
+ num_inference_steps=50,
+ width=1024,
+ height=1024,
+).images[0]
+image.save('demo.png')
+```
+
+## Results Example
+
+#### stable_diffusion_xl_dpo_pickascore
+
+![example1](https://github.com/okotaku/diffengine/assets/24734142/efa32784-3151-4cb2-9af2-368a8c4b527b)
diff --git a/diffengine/configs/diffusion_dpo/stable_diffusion_xl_dpo_pickascore.py b/diffengine/configs/diffusion_dpo/stable_diffusion_xl_dpo_pickascore.py
new file mode 100644
index 0000000..950bff4
--- /dev/null
+++ b/diffengine/configs/diffusion_dpo/stable_diffusion_xl_dpo_pickascore.py
@@ -0,0 +1,12 @@
+from mmengine.config import read_base
+
+with read_base():
+ from .._base_.datasets.pickascore_xl import *
+ from .._base_.default_runtime import *
+ from .._base_.models.stable_diffusion_xl_dpo import *
+ from .._base_.schedules.stable_diffusion_xl_50e import *
+
+
+train_dataloader.update(batch_size=1)
+
+optim_wrapper.update(accumulative_counts=4)  # accumulate gradients over 4 iterations
diff --git a/diffengine/datasets/__init__.py b/diffengine/datasets/__init__.py
index 483fff4..38fbc86 100644
--- a/diffengine/datasets/__init__.py
+++ b/diffengine/datasets/__init__.py
@@ -1,5 +1,6 @@
from .hf_controlnet_datasets import HFControlNetDataset
from .hf_datasets import HFDataset, HFDatasetPreComputeEmbs
+from .hf_dpo_datasets import HFDPODataset
from .hf_dreambooth_datasets import HFDreamBoothDataset
from .hf_esd_datasets import HFESDDatasetPreComputeEmbs
from .samplers import * # noqa: F403
@@ -11,4 +12,5 @@
"HFControlNetDataset",
"HFDatasetPreComputeEmbs",
"HFESDDatasetPreComputeEmbs",
+ "HFDPODataset",
]
diff --git a/diffengine/datasets/hf_dpo_datasets.py b/diffengine/datasets/hf_dpo_datasets.py
new file mode 100644
index 0000000..84641ad
--- /dev/null
+++ b/diffengine/datasets/hf_dpo_datasets.py
@@ -0,0 +1,112 @@
+# flake8: noqa: TRY004,S311
+import io
+import os
+import random
+from collections.abc import Sequence
+from pathlib import Path
+
+import numpy as np
+from datasets import load_dataset
+from mmengine.dataset.base_dataset import Compose
+from PIL import Image
+from torch.utils.data import Dataset
+
+from diffengine.registry import DATASETS
+
+Image.MAX_IMAGE_PIXELS = 1000000000
+
+
+@DATASETS.register_module()
+class HFDPODataset(Dataset):
+ """DPO Dataset for huggingface datasets.
+
+ Args:
+ ----
+ dataset (str): Dataset name or path to dataset.
+        image_columns (list[str]): Image column names.
+            Defaults to ['image', 'image2'].
+ caption_column (str): Caption column name. Defaults to 'text'.
+        label_column (str): Label column name indicating whether
+            image_columns[0] is better than image_columns[1].
+            Defaults to 'label_0'.
+ csv (str): Caption csv file name when loading local folder.
+ Defaults to 'metadata.csv'.
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+ split (str): Dataset split. Defaults to 'train'.
+ cache_dir (str, optional): The directory where the downloaded datasets
+            will be stored. Defaults to None.
+ """
+
+ def __init__(self,
+ dataset: str,
+ image_columns: list[str] | None = None,
+ caption_column: str = "text",
+ label_column: str = "label_0",
+ csv: str = "metadata.csv",
+ pipeline: Sequence = (),
+ split: str = "train",
+ cache_dir: str | None = None) -> None:
+ if image_columns is None:
+ image_columns = ["image", "image2"]
+ self.dataset_name = dataset
+ if Path(dataset).exists():
+ # load local folder
+ data_file = os.path.join(dataset, csv)
+ self.dataset = load_dataset(
+ "csv", data_files=data_file, cache_dir=cache_dir)[split]
+ else:
+ # load huggingface online
+ self.dataset = load_dataset(dataset, cache_dir=cache_dir)[split]
+ self.pipeline = Compose(pipeline)
+
+ self.image_columns = image_columns
+ self.caption_column = caption_column
+ self.label_column = label_column
+
+ def __len__(self) -> int:
+ """Get the length of dataset.
+
+ Returns
+ -------
+ int: The length of filtered dataset.
+ """
+ return len(self.dataset)
+
+ def __getitem__(self, idx: int) -> dict:
+ """Get item.
+
+ Get the idx-th image and data information of dataset after
+        ``self.pipeline``.
+
+ Args:
+ ----
+            idx (int): The index of the dataset.
+
+ Returns:
+ -------
+ dict: The idx-th image and data information of dataset after
+ ``self.pipeline``.
+ """
+ data_info = self.dataset[idx]
+ images = []
+ for image_column in self.image_columns:
+ image = data_info[image_column]
+ if isinstance(image, str):
+ image = Image.open(os.path.join(self.dataset_name, image))
+ elif not isinstance(image, Image.Image):
+ image = Image.open(io.BytesIO(image))
+ image = image.convert("RGB")
+ images.append(image)
+ label = data_info[self.label_column]
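+        # `label_column` indicates whether image_columns[0] is the preferred
+        # image; if it is not, reverse the pair so the preferred image
+        # always comes first.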
+ if not label:
+ images = images[::-1]
+ caption = data_info[self.caption_column]
+ if isinstance(caption, str):
+ pass
+ elif isinstance(caption, list | np.ndarray):
+ # take a random caption if there are multiple
+ caption = random.choice(caption)
+ else:
+ msg = (f"Caption column `{self.caption_column}` should "
+ "contain either strings or lists of strings.")
+ raise ValueError(msg)
+ result = {"img": images, "text": caption}
+ return self.pipeline(result)
diff --git a/diffengine/datasets/transforms/__init__.py b/diffengine/datasets/transforms/__init__.py
index d47d2ab..b9ff756 100644
--- a/diffengine/datasets/transforms/__init__.py
+++ b/diffengine/datasets/transforms/__init__.py
@@ -9,6 +9,7 @@
CLIPImageProcessor,
ComputePixArtImgInfo,
ComputeTimeIds,
+ ConcatMultipleImgs,
GetMaskedImage,
MaskToTensor,
MultiAspectRatioResizeCenterCrop,
@@ -43,4 +44,5 @@
"AddConstantCaption",
"DumpMaskedImage",
"TorchVisonTransformWrapper",
+ "ConcatMultipleImgs",
]
diff --git a/diffengine/datasets/transforms/processing.py b/diffengine/datasets/transforms/processing.py
index 535f6ac..cedfff4 100644
--- a/diffengine/datasets/transforms/processing.py
+++ b/diffengine/datasets/transforms/processing.py
@@ -78,7 +78,10 @@ def __init__(self,
def __call__(self, results: dict) -> dict:
"""Call transform."""
for k in self.keys:
- results[k] = self.t(results[k])
+ if not isinstance(results[k], list):
+ results[k] = self.t(results[k])
+ else:
+ results[k] = [self.t(img) for img in results[k]]
return results
def __repr__(self) -> str:
@@ -132,10 +135,15 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
-------
dict: 'ori_img_shape' key is added as original image shape.
"""
- results["ori_img_shape"] = [
- results["img"].height,
- results["img"].width,
- ]
+ if not isinstance(results["img"], list):
+ imgs = [results["img"]]
+ else:
+ imgs = results["img"]
+
+ ori_img_shape = [[img.height, img.width] for img in imgs]
+ if not isinstance(results["img"], list):
+ ori_img_shape = ori_img_shape[0]
+ results["ori_img_shape"] = ori_img_shape
return results
@@ -187,16 +195,46 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
dict: 'crop_top_left' and `crop_bottom_right` key is added as crop
point.
"""
- if self.force_same_size:
- assert all(
- results["img"].size == results[k].size for k in self.keys), (
- "Size mismatch. {k: results[k].size for k in self.keys}"
- )
- y1, x1, h, w = self.pipeline.get_params(results["img"], self.size)
+ components = dict()
+ for k in self.keys:
+ if not isinstance(results["img"], list):
+ components[k] = [results[k]]
+ else:
+ components[k] = results[k]
+
+ crop_top_left = []
+ crop_bottom_right = []
+ before_crop_size = []
+ for i in range(len(components["img"])):
+ if self.force_same_size:
+ assert all(
+ components["img"][i].size == components[k][i].size
+ for k in self.keys), (
+ "Size mismatch."
+ )
+ before_crop_size.append([components["img"][i].height,
+ components["img"][i].width])
+
+ y1, x1, h, w = self.pipeline.get_params(components["img"][i],
+ self.size)
+ for k in self.keys:
+ components[k][i] = crop(components[k][i], y1, x1, h, w)
+ crop_top_left.append([y1, x1])
+ crop_bottom_right.append([y1 + h, x1 + w])
+
+ if not isinstance(results["img"], list):
+ for k in self.keys:
+ components[k] = components[k][0]
+ crop_top_left = crop_top_left[0]
+ crop_bottom_right = crop_bottom_right[0]
+ before_crop_size = before_crop_size[0]
+
for k in self.keys:
- results[k] = crop(results[k], y1, x1, h, w)
- results["crop_top_left"] = [y1, x1]
- results["crop_bottom_right"] = [y1 + h, x1 + w]
+ results[k] = components[k]
+
+ results["crop_top_left"] = crop_top_left
+ results["crop_bottom_right"] = crop_bottom_right
+ results["before_crop_size"] = before_crop_size
return results
@@ -242,15 +280,50 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
-------
dict: 'crop_top_left' key is added as crop points.
"""
- assert all(results["img"].size == results[k].size for k in self.keys)
- y1 = max(0, int(round((results["img"].height - self.size[0]) / 2.0)))
- x1 = max(0, int(round((results["img"].width - self.size[1]) / 2.0)))
- y2 = max(0, int(round((results["img"].height + self.size[0]) / 2.0)))
- x2 = max(0, int(round((results["img"].width + self.size[1]) / 2.0)))
+ components = dict()
+ for k in self.keys:
+ if not isinstance(results["img"], list):
+ components[k] = [results[k]]
+ else:
+ components[k] = results[k]
+
+ crop_top_left: list = []
+ crop_bottom_right: list = []
+ before_crop_size: list = []
+ for i in range(len(components["img"])):
+ assert all(
+ components["img"][i].size == components[k][i].size
+ for k in self.keys), (
+ "Size mismatch."
+ )
+ before_crop_size.append([components["img"][i].height,
+ components["img"][i].width])
+
+ y1 = max(0, int(round(
+ (components["img"][i].height - self.size[0]) / 2.0)))
+ x1 = max(0, int(round(
+ (components["img"][i].width - self.size[1]) / 2.0)))
+ y2 = max(0, int(round(
+ (components["img"][i].height + self.size[0]) / 2.0)))
+ x2 = max(0, int(round(
+ (components["img"][i].width + self.size[1]) / 2.0)))
+ for k in self.keys:
+ components[k][i] = self.pipeline(components[k][i])
+ crop_top_left.append([y1, x1])
+ crop_bottom_right.append([y2, x2])
+
+ if not isinstance(results["img"], list):
+ for k in self.keys:
+ components[k] = components[k][0]
+ crop_top_left = crop_top_left[0]
+ crop_bottom_right = crop_bottom_right[0]
+ before_crop_size = before_crop_size[0]
+
for k in self.keys:
- results[k] = self.pipeline(results[k])
- results["crop_top_left"] = [y1, x1]
- results["crop_bottom_right"] = [y2, x2]
+ results[k] = components[k]
+ results["crop_top_left"] = crop_top_left
+ results["crop_bottom_right"] = crop_bottom_right
+ results["before_crop_size"] = before_crop_size
return results
@@ -298,6 +371,8 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
----
results (dict): The result dict.
"""
+ assert not isinstance(results["img"], list), (
+ "MultiAspectRatioResizeCenterCrop only support single image.")
aspect_ratio = results["img"].height / results["img"].width
bucked_id = np.argmin(np.abs(aspect_ratio - self.aspect_ratios))
return self.pipelines[bucked_id](results)
@@ -328,7 +403,7 @@ def __init__(self, *args, p: float = 0.5,
self.pipeline = torchvision.transforms.RandomHorizontalFlip(
*args, p=1.0, **kwargs)
- def transform(self, results: dict) -> dict | tuple[list, list] | None:
+ def transform(self, results: dict) -> dict | tuple[list, list] | None: # noqa: C901,PLR0912
"""Transform.
Args:
@@ -339,15 +414,42 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
-------
dict: 'crop_top_left' key is fixed.
"""
- if random.random() < self.p:
- assert all(results["img"].size == results[k].size
- for k in self.keys)
+ components = dict()
+ additional_keys = [
+ "crop_top_left", "crop_bottom_right", "before_crop_size",
+ ] if "crop_top_left" in results else []
+ for k in self.keys + additional_keys:
+ if not isinstance(results["img"], list):
+ components[k] = [results[k]]
+ else:
+ components[k] = results[k]
+
+ crop_top_left = []
+ for i in range(len(components["img"])):
+ if random.random() < self.p:
+ assert all(components["img"][i].size == components[k][i].size
+ for k in self.keys)
+ for k in self.keys:
+ components[k][i] = self.pipeline(components[k][i])
+ if "crop_top_left" in results:
+ y1 = components["crop_top_left"][i][0]
+ x1 = (
+ components["before_crop_size"][i][1] - components[
+ "crop_bottom_right"][i][1])
+ crop_top_left.append([y1, x1])
+ elif "crop_top_left" in results:
+ crop_top_left.append(components["crop_top_left"][i])
+
+ if not isinstance(results["img"], list):
for k in self.keys:
- results[k] = self.pipeline(results[k])
+ components[k] = components[k][0]
if "crop_top_left" in results:
- y1 = results["crop_top_left"][0]
- x1 = results["img"].width - results["crop_bottom_right"][1]
- results["crop_top_left"] = [y1, x1]
+ crop_top_left = crop_top_left[0]
+
+ for k in self.keys:
+ results[k] = components[k]
+ if "crop_top_left" in results:
+ results["crop_top_left"] = crop_top_left
return results
@@ -368,9 +470,23 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
"""
assert "ori_img_shape" in results
assert "crop_top_left" in results
- target_size = [results["img"].height, results["img"].width]
- time_ids = results["ori_img_shape"] + results[
- "crop_top_left"] + target_size
+
+ time_ids = []
+ if not isinstance(results["img"], list):
+ img = [results["img"]]
+ ori_img_shape = [results["ori_img_shape"]]
+ crop_top_left = [results["crop_top_left"]]
+ else:
+ img = results["img"]
+ ori_img_shape = results["ori_img_shape"]
+ crop_top_left = results["crop_top_left"]
+
+ for i in range(len(img)):
+ target_size = [img[i].height, img[i].width]
+ time_ids.append(ori_img_shape[i] + crop_top_left[i] + target_size)
+
+ if not isinstance(results["img"], list):
+ time_ids = time_ids[0]
results["time_ids"] = time_ids
return results
@@ -394,8 +510,24 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
dict: 'time_ids' key is added as original image shape.
"""
assert "ori_img_shape" in results
- results["resolution"] = [float(s) for s in results["ori_img_shape"]]
- results["aspect_ratio"] = results["img"].height / results["img"].width
+ if not isinstance(results["img"], list):
+ img = [results["img"]]
+ ori_img_shape = [results["ori_img_shape"]]
+ else:
+ img = results["img"]
+ ori_img_shape = results["ori_img_shape"]
+
+ resolution: list = []
+ aspect_ratio: list = []
+ for i in range(len(img)):
+ resolution.append([float(s) for s in ori_img_shape[i]])
+ aspect_ratio.append(img[i].height / img[i].width)
+
+ if not isinstance(results["img"], list):
+ resolution = resolution[0]
+ aspect_ratio = aspect_ratio[0]
+ results["resolution"] = resolution
+ results["aspect_ratio"] = aspect_ratio
return results
@@ -427,6 +559,8 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
----
results (dict): The result dict.
"""
+ assert not isinstance(results[self.key], list), (
+ "CLIPImageProcessor only support single image.")
# (1, 3, 224, 224) -> (3, 224, 224)
results[self.output_key] = self.pipeline(
images=results[self.key], return_tensors="pt").pixel_values[0]
@@ -668,6 +802,8 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
----
results (dict): The result dict.
"""
+ assert not isinstance(results[self.key], list), (
+ "MaskToTensor only support single image.")
# (1, 3, 224, 224) -> (3, 224, 224)
results[self.key] = torch.Tensor(results[self.key]).permute(2, 0, 1)
return results
@@ -693,6 +829,8 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
----
results (dict): The result dict.
"""
+ assert not isinstance(results["img"], list), (
+ "GetMaskedImage only support single image.")
mask_threahold = 0.5
results[self.key] = results["img"] * (results["mask"] < mask_threahold)
return results
@@ -730,3 +868,31 @@ def transform(self,
for k in self.keys:
results[k] = results[k] + " " + self.constant_caption
return results
+
+
+@TRANSFORMS.register_module()
+class ConcatMultipleImgs(BaseTransform):
+ """ConcatMultipleImgs.
+
+ Args:
+ ----
+ keys (List[str], optional): `keys` to apply augmentation from results.
+ Defaults to None.
+ """
+
+ def __init__(self, keys: list[str] | None = None) -> None:
+ if keys is None:
+ keys = ["img"]
+ self.keys = keys
+
+ def transform(self,
+ results: dict) -> dict | tuple[list, list] | None:
+ """Transform.
+
+ Args:
+ ----
+ results (dict): The result dict.
+ """
+ for k in self.keys:
+ results[k] = torch.cat(results[k], dim=0)
+ return results
diff --git a/diffengine/models/editors/__init__.py b/diffengine/models/editors/__init__.py
index 80a09a0..1462681 100644
--- a/diffengine/models/editors/__init__.py
+++ b/diffengine/models/editors/__init__.py
@@ -12,6 +12,7 @@
from .stable_diffusion_inpaint import * # noqa: F403
from .stable_diffusion_xl import * # noqa: F403
from .stable_diffusion_xl_controlnet import * # noqa: F403
+from .stable_diffusion_xl_dpo import * # noqa: F403
from .stable_diffusion_xl_inpaint import * # noqa: F403
from .t2i_adapter import * # noqa: F403
from .wuerstchen import * # noqa: F403
diff --git a/diffengine/models/editors/stable_diffusion_xl_dpo/__init__.py b/diffengine/models/editors/stable_diffusion_xl_dpo/__init__.py
new file mode 100644
index 0000000..c1ae4e2
--- /dev/null
+++ b/diffengine/models/editors/stable_diffusion_xl_dpo/__init__.py
@@ -0,0 +1,4 @@
+from .sdxl_dpo_data_preprocessor import SDXLDPODataPreprocessor
+from .stable_diffusion_xl_dpo import StableDiffusionXLDPO
+
+__all__ = ["StableDiffusionXLDPO", "SDXLDPODataPreprocessor"]
diff --git a/diffengine/models/editors/stable_diffusion_xl_dpo/sdxl_dpo_data_preprocessor.py b/diffengine/models/editors/stable_diffusion_xl_dpo/sdxl_dpo_data_preprocessor.py
new file mode 100644
index 0000000..b6ab821
--- /dev/null
+++ b/diffengine/models/editors/stable_diffusion_xl_dpo/sdxl_dpo_data_preprocessor.py
@@ -0,0 +1,45 @@
+import torch
+from mmengine.model.base_model.data_preprocessor import BaseDataPreprocessor
+
+from diffengine.registry import MODELS
+
+
+@MODELS.register_module()
+class SDXLDPODataPreprocessor(BaseDataPreprocessor):
+ """SDXLDataPreprocessor."""
+
+ def forward(
+ self,
+ data: dict,
+ training: bool = False # noqa
+ ) -> dict | list:
+ """Preprocesses the data into the model input format.
+
+ After the data pre-processing of :meth:`cast_data`, ``forward``
+ will stack the input tensor list to a batch tensor at the first
+ dimension.
+
+ Args:
+ ----
+ data (dict): Data returned by dataloader
+ training (bool): Whether to enable training time augmentation.
+
+ Returns:
+ -------
+ dict or list: Data in the same format as the model input.
+ """
+ assert "result_class_image" not in data["inputs"], (
+ "result_class_image is not supported for SDXLDPO")
+
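+        # Each sample packs its preferred and rejected images along the
+        # channel dim (e.g. (6, H, W)); split them and concatenate along the
+        # batch dim so the batch becomes [preferred..., rejected...].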
+ data["inputs"]["img"] = torch.cat(
+ torch.stack(data["inputs"]["img"]).chunk(2, dim=1))
+ data["inputs"]["time_ids"] = torch.cat(
+ torch.stack(data["inputs"]["time_ids"]).chunk(2, dim=1))
+ # pre-compute text embeddings
+ if "prompt_embeds" in data["inputs"]:
+ data["inputs"]["prompt_embeds"] = torch.stack(
+ data["inputs"]["prompt_embeds"])
+ if "pooled_prompt_embeds" in data["inputs"]:
+ data["inputs"]["pooled_prompt_embeds"] = torch.stack(
+ data["inputs"]["pooled_prompt_embeds"])
+ return super().forward(data)
diff --git a/diffengine/models/editors/stable_diffusion_xl_dpo/stable_diffusion_xl_dpo.py b/diffengine/models/editors/stable_diffusion_xl_dpo/stable_diffusion_xl_dpo.py
new file mode 100644
index 0000000..9810c76
--- /dev/null
+++ b/diffengine/models/editors/stable_diffusion_xl_dpo/stable_diffusion_xl_dpo.py
@@ -0,0 +1,191 @@
+from copy import deepcopy
+from typing import Optional
+
+import torch
+import torch.nn.functional as F # noqa: N812
+from torch import nn
+
+from diffengine.models.editors.stable_diffusion_xl import StableDiffusionXL
+from diffengine.registry import MODELS
+
+
+@MODELS.register_module()
+class StableDiffusionXLDPO(StableDiffusionXL):
+ """Stable Diffusion XL DPO.
+
+ Args:
+ ----
+ beta_dpo (int): DPO KL Divergence penalty. Defaults to 5000.
+ loss (dict, optional): The loss config. Defaults to None.
+ data_preprocessor (dict, optional): The pre-process config of
+ :class:`SDXLDPODataPreprocessor`.
+ """
+
+ def __init__(self,
+ *args,
+ beta_dpo: int = 5000,
+ loss: dict | None = None,
+ data_preprocessor: dict | nn.Module | None = None,
+ **kwargs) -> None:
+ if loss is None:
+ loss = {"type": "L2Loss", "loss_weight": 1.0,
+ "reduction": "none"}
+ if data_preprocessor is None:
+ data_preprocessor = {"type": "SDXLDPODataPreprocessor"}
+
+ super().__init__(
+ *args,
+ loss=loss,
+ data_preprocessor=data_preprocessor,
+ **kwargs) # type: ignore[misc]
+
+ self.beta_dpo = beta_dpo
+
+ def prepare_model(self) -> None:
+ """Prepare model for training.
+
+ Disable gradient for some models.
+ """
+ self.orig_unet = deepcopy(
+ self.unet).requires_grad_(requires_grad=False)
+
+ super().prepare_model()
+
+ def loss( # type: ignore[override]
+ self,
+ model_pred: torch.Tensor,
+ ref_pred: torch.Tensor,
+ noise: torch.Tensor,
+ latents: torch.Tensor,
+ timesteps: torch.Tensor,
+ weight: torch.Tensor | None = None) -> dict[str, torch.Tensor]:
+ """Calculate loss."""
+ if self.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ self.scheduler.register_to_config(
+ prediction_type=self.prediction_type)
+
+ if self.scheduler.config.prediction_type == "epsilon":
+ gt = noise
+ elif self.scheduler.config.prediction_type == "v_prediction":
+ gt = self.scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ msg = f"Unknown prediction type {self.scheduler.config.prediction_type}"
+ raise ValueError(msg)
+
+ loss_dict = {}
+ # calculate loss in FP32
+ if self.loss_module.use_snr:
+ model_loss = self.loss_module(
+ model_pred.float(),
+ gt.float(),
+ timesteps,
+ self.scheduler.alphas_cumprod,
+ self.scheduler.config.prediction_type,
+ weight=weight)
+ ref_loss = self.loss_module(
+ ref_pred.float(),
+ gt.float(),
+ timesteps,
+ self.scheduler.alphas_cumprod,
+ self.scheduler.config.prediction_type,
+ weight=weight)
+ else:
+ model_loss = self.loss_module(
+ model_pred.float(), gt.float(), weight=weight)
+ ref_loss = self.loss_module(
+ ref_pred.float(), gt.float(), weight=weight)
+ model_loss = model_loss.mean(
+ dim=list(range(1, len(model_loss.shape))))
+ ref_loss = ref_loss.mean(
+ dim=list(range(1, len(ref_loss.shape))))
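+        # Split per-sample losses into preferred (w) and rejected (l) halves;
+        # the DPO objective rewards improving on the preferred images more
+        # than the frozen reference UNet does, relative to the rejected ones.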
+ model_losses_w, model_losses_l = model_loss.chunk(2)
+ model_diff = model_losses_w - model_losses_l
+
+ ref_losses_w, ref_losses_l = ref_loss.chunk(2)
+ ref_diff = ref_losses_w - ref_losses_l
+ scale_term = -0.5 * self.beta_dpo
+ inside_term = scale_term * (model_diff - ref_diff)
+ loss = -1 * F.logsigmoid(inside_term.mean())
+
+ loss_dict["loss"] = loss
+ return loss_dict
+
+ def forward(
+ self,
+ inputs: dict,
+ data_samples: Optional[list] = None, # noqa
+ mode: str = "loss") -> dict:
+ """Forward function.
+
+ Args:
+ ----
+ inputs (dict): The input dict.
+ data_samples (Optional[list], optional): The data samples.
+ Defaults to None.
+ mode (str, optional): The mode. Defaults to "loss".
+
+ Returns:
+ -------
+ dict: The loss dict.
+ """
+ assert mode == "loss"
+ assert "result_class_image" not in inputs, (
+ "result_class_image is not supported for SDXLDPO")
+        # each sample contributes two images (preferred and rejected),
+        # so the number of samples is half the stacked batch
+ num_batches = len(inputs["img"]) // 2
+
+ latents = self.vae.encode(inputs["img"]).latent_dist.sample()
+ latents = latents * self.vae.config.scaling_factor
+
+ noise = self.noise_generator(latents[:num_batches])
+ # repeat noise for each sample set
+ noise = noise.repeat(2, 1, 1, 1)
+
+ timesteps = self.timesteps_generator(self.scheduler, num_batches,
+ self.device)
+ # repeat timesteps for each sample set
+ timesteps = timesteps.repeat(2)
+
+ noisy_latents = self._preprocess_model_input(latents, noise, timesteps)
+
+ if not self.pre_compute_text_embeddings:
+ inputs["text_one"] = self.tokenizer_one(
+ inputs["text"],
+ max_length=self.tokenizer_one.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt").input_ids.to(self.device)
+ inputs["text_two"] = self.tokenizer_two(
+ inputs["text"],
+ max_length=self.tokenizer_two.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt").input_ids.to(self.device)
+ prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
+ inputs["text_one"], inputs["text_two"])
+ else:
+ prompt_embeds = inputs["prompt_embeds"]
+ pooled_prompt_embeds = inputs["pooled_prompt_embeds"]
+ # repeat text embeds for each sample set
+ prompt_embeds = prompt_embeds.repeat(2, 1, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)
+ unet_added_conditions = {
+ "time_ids": inputs["time_ids"],
+ "text_embeds": pooled_prompt_embeds,
+ }
+
+ model_pred = self.unet(
+ noisy_latents,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions).sample
+ with torch.no_grad():
+ ref_pred = self.orig_unet(
+ noisy_latents,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions,
+ ).sample
+
+ return self.loss(model_pred, ref_pred, noise, latents, timesteps)
diff --git a/diffengine/models/losses/debias_estimation_loss.py b/diffengine/models/losses/debias_estimation_loss.py
index ee8bb23..91ab1be 100644
--- a/diffengine/models/losses/debias_estimation_loss.py
+++ b/diffengine/models/losses/debias_estimation_loss.py
@@ -16,6 +16,8 @@ class DeBiasEstimationLoss(BaseLoss):
----
loss_weight (float): Weight of this loss item.
Defaults to ``1.``.
+        reduction (str): The reduction method for the loss.
+ Defaults to 'mean'.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'l2'.
@@ -23,10 +25,15 @@ class DeBiasEstimationLoss(BaseLoss):
def __init__(self,
loss_weight: float = 1.0,
+ reduction: str = "mean",
loss_name: str = "debias_estimation") -> None:
super().__init__()
+ assert reduction in ["mean", "none"], (
+ f"reduction should be 'mean' or 'none', got {reduction}"
+ )
self.loss_weight = loss_weight
+ self.reduction = reduction
self._loss_name = loss_name
@property
@@ -69,5 +76,6 @@ def forward(self,
dim=list(range(1, len(loss.shape)))) * mse_loss_weights
if weight is not None:
loss = loss * weight
- loss = loss.mean()
+ if self.reduction == "mean":
+ loss = loss.mean()
return loss * self.loss_weight
diff --git a/diffengine/models/losses/hubar_loss.py b/diffengine/models/losses/hubar_loss.py
index 4b7d2ae..85f7a4f 100644
--- a/diffengine/models/losses/hubar_loss.py
+++ b/diffengine/models/losses/hubar_loss.py
@@ -16,6 +16,8 @@ class HuberLoss(BaseLoss):
Default: 1.0
loss_weight (float, optional): Weight of this loss item.
Defaults to ``1.``.
+        reduction (str): The reduction method for the loss.
+ Defaults to 'mean'.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'l2'.
@@ -24,11 +26,16 @@ class HuberLoss(BaseLoss):
def __init__(self,
delta: float = 1.0,
loss_weight: float = 1.0,
+ reduction: str = "mean",
loss_name: str = "l2") -> None:
super().__init__()
+ assert reduction in ["mean", "none"], (
+ f"reduction should be 'mean' or 'none', got {reduction}"
+ )
self.delta = delta
self.loss_weight = loss_weight
+ self.reduction = reduction
self._loss_name = loss_name
def forward(self,
@@ -51,6 +58,9 @@ def forward(self,
if weight is not None:
loss = F.huber_loss(
pred, gt, reduction="none", delta=self.delta) * weight
- return loss.mean() * self.loss_weight
+ if self.reduction == "mean":
+ loss = loss.mean()
+ return loss * self.loss_weight
- return F.huber_loss(pred, gt, delta=self.delta) * self.loss_weight
+ return F.huber_loss(pred, gt, delta=self.delta,
+ reduction=self.reduction) * self.loss_weight
diff --git a/diffengine/models/losses/l2_loss.py b/diffengine/models/losses/l2_loss.py
index c3773e1..250db5c 100644
--- a/diffengine/models/losses/l2_loss.py
+++ b/diffengine/models/losses/l2_loss.py
@@ -13,6 +13,8 @@ class L2Loss(BaseLoss):
----
loss_weight (float, optional): Weight of this loss item.
Defaults to ``1.``.
+        reduction (str): The reduction method for the loss.
+ Defaults to 'mean'.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'l2'.
@@ -20,10 +22,15 @@ class L2Loss(BaseLoss):
def __init__(self,
loss_weight: float = 1.0,
+ reduction: str = "mean",
loss_name: str = "l2") -> None:
super().__init__()
+ assert reduction in ["mean", "none"], (
+ f"reduction should be 'mean' or 'none', got {reduction}"
+ )
self.loss_weight = loss_weight
+ self.reduction = reduction
self._loss_name = loss_name
def forward(self,
@@ -45,6 +52,8 @@ def forward(self,
"""
if weight is not None:
loss = F.mse_loss(pred, gt, reduction="none") * weight
- return loss.mean() * self.loss_weight
+ if self.reduction == "mean":
+ loss = loss.mean()
+ return loss * self.loss_weight
- return F.mse_loss(pred, gt) * self.loss_weight
+ return F.mse_loss(pred, gt, reduction=self.reduction) * self.loss_weight
diff --git a/diffengine/models/losses/snr_l2_loss.py b/diffengine/models/losses/snr_l2_loss.py
index bf11d46..ec07720 100644
--- a/diffengine/models/losses/snr_l2_loss.py
+++ b/diffengine/models/losses/snr_l2_loss.py
@@ -19,6 +19,8 @@ class SNRL2Loss(BaseLoss):
snr_gamma (float): SNR weighting gamma to be used if re balancing the
loss. "More details here: https://arxiv.org/abs/2303.09556."
Defaults to ``5.``.
+        reduction (str): The reduction method for the loss.
+ Defaults to 'mean'.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'l2'.
@@ -27,11 +29,16 @@ class SNRL2Loss(BaseLoss):
def __init__(self,
loss_weight: float = 1.0,
snr_gamma: float = 5.0,
+ reduction: str = "mean",
loss_name: str = "snrl2") -> None:
super().__init__()
+ assert reduction in ["mean", "none"], (
+ f"reduction should be 'mean' or 'none', got {reduction}"
+ )
self.loss_weight = loss_weight
self.snr_gamma = snr_gamma
+ self.reduction = reduction
self._loss_name = loss_name
@property
@@ -76,5 +83,6 @@ def forward(self,
dim=list(range(1, len(loss.shape)))) * mse_loss_weights
if weight is not None:
loss = loss * weight
- loss = loss.mean()
+ if self.reduction == "mean":
+ loss = loss.mean()
return loss * self.loss_weight
diff --git a/tests/test_datasets/test_transforms/test_processing.py b/tests/test_datasets/test_transforms/test_processing.py
index a279944..e889cb7 100644
--- a/tests/test_datasets/test_transforms/test_processing.py
+++ b/tests/test_datasets/test_transforms/test_processing.py
@@ -113,6 +113,17 @@ def test_transform(self):
data = trans(data)
self.assertListEqual(data["ori_img_shape"], ori_img_shape)
+ def test_transform_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ data = {"img": [Image.open(img_path),
+ Image.open(img_path).resize((64, 64))]}
+ ori_img_shape = [[img.height, img.width] for img in data["img"]]
+
+ # test transform
+ trans = TRANSFORMS.build(dict(type="SaveImageShape"))
+ data = trans(data)
+ self.assertListEqual(data["ori_img_shape"], ori_img_shape)
+
class TestComputeTimeIds(TestCase):
@@ -130,6 +141,20 @@ def test_transform(self):
self.assertListEqual(data["time_ids"],
[32, 32, 0, 0, img.height, img.width])
+ def test_transform_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ img = Image.open(img_path)
+ data = {"img": [img, img],
+ "ori_img_shape": [[32, 32], [48, 48]],
+ "crop_top_left": [[0, 0], [10, 10]]}
+
+ # test transform
+ trans = TRANSFORMS.build(dict(type="ComputeTimeIds"))
+ data = trans(data)
+ self.assertListEqual(data["time_ids"],
+ [[32, 32, 0, 0, img.height, img.width],
+ [48, 48, 10, 10, img.height, img.width]])
+
class TestRandomCrop(TestCase):
crop_size = 32
@@ -206,6 +231,93 @@ def test_transform_multiple_keys(self):
assert lower == upper + self.crop_size
assert right == left + self.crop_size
+ def test_transform_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ data = {"img": [Image.open(img_path),
+ Image.open(img_path).resize((64, 64))]}
+
+ # test transform
+ trans = TRANSFORMS.build(dict(type="RandomCrop", size=self.crop_size))
+ data = trans(data)
+ assert "crop_top_left" in data
+ assert len(data["crop_top_left"]) == 2
+ for i in range(len(data["img"])):
+ assert (
+ data["img"][i].height == data["img"][i].width == self.crop_size
+ )
+ upper, left = data["crop_top_left"][i]
+ lower, right = data["crop_bottom_right"][i]
+ assert lower == upper + self.crop_size
+ assert right == left + self.crop_size
+ np.equal(
+ np.array(data["img"][i]),
+ np.array(
+ Image.open(img_path).crop((left, upper, right, lower))))
+
+ def test_transform_multiple_keys_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ data = {
+ "img": [
+ Image.open(img_path), Image.open(img_path).resize((64, 64))],
+ "condition_img": [
+ Image.open(img_path), Image.open(img_path).resize((64, 64))],
+ }
+
+ # test transform
+ trans = TRANSFORMS.build(
+ dict(
+ type="RandomCrop",
+ size=self.crop_size,
+ keys=["img", "condition_img"]))
+ data = trans(data)
+ assert "crop_top_left" in data
+ assert len(data["crop_top_left"]) == 2
+ for i in range(len(data["img"])):
+ assert (
+ data["img"][i].height == data["img"][i].width == self.crop_size
+ )
+ upper, left = data["crop_top_left"][i]
+ lower, right = data["crop_bottom_right"][i]
+ assert lower == upper + self.crop_size
+ assert right == left + self.crop_size
+ np.equal(
+ np.array(data["img"][i]),
+ np.array(
+ Image.open(img_path).crop((left, upper, right, lower))))
+ np.equal(np.array(data["img"][i]),
+ np.array(data["condition_img"][i]))
+
+ # size mismatch
+ data = {
+ "img": [Image.open(img_path),
+ Image.open(img_path).resize((64, 64))],
+ "condition_img": [
+ Image.open(img_path).resize((298, 398)),
+ Image.open(img_path).resize((64, 64))],
+ }
+ with pytest.raises(
+ AssertionError, match="Size mismatch"):
+ data = trans(data)
+
+ # test transform force_same_size=False
+ trans = TRANSFORMS.build(
+ dict(
+ type="RandomCrop",
+ size=self.crop_size,
+ force_same_size=False,
+ keys=["img", "condition_img"]))
+ data = trans(data)
+ assert "crop_top_left" in data
+ assert len(data["crop_top_left"]) == 2
+ for i in range(len(data["img"])):
+ assert (
+ data["img"][i].height == data["img"][i].width == self.crop_size
+ )
+ upper, left = data["crop_top_left"][i]
+ lower, right = data["crop_bottom_right"][i]
+ assert lower == upper + self.crop_size
+ assert right == left + self.crop_size
+
class TestCenterCrop(TestCase):
crop_size = 32
@@ -257,6 +369,62 @@ def test_transform_multiple_keys(self):
np.array(Image.open(img_path).crop((left, upper, right, lower))))
np.equal(np.array(data["img"]), np.array(data["condition_img"]))
+ def test_transform_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ data = {"img": [
+ Image.open(img_path), Image.open(img_path).resize((64, 64))]}
+
+ # test transform
+ trans = TRANSFORMS.build(dict(type="CenterCrop", size=self.crop_size))
+ data = trans(data)
+ assert "crop_top_left" in data
+ assert len(data["crop_top_left"]) == 2
+ for i in range(len(data["img"])):
+ assert (
+ data["img"][i].height == data["img"][i].width == self.crop_size
+ )
+ upper, left = data["crop_top_left"][i]
+ lower, right = data["crop_bottom_right"][i]
+ assert lower == upper + self.crop_size
+ assert right == left + self.crop_size
+ np.equal(
+ np.array(data["img"][i]),
+ np.array(
+ Image.open(img_path).crop((left, upper, right, lower))))
+
+ def test_transform_multiple_keys_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ data = {
+ "img": [
+ Image.open(img_path), Image.open(img_path).resize((64, 64))],
+ "condition_img": [
+ Image.open(img_path), Image.open(img_path).resize((64, 64))],
+ }
+
+ # test transform
+ trans = TRANSFORMS.build(
+ dict(
+ type="CenterCrop",
+ size=self.crop_size,
+ keys=["img", "condition_img"]))
+ data = trans(data)
+ assert "crop_top_left" in data
+ assert len(data["crop_top_left"]) == 2
+ for i in range(len(data["img"])):
+ assert (
+ data["img"][i].height == data["img"][i].width == self.crop_size
+ )
+ upper, left = data["crop_top_left"][i]
+ lower, right = data["crop_bottom_right"][i]
+ assert lower == upper + self.crop_size
+ assert right == left + self.crop_size
+ np.equal(
+ np.array(data["img"][i]),
+ np.array(
+ Image.open(img_path).crop((left, upper, right, lower))))
+ np.equal(np.array(data["img"][i]),
+ np.array(data["condition_img"][i]))
+
class TestRandomHorizontalFlip(TestCase):
@@ -268,7 +436,8 @@ def test_transform(self):
data = {
"img": Image.open(img_path),
"crop_top_left": [0, 0],
- "crop_bottom_right": [10, 10],
+ "crop_bottom_right": [200, 200],
+ "before_crop_size": [224, 224],
}
# test transform
@@ -277,7 +446,7 @@ def test_transform(self):
assert "crop_top_left" in data
assert len(data["crop_top_left"]) == 2
self.assertListEqual(data["crop_top_left"],
- [0, data["img"].width - 10])
+ [0, data["before_crop_size"][1] - 200])
np.equal(
np.array(data["img"]),
@@ -287,7 +456,8 @@ def test_transform(self):
data = {
"img": Image.open(img_path),
"crop_top_left": [0, 0],
- "crop_bottom_right": [10, 10],
+ "crop_bottom_right": [200, 200],
+ "before_crop_size": [224, 224],
}
trans = TRANSFORMS.build(dict(type="RandomHorizontalFlip", p=0.))
data = trans(data)
@@ -302,7 +472,8 @@ def test_transform_multiple_keys(self):
"img": Image.open(img_path),
"condition_img": Image.open(img_path),
"crop_top_left": [0, 0],
- "crop_bottom_right": [10, 10],
+ "crop_bottom_right": [200, 200],
+ "before_crop_size": [224, 224],
}
# test transform
@@ -315,13 +486,92 @@ def test_transform_multiple_keys(self):
assert "crop_top_left" in data
assert len(data["crop_top_left"]) == 2
self.assertListEqual(data["crop_top_left"],
- [0, data["img"].width - 10])
+ [0, data["before_crop_size"][1] - 200])
np.equal(
np.array(data["img"]),
np.array(Image.open(img_path).transpose(Image.FLIP_LEFT_RIGHT)))
np.equal(np.array(data["img"]), np.array(data["condition_img"]))
+ def test_transform_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ data = {
+ "img": [
+ Image.open(img_path), Image.open(img_path).resize((64, 64))],
+ "crop_top_left": [[0, 0], [10, 10]],
+ "crop_bottom_right": [[200, 200], [220, 220]],
+ "before_crop_size": [[224, 224], [256, 256]],
+ }
+
+ # test transform
+ trans = TRANSFORMS.build(dict(type="RandomHorizontalFlip", p=1.))
+ transformed_data = trans(data)
+ assert "crop_top_left" in data
+ assert len(data["crop_top_left"]) == 2
+ for i in range(len(data["img"])):
+ self.assertListEqual(
+ data["crop_top_left"][i],
+ [data["crop_top_left"][i][0],
+ data["before_crop_size"][i][1] - data[
+ "crop_bottom_right"][i][1]])
+
+ np.equal(
+ np.array(transformed_data["img"][i]),
+ np.array(
+ data["img"][i].transpose(Image.FLIP_LEFT_RIGHT)))
+
+ # test transform p=0.0
+ data = {
+ "img": [
+ Image.open(img_path), Image.open(img_path).resize((64, 64))],
+ "crop_top_left": [[0, 0], [10, 10]],
+ "crop_bottom_right": [[200, 200], [220, 220]],
+ "before_crop_size": [[224, 224], [256, 256]],
+ }
+ trans = TRANSFORMS.build(dict(type="RandomHorizontalFlip", p=0.))
+ transformed_data = trans(data)
+ assert "crop_top_left" in data
+ for i in range(len(data["img"])):
+ self.assertListEqual(data["crop_top_left"][i],
+ data["crop_top_left"][i])
+ np.equal(np.array(transformed_data["img"][i]),
+ np.array(data["img"][i]))
+
+ def test_transform_multiple_keys_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ data = {
+ "img": [
+ Image.open(img_path), Image.open(img_path).resize((64, 64))],
+ "condition_img": [
+ Image.open(img_path), Image.open(img_path).resize((64, 64))],
+ "crop_top_left": [[0, 0], [10, 10]],
+ "crop_bottom_right": [[200, 200], [220, 220]],
+ "before_crop_size": [[224, 224], [256, 256]],
+ }
+
+ # test transform
+ trans = TRANSFORMS.build(
+ dict(
+ type="RandomHorizontalFlip",
+ p=1.,
+ keys=["img", "condition_img"]))
+ transformed_data = trans(data)
+ assert "crop_top_left" in data
+ assert len(data["crop_top_left"]) == 2
+ for i in range(len(data["img"])):
+ self.assertListEqual(
+ data["crop_top_left"][i],
+ [data["crop_top_left"][i][0],
+ data["before_crop_size"][i][1] - data[
+ "crop_bottom_right"][i][1]])
+
+ np.equal(
+ np.array(data["img"][i]),
+ np.array(
+ transformed_data["img"][i].transpose(Image.FLIP_LEFT_RIGHT)))
+ np.equal(np.array(data["img"][i]),
+ np.array(data["condition_img"][i]))
+
class TestMultiAspectRatioResizeCenterCrop(TestCase):
sizes = [(32, 32), (16, 48)] # noqa
@@ -397,6 +647,33 @@ def test_transform_multiple_keys(self):
(left, upper, right, lower))))
np.equal(np.array(data["img"]), np.array(data["condition_img"]))
+ def test_transform_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ data = {"img": [Image.open(img_path).resize((32, 36)),
+ Image.open(img_path).resize((55, 16))]}
+
+ # test transform
+ trans = TRANSFORMS.build(
+ dict(type="MultiAspectRatioResizeCenterCrop", sizes=self.sizes))
+ with pytest.raises(
+ AssertionError, match="MultiAspectRatioResizeCenterCrop only"):
+ _ = trans(data)
+
+ def test_transform_multiple_keys_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ data = {
+ "img": [Image.open(img_path).resize((32, 36)),
+ Image.open(img_path).resize((55, 16))],
+ "condition_img": [Image.open(img_path).resize((32, 36)),
+ Image.open(img_path).resize((55, 16))]}
+
+ # test transform
+ trans = TRANSFORMS.build(
+ dict(type="MultiAspectRatioResizeCenterCrop", sizes=self.sizes))
+ with pytest.raises(
+ AssertionError, match="MultiAspectRatioResizeCenterCrop only"):
+ _ = trans(data)
+
class TestCLIPImageProcessor(TestCase):
@@ -416,6 +693,18 @@ def test_transform(self):
assert type(data["clip_img"]) == torch.Tensor
assert data["clip_img"].size() == (3, 224, 224)
+ def test_transform_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ data = {
+ "img": [Image.open(img_path), Image.open(img_path)],
+ }
+
+ # test transform
+ trans = TRANSFORMS.build(dict(type="CLIPImageProcessor"))
+ with pytest.raises(
+ AssertionError, match="CLIPImageProcessor only support"):
+ _ = trans(data)
+
class TestRandomTextDrop(TestCase):
@@ -458,6 +747,22 @@ def test_transform(self):
[float(d) for d in data["ori_img_shape"]])
assert data["aspect_ratio"] == img.height / img.width
+ def test_transform_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ img = Image.open(img_path)
+ data = {
+ "img": [img, img],
+ "ori_img_shape": [[32, 32], [48, 48]],
+ "crop_top_left": [[0, 0], [10, 10]]}
+
+ # test transform
+ trans = TRANSFORMS.build(dict(type="ComputePixArtImgInfo"))
+ data = trans(data)
+ for i in range(len(data["img"])):
+ self.assertListEqual(data["resolution"][i],
+ [float(d) for d in data["ori_img_shape"][i]])
+ assert data["aspect_ratio"][i] == img.height / img.width
+
class TestT5TextPreprocess(TestCase):
@@ -494,6 +799,15 @@ def test_transform(self):
data = trans(data)
assert data["mask"].shape == (1, 32, 32)
+ def test_transform_list(self):
+ data = {"mask": [np.zeros((32, 32, 1))] * 2}
+
+ # test transform
+ trans = TRANSFORMS.build(dict(type="MaskToTensor"))
+ with pytest.raises(
+ AssertionError, match="MaskToTensor only support"):
+ _ = trans(data)
+
class TestGetMaskedImage(TestCase):
@@ -516,6 +830,20 @@ def test_transform(self):
assert torch.allclose(data["masked_image"][10:, 10:], img[10:, 10:])
assert data["masked_image"][:10, :10].sum() == 0
+ def test_transform_list(self):
+ img_path = osp.join(osp.dirname(__file__), "../../testdata/color.jpg")
+ img = torch.Tensor(np.array(Image.open(img_path)))
+ mask = np.zeros((img.shape[0], img.shape[1], 1))
+ mask[:10, :10] = 1
+ mask = torch.Tensor(mask)
+ data = {"img": [img, img], "mask": [mask, mask]}
+
+ # test transform
+ trans = TRANSFORMS.build(dict(type="GetMaskedImage"))
+ with pytest.raises(
+ AssertionError, match="GetMaskedImage only support"):
+ _ = trans(data)
+
class TestAddConstantCaption(TestCase):
@@ -532,3 +860,17 @@ def test_transform(self):
constant_caption="in szn style"))
data = trans(data)
assert data["text"] == "a dog. in szn style"
+
+
+class TestConcatMultipleImgs(TestCase):
+
+ def test_register(self):
+ assert "ConcatMultipleImgs" in TRANSFORMS
+
+ def test_transform_list(self):
+ data = {"img": [torch.zeros((3, 32, 32))] * 2}
+
+ # test transform
+ trans = TRANSFORMS.build(dict(type="ConcatMultipleImgs"))
+ data = trans(data)
+ assert data["img"].shape == (6, 32, 32) # type: ignore[attr-defined]
diff --git a/tests/test_models/test_editors/test_stable_diffusion_xl_dpo/test_stable_diffusion_xl_dpo.py b/tests/test_models/test_editors/test_stable_diffusion_xl_dpo/test_stable_diffusion_xl_dpo.py
new file mode 100644
index 0000000..0e6d696
--- /dev/null
+++ b/tests/test_models/test_editors/test_stable_diffusion_xl_dpo/test_stable_diffusion_xl_dpo.py
@@ -0,0 +1,340 @@
+from unittest import TestCase
+
+import pytest
+import torch
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from mmengine.optim import OptimWrapper
+from torch.optim import SGD
+from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection
+
+from diffengine.models.editors import SDXLDPODataPreprocessor, StableDiffusionXLDPO
+from diffengine.models.losses import DeBiasEstimationLoss, L2Loss, SNRL2Loss
+from diffengine.registry import MODELS
+
+
+class TestStableDiffusionXLDPO(TestCase):
+
+ def _get_config(self) -> dict:
+ base_model = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+ return dict(type=StableDiffusionXLDPO,
+ model=base_model,
+ tokenizer_one=dict(type=AutoTokenizer.from_pretrained,
+ pretrained_model_name_or_path=base_model,
+ subfolder="tokenizer",
+ use_fast=False),
+ tokenizer_two=dict(type=AutoTokenizer.from_pretrained,
+ pretrained_model_name_or_path=base_model,
+ subfolder="tokenizer_2",
+ use_fast=False),
+ scheduler=dict(type=DDPMScheduler.from_pretrained,
+ pretrained_model_name_or_path=base_model,
+ subfolder="scheduler"),
+ text_encoder_one=dict(type=CLIPTextModel.from_pretrained,
+ pretrained_model_name_or_path=base_model,
+ subfolder="text_encoder"),
+ text_encoder_two=dict(type=CLIPTextModelWithProjection.from_pretrained,
+ pretrained_model_name_or_path=base_model,
+ subfolder="text_encoder_2"),
+ vae=dict(
+ type=AutoencoderKL.from_pretrained,
+ pretrained_model_name_or_path=base_model,
+ subfolder="vae"),
+ unet=dict(type=UNet2DConditionModel.from_pretrained,
+ pretrained_model_name_or_path=base_model,
+ subfolder="unet"),
+ data_preprocessor=dict(type=SDXLDPODataPreprocessor),
+ loss=dict(type=L2Loss, reduction="none"))
+
+ def test_init(self):
+ cfg = self._get_config()
+ cfg.update(text_encoder_lora_config=dict(type="dummy"))
+ with pytest.raises(
+ AssertionError, match="If you want to use LoRA"):
+ _ = MODELS.build(cfg)
+
+ cfg = self._get_config()
+ cfg.update(
+ unet_lora_config=dict(type="dummy"),
+ finetune_text_encoder=True,
+ )
+ with pytest.raises(
+ AssertionError, match="If you want to finetune text"):
+ _ = MODELS.build(cfg)
+
+ def test_infer(self):
+ cfg = self._get_config()
+ StableDiffuser = MODELS.build(cfg)
+
+ # test infer
+ result = StableDiffuser.infer(
+ ["an insect robot preparing a delicious meal"],
+ height=64,
+ width=64)
+ assert len(result) == 1
+ assert result[0].shape == (64, 64, 3)
+
+ # test device
+ assert StableDiffuser.device.type == "cpu"
+
+ # test infer with negative_prompt
+ result = StableDiffuser.infer(
+ ["an insect robot preparing a delicious meal"],
+ negative_prompt="noise",
+ height=64,
+ width=64)
+ assert len(result) == 1
+ assert result[0].shape == (64, 64, 3)
+
+ result = StableDiffuser.infer(
+ ["an insect robot preparing a delicious meal"],
+ output_type="latent",
+ height=64,
+ width=64)
+ assert len(result) == 1
+ assert type(result[0]) == torch.Tensor
+ assert result[0].shape == (4, 32, 32)
+
+ def test_infer_v_prediction(self):
+ cfg = self._get_config()
+ cfg.update(prediction_type="v_prediction")
+ StableDiffuser = MODELS.build(cfg)
+ assert StableDiffuser.prediction_type == "v_prediction"
+
+ # test infer
+ result = StableDiffuser.infer(
+ ["an insect robot preparing a delicious meal"],
+ height=64,
+ width=64)
+ assert len(result) == 1
+ assert result[0].shape == (64, 64, 3)
+
+ def test_infer_with_lora(self):
+ cfg = self._get_config()
+ cfg.update(
+ unet_lora_config=dict(
+ type="LoRA", r=4,
+ target_modules=["to_q", "to_v", "to_k", "to_out.0"]),
+ text_encoder_lora_config = dict(
+ type="LoRA", r=4,
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]),
+ )
+ StableDiffuser = MODELS.build(cfg)
+
+ # test infer
+ result = StableDiffuser.infer(
+ ["an insect robot preparing a delicious meal"],
+ height=64,
+ width=64)
+ assert len(result) == 1
+ assert result[0].shape == (64, 64, 3)
+
+ def test_infer_with_pre_compute_embs(self):
+ cfg = self._get_config()
+ cfg.update(pre_compute_text_embeddings=True)
+ StableDiffuser = MODELS.build(cfg)
+
+ assert not hasattr(StableDiffuser, "tokenizer_one")
+ assert not hasattr(StableDiffuser, "text_encoder_one")
+ assert not hasattr(StableDiffuser, "tokenizer_two")
+ assert not hasattr(StableDiffuser, "text_encoder_two")
+
+ # test infer
+ result = StableDiffuser.infer(
+ ["an insect robot preparing a delicious meal"],
+ height=64,
+ width=64)
+ assert len(result) == 1
+ assert result[0].shape == (64, 64, 3)
+
+ # test device
+ assert StableDiffuser.device.type == "cpu"
+
+ def test_train_step(self):
+        # build the model with the configured L2 loss
+ cfg = self._get_config()
+ StableDiffuser = MODELS.build(cfg)
+
+ # test train step
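+        # DPO batches stack the preferred and rejected images along the
+        # channel dim (2 x 3 = 6 channels) and carry one SDXL time_ids
+        # vector per image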
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((6, 64, 64))],
+ text=["a dog"],
+ time_ids=[torch.zeros((2, 6))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ assert isinstance(log_vars["loss"], torch.Tensor)
+
+ def test_train_step_with_lora(self):
+        # build the model with LoRA configs for the UNet and text encoders
+ cfg = self._get_config()
+ cfg.update(
+ unet_lora_config=dict(
+ type="LoRA", r=4,
+ target_modules=["to_q", "to_v", "to_k", "to_out.0"]),
+            text_encoder_lora_config=dict(
+ type="LoRA", r=4,
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]),
+ )
+ StableDiffuser = MODELS.build(cfg)
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((6, 64, 64))],
+ text=["a dog"],
+ time_ids=[torch.zeros((2, 6))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ assert isinstance(log_vars["loss"], torch.Tensor)
+
+ def test_train_step_input_perturbation(self):
+        # build the model with input perturbation enabled
+ cfg = self._get_config()
+ cfg.update(input_perturbation_gamma=0.1)
+ StableDiffuser = MODELS.build(cfg)
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((6, 64, 64))],
+ text=["a dog"],
+ time_ids=[torch.zeros((2, 6))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ assert isinstance(log_vars["loss"], torch.Tensor)
+
+ def test_train_step_with_gradient_checkpointing(self):
+        # build the model with gradient checkpointing enabled
+ cfg = self._get_config()
+ cfg.update(gradient_checkpointing=True)
+ StableDiffuser = MODELS.build(cfg)
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((6, 64, 64))],
+ text=["a dog"],
+ time_ids=[torch.zeros((2, 6))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ assert isinstance(log_vars["loss"], torch.Tensor)
+
+ def test_train_step_with_pre_compute_embs(self):
+        # build the model with pre-computed text embeddings
+ cfg = self._get_config()
+ cfg.update(pre_compute_text_embeddings=True)
+ StableDiffuser = MODELS.build(cfg)
+
+ assert not hasattr(StableDiffuser, "tokenizer_one")
+ assert not hasattr(StableDiffuser, "text_encoder_one")
+ assert not hasattr(StableDiffuser, "tokenizer_two")
+ assert not hasattr(StableDiffuser, "text_encoder_two")
+
+ # test train step
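+        # with pre-computed embeddings the batch carries prompt_embeds /
+        # pooled_prompt_embeds instead of raw captions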
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((6, 64, 64))],
+ prompt_embeds=[torch.zeros((77, 64))],
+ pooled_prompt_embeds=[torch.zeros(32)],
+ time_ids=[torch.zeros((2, 6))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ assert isinstance(log_vars["loss"], torch.Tensor)
+
+ def test_train_step_dreambooth(self):
+        # build the model with the default config
+ cfg = self._get_config()
+ StableDiffuser = MODELS.build(cfg)
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((6, 64, 64))],
+ text=["a sks dog"],
+ time_ids=[torch.zeros((2, 6))]))
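+        # DreamBooth prior-preservation class images are not supported by the
+        # DPO trainer, so train_step should raise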
+ data["inputs"]["result_class_image"] = dict(
+ img=[torch.zeros((6, 64, 64))],
+ text=["a dog"],
+            time_ids=[torch.zeros((2, 6))])  # type: ignore[assignment]
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ with pytest.raises(
+ AssertionError, match="result_class_image is not supported"):
+ _ = StableDiffuser.train_step(data, optim_wrapper)
+
+ def test_train_step_v_prediction(self):
+        # build the model with v-prediction
+ cfg = self._get_config()
+ cfg.update(prediction_type="v_prediction")
+ StableDiffuser = MODELS.build(cfg)
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((6, 64, 64))],
+ text=["a dog"],
+ time_ids=[torch.zeros((2, 6))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ assert isinstance(log_vars["loss"], torch.Tensor)
+
+ def test_train_step_snr_loss(self):
+        # build the model with the SNR-weighted L2 loss
+ cfg = self._get_config()
+ cfg.update(loss=dict(type=SNRL2Loss, reduction="none"))
+ StableDiffuser = MODELS.build(cfg)
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((6, 64, 64))],
+ text=["a dog"],
+ time_ids=[torch.zeros((2, 6))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ assert isinstance(log_vars["loss"], torch.Tensor)
+
+ def test_train_step_debias_estimation_loss(self):
+        # build the model with the debias estimation loss
+ cfg = self._get_config()
+ cfg.update(loss=dict(type=DeBiasEstimationLoss, reduction="none"))
+ StableDiffuser = MODELS.build(cfg)
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((6, 64, 64))],
+ text=["a dog"],
+ time_ids=[torch.zeros((2, 6))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ assert isinstance(log_vars["loss"], torch.Tensor)
+
+ def test_val_and_test_step(self):
+ cfg = self._get_config()
+ cfg.update(prediction_type="v_prediction")
+ StableDiffuser = MODELS.build(cfg)
+
+ # test val_step
+ with pytest.raises(NotImplementedError, match="val_step is not"):
+ StableDiffuser.val_step(torch.zeros((1, )))
+
+ # test test_step
+ with pytest.raises(NotImplementedError, match="test_step is not"):
+ StableDiffuser.test_step(torch.zeros((1, )))
diff --git a/tests/test_models/test_losses.py b/tests/test_models/test_losses.py
index a3bbd30..1aefaa2 100644
--- a/tests/test_models/test_losses.py
+++ b/tests/test_models/test_losses.py
@@ -1,3 +1,4 @@
+import pytest
import torch
from diffusers import DDPMScheduler
@@ -5,6 +6,10 @@
def test_l2_loss():
+ with pytest.raises(
+ AssertionError, match="reduction should be 'mean' or 'none'"):
+ _ = L2Loss(reduction="dummy")
+
# test asymmetric_loss
pred = torch.Tensor([[5, -5, 0], [5, -5, 0]])
gt = torch.Tensor([[1, 0, 1], [0, 1, 0]])
@@ -14,8 +19,15 @@ def test_l2_loss():
assert torch.allclose(loss(pred, gt), torch.tensor(17.1667))
assert torch.allclose(loss(pred, gt, weight=weight), torch.tensor(8.0167))
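+    # reduction="none" keeps the per-element loss shape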
+ loss = L2Loss(reduction="none")
+ assert loss(pred, gt).shape == (2, 3)
+
def test_snr_l2_loss():
+ with pytest.raises(
+ AssertionError, match="reduction should be 'mean' or 'none'"):
+ _ = SNRL2Loss(reduction="dummy")
+
# test asymmetric_loss
pred = torch.Tensor([[5, -5, 0], [5, -5, 0]])
gt = torch.Tensor([[1, 0, 1], [0, 1, 0]])
@@ -53,8 +65,16 @@ def test_snr_l2_loss():
rtol=1e-04,
atol=1e-04)
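+    # with reduction="none" the SNR-weighted loss is returned per sample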
+ loss = SNRL2Loss(reduction="none")
+ assert loss(pred, gt, timesteps.long(), scheduler.alphas_cumprod,
+ scheduler.config.prediction_type).shape == (2,)
+
def test_debias_estimation_loss():
+ with pytest.raises(
+ AssertionError, match="reduction should be 'mean' or 'none'"):
+ _ = DeBiasEstimationLoss(reduction="dummy")
+
# test asymmetric_loss
pred = torch.Tensor([[5, -5, 0], [5, -5, 0]])
gt = torch.Tensor([[1, 0, 1], [0, 1, 0]])
@@ -92,8 +112,16 @@ def test_debias_estimation_loss():
rtol=1e-04,
atol=1e-04)
+ loss = DeBiasEstimationLoss(reduction="none")
+ assert loss(pred, gt, timesteps.long(), scheduler.alphas_cumprod,
+ scheduler.config.prediction_type).shape == (2,)
+
def test_huber_loss():
+ with pytest.raises(
+ AssertionError, match="reduction should be 'mean' or 'none'"):
+ _ = HuberLoss(reduction="dummy")
+
# test asymmetric_loss
pred = torch.Tensor([[5, -5, 0], [5, -5, 0]])
gt = torch.Tensor([[1, 0, 1], [0, 1, 0]])
@@ -106,3 +134,6 @@ def test_huber_loss():
assert torch.allclose(loss(pred, gt, weight=weight), torch.tensor(1.5833),
rtol=1e-04,
atol=1e-04)
+
+ loss = HuberLoss(reduction="none")
+ assert loss(pred, gt).shape == (2, 3)