Merge pull request #133 from okotaku/feat/ip_adapter_dino
[Feat] IP Adapter DINO
okotaku authored Feb 11, 2024
2 parents 1f9006f + c44d950 commit 665b109
Showing 14 changed files with 317 additions and 251 deletions.
@@ -0,0 +1,60 @@
import torchvision
from mmengine.dataset import DefaultSampler

from diffengine.datasets import HFDataset
from diffengine.datasets.transforms import (
    ComputeTimeIds,
    PackInputs,
    RandomCrop,
    RandomHorizontalFlip,
    RandomTextDrop,
    SaveImageShape,
    TorchVisonTransformWrapper,
    TransformersImageProcessor,
)
from diffengine.engine.hooks import IPAdapterSaveHook, VisualizationHook

train_pipeline = [
    dict(type=SaveImageShape),
    dict(type=TransformersImageProcessor,
         pretrained="facebook/dinov2-base"),
    dict(type=RandomTextDrop),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.Resize,
         size=1024, interpolation="bilinear"),
    dict(type=RandomCrop, size=1024),
    dict(type=RandomHorizontalFlip, p=0.5),
    dict(type=ComputeTimeIds),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.ToTensor),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.Normalize, mean=[0.5], std=[0.5]),
    dict(
        type=PackInputs, input_keys=["img", "text", "time_ids", "clip_img"]),
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    dataset=dict(
        type=HFDataset,
        dataset="lambdalabs/pokemon-blip-captions",
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(
        type=VisualizationHook,
        prompt=["a drawing of a green pokemon with red eyes"] * 2 + [""] * 2,
        example_image=[
            'https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true' # noqa
        ] * 4,
        height=1024,
        width=1024),
    dict(type=IPAdapterSaveHook),
]
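This first new file appears to be the `_base_` dataset config that the DINOv2 runs below import as `pokemon_blip_xl_ip_adapter_dinov2`. Note that `TransformersImageProcessor` sits before the 1024-pixel resize/crop, so `clip_img` is computed from the original image while `img` goes through the SDXL-resolution transforms. As a minimal, hedged sketch (not part of this commit), the snippet below shows the tensor the DINOv2 processor would put under `clip_img`; only the checkpoint name comes from the config, the toy image is illustrative.

```python
# Hedged sketch: inspect the tensor the DINOv2 image processor would place
# under the `clip_img` key. Only the model name comes from the config above.
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
image = Image.new("RGB", (1280, 720))  # any resolution; the processor resizes
pixel_values = processor(images=image, return_tensors="pt").pixel_values
print(pixel_values.shape)  # typically (1, 3, 224, 224); [0] is stored as clip_img
```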
@@ -0,0 +1,60 @@
import torchvision
from mmengine.dataset import DefaultSampler

from diffengine.datasets import HFDataset
from diffengine.datasets.transforms import (
    ComputeTimeIds,
    PackInputs,
    RandomCrop,
    RandomHorizontalFlip,
    RandomTextDrop,
    SaveImageShape,
    TorchVisonTransformWrapper,
    TransformersImageProcessor,
)
from diffengine.engine.hooks import IPAdapterSaveHook, VisualizationHook

train_pipeline = [
    dict(type=SaveImageShape),
    dict(type=TransformersImageProcessor,
         pretrained="facebook/dinov2-giant"),
    dict(type=RandomTextDrop),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.Resize,
         size=1024, interpolation="bilinear"),
    dict(type=RandomCrop, size=1024),
    dict(type=RandomHorizontalFlip, p=0.5),
    dict(type=ComputeTimeIds),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.ToTensor),
    dict(type=TorchVisonTransformWrapper,
         transform=torchvision.transforms.Normalize, mean=[0.5], std=[0.5]),
    dict(
        type=PackInputs, input_keys=["img", "text", "time_ids", "clip_img"]),
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    dataset=dict(
        type=HFDataset,
        dataset="lambdalabs/pokemon-blip-captions",
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(
        type=VisualizationHook,
        prompt=["a drawing of a green pokemon with red eyes"] * 2 + [""] * 2,
        example_image=[
            'https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true' # noqa
        ] * 4,
        height=1024,
        width=1024),
    dict(type=IPAdapterSaveHook),
]
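The second new file is the same pipeline with the `facebook/dinov2-giant` processor. The giant encoder is substantially wider and deeper than the base one; a quick, hedged comparison (not part of this commit) using standard `transformers` config attributes:

```python
# Hedged sketch: compare the DINOv2 variants referenced by these configs.
# Only the model names come from the configs above.
from transformers import Dinov2Config

for name in ("facebook/dinov2-base", "facebook/dinov2-giant"):
    cfg = Dinov2Config.from_pretrained(name)
    print(name, cfg.hidden_size, cfg.num_hidden_layers)
# Expected roughly: base -> 768 hidden size, giant -> 1536 hidden size.
```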
@@ -2,6 +2,7 @@
from diffusers.models.embeddings import ImageProjection
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
@@ -34,4 +35,5 @@
        subfolder="sdxl_models/image_encoder"),
    image_projection=dict(type=ImageProjection,
                          num_image_text_embeds=4),
    feature_extractor=dict(type=CLIPImageProcessor),
    gradient_checkpointing=True)
@@ -2,6 +2,7 @@
from diffusers.models.embeddings import IPAdapterPlusImageProjection
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
@@ -39,4 +40,5 @@
        heads=20,
        num_queries=16,
        ffn_ratio=4),
    feature_extractor=dict(type=CLIPImageProcessor),
    gradient_checkpointing=True)
12 changes: 12 additions & 0 deletions diffengine/configs/ip_adapter/README.md
@@ -95,3 +95,15 @@ You can see more details on [`docs/source/run_guides/run_ip_adapter.md`](../../d
![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/ace81220-010b-44a5-aa8f-3acdf3f54433)

#### stable_diffusion_xl_pokemon_blip_ip_adapter_plus_dinov2

![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/5e1e2088-d00b-4909-9c64-61a7b5ac6b44)

#### stable_diffusion_xl_pokemon_blip_ip_adapter_plus_dinov2_giant

![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/f76c33ba-c1ac-4f6f-b256-d48de5e58bf8)
@@ -0,0 +1,22 @@
from mmengine.config import read_base
from transformers import AutoImageProcessor, Dinov2Model

with read_base():
    from .._base_.datasets.pokemon_blip_xl_ip_adapter_dinov2 import *
    from .._base_.default_runtime import *
    from .._base_.models.stable_diffusion_xl_ip_adapter_plus import *
    from .._base_.schedules.stable_diffusion_xl_50e import *


model.image_encoder = dict(
    type=Dinov2Model.from_pretrained,
    pretrained_model_name_or_path="facebook/dinov2-base")
model.feature_extractor = dict(
    type=AutoImageProcessor.from_pretrained,
    pretrained_model_name_or_path="facebook/dinov2-base")

train_dataloader.update(batch_size=1)

optim_wrapper.update(accumulative_counts=4)  # update every four iterations

train_cfg.update(by_epoch=True, max_epochs=100)
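This run config (presumably `stable_diffusion_xl_pokemon_blip_ip_adapter_plus_dinov2.py` under `diffengine/configs/ip_adapter/`, matching the README section added above) swaps in the DINOv2 encoder and processor and trades `batch_size=1` against 4-step gradient accumulation. A hedged sketch of launching it through mmengine's `Runner`; the file path and the explicit `work_dir` are assumptions, not part of this commit:

```python
# Hedged sketch (not part of this commit): run the config through mmengine's
# Runner. The path below is assumed from the README section name; adjust as
# needed, and set a work_dir either in the config or here.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    "diffengine/configs/ip_adapter/"
    "stable_diffusion_xl_pokemon_blip_ip_adapter_plus_dinov2.py",
)
cfg.work_dir = "work_dirs/ip_adapter_plus_dinov2"  # assumed output directory
runner = Runner.from_cfg(cfg)
runner.train()
```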
@@ -0,0 +1,22 @@
from mmengine.config import read_base
from transformers import AutoImageProcessor, Dinov2Model

with read_base():
    from .._base_.datasets.pokemon_blip_xl_ip_adapter_dinov2_giant import *
    from .._base_.default_runtime import *
    from .._base_.models.stable_diffusion_xl_ip_adapter_plus import *
    from .._base_.schedules.stable_diffusion_xl_50e import *


model.image_encoder = dict(
    type=Dinov2Model.from_pretrained,
    pretrained_model_name_or_path="facebook/dinov2-giant")
model.feature_extractor = dict(
    type=AutoImageProcessor.from_pretrained,
    pretrained_model_name_or_path="facebook/dinov2-giant")

train_dataloader.update(batch_size=1)

optim_wrapper.update(accumulative_counts=4)  # update every four iterations

train_cfg.update(by_epoch=True, max_epochs=100)
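The giant variant differs only in the checkpoint name. As a hedged illustration of the image-embedding path these configs set up, the sketch below mirrors the `hidden_states[-2]` usage in `ip_adapter_xl.py` further down: the penultimate DINOv2 hidden states become the per-patch tokens fed to the IP-Adapter plus projection. Everything except the model name is illustrative.

```python
# Hedged sketch: DINOv2 features as IP-Adapter plus inputs. The model name
# comes from the config above; the dummy image and printed shape are
# illustrative only.
import torch
from PIL import Image
from transformers import AutoImageProcessor, Dinov2Model

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-giant")
encoder = Dinov2Model.from_pretrained("facebook/dinov2-giant")

image = Image.new("RGB", (1024, 1024))
pixel_values = processor(images=image, return_tensors="pt").pixel_values
with torch.no_grad():
    out = encoder(pixel_values, output_hidden_states=True)
tokens = out.hidden_states[-2]  # penultimate layer, as in the model code
print(tokens.shape)  # (1, 1 + num_patches, hidden_size), e.g. (1, 257, 1536)
```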
2 changes: 2 additions & 0 deletions diffengine/datasets/transforms/__init__.py
@@ -20,6 +20,7 @@
    SaveImageShape,
    T5TextPreprocess,
    TorchVisonTransformWrapper,
    TransformersImageProcessor,
)
from .wrappers import RandomChoice

@@ -47,4 +48,5 @@
    "TorchVisonTransformWrapper",
    "ConcatMultipleImgs",
    "ComputeaMUSEdMicroConds",
    "TransformersImageProcessor",
]
33 changes: 33 additions & 0 deletions diffengine/datasets/transforms/processing.py
@@ -14,6 +14,7 @@
from mmengine.dataset.base_dataset import Compose
from torchvision.transforms.functional import crop
from torchvision.transforms.transforms import InterpolationMode
from transformers import AutoImageProcessor
from transformers import CLIPImageProcessor as HFCLIPImageProcessor

from diffengine.datasets.transforms.base import BaseTransform
@@ -936,3 +937,35 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
        micro_conds = micro_conds[0]
        results["micro_conds"] = micro_conds
        return results


@TRANSFORMS.register_module()
class TransformersImageProcessor(BaseTransform):
    """TransformersImageProcessor.

    Args:
    ----
        key (str): `key` to apply augmentation from results. Defaults to 'img'.
        output_key (str): `output_key` after applying augmentation from
            results. Defaults to 'clip_img'.
        pretrained (str, optional): Name or path of the pretrained image
            processor loaded via `AutoImageProcessor.from_pretrained`.
            Defaults to None.
    """

    def __init__(self, key: str = "img", output_key: str = "clip_img",
                 pretrained: str | None = None) -> None:
        self.key = key
        self.output_key = output_key
        self.pipeline = AutoImageProcessor.from_pretrained(pretrained)

    def transform(self, results: dict) -> dict | tuple[list, list] | None:
        """Transform.

        Args:
        ----
            results (dict): The result dict.
        """
        assert not isinstance(results[self.key], list), (
            "TransformersImageProcessor only supports a single image.")
        # (1, 3, 224, 224) -> (3, 224, 224)
        results[self.output_key] = self.pipeline(
            images=results[self.key], return_tensors="pt").pixel_values[0]
        return results
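A short, hedged usage sketch for the new transform (not part of this commit); the checkpoint name matches the dataset configs above and the dummy image is illustrative:

```python
# Hedged sketch: apply the new transform to a results dict the way the
# dataset pipeline would, then inspect the key it adds.
from PIL import Image

from diffengine.datasets.transforms import TransformersImageProcessor

transform = TransformersImageProcessor(pretrained="facebook/dinov2-base")
results = {"img": Image.new("RGB", (640, 480))}
results = transform.transform(results)
print(results["clip_img"].shape)  # e.g. torch.Size([3, 224, 224])
```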
11 changes: 6 additions & 5 deletions diffengine/models/editors/ip_adapter/ip_adapter_xl.py
@@ -7,15 +7,14 @@
from diffusers.utils import load_image
from PIL import Image
from torch import nn
-from transformers import CLIPImageProcessor

from diffengine.models.archs import (
    load_ip_adapter,
    process_ip_adapter_state_dict,
    set_unet_ip_adapter,
)
from diffengine.models.editors.stable_diffusion_xl import StableDiffusionXL
-from diffengine.registry import MODELS
+from diffengine.registry import MODELS, TRANSFORMS


@MODELS.register_module()
@@ -26,6 +25,7 @@ class IPAdapterXL(StableDiffusionXL):
    ----
        image_encoder (dict): The image encoder config.
        image_projection (dict): The image projection config.
        feature_extractor (dict): The feature extractor config.
        pretrained_adapter (str, optional): Path to pretrained IP-Adapter.
            Defaults to None.
        pretrained_adapter_subfolder (str, optional): Sub folder of pretrained
@@ -55,6 +55,7 @@ def __init__(self,
                 *args,
                 image_encoder: dict,
                 image_projection: dict,
                 feature_extractor: dict,
                 pretrained_adapter: str | None = None,
                 pretrained_adapter_subfolder: str = "",
                 pretrained_adapter_weights_name: str = "",
@@ -80,6 +81,8 @@ def __init__(self,
        self.pretrained_adapter_weights_name = pretrained_adapter_weights_name
        self.zeros_image_embeddings_prob = zeros_image_embeddings_prob

        self.feature_extractor = TRANSFORMS.build(feature_extractor)

        super().__init__(
            *args,
            unet_lora_config=unet_lora_config,
@@ -162,7 +165,7 @@ def infer(self,
            tokenizer_2=self.tokenizer_two,
            unet=self.unet,
            image_encoder=self.image_encoder,
-            feature_extractor=CLIPImageProcessor(),
+            feature_extractor=self.feature_extractor,
            torch_dtype=(torch.float16 if self.device != torch.device("cpu")
                         else torch.float32),
        )
@@ -286,7 +289,6 @@ def forward(
                replacement=True).to(image_embeds)
        image_embeds = (image_embeds * mask.view(-1, 1)).view(num_batches, 1, 1, -1)

-        # TODO(takuoko): drop image # noqa
        ip_tokens = self.image_projection(image_embeds)

        model_pred = self.unet(
@@ -387,7 +389,6 @@ def forward(
        image_embeds = self.image_encoder(
            clip_img, output_hidden_states=True).hidden_states[-2]

-        # TODO(takuoko): drop image # noqa
        ip_tokens = self.image_projection(image_embeds)

        model_pred = self.unet(
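One detail worth noting in the changes above: `self.feature_extractor = TRANSFORMS.build(feature_extractor)` relies on mmengine accepting a config dict whose `type` is a class or callable rather than a registered string, which is how both `dict(type=CLIPImageProcessor)` and `dict(type=AutoImageProcessor.from_pretrained, ...)` get built. A minimal, hedged sketch of that mechanism with a throwaway registry (the real code uses `diffengine.registry.TRANSFORMS`; the sketch assumes an mmengine version that, like the one these configs rely on, builds callable `type` entries directly):

```python
# Hedged sketch: mmengine builds a dict whose `type` is a class or callable
# without a string lookup. The toy registry is illustrative only.
from mmengine.registry import Registry
from transformers import AutoImageProcessor, CLIPImageProcessor

TOY = Registry("toy")

clip_fe = TOY.build(dict(type=CLIPImageProcessor))
dino_fe = TOY.build(dict(
    type=AutoImageProcessor.from_pretrained,
    pretrained_model_name_or_path="facebook/dinov2-base"))
print(type(clip_fe).__name__, type(dino_fe).__name__)
```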
33 changes: 19 additions & 14 deletions tests/configs/ip_adapter.py
@@ -6,6 +6,7 @@
from diffusers.models.embeddings import ImageProjection
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
@@ -16,28 +17,32 @@

base_model = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
model = dict(type=IPAdapterXL,
             model=base_model,
             tokenizer_one=dict(type=AutoTokenizer.from_pretrained,
                                subfolder="tokenizer",
                                use_fast=False),
             tokenizer_two=dict(type=AutoTokenizer.from_pretrained,
                                subfolder="tokenizer_2",
                                use_fast=False),
             scheduler=dict(type=DDPMScheduler.from_pretrained,
                            subfolder="scheduler"),
             text_encoder_one=dict(type=CLIPTextModel.from_pretrained,
                                   subfolder="text_encoder"),
             text_encoder_two=dict(type=CLIPTextModelWithProjection.from_pretrained,
                                   subfolder="text_encoder_2"),
             vae=dict(
                 type=AutoencoderKL.from_pretrained,
                 subfolder="vae"),
             unet=dict(type=UNet2DConditionModel.from_pretrained,
                       subfolder="unet"),
             image_encoder=dict(type=CLIPVisionModelWithProjection.from_pretrained,
                                pretrained_model_name_or_path="hf-internal-testing/unidiffuser-diffusers-test",
                                subfolder="image_encoder"),
             image_projection=dict(type=ImageProjection,
                                   num_image_text_embeds=4),
             feature_extractor=dict(
                 type=CLIPImageProcessor.from_pretrained,
                 pretrained_model_name_or_path="hf-internal-testing/unidiffuser-diffusers-test",
                 subfolder="image_processor"),
             data_preprocessor=dict(type=IPAdapterXLDataPreprocessor),
             loss=dict(type=L2Loss))