[Feature] Support Diffusion DPO #124

Merged · 1 commit · Jan 11, 2024
79 changes: 40 additions & 39 deletions README.md
Original file line number Diff line number Diff line change
@@ -189,48 +189,49 @@ For detailed user guides and advanced guides, please refer to our [Documentation
<tr valign="top">
<td>
<ul>
<li><a href="configs/stable_diffusion/README.md">Stable Diffusion (2022)</a></li>
<li><a href="configs/stable_diffusion_controlnet/README.md">ControlNet (ICCV'2023)</a></li>
<li><a href="configs/stable_diffusion_dreambooth/README.md">DreamBooth (CVPR'2023)</a></li>
<li><a href="configs/stable_diffusion_lora/README.md">LoRA (ICLR'2022)</a></li>
<li><a href="configs/distill_sd_dreambooth/README.md">Distill SD DreamBooth (2023)</a></li>
<li><a href="configs/stable_diffusion_inpaint/README.md">Inpaint</a></li>
<li><a href="diffengine/configs/stable_diffusion/README.md">Stable Diffusion (2022)</a></li>
<li><a href="diffengine/configs/stable_diffusion_controlnet/README.md">ControlNet (ICCV'2023)</a></li>
<li><a href="diffengine/configs/stable_diffusion_dreambooth/README.md">DreamBooth (CVPR'2023)</a></li>
<li><a href="diffengine/configs/stable_diffusion_lora/README.md">LoRA (ICLR'2022)</a></li>
<li><a href="diffengine/configs/distill_sd_dreambooth/README.md">Distill SD DreamBooth (2023)</a></li>
<li><a href="diffengine/configs/stable_diffusion_inpaint/README.md">Inpaint</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/stable_diffusion_xl/README.md">Stable Diffusion XL (2023)</a></li>
<li><a href="configs/stable_diffusion_xl_controlnet/README.md">ControlNet (ICCV'2023)</a></li>
<li><a href="configs/stable_diffusion_xl_dreambooth/README.md">DreamBooth (CVPR'2023)</a></li>
<li><a href="configs/stable_diffusion_xl_lora/README.md">LoRA (ICLR'2022)</a></li>
<li><a href="configs/stable_diffusion_xl_controlnet_small/README.md">ControlNet Small (2023)</a></li>
<li><a href="configs/t2i_adapter/README.md">T2I-Adapter (2023)</a></li>
<li><a href="configs/ip_adapter/README.md">IP-Adapter (2023)</a></li>
<li><a href="configs/esd/README.md">Erasing Concepts from Diffusion Models (2023)</a></li>
<li><a href="configs/ssd_1b/README.md">SSD-1B (2023)</a></li>
<li><a href="configs/instruct_pix2pix/README.md">InstructPix2Pix (2022)</a></li>
<li><a href="configs/loha/README.md">LoHa (ICLR'2022)</a></li>
<li><a href="configs/lokr/README.md">LoKr (2022)</a></li>
<li><a href="configs/oft/README.md">OFT (NeurIPS'2023)</a></li>
<li><a href="configs/stable_diffusion_xl_controlnetxs/README.md">ControlNet-XS (2023)</a></li>
<li><a href="configs/stable_diffusion_xl_inpaint/README.md">Inpaint</a></li>
<li><a href="diffengine/configs/stable_diffusion_xl/README.md">Stable Diffusion XL (2023)</a></li>
<li><a href="diffengine/configs/stable_diffusion_xl_controlnet/README.md">ControlNet (ICCV'2023)</a></li>
<li><a href="diffengine/configs/stable_diffusion_xl_dreambooth/README.md">DreamBooth (CVPR'2023)</a></li>
<li><a href="diffengine/configs/stable_diffusion_xl_lora/README.md">LoRA (ICLR'2022)</a></li>
<li><a href="diffengine/configs/stable_diffusion_xl_controlnet_small/README.md">ControlNet Small (2023)</a></li>
<li><a href="diffengine/configs/t2i_adapter/README.md">T2I-Adapter (2023)</a></li>
<li><a href="diffengine/configs/ip_adapter/README.md">IP-Adapter (2023)</a></li>
<li><a href="diffengine/configs/esd/README.md">Erasing Concepts from Diffusion Models (2023)</a></li>
<li><a href="diffengine/configs/ssd_1b/README.md">SSD-1B (2023)</a></li>
<li><a href="diffengine/configs/instruct_pix2pix/README.md">InstructPix2Pix (2022)</a></li>
<li><a href="diffengine/configs/loha/README.md">LoHa (ICLR'2022)</a></li>
<li><a href="diffengine/configs/lokr/README.md">LoKr (2022)</a></li>
<li><a href="diffengine/configs/oft/README.md">OFT (NeurIPS'2023)</a></li>
<li><a href="diffengine/configs/stable_diffusion_xl_controlnetxs/README.md">ControlNet-XS (2023)</a></li>
<li><a href="diffengine/configs/stable_diffusion_xl_inpaint/README.md">Inpaint</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/deepfloyd_if/README.md">DeepFloyd IF (2023)</a></li>
<li><a href="configs/deepfloyd_if_dreambooth/README.md">DreamBooth (CVPR'2023)</a></li>
<li><a href="diffengine/configs/deepfloyd_if/README.md">DeepFloyd IF (2023)</a></li>
<li><a href="diffengine/configs/deepfloyd_if_dreambooth/README.md">DreamBooth (CVPR'2023)</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/min_snr_loss/README.md">Min-SNR Loss (ICCV'2023)</a></li>
<li><a href="configs/debias_estimation_loss/README.md">DeBias Estimation Loss (2023)</a></li>
<li><a href="configs/offset_noise/README.md">Offset Noise (2023)</a></li>
<li><a href="configs/pyramid_noise/README.md">Pyramid Noise (2023)</a></li>
<li><a href="configs/input_perturbation/README.md">Input Perturbation (2023)</a></li>
<li><a href="configs/timesteps_bias/README.md">Time Steps Bias (2023)</a></li>
<li><a href="configs/v_prediction/README.md">V Prediction (ICLR'2022)</a></li>
<li><a href="diffengine/configs/min_snr_loss/README.md">Min-SNR Loss (ICCV'2023)</a></li>
<li><a href="diffengine/configs/debias_estimation_loss/README.md">DeBias Estimation Loss (2023)</a></li>
<li><a href="diffengine/configs/offset_noise/README.md">Offset Noise (2023)</a></li>
<li><a href="diffengine/configs/pyramid_noise/README.md">Pyramid Noise (2023)</a></li>
<li><a href="diffengine/configs/input_perturbation/README.md">Input Perturbation (2023)</a></li>
<li><a href="diffengine/configs/timesteps_bias/README.md">Time Steps Bias (2023)</a></li>
<li><a href="diffengine/configs/v_prediction/README.md">V Prediction (ICLR'2022)</a></li>
<li><a href="diffengine/configs/diffusion_dpo/README.md">Diffusion DPO (2023)</a></li>
</ul>
</td>
</tr>
@@ -255,27 +256,27 @@ For detailed user guides and advanced guides, please refer to our [Documentation
<tr valign="top">
<td>
<ul>
<li><a href="configs/wuerstchen/README.md">Wuerstchen (2023)</a></li>
<li><a href="configs/wuerstchen_lora/README.md">LoRA (ICLR'2022)</a></li>
<li><a href="diffengine/configs/wuerstchen/README.md">Wuerstchen (2023)</a></li>
<li><a href="diffengine/configs/wuerstchen_lora/README.md">LoRA (ICLR'2022)</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/lcm/README.md">Latent Consistency Models (2023)</a></li>
<li><a href="configs/lcm_lora/README.md">LoRA (ICLR'2022)</a></li>
<li><a href="diffengine/configs/lcm/README.md">Latent Consistency Models (2023)</a></li>
<li><a href="diffengine/configs/lcm_lora/README.md">LoRA (ICLR'2022)</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/pixart_alpha/README.md">PixArt-α (2023)</a></li>
<li><a href="configs/pixart_alpha_lora/README.md">LoRA (ICLR'2022)</a></li>
<li><a href="configs/pixart_alpha_dreambooth/README.md">DreamBooth (CVPR'2023)</a></li>
<li><a href="diffengine/configs/pixart_alpha/README.md">PixArt-α (2023)</a></li>
<li><a href="diffengine/configs/pixart_alpha_lora/README.md">LoRA (ICLR'2022)</a></li>
<li><a href="diffengine/configs/pixart_alpha_dreambooth/README.md">DreamBooth (CVPR'2023)</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/kandinsky_v22/README.md">Kandinsky 2.2 (2023)</a></li>
<li><a href="configs/kandinsky_v3/README.md">Kandinsky 3 (2023)</a></li>
<li><a href="diffengine/configs/kandinsky_v22/README.md">Kandinsky 2.2 (2023)</a></li>
<li><a href="diffengine/configs/kandinsky_v3/README.md">Kandinsky 3 (2023)</a></li>
</ul>
</td>
</tr>
61 changes: 61 additions & 0 deletions diffengine/configs/_base_/datasets/pickascore_xl.py
@@ -0,0 +1,61 @@
import torchvision
from mmengine.dataset import DefaultSampler

from diffengine.datasets import HFDPODataset
from diffengine.datasets.transforms import (
    ComputeTimeIds,
    ConcatMultipleImgs,
    PackInputs,
    RandomCrop,
    RandomHorizontalFlip,
    SaveImageShape,
    TorchVisonTransformWrapper,
)
from diffengine.engine.hooks import SDCheckpointHook, VisualizationHook

train_pipeline = [
    dict(type=SaveImageShape),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.Resize,
         size=1024, interpolation="bilinear"),
    dict(type=RandomCrop, size=1024),
    dict(type=RandomHorizontalFlip, p=0.5),
    dict(type=ComputeTimeIds),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.ToTensor),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.Normalize, mean=[0.5], std=[0.5]),
    dict(type=ConcatMultipleImgs),
    dict(type=PackInputs, input_keys=["img", "text", "time_ids"]),
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    dataset=dict(
        type=HFDPODataset,
        dataset="kashif/pickascore",
        split="validation",
        image_columns=["jpg_0", "jpg_1"],
        caption_column="caption",
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(
        type=VisualizationHook,
        prompt=[
            "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",  # noqa
            "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",  # noqa
        ],
        height=1024,
        width=1024),
    dict(type=SDCheckpointHook),
]
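The pipeline above ends with `ConcatMultipleImgs`, which presumably merges the preferred (`jpg_0`) and dispreferred (`jpg_1`) image of each Pick-a-Pic comparison into a single tensor so the pair travels through batching together; the DPO model then has to split it back apart. A minimal sketch of that round trip (the channel-axis concatenation and the tensor names are illustrative assumptions, not taken from the diffengine source):

```python
import torch

# One comparison pair: the preferred (jpg_0) and dispreferred (jpg_1)
# image, each a 3-channel tensor of the same spatial size.
img_w = torch.randn(3, 64, 64)
img_l = torch.randn(3, 64, 64)

# Concatenate along the channel axis so downstream transforms and the
# dataloader see a single (6, H, W) tensor per sample.
pair = torch.cat([img_w, img_l], dim=0)

# Inside the DPO model, split the pair back into its two halves before
# computing the per-sample denoising losses.
chosen, rejected = pair.chunk(2, dim=0)
assert chosen.shape == img_w.shape
```
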
26 changes: 26 additions & 0 deletions diffengine/configs/_base_/models/stable_diffusion_xl_dpo.py
@@ -0,0 +1,26 @@
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection

from diffengine.models.editors import StableDiffusionXLDPO

base_model = "stabilityai/stable-diffusion-xl-base-1.0"
model = dict(type=StableDiffusionXLDPO,
             model=base_model,
             tokenizer_one=dict(type=AutoTokenizer.from_pretrained,
                                subfolder="tokenizer",
                                use_fast=False),
             tokenizer_two=dict(type=AutoTokenizer.from_pretrained,
                                subfolder="tokenizer_2",
                                use_fast=False),
             scheduler=dict(type=DDPMScheduler.from_pretrained,
                            subfolder="scheduler"),
             text_encoder_one=dict(type=CLIPTextModel.from_pretrained,
                                   subfolder="text_encoder"),
             text_encoder_two=dict(type=CLIPTextModelWithProjection.from_pretrained,
                                   subfolder="text_encoder_2"),
             vae=dict(
                 type=AutoencoderKL.from_pretrained,
                 pretrained_model_name_or_path="madebyollin/sdxl-vae-fp16-fix"),
             unet=dict(type=UNet2DConditionModel.from_pretrained,
                       subfolder="unet"),
             gradient_checkpointing=True)
76 changes: 76 additions & 0 deletions diffengine/configs/diffusion_dpo/README.md
@@ -0,0 +1,76 @@
# Diffusion DPO

[Diffusion Model Alignment Using Direct Preference Optimization](https://arxiv.org/abs/2311.12908)

## Abstract

Large language models (LLMs) are fine-tuned using human comparison data with Reinforcement Learning from Human Feedback (RLHF) methods to make them better aligned with users' preferences. In contrast to LLMs, human preference learning has not been widely explored in text-to-image diffusion models; the best existing approach is to fine-tune a pretrained model using carefully curated high quality images and captions to improve visual appeal and text alignment. We propose Diffusion-DPO, a method to align diffusion models to human preferences by directly optimizing on human comparison data. Diffusion-DPO is adapted from the recently developed Direct Preference Optimization (DPO), a simpler alternative to RLHF which directly optimizes a policy that best satisfies human preferences under a classification objective. We re-formulate DPO to account for a diffusion model notion of likelihood, utilizing the evidence lower bound to derive a differentiable objective. Using the Pick-a-Pic dataset of 851K crowdsourced pairwise preferences, we fine-tune the base model of the state-of-the-art Stable Diffusion XL (SDXL)-1.0 model with Diffusion-DPO. Our fine-tuned base model significantly outperforms both base SDXL-1.0 and the larger SDXL-1.0 model consisting of an additional refinement model in human evaluation, improving visual appeal and prompt alignment. We also develop a variant that uses AI feedback and has comparable performance to training on human preferences, opening the door for scaling of diffusion model alignment methods.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/9e7ab78e-4c61-4490-aac2-fa063554b2b8"/>
</div>
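The reformulated objective described in the abstract can be sketched numerically. The following is a hedged approximation of the Diffusion-DPO loss, not code from this PR: the variable names and the β value are illustrative assumptions. Given per-sample denoising errors of the trained UNet and a frozen reference UNet on the preferred (w) and dispreferred (l) image, the loss falls when the model improves on the preferred sample, relative to the reference, more than it does on the dispreferred one.

```python
import torch
import torch.nn.functional as F


def diffusion_dpo_loss(model_err_w, model_err_l, ref_err_w, ref_err_l,
                       beta=5000.0):
    """Sketch of the Diffusion-DPO objective.

    Each argument is a per-sample denoising MSE; `beta` is the DPO
    temperature (a value in the thousands is commonly cited for SDXL,
    but treat it as an assumption here).
    """
    model_diff = model_err_w - model_err_l
    ref_diff = ref_err_w - ref_err_l
    # -log sigmoid(-0.5 * beta * (model_diff - ref_diff)): large negative
    # model_diff (model much better on the preferred sample) drives the
    # loss toward zero.
    logits = -0.5 * beta * (model_diff - ref_diff)
    return -F.logsigmoid(logits).mean()


# Model better on the preferred image than the reference: loss near zero.
aligned = diffusion_dpo_loss(
    torch.tensor([0.10]), torch.tensor([0.30]),
    torch.tensor([0.20]), torch.tensor([0.20]))
# Preferences flipped: the loss grows large.
misaligned = diffusion_dpo_loss(
    torch.tensor([0.30]), torch.tensor([0.10]),
    torch.tensor([0.20]), torch.tensor([0.20]))
```
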

## Citation

```
```

## Run Training

```bash
# single gpu
$ diffengine train ${CONFIG_FILE}
# multi gpus
$ NPROC_PER_NODE=${GPU_NUM} diffengine train ${CONFIG_FILE}

# Example.
$ diffengine train stable_diffusion_xl_dpo_pickascore
```

## Inference with diffusers

Once you have trained a model, specify the path to the saved checkpoint and use it for inference with the `diffusers` pipeline.

Before running inference, convert the trained weights to the diffusers format:

```bash
$ diffengine convert ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
# Example
$ diffengine convert stable_diffusion_xl_dpo_pickascore work_dirs/stable_diffusion_xl_dpo_pickascore/epoch_50.pth work_dirs/stable_diffusion_xl_dpo_pickascore --save-keys unet
```

Then we can run inference.

```py
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, AutoencoderKL

prompt = 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k'
checkpoint = 'work_dirs/stable_diffusion_xl_dpo_pickascore'

unet = UNet2DConditionModel.from_pretrained(
    checkpoint, subfolder='unet', torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(
    'madebyollin/sdxl-vae-fp16-fix',
    torch_dtype=torch.float16,
)
pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    unet=unet,
    vae=vae,
    torch_dtype=torch.float16)
pipe.to('cuda')

image = pipe(
    prompt,
    num_inference_steps=50,
    width=1024,
    height=1024,
).images[0]
image.save('demo.png')
```

## Results Example

#### stable_diffusion_xl_dpo_pickascore

![example1](https://github.com/okotaku/diffengine/assets/24734142/efa32784-3151-4cb2-9af2-368a8c4b527b)
@@ -0,0 +1,12 @@
from mmengine.config import read_base

with read_base():
    from .._base_.datasets.pickascore_xl import *
    from .._base_.default_runtime import *
    from .._base_.models.stable_diffusion_xl_dpo import *
    from .._base_.schedules.stable_diffusion_xl_50e import *


train_dataloader.update(batch_size=1)

optim_wrapper.update(accumulative_counts=4)  # update weights every four steps
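The two overrides above lower the per-GPU dataloader batch from the base config's 2 to 1 (presumably to offset the memory cost of two-image DPO pairs) and compensate with gradient accumulation. The resulting number of samples per optimizer step is simple arithmetic from the values in this config:

```python
batch_size = 1           # train_dataloader.update(batch_size=1)
accumulative_counts = 4  # optim_wrapper.update(accumulative_counts=4)
num_gpus = 1             # scale by NPROC_PER_NODE for multi-GPU runs

# Samples contributing to each optimizer update.
effective_batch = batch_size * accumulative_counts * num_gpus
print(effective_batch)  # → 4
```
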
2 changes: 2 additions & 0 deletions diffengine/datasets/__init__.py
@@ -1,5 +1,6 @@
from .hf_controlnet_datasets import HFControlNetDataset
from .hf_datasets import HFDataset, HFDatasetPreComputeEmbs
from .hf_dpo_datasets import HFDPODataset
from .hf_dreambooth_datasets import HFDreamBoothDataset
from .hf_esd_datasets import HFESDDatasetPreComputeEmbs
from .samplers import * # noqa: F403
@@ -11,4 +12,5 @@
"HFControlNetDataset",
"HFDatasetPreComputeEmbs",
"HFESDDatasetPreComputeEmbs",
"HFDPODataset",
]