[Feature] Support T2I Adapter #64

Merged: 1 commit, Oct 2, 2023
1 change: 1 addition & 0 deletions README.md
@@ -119,6 +119,7 @@ For detailed user guides and advanced guides, please refer to our [Documentation
 - [Run Stable Diffusion ControlNet](https://diffengine.readthedocs.io/en/latest/run_guides/run_controlnet.html)
 - [Run Stable Diffusion XL ControlNet](https://diffengine.readthedocs.io/en/latest/run_guides/run_controlnet_xl.html)
 - [Run IP Adapter](https://diffengine.readthedocs.io/en/latest/run_guides/run_ip_adapter.html)
+- [Run T2I Adapter](https://diffengine.readthedocs.io/en/latest/run_guides/run_t2i_adapter.html)
 - [Inference](https://diffengine.readthedocs.io/en/latest/run_guides/inference.html)

 </details>
45 changes: 45 additions & 0 deletions configs/_base_/datasets/fill50k_t2i_adapter_xl.py
@@ -0,0 +1,45 @@
train_pipeline = [
    dict(type='SaveImageShape'),
    dict(
        type='torchvision/Resize',
        size=1024,
        interpolation='bilinear',
        keys=['img', 'condition_img']),
    dict(type='RandomCrop', size=1024, keys=['img', 'condition_img']),
    dict(type='RandomHorizontalFlip', p=0.5, keys=['img', 'condition_img']),
    dict(type='ComputeTimeIds'),
    dict(type='torchvision/ToTensor', keys=['img', 'condition_img']),
    dict(type='DumpImage', max_imgs=10, dump_dir='work_dirs/dump'),
    dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
    dict(
        type='PackInputs',
        input_keys=['img', 'condition_img', 'text', 'time_ids']),
]
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    dataset=dict(
        type='HFControlNetDataset',
        dataset='fusing/fill50k',
        condition_column='conditioning_image',
        caption_column='text',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(
        type='VisualizationHook',
        prompt=['cyan circle with brown floral background'] * 4,
        condition_image=[
            'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg'  # noqa
        ] * 4,
        height=1024,
        width=1024),
    dict(type='T2IAdapterSaveHook')
]
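
For reference, the same pipeline can be pointed at a different dataset by overriding only the dataset fields in a derived config; a minimal sketch, with a hypothetical Hugging Face dataset id (the column names must match your data):

```py
# Sketch of a derived config; the dataset id below is hypothetical.
_base_ = ['./fill50k_t2i_adapter_xl.py']

train_dataloader = dict(
    dataset=dict(
        dataset='your-name/your-fill-dataset',  # hypothetical dataset id
        condition_column='conditioning_image',
        caption_column='text'))
```

mmengine merges this dict into the base config, so the pipeline, sampler, and batch size above are inherited unchanged.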
5 changes: 5 additions & 0 deletions configs/_base_/models/stable_diffusion_xl_t2i_adapter.py
@@ -0,0 +1,5 @@
model = dict(
    type='StableDiffusionXLT2IAdapter',
    model='stabilityai/stable-diffusion-xl-base-1.0',
    vae_model='madebyollin/sdxl-vae-fp16-fix',
    gradient_checkpointing=True)
84 changes: 84 additions & 0 deletions configs/t2i_adapter/README.md
@@ -0,0 +1,84 @@
# T2I-Adapter

[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.08453)

## Abstract

The incredible generative ability of large-scale text-to-image (T2I) models has demonstrated strong power of learning complex structures and meaningful semantics. However, relying solely on text prompts cannot fully take advantage of the knowledge learned by the model, especially when flexible and accurate controlling (e.g., color and structure) is needed. In this paper, we aim to "dig out" the capabilities that T2I models have implicitly learned, and then explicitly use them to control the generation more granularly. Specifically, we propose to learn simple and lightweight T2I-Adapters to align internal knowledge in T2I models with external control signals, while freezing the original large T2I models. In this way, we can train various adapters according to different conditions, achieving rich control and editing effects in the color and structure of the generation results. Further, the proposed T2I-Adapters have attractive properties of practical value, such as composability and generalization ability. Extensive experiments demonstrate that our T2I-Adapter has promising generation quality and a wide range of applications.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/d3de5325-34e6-44d8-ba9a-47955afcca47"/>
</div>
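
To make the mechanism concrete, here is a minimal sketch using `diffusers`, assuming the SDXL adapter hyperparameters from diffusers' T2I-Adapter training example: the adapter encodes the condition image into multi-scale features that are added to the frozen UNet's down-block activations.

```py
import torch
from diffusers import T2IAdapter

# A lightweight adapter (assumed SDXL-style hyperparameters): it maps a
# condition image to multi-scale features injected into the frozen UNet.
adapter = T2IAdapter(
    in_channels=3,
    channels=(320, 640, 1280, 1280),  # assumed to mirror SDXL block widths
    num_res_blocks=2,
    downscale_factor=16,
    adapter_type='full_adapter_xl',
)
cond = torch.randn(1, 3, 1024, 1024)  # stand-in for a conditioning image
features = adapter(cond)              # one feature map per down block
print([tuple(f.shape) for f in features])
```

Only these adapter weights train; the base T2I model stays frozen, which is why the saved adapter checkpoints are small.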

## Citation

```
@article{mou2023t2i,
  title={T2i-adapter: Learning adapters to dig out more controllable ability for text-to-image diffusion models},
  author={Mou, Chong and Wang, Xintao and Xie, Liangbin and Wu, Yanze and Zhang, Jian and Qi, Zhongang and Shan, Ying and Qie, Xiaohu},
  journal={arXiv preprint arXiv:2302.08453},
  year={2023}
}
```

## Run Training

```
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example.
$ mim train diffengine configs/t2i_adapter/stable_diffusion_xl_t2i_adapter_fill50k.py
```

## Inference with diffusers

Once you have trained a model, specify the path to the saved model and use it for inference with `diffusers`.

```py
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, AutoencoderKL
from diffusers.utils import load_image

checkpoint = 'work_dirs/stable_diffusion_xl_t2i_adapter_fill50k/step75000'
prompt = 'cyan circle with brown floral background'
condition_image = load_image(
    'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg'
).resize((1024, 1024))

adapter = T2IAdapter.from_pretrained(
    checkpoint, subfolder='adapter', torch_dtype=torch.float16)

vae = AutoencoderKL.from_pretrained(
    'madebyollin/sdxl-vae-fp16-fix',
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    adapter=adapter, vae=vae, torch_dtype=torch.float16)
pipe.to('cuda')

image = pipe(
    prompt,
    image=condition_image,
    num_inference_steps=50,
).images[0]
image.save('demo.png')
```

You can see more details on [`docs/source/run_guides/run_t2i_adapter.md`](../../docs/source/run_guides/run_t2i_adapter.md#inference-with-diffusers).
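
If the structural guidance is too strong or too weak, the pipeline's `adapter_conditioning_scale` argument weights the injected adapter features; a small variation of the call above, with 0.8 as an arbitrary example value:

```py
# Same call as above, with weaker adapter guidance (example value).
image = pipe(
    prompt,
    image=condition_image,
    num_inference_steps=50,
    adapter_conditioning_scale=0.8,
).images[0]
image.save('demo_scaled.png')
```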

## Results Example

#### stable_diffusion_xl_t2i_adapter_fill50k

![input1](https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg)

![example1](https://github.com/okotaku/diffengine/assets/24734142/7ea65b62-a8c4-4888-8e11-9cdb69855d3c)

## Acknowledgement

These experiments are based on the [diffusers docs](https://huggingface.co/docs/diffusers/main/en/training/t2i_adapters). Thank you for the great articles.
11 changes: 11 additions & 0 deletions configs/t2i_adapter/stable_diffusion_xl_t2i_adapter_fill50k.py
@@ -0,0 +1,11 @@
_base_ = [
    '../_base_/models/stable_diffusion_xl_t2i_adapter.py',
    '../_base_/datasets/fill50k_t2i_adapter_xl.py',
    '../_base_/schedules/stable_diffusion_3e.py',
    '../_base_/default_runtime.py'
]

optim_wrapper = dict(
    optimizer=dict(lr=1e-5),
    accumulative_counts=2,
)
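
Note the gradient accumulation here: with the dataset config's per-GPU batch size of 2 and `accumulative_counts=2`, each optimizer step sees an effective batch of 2 × 2 = 4 samples per GPU (times the GPU count when launching with `--gpus`).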
3 changes: 2 additions & 1 deletion diffengine/engine/hooks/__init__.py
@@ -2,10 +2,11 @@
 from .ip_adapter_save_hook import IPAdapterSaveHook
 from .lora_save_hook import LoRASaveHook
 from .sd_checkpoint_hook import SDCheckpointHook
+from .t2i_adapter_save_hook import T2IAdapterSaveHook
 from .unet_ema_hook import UnetEMAHook
 from .visualization_hook import VisualizationHook

 __all__ = [
     'VisualizationHook', 'UnetEMAHook', 'SDCheckpointHook', 'LoRASaveHook',
-    'ControlNetSaveHook', 'IPAdapterSaveHook'
+    'ControlNetSaveHook', 'IPAdapterSaveHook', 'T2IAdapterSaveHook'
 ]
35 changes: 35 additions & 0 deletions diffengine/engine/hooks/t2i_adapter_save_hook.py
@@ -0,0 +1,35 @@
import os.path as osp
from collections import OrderedDict

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS


@HOOKS.register_module()
class T2IAdapterSaveHook(Hook):
    """Save T2I-Adapter weights in diffusers format and pick up adapter
    weights from the checkpoint."""

    priority = 'VERY_LOW'

    def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
        """
        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
            checkpoint (dict): Model's checkpoint.
        """
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        ckpt_path = osp.join(runner.work_dir, f'step{runner.iter}')

        model.adapter.save_pretrained(osp.join(ckpt_path, 'adapter'))

        # Keep only adapter weights in the checkpoint; the frozen base model
        # carries no gradients and does not need to be saved.
        new_ckpt = OrderedDict()
        sd_keys = checkpoint['state_dict'].keys()
        for k in sd_keys:
            if k.startswith('adapter'):
                new_ckpt[k] = checkpoint['state_dict'][k]
        checkpoint['state_dict'] = new_ckpt
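
As wired up in the fill50k dataset config above, the hook only needs to be listed in `custom_hooks`:

```py
custom_hooks = [
    dict(type='T2IAdapterSaveHook'),
]
```

After training, the diffusers-format weights land in `work_dirs/<config>/step<iter>/adapter`, which is exactly the `subfolder='adapter'` location the inference example in the README loads from.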
1 change: 1 addition & 0 deletions diffengine/models/editors/__init__.py
@@ -3,3 +3,4 @@
 from .stable_diffusion_controlnet import *  # noqa: F401, F403
 from .stable_diffusion_xl import *  # noqa: F401, F403
 from .stable_diffusion_xl_controlnet import *  # noqa: F401, F403
+from .t2i_adapter import *  # noqa: F401, F403
@@ -18,9 +18,8 @@ class StableDiffusionControlNet(StableDiffusion):
     """Stable Diffusion ControlNet.

     Args:
-        controlnet_model (str, optional): Path to pretrained VAE model with
-            better numerical stability. More details:
-            https://github.com/huggingface/diffusers/pull/4038.
+        controlnet_model (str, optional): Path to pretrained ControlNet model.
+            If None, use the default ControlNet model from Unet.
+            Defaults to None.
         transformer_layers_per_block (List[int], optional):
             The number of layers per block in the transformer. More details:
@@ -18,9 +18,8 @@ class StableDiffusionXLControlNet(StableDiffusionXL):
     """Stable Diffusion XL ControlNet.

     Args:
-        controlnet_model (str, optional): Path to pretrained VAE model with
-            better numerical stability. More details:
-            https://github.com/huggingface/diffusers/pull/4038.
+        controlnet_model (str, optional): Path to pretrained ControlNet model.
+            If None, use the default ControlNet model from Unet.
+            Defaults to None.
         transformer_layers_per_block (List[int], optional):
             The number of layers per block in the transformer. More details:
3 changes: 3 additions & 0 deletions diffengine/models/editors/t2i_adapter/__init__.py
@@ -0,0 +1,3 @@
from .stable_diffusion_xl_t2i_adapter import StableDiffusionXLT2IAdapter

__all__ = ['StableDiffusionXLT2IAdapter']