diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 83693485d0e2..ba214f8e49f0 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -282,6 +282,10 @@
title: ControlNet
- local: api/pipelines/controlnet_sdxl
title: ControlNet with Stable Diffusion XL
+ - local: api/pipelines/controlnetxs
+ title: ControlNet-XS
+ - local: api/pipelines/controlnetxs_sdxl
+ title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/ddim
diff --git a/examples/research_projects/controlnetxs/README.md b/docs/source/en/api/pipelines/controlnetxs.md
similarity index 61%
rename from examples/research_projects/controlnetxs/README.md
rename to docs/source/en/api/pipelines/controlnetxs.md
index 72ed91c01db2..2d4ae7b8ce46 100644
--- a/examples/research_projects/controlnetxs/README.md
+++ b/docs/source/en/api/pipelines/controlnetxs.md
@@ -1,3 +1,15 @@
+
+
# ControlNet-XS
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -12,5 +24,16 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
-> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
\ No newline at end of file
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionControlNetXSPipeline
+[[autodoc]] StableDiffusionControlNetXSPipeline
+ - all
+ - __call__
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
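+
+The snippet below is a minimal usage sketch: the Canny adapter checkpoint id is a placeholder, and it assumes `ControlNetXSAdapter` loads with `from_pretrained` like other diffusers models and is passed to the pipeline via `controlnet=`, mirroring the regular ControlNet pipelines.
+
+```py
+# !pip install opencv-python transformers accelerate
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+from diffusers import ControlNetXSAdapter, StableDiffusionControlNetXSPipeline
+from diffusers.utils import load_image
+
+# Load a ControlNet-XS adapter (checkpoint id assumed) and attach it to a Stable Diffusion base model.
+controlnet = ControlNetXSAdapter.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
+)
+pipe.enable_model_cpu_offload()
+
+# Turn the input image into a Canny edge map that serves as the conditioning image.
+image = np.array(load_image("https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"))
+image = cv2.Canny(image, 100, 200)[:, :, None]
+canny_image = Image.fromarray(np.concatenate([image, image, image], axis=2))
+
+prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+image = pipe(
+    prompt, image=canny_image, controlnet_conditioning_scale=0.7, num_inference_steps=50
+).images[0]
+image.save("cnxs_sd_canny.png")
+```
+
+Lower values of `controlnet_conditioning_scale` weaken the influence of the edge map on the generated image.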
diff --git a/examples/research_projects/controlnetxs/README_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
similarity index 56%
rename from examples/research_projects/controlnetxs/README_sdxl.md
rename to docs/source/en/api/pipelines/controlnetxs_sdxl.md
index d401c1e76698..31075c0ef96a 100644
--- a/examples/research_projects/controlnetxs/README_sdxl.md
+++ b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
@@ -1,3 +1,15 @@
+
+
# ControlNet-XS with Stable Diffusion XL
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -12,4 +24,22 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
-> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
\ No newline at end of file
+
+
+🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
+
+
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionXLControlNetXSPipeline
+[[autodoc]] StableDiffusionXLControlNetXSPipeline
+ - all
+ - __call__
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
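+
+The snippet below is a minimal usage sketch for the SDXL variant; the adapter checkpoint id is a placeholder, and the `controlnet=` loading pattern is assumed to mirror the other ControlNet pipelines.
+
+```py
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+from diffusers import ControlNetXSAdapter, StableDiffusionXLControlNetXSPipeline
+from diffusers.utils import load_image
+
+# Load a ControlNet-XS adapter (checkpoint id assumed) and attach it to the SDXL base model.
+controlnet = ControlNetXSAdapter.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
+pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+)
+pipe.enable_model_cpu_offload()
+
+# Create a Canny edge map to use as the conditioning image.
+image = np.array(load_image("https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"))
+image = cv2.Canny(image, 100, 200)[:, :, None]
+canny_image = Image.fromarray(np.concatenate([image, image, image], axis=2))
+
+prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+image = pipe(
+    prompt, image=canny_image, controlnet_conditioning_scale=0.7, num_inference_steps=50
+).images[0]
+image.save("cnxs_sdxl_canny.png")
+```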
diff --git a/examples/research_projects/controlnetxs/controlnetxs.py b/examples/research_projects/controlnetxs/controlnetxs.py
deleted file mode 100644
index 14ad1d8a3af9..000000000000
--- a/examples/research_projects/controlnetxs/controlnetxs.py
+++ /dev/null
@@ -1,1014 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import functional as F
-from torch.nn.modules.normalization import GroupNorm
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProcessor
-from diffusers.models.autoencoders import AutoencoderKL
-from diffusers.models.lora import LoRACompatibleConv
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.unets.unet_2d_blocks import (
- CrossAttnDownBlock2D,
- CrossAttnUpBlock2D,
- DownBlock2D,
- Downsample2D,
- ResnetBlock2D,
- Transformer2DModel,
- UpBlock2D,
- Upsample2D,
-)
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from diffusers.utils import BaseOutput, logging
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-@dataclass
-class ControlNetXSOutput(BaseOutput):
- """
- The output of [`ControlNetXSModel`].
-
- Args:
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
- The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model
- output, but is already the final output.
- """
-
- sample: torch.FloatTensor = None
-
-
-# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding
-class ControlNetConditioningEmbedding(nn.Module):
- """
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
- training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
- convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
- (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
- model) to encode image-space conditions ... into feature maps ..."
- """
-
- def __init__(
- self,
- conditioning_embedding_channels: int,
- conditioning_channels: int = 3,
- block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
- ):
- super().__init__()
-
- self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
-
- self.blocks = nn.ModuleList([])
-
- for i in range(len(block_out_channels) - 1):
- channel_in = block_out_channels[i]
- channel_out = block_out_channels[i + 1]
- self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
- self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
-
- self.conv_out = zero_module(
- nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
- )
-
- def forward(self, conditioning):
- embedding = self.conv_in(conditioning)
- embedding = F.silu(embedding)
-
- for block in self.blocks:
- embedding = block(embedding)
- embedding = F.silu(embedding)
-
- embedding = self.conv_out(embedding)
-
- return embedding
-
-
-class ControlNetXSModel(ModelMixin, ConfigMixin):
- r"""
- A ControlNet-XS model
-
- This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
- methods implemented for all models (such as downloading or saving).
-
- Most of the parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation
- of [`UNet2DConditionModel`] for them.
-
- Parameters:
- conditioning_channels (`int`, defaults to 3):
- Number of channels of conditioning input (e.g. an image)
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
- conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
- The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
- time_embedding_input_dim (`int`, defaults to 320):
- Dimension of input into time embedding. Needs to be same as in the base model.
- time_embedding_dim (`int`, defaults to 1280):
- Dimension of output from time embedding. Needs to be same as in the base model.
- learn_embedding (`bool`, defaults to `False`):
- Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of
- the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`.
- time_embedding_mix (`float`, defaults to 1.0):
- Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the
- control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used.
- base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`):
- Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it.
- """
-
- @classmethod
- def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True):
- """
- Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS).
-
- Parameters:
- base_model (`UNet2DConditionModel`):
- Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL.
- is_sdxl (`bool`, defaults to `True`):
- Whether passed `base_model` is a StableDiffusion-XL model.
- """
-
- def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int):
- """
- Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why).
- The original ControlNet-XS model, however, defines the number of attention heads.
- That's why we compute the dimensions needed to get the correct number of attention heads.
- """
- block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels]
- dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels]
- return dim_attn_heads
-
- if is_sdxl:
- return ControlNetXSModel.from_unet(
- base_model,
- time_embedding_mix=0.95,
- learn_embedding=True,
- size_ratio=0.1,
- conditioning_embedding_out_channels=(16, 32, 96, 256),
- num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64),
- )
- else:
- return ControlNetXSModel.from_unet(
- base_model,
- time_embedding_mix=1.0,
- learn_embedding=True,
- size_ratio=0.0125,
- conditioning_embedding_out_channels=(16, 32, 96, 256),
- num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8),
- )
-
- @classmethod
- def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str):
- """To create correctly sized connections between base and control model, we need to know
- the input and output channels of each subblock.
-
- Parameters:
- unet (`UNet2DConditionModel`):
- Unet of which the subblock channels sizes are to be gathered.
- base_or_control (`str`):
- Needs to be either "base" or "control". If "base", decoder is also considered.
- """
- if base_or_control not in ["base", "control"]:
- raise ValueError("`base_or_control` needs to be either `base` or `control`")
-
- channel_sizes = {"down": [], "mid": [], "up": []}
-
- # input convolution
- channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels))
-
- # encoder blocks
- for module in unet.down_blocks:
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
- for r in module.resnets:
- channel_sizes["down"].append((r.in_channels, r.out_channels))
- if module.downsamplers:
- channel_sizes["down"].append(
- (module.downsamplers[0].channels, module.downsamplers[0].out_channels)
- )
- else:
- raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.")
-
- # middle block
- channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels))
-
- # decoder blocks
- if base_or_control == "base":
- for module in unet.up_blocks:
- if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)):
- for r in module.resnets:
- channel_sizes["up"].append((r.in_channels, r.out_channels))
- else:
- raise ValueError(
- f"Encountered unknown module of type {type(module)} while creating ControlNet-XS."
- )
-
- return channel_sizes
-
- @register_to_config
- def __init__(
- self,
- conditioning_channels: int = 3,
- conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
- controlnet_conditioning_channel_order: str = "rgb",
- time_embedding_input_dim: int = 320,
- time_embedding_dim: int = 1280,
- time_embedding_mix: float = 1.0,
- learn_embedding: bool = False,
- base_model_channel_sizes: Dict[str, List[Tuple[int]]] = {
- "down": [
- (4, 320),
- (320, 320),
- (320, 320),
- (320, 320),
- (320, 640),
- (640, 640),
- (640, 640),
- (640, 1280),
- (1280, 1280),
- ],
- "mid": [(1280, 1280)],
- "up": [
- (2560, 1280),
- (2560, 1280),
- (1920, 1280),
- (1920, 640),
- (1280, 640),
- (960, 640),
- (960, 320),
- (640, 320),
- (640, 320),
- ],
- },
- sample_size: Optional[int] = None,
- down_block_types: Tuple[str] = (
- "CrossAttnDownBlock2D",
- "CrossAttnDownBlock2D",
- "CrossAttnDownBlock2D",
- "DownBlock2D",
- ),
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
- norm_num_groups: Optional[int] = 32,
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
- num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
- upcast_attention: bool = False,
- ):
- super().__init__()
-
- # 1 - Create control unet
- self.control_model = UNet2DConditionModel(
- sample_size=sample_size,
- down_block_types=down_block_types,
- up_block_types=up_block_types,
- block_out_channels=block_out_channels,
- norm_num_groups=norm_num_groups,
- cross_attention_dim=cross_attention_dim,
- transformer_layers_per_block=transformer_layers_per_block,
- attention_head_dim=num_attention_heads,
- use_linear_projection=True,
- upcast_attention=upcast_attention,
- time_embedding_dim=time_embedding_dim,
- )
-
- # 2 - Do model surgery on control model
- # 2.1 - Allow to use the same time information as the base model
- adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim)
-
- # 2.2 - Allow for information infusion from base model
-
- # We concat the output of each base encoder subblocks to the input of the next control encoder subblock
- # (We ignore the 1st element, as it represents the `conv_in`.)
- extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]]
- it_extra_input_channels = iter(extra_input_channels)
-
- for b, block in enumerate(self.control_model.down_blocks):
- for r in range(len(block.resnets)):
- increase_block_input_in_encoder_resnet(
- self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels)
- )
-
- if block.downsamplers:
- increase_block_input_in_encoder_downsampler(
- self.control_model, block_no=b, by=next(it_extra_input_channels)
- )
-
- increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1])
-
- # 2.3 - Make group norms work with modified channel sizes
- adjust_group_norms(self.control_model)
-
- # 3 - Gather Channel Sizes
- self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control")
- self.ch_inout_base = base_model_channel_sizes
-
- # 4 - Build connections between base and control model
- self.down_zero_convs_out = nn.ModuleList([])
- self.down_zero_convs_in = nn.ModuleList([])
- self.middle_block_out = nn.ModuleList([])
- self.middle_block_in = nn.ModuleList([])
- self.up_zero_convs_out = nn.ModuleList([])
- self.up_zero_convs_in = nn.ModuleList([])
-
- for ch_io_base in self.ch_inout_base["down"]:
- self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1]))
- for i in range(len(self.ch_inout_ctrl["down"])):
- self.down_zero_convs_out.append(
- self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1])
- )
-
- self.middle_block_out = self._make_zero_conv(
- self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1]
- )
-
- self.up_zero_convs_out.append(
- self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1])
- )
- for i in range(1, len(self.ch_inout_ctrl["down"])):
- self.up_zero_convs_out.append(
- self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1])
- )
-
- # 5 - Create conditioning hint embedding
- self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
- conditioning_embedding_channels=block_out_channels[0],
- block_out_channels=conditioning_embedding_out_channels,
- conditioning_channels=conditioning_channels,
- )
-
- # In the minimal implementation setting, we only need the control model up to the mid block
- del self.control_model.up_blocks
- del self.control_model.conv_norm_out
- del self.control_model.conv_out
-
- @classmethod
- def from_unet(
- cls,
- unet: UNet2DConditionModel,
- conditioning_channels: int = 3,
- conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
- controlnet_conditioning_channel_order: str = "rgb",
- learn_embedding: bool = False,
- time_embedding_mix: float = 1.0,
- block_out_channels: Optional[Tuple[int]] = None,
- size_ratio: Optional[float] = None,
- num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
- norm_num_groups: Optional[int] = None,
- ):
- r"""
- Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`].
-
- Parameters:
- unet (`UNet2DConditionModel`):
- The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it.
- conditioning_channels (`int`, defaults to 3):
- Number of channels of conditioning input (e.g. an image)
- conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
- The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
- learn_embedding (`bool`, defaults to `False`):
- Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation
- of the time embeddings of the control and base model with interpolation parameter
- `time_embedding_mix**3`.
- time_embedding_mix (`float`, defaults to 1.0):
- Linear interpolation parameter used if `learn_embedding` is `True`.
- block_out_channels (`Tuple[int]`, *optional*):
- Down blocks output channels in control model. Either this or `size_ratio` must be given.
- size_ratio (float, *optional*):
- When given, block_out_channels is set to a relative fraction of the base model's block_out_channels.
- Either this or `block_out_channels` must be given.
- num_attention_heads (`Union[int, Tuple[int]]`, *optional*):
- The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
- norm_num_groups (int, *optional*, defaults to `None`):
- The number of groups to use for the normalization of the control unet. If `None`,
- `int(unet.config.norm_num_groups * size_ratio)` is taken.
- """
-
- # Check input
- fixed_size = block_out_channels is not None
- relative_size = size_ratio is not None
- if not (fixed_size ^ relative_size):
- raise ValueError(
- "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)."
- )
-
- # Create model
- if block_out_channels is None:
- block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels]
-
- # Check that attention heads and group norms match channel sizes
- # - attention heads
- def attn_heads_match_channel_sizes(attn_heads, channel_sizes):
- if isinstance(attn_heads, (tuple, list)):
- return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes))
- else:
- return all(c % attn_heads == 0 for c in channel_sizes)
-
- num_attention_heads = num_attention_heads or unet.config.attention_head_dim
- if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels):
- raise ValueError(
- f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually."
- )
-
- # - group norms
- def group_norms_match_channel_sizes(num_groups, channel_sizes):
- return all(c % num_groups == 0 for c in channel_sizes)
-
- if norm_num_groups is None:
- if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels):
- norm_num_groups = unet.config.norm_num_groups
- else:
- norm_num_groups = min(block_out_channels)
-
- if group_norms_match_channel_sizes(norm_num_groups, block_out_channels):
- print(
- f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information."
- )
- else:
- raise ValueError(
- f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels."
- )
-
- def get_time_emb_input_dim(unet: UNet2DConditionModel):
- return unet.time_embedding.linear_1.in_features
-
- def get_time_emb_dim(unet: UNet2DConditionModel):
- return unet.time_embedding.linear_2.out_features
-
- # Clone params from base unet if
- # (i) it's required to build SD or SDXL, and
- # (ii) it's not used for the time embedding (as time embedding of control model is never used), and
- # (iii) it's not set further below anyway
- to_keep = [
- "cross_attention_dim",
- "down_block_types",
- "sample_size",
- "transformer_layers_per_block",
- "up_block_types",
- "upcast_attention",
- ]
- kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep}
- kwargs.update(block_out_channels=block_out_channels)
- kwargs.update(num_attention_heads=num_attention_heads)
- kwargs.update(norm_num_groups=norm_num_groups)
-
- # Add controlnetxs-specific params
- kwargs.update(
- conditioning_channels=conditioning_channels,
- controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
- time_embedding_input_dim=get_time_emb_input_dim(unet),
- time_embedding_dim=get_time_emb_dim(unet),
- time_embedding_mix=time_embedding_mix,
- learn_embedding=learn_embedding,
- base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"),
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
- )
-
- return cls(**kwargs)
-
- @property
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model,
- indexed by its weight name.
- """
- return self.control_model.attn_processors
-
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- self.control_model.set_attn_processor(processor)
-
- def set_default_attn_processor(self):
- """
- Disables custom attention processors and sets the default attention implementation.
- """
- self.control_model.set_default_attn_processor()
-
- def set_attention_slice(self, slice_size):
- r"""
- Enable sliced attention computation.
-
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
-
- Args:
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
- must be a multiple of `slice_size`.
- """
- self.control_model.set_attention_slice(slice_size)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, (UNet2DConditionModel)):
- if value:
- module.enable_gradient_checkpointing()
- else:
- module.disable_gradient_checkpointing()
-
- def forward(
- self,
- base_model: UNet2DConditionModel,
- sample: torch.FloatTensor,
- timestep: Union[torch.Tensor, float, int],
- encoder_hidden_states: torch.Tensor,
- controlnet_cond: torch.Tensor,
- conditioning_scale: float = 1.0,
- class_labels: Optional[torch.Tensor] = None,
- timestep_cond: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
- return_dict: bool = True,
- ) -> Union[ControlNetXSOutput, Tuple]:
- """
- The [`ControlNetModel`] forward method.
-
- Args:
- base_model (`UNet2DConditionModel`):
- The base unet model we want to control.
- sample (`torch.FloatTensor`):
- The noisy input tensor.
- timestep (`Union[torch.Tensor, float, int]`):
- The number of timesteps to denoise an input.
- encoder_hidden_states (`torch.Tensor`):
- The encoder hidden states.
- controlnet_cond (`torch.FloatTensor`):
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- conditioning_scale (`float`, defaults to `1.0`):
- How much the control model affects the base model outputs.
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
- Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
- timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
- embeddings.
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
- negative values to the attention scores corresponding to "discard" tokens.
- added_cond_kwargs (`dict`):
- Additional conditions for the Stable Diffusion XL UNet.
- cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
- return_dict (`bool`, defaults to `True`):
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
-
- Returns:
- [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`:
- If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a
- tuple is returned where the first element is the sample tensor.
- """
- # check channel order
- channel_order = self.config.controlnet_conditioning_channel_order
-
- if channel_order == "rgb":
- # in rgb order by default
- ...
- elif channel_order == "bgr":
- controlnet_cond = torch.flip(controlnet_cond, dims=[1])
- else:
- raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
-
- # scale control strength
- n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out)
- scale_list = torch.full((n_connections,), conditioning_scale)
-
- # prepare attention_mask
- if attention_mask is not None:
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # 1. time
- timesteps = timestep
- if not torch.is_tensor(timesteps):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
- # This would be a good case for the `match` statement (Python 3.10+)
- is_mps = sample.device.type == "mps"
- if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
- else:
- dtype = torch.int32 if is_mps else torch.int64
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
- elif len(timesteps.shape) == 0:
- timesteps = timesteps[None].to(sample.device)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timesteps = timesteps.expand(sample.shape[0])
-
- t_emb = base_model.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=sample.dtype)
-
- if self.config.learn_embedding:
- ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond)
- base_temb = base_model.time_embedding(t_emb, timestep_cond)
- interpolation_param = self.config.time_embedding_mix**0.3
-
- temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
- else:
- temb = base_model.time_embedding(t_emb)
-
- # added time & text embeddings
- aug_emb = None
-
- if base_model.class_embedding is not None:
- if class_labels is None:
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
-
- if base_model.config.class_embed_type == "timestep":
- class_labels = base_model.time_proj(class_labels)
-
- class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype)
- temb = temb + class_emb
-
- if base_model.config.addition_embed_type is not None:
- if base_model.config.addition_embed_type == "text":
- aug_emb = base_model.add_embedding(encoder_hidden_states)
- elif base_model.config.addition_embed_type == "text_image":
- raise NotImplementedError()
- elif base_model.config.addition_embed_type == "text_time":
- # SDXL - style
- if "text_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
- )
- text_embeds = added_cond_kwargs.get("text_embeds")
- if "time_ids" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
- )
- time_ids = added_cond_kwargs.get("time_ids")
- time_embeds = base_model.add_time_proj(time_ids.flatten())
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
- add_embeds = add_embeds.to(temb.dtype)
- aug_emb = base_model.add_embedding(add_embeds)
- elif base_model.config.addition_embed_type == "image":
- raise NotImplementedError()
- elif base_model.config.addition_embed_type == "image_hint":
- raise NotImplementedError()
-
- temb = temb + aug_emb if aug_emb is not None else temb
-
- # text embeddings
- cemb = encoder_hidden_states
-
- # Preparation
- guided_hint = self.controlnet_cond_embedding(controlnet_cond)
-
- h_ctrl = h_base = sample
- hs_base, hs_ctrl = [], []
- it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map(
- iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out)
- )
- scales = iter(scale_list)
-
- base_down_subblocks = to_sub_blocks(base_model.down_blocks)
- ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks)
- base_mid_subblocks = to_sub_blocks([base_model.mid_block])
- ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block])
- base_up_subblocks = to_sub_blocks(base_model.up_blocks)
-
- # Cross Control
- # 0 - conv in
- h_base = base_model.conv_in(h_base)
- h_ctrl = self.control_model.conv_in(h_ctrl)
- if guided_hint is not None:
- h_ctrl += guided_hint
- h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base
-
- hs_base.append(h_base)
- hs_ctrl.append(h_ctrl)
-
- # 1 - down
- for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks):
- h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl
- h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock
- h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock
- h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base
- hs_base.append(h_base)
- hs_ctrl.append(h_ctrl)
-
- # 2 - mid
- h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl
- for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks):
- h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock
- h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock
- h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base
-
- # 3 - up
- for i, m_base in enumerate(base_up_subblocks):
- h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder
- h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder
- h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs)
-
- h_base = base_model.conv_norm_out(h_base)
- h_base = base_model.conv_act(h_base)
- h_base = base_model.conv_out(h_base)
-
- if not return_dict:
- return h_base
-
- return ControlNetXSOutput(sample=h_base)
-
- def _make_zero_conv(self, in_channels, out_channels=None):
- # keep running track of channels sizes
- self.in_channels = in_channels
- self.out_channels = out_channels or in_channels
-
- return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
-
- @torch.no_grad()
- def _check_if_vae_compatible(self, vae: AutoencoderKL):
- condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1)
- vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
- compatible = condition_downscale_factor == vae_downscale_factor
- return compatible, condition_downscale_factor, vae_downscale_factor
-
-
-class SubBlock(nn.ModuleList):
- """A SubBlock is the largest piece of either base or control model, that is executed independently of the other model respectively.
- Before each subblock, information is concatted from base to control. And after each subblock, information is added from control to base.
- """
-
- def __init__(self, ms, *args, **kwargs):
- if not is_iterable(ms):
- ms = [ms]
- super().__init__(ms, *args, **kwargs)
-
- def forward(
- self,
- x: torch.Tensor,
- temb: torch.Tensor,
- cemb: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
- """Iterate through children and pass correct information to each."""
- for m in self:
- if isinstance(m, ResnetBlock2D):
- x = m(x, temb)
- elif isinstance(m, Transformer2DModel):
- x = m(x, cemb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs).sample
- elif isinstance(m, Downsample2D):
- x = m(x)
- elif isinstance(m, Upsample2D):
- x = m(x)
- else:
- raise ValueError(
- f"Type of m is {type(m)} but should be `ResnetBlock2D`, `Transformer2DModel`, `Downsample2D` or `Upsample2D`"
- )
-
- return x
-
-
-def adjust_time_dims(unet: UNet2DConditionModel, in_dim: int, out_dim: int):
- unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim)
-
-
-def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by):
- """Increase channels sizes to allow for additional concatted information from base model"""
- r = unet.down_blocks[block_no].resnets[resnet_idx]
- old_norm1, old_conv1 = r.norm1, r.conv1
- # norm
- norm_args = "num_groups num_channels eps affine".split(" ")
- for a in norm_args:
- assert hasattr(old_norm1, a)
- norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
- norm_kwargs["num_channels"] += by # surgery done here
- # conv1
- conv1_args = [
- "in_channels",
- "out_channels",
- "kernel_size",
- "stride",
- "padding",
- "dilation",
- "groups",
- "bias",
- "padding_mode",
- ]
- if not USE_PEFT_BACKEND:
- conv1_args.append("lora_layer")
-
- for a in conv1_args:
- assert hasattr(old_conv1, a)
-
- conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
- conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
- conv1_kwargs["in_channels"] += by # surgery done here
- # conv_shortcut
- # as we changed the input size of the block, the input and output sizes are likely different,
- # therefore we need a conv_shortcut (simply adding won't work)
- conv_shortcut_args_kwargs = {
- "in_channels": conv1_kwargs["in_channels"],
- "out_channels": conv1_kwargs["out_channels"],
- # default arguments from resnet.__init__
- "kernel_size": 1,
- "stride": 1,
- "padding": 0,
- "bias": True,
- }
- # swap old with new modules
- unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
- unet.down_blocks[block_no].resnets[resnet_idx].conv1 = (
- nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
- )
- unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = (
- nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
- )
- unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here
-
-
-def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
- """Increase channels sizes to allow for additional concatted information from base model"""
- old_down = unet.down_blocks[block_no].downsamplers[0].conv
-
- args = [
- "in_channels",
- "out_channels",
- "kernel_size",
- "stride",
- "padding",
- "dilation",
- "groups",
- "bias",
- "padding_mode",
- ]
- if not USE_PEFT_BACKEND:
- args.append("lora_layer")
-
- for a in args:
- assert hasattr(old_down, a)
- kwargs = {a: getattr(old_down, a) for a in args}
- kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor.
- kwargs["in_channels"] += by # surgery done here
- # swap old with new modules
- unet.down_blocks[block_no].downsamplers[0].conv = (
- nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
- )
- unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here
-
-
-def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
- """Increase channels sizes to allow for additional concatted information from base model"""
- m = unet.mid_block.resnets[0]
- old_norm1, old_conv1 = m.norm1, m.conv1
- # norm
- norm_args = "num_groups num_channels eps affine".split(" ")
- for a in norm_args:
- assert hasattr(old_norm1, a)
- norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
- norm_kwargs["num_channels"] += by # surgery done here
- conv1_args = [
- "in_channels",
- "out_channels",
- "kernel_size",
- "stride",
- "padding",
- "dilation",
- "groups",
- "bias",
- "padding_mode",
- ]
- if not USE_PEFT_BACKEND:
- conv1_args.append("lora_layer")
-
- conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
- conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
- conv1_kwargs["in_channels"] += by # surgery done here
- # conv_shortcut
- # as we changed the input size of the block, the input and output sizes are likely different,
- # therefore we need a conv_shortcut (simply adding won't work)
- conv_shortcut_args_kwargs = {
- "in_channels": conv1_kwargs["in_channels"],
- "out_channels": conv1_kwargs["out_channels"],
- # default arguments from resnet.__init__
- "kernel_size": 1,
- "stride": 1,
- "padding": 0,
- "bias": True,
- }
- # swap old with new modules
- unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
- unet.mid_block.resnets[0].conv1 = (
- nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
- )
- unet.mid_block.resnets[0].conv_shortcut = (
- nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
- )
- unet.mid_block.resnets[0].in_channels += by # surgery done here
-
-
-def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32):
- def find_denominator(number, start):
- if start >= number:
- return number
- while start != 0:
- residual = number % start
- if residual == 0:
- return start
- start -= 1
-
- for block in [*unet.down_blocks, unet.mid_block]:
- # resnets
- for r in block.resnets:
- if r.norm1.num_groups < max_num_group:
- r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group)
-
- if r.norm2.num_groups < max_num_group:
- r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group)
-
- # transformers
- if hasattr(block, "attentions"):
- for a in block.attentions:
- if a.norm.num_groups < max_num_group:
- a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group)
-
-
-def is_iterable(o):
- if isinstance(o, str):
- return False
- try:
- iter(o)
- return True
- except TypeError:
- return False
-
-
-def to_sub_blocks(blocks):
- if not is_iterable(blocks):
- blocks = [blocks]
-
- sub_blocks = []
-
- for b in blocks:
- if hasattr(b, "resnets"):
- if hasattr(b, "attentions") and b.attentions is not None:
- for r, a in zip(b.resnets, b.attentions):
- sub_blocks.append([r, a])
-
- num_resnets = len(b.resnets)
- num_attns = len(b.attentions)
-
- if num_resnets > num_attns:
- # we can have more resnets than attentions, so add each resnet as separate subblock
- for i in range(num_attns, num_resnets):
- sub_blocks.append([b.resnets[i]])
- else:
- for r in b.resnets:
- sub_blocks.append([r])
-
- # upsamplers are part of the same subblock
- if hasattr(b, "upsamplers") and b.upsamplers is not None:
- for u in b.upsamplers:
- sub_blocks[-1].extend([u])
-
- # downsamplers are own subblock
- if hasattr(b, "downsamplers") and b.downsamplers is not None:
- for d in b.downsamplers:
- sub_blocks.append([d])
-
- return list(map(SubBlock, sub_blocks))
-
-
-def zero_module(module):
- for p in module.parameters():
- nn.init.zeros_(p)
- return module
diff --git a/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py b/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py
deleted file mode 100644
index 722b282a3251..000000000000
--- a/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# !pip install opencv-python transformers accelerate
-import argparse
-
-import cv2
-import numpy as np
-import torch
-from controlnetxs import ControlNetXSModel
-from PIL import Image
-from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
-
-from diffusers.utils import load_image
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
- "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
-)
-parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
-parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
-parser.add_argument(
- "--image_path",
- type=str,
- default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
-)
-parser.add_argument("--num_inference_steps", type=int, default=50)
-
-args = parser.parse_args()
-
-prompt = args.prompt
-negative_prompt = args.negative_prompt
-# download an image
-image = load_image(args.image_path)
-
-# initialize the models and pipeline
-controlnet_conditioning_scale = args.controlnet_conditioning_scale
-controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
-pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
-)
-pipe.enable_model_cpu_offload()
-
-# get canny image
-image = np.array(image)
-image = cv2.Canny(image, 100, 200)
-image = image[:, :, None]
-image = np.concatenate([image, image, image], axis=2)
-canny_image = Image.fromarray(image)
-
-num_inference_steps = args.num_inference_steps
-
-# generate image
-image = pipe(
- prompt,
- controlnet_conditioning_scale=controlnet_conditioning_scale,
- image=canny_image,
- num_inference_steps=num_inference_steps,
-).images[0]
-image.save("cnxs_sd.canny.png")
diff --git a/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py b/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py
deleted file mode 100644
index e5b8cfd88223..000000000000
--- a/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# !pip install opencv-python transformers accelerate
-import argparse
-
-import cv2
-import numpy as np
-import torch
-from controlnetxs import ControlNetXSModel
-from PIL import Image
-from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
-
-from diffusers.utils import load_image
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
- "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
-)
-parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
-parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
-parser.add_argument(
- "--image_path",
- type=str,
- default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
-)
-parser.add_argument("--num_inference_steps", type=int, default=50)
-
-args = parser.parse_args()
-
-prompt = args.prompt
-negative_prompt = args.negative_prompt
-# download an image
-image = load_image(args.image_path)
-# initialize the models and pipeline
-controlnet_conditioning_scale = args.controlnet_conditioning_scale
-controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
-pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
-)
-pipe.enable_model_cpu_offload()
-
-# get canny image
-image = np.array(image)
-image = cv2.Canny(image, 100, 200)
-image = image[:, :, None]
-image = np.concatenate([image, image, image], axis=2)
-canny_image = Image.fromarray(image)
-
-num_inference_steps = args.num_inference_steps
-
-# generate image
-image = pipe(
- prompt,
- controlnet_conditioning_scale=controlnet_conditioning_scale,
- image=canny_image,
- num_inference_steps=num_inference_steps,
-).images[0]
-image.save("cnxs_sdxl.canny.png")
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 770045923d5d..5d6761663938 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -80,6 +80,7 @@
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
+ "ControlNetXSAdapter",
"I2VGenXLUNet",
"Kandinsky3UNet",
"ModelMixin",
@@ -94,6 +95,7 @@
"UNet2DConditionModel",
"UNet2DModel",
"UNet3DConditionModel",
+ "UNetControlNetXSModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
@@ -270,6 +272,7 @@
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPipeline",
+ "StableDiffusionControlNetXSPipeline",
"StableDiffusionDepth2ImgPipeline",
"StableDiffusionDiffEditPipeline",
"StableDiffusionGLIGENPipeline",
@@ -293,6 +296,7 @@
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPipeline",
+ "StableDiffusionXLControlNetXSPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
@@ -474,6 +478,7 @@
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
+ ControlNetXSAdapter,
I2VGenXLUNet,
Kandinsky3UNet,
ModelMixin,
@@ -487,6 +492,7 @@
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
+ UNetControlNetXSModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
UVit2DModel,
@@ -642,6 +648,7 @@
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
+ StableDiffusionControlNetXSPipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
@@ -665,6 +672,7 @@
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
+ StableDiffusionXLControlNetXSPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index da77e4450c86..78b0efff921d 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -32,6 +32,7 @@
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
+ _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
@@ -68,6 +69,7 @@
ConsistencyDecoderVAE,
)
from .controlnet import ControlNetModel
+ from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py
new file mode 100644
index 000000000000..4bbe1dd4dc25
--- /dev/null
+++ b/src/diffusers/models/controlnet_xs.py
@@ -0,0 +1,1892 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from math import gcd
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import FloatTensor, nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, is_torch_version, logging
+from ..utils.torch_utils import apply_freeu
+from .attention_processor import Attention, AttentionProcessor
+from .controlnet import ControlNetConditioningEmbedding
+from .embeddings import TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .unets.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ Downsample2D,
+ ResnetBlock2D,
+ Transformer2DModel,
+ UNetMidBlock2DCrossAttn,
+ Upsample2D,
+)
+from .unets.unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetXSOutput(BaseOutput):
+ """
+ The output of [`UNetControlNetXSModel`].
+
+ Args:
+ sample (`FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base
+ model output, but is already the final output.
+ """
+
+ sample: FloatTensor = None
+
+
+class DownBlockControlNetXSAdapter(nn.Module):
+ """Components that together with corresponding components from the base model will form a
+ `ControlNetXSCrossAttnDownBlock2D`"""
+
+ def __init__(
+ self,
+ resnets: nn.ModuleList,
+ base_to_ctrl: nn.ModuleList,
+ ctrl_to_base: nn.ModuleList,
+ attentions: Optional[nn.ModuleList] = None,
+ downsampler: Optional[nn.Conv2d] = None,
+ ):
+ super().__init__()
+ self.resnets = resnets
+ self.base_to_ctrl = base_to_ctrl
+ self.ctrl_to_base = ctrl_to_base
+ self.attentions = attentions
+ self.downsamplers = downsampler
+
+
+class MidBlockControlNetXSAdapter(nn.Module):
+ """Components that together with corresponding components from the base model will form a
+ `ControlNetXSCrossAttnMidBlock2D`"""
+
+ def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList):
+ super().__init__()
+ self.midblock = midblock
+ self.base_to_ctrl = base_to_ctrl
+ self.ctrl_to_base = ctrl_to_base
+
+
+class UpBlockControlNetXSAdapter(nn.Module):
+ """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`"""
+
+ def __init__(self, ctrl_to_base: nn.ModuleList):
+ super().__init__()
+ self.ctrl_to_base = ctrl_to_base
+
+
+def get_down_block_adapter(
+ base_in_channels: int,
+ base_out_channels: int,
+ ctrl_in_channels: int,
+ ctrl_out_channels: int,
+ temb_channels: int,
+ max_norm_num_groups: Optional[int] = 32,
+ has_crossattn=True,
+ transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
+ num_attention_heads: Optional[int] = 1,
+ cross_attention_dim: Optional[int] = 1024,
+ add_downsample: bool = True,
+ upcast_attention: Optional[bool] = False,
+):
+ num_layers = 2 # only support sd + sdxl
+
+ resnets = []
+ attentions = []
+ ctrl_to_base = []
+ base_to_ctrl = []
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
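+ # Only the first layer of the block consumes the incoming channel widths;
+ # subsequent layers operate on the block's output widths for both streams.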
+ base_in_channels = base_in_channels if i == 0 else base_out_channels
+ ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
+
+ # Before the resnet/attention application, information is concatted from base to control.
+ # Concat doesn't require change in number of channels
+ base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl
+ out_channels=ctrl_out_channels,
+ temb_channels=temb_channels,
+ groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups),
+ groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
+ eps=1e-5,
+ )
+ )
+
+ if has_crossattn:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ ctrl_out_channels // num_attention_heads,
+ in_channels=ctrl_out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
+ )
+ )
+
+ # After the resnet/attention application, information is added from control to base
+ # Addition requires change in number of channels
+ ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
+
+ if add_downsample:
+ # Before the downsampler application, information is concatted from base to control
+ # Concat doesn't require change in number of channels
+ base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))
+
+ downsamplers = Downsample2D(
+ ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
+ )
+
+ # After the downsampler application, information is added from control to base
+ # Addition requires change in number of channels
+ ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
+ else:
+ downsamplers = None
+
+ down_block_components = DownBlockControlNetXSAdapter(
+ resnets=nn.ModuleList(resnets),
+ base_to_ctrl=nn.ModuleList(base_to_ctrl),
+ ctrl_to_base=nn.ModuleList(ctrl_to_base),
+ )
+
+ if has_crossattn:
+ down_block_components.attentions = nn.ModuleList(attentions)
+ if downsamplers is not None:
+ down_block_components.downsamplers = downsamplers
+
+ return down_block_components
+
+
+def get_mid_block_adapter(
+ base_channels: int,
+ ctrl_channels: int,
+ temb_channels: Optional[int] = None,
+ max_norm_num_groups: Optional[int] = 32,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = 1,
+ cross_attention_dim: Optional[int] = 1024,
+ upcast_attention: bool = False,
+):
+ # Before the midblock application, information is concatted from base to control.
+ # Concat doesn't require change in number of channels
+ base_to_ctrl = make_zero_conv(base_channels, base_channels)
+
+ midblock = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=ctrl_channels + base_channels,
+ out_channels=ctrl_channels,
+ temb_channels=temb_channels,
+        # number of norm groups must divide both in_channels and out_channels
+ resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ )
+
+ # After the midblock application, information is added from control to base
+ # Addition requires change in number of channels
+ ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)
+
+ return MidBlockControlNetXSAdapter(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base)
+
+
+def get_up_block_adapter(
+ out_channels: int,
+ prev_output_channel: int,
+ ctrl_skip_channels: List[int],
+):
+ ctrl_to_base = []
+ num_layers = 3 # only support sd + sdxl
+ for i in range(num_layers):
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+ ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))
+
+ return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(ctrl_to_base))
+
+
+class ControlNetXSAdapter(ModelMixin, ConfigMixin):
+ r"""
+ A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a
+ `UNet2DConditionModel` base model).
+
+    This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
+ methods implemented for all models (such as downloading or saving).
+
+    Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. Its
+ default parameters are compatible with StableDiffusion.
+
+ Parameters:
+ conditioning_channels (`int`, defaults to 3):
+ Number of channels of conditioning input (e.g. an image)
+ conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
+ time_embedding_mix (`float`, defaults to 1.0):
+            If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
+ embedding is used. Otherwise, both are combined.
+ learn_time_embedding (`bool`, defaults to `False`):
+ Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time
+ embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base
+ model's time embedding.
+ num_attention_heads (`list[int]`, defaults to `[4]`):
+ The number of attention heads.
+ block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`):
+ The tuple of output channels for each block.
+ base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`):
+ The tuple of output channels for each block in the base unet.
+ cross_attention_dim (`int`, defaults to 1024):
+ The dimension of the cross attention features.
+ down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`):
+ The tuple of downsample blocks to use.
+ sample_size (`int`, defaults to 96):
+ Height and width of input/output sample.
+ transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ upcast_attention (`bool`, defaults to `True`):
+ Whether the attention computation should always be upcasted.
+ max_norm_num_groups (`int`, defaults to 32):
+            Maximum number of groups in group normalization. The actual number will be the largest divisor of the
+            respective channels that is <= max_norm_num_groups.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ conditioning_channels: int = 3,
+ conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
+ time_embedding_mix: float = 1.0,
+ learn_time_embedding: bool = False,
+ num_attention_heads: Union[int, Tuple[int]] = 4,
+ block_out_channels: Tuple[int] = (4, 8, 16, 16),
+ base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ cross_attention_dim: int = 1024,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ sample_size: Optional[int] = 96,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ upcast_attention: bool = True,
+ max_norm_num_groups: int = 32,
+ ):
+ super().__init__()
+
+ time_embedding_input_dim = base_block_out_channels[0]
+ time_embedding_dim = base_block_out_channels[0] * 4
+
+ # Check inputs
+ if conditioning_channel_order not in ["rgb", "bgr"]:
+ raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}")
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(transformer_layers_per_block, (list, tuple)):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+ if not isinstance(cross_attention_dim, (list, tuple)):
+ cross_attention_dim = [cross_attention_dim] * len(down_block_types)
+ # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAdapter` takes `num_attention_heads` instead of `attention_head_dim`
+ if not isinstance(num_attention_heads, (list, tuple)):
+ num_attention_heads = [num_attention_heads] * len(down_block_types)
+
+ if len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ # 5 - Create conditioning hint embedding
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ # time
+ if learn_time_embedding:
+ self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim)
+ else:
+ self.time_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_connections = nn.ModuleList([])
+
+ # input
+ self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
+ self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0])
+
+ # down
+ base_out_channels = base_block_out_channels[0]
+ ctrl_out_channels = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ base_in_channels = base_out_channels
+ base_out_channels = base_block_out_channels[i]
+ ctrl_in_channels = ctrl_out_channels
+ ctrl_out_channels = block_out_channels[i]
+ has_crossattn = "CrossAttn" in down_block_type
+ is_final_block = i == len(down_block_types) - 1
+
+ self.down_blocks.append(
+ get_down_block_adapter(
+ base_in_channels=base_in_channels,
+ base_out_channels=base_out_channels,
+ ctrl_in_channels=ctrl_in_channels,
+ ctrl_out_channels=ctrl_out_channels,
+ temb_channels=time_embedding_dim,
+ max_norm_num_groups=max_norm_num_groups,
+ has_crossattn=has_crossattn,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ num_attention_heads=num_attention_heads[i],
+ cross_attention_dim=cross_attention_dim[i],
+ add_downsample=not is_final_block,
+ upcast_attention=upcast_attention,
+ )
+ )
+
+ # mid
+ self.mid_block = get_mid_block_adapter(
+ base_channels=base_block_out_channels[-1],
+ ctrl_channels=block_out_channels[-1],
+ temb_channels=time_embedding_dim,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ num_attention_heads=num_attention_heads[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ upcast_attention=upcast_attention,
+ )
+
+ # up
+ # The skip connection channels are the output of the conv_in and of all the down subblocks
+ ctrl_skip_channels = [block_out_channels[0]]
+ for i, out_channels in enumerate(block_out_channels):
+ number_of_subblocks = (
+ 3 if i < len(block_out_channels) - 1 else 2
+ ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
+ ctrl_skip_channels.extend([out_channels] * number_of_subblocks)
+
+ reversed_base_block_out_channels = list(reversed(base_block_out_channels))
+
+ base_out_channels = reversed_base_block_out_channels[0]
+ for i in range(len(down_block_types)):
+ prev_base_output_channel = base_out_channels
+ base_out_channels = reversed_base_block_out_channels[i]
+ ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]
+
+ self.up_connections.append(
+ get_up_block_adapter(
+ out_channels=base_out_channels,
+ prev_output_channel=prev_base_output_channel,
+ ctrl_skip_channels=ctrl_skip_channels_,
+ )
+ )
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ size_ratio: Optional[float] = None,
+ block_out_channels: Optional[List[int]] = None,
+ num_attention_heads: Optional[List[int]] = None,
+ learn_time_embedding: bool = False,
+        time_embedding_mix: float = 1.0,
+ conditioning_channels: int = 3,
+ conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
+ ):
+ r"""
+ Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it.
+ size_ratio (float, *optional*, defaults to `None`):
+ When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this
+ or `block_out_channels` must be given.
+ block_out_channels (`List[int]`, *optional*, defaults to `None`):
+ Down blocks output channels in control model. Either this or `size_ratio` must be given.
+ num_attention_heads (`List[int]`, *optional*, defaults to `None`):
+ The dimension of the attention heads. The naming seems a bit confusing and it is, see
+ https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
+ learn_time_embedding (`bool`, defaults to `False`):
+ Whether the `ControlNetXSAdapter` should learn a time embedding.
+ time_embedding_mix (`float`, defaults to 1.0):
+ If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
+ embedding is used. Otherwise, both are combined.
+ conditioning_channels (`int`, defaults to 3):
+ Number of channels of conditioning input (e.g. an image)
+ conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
+                The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
+ """
+
+ # Check input
+ fixed_size = block_out_channels is not None
+ relative_size = size_ratio is not None
+ if not (fixed_size ^ relative_size):
+ raise ValueError(
+ "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)."
+ )
+
+ # Create model
+ block_out_channels = block_out_channels or [int(b * size_ratio) for b in unet.config.block_out_channels]
+ if num_attention_heads is None:
+ # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
+ num_attention_heads = unet.config.attention_head_dim
+
+ model = cls(
+ conditioning_channels=conditioning_channels,
+ conditioning_channel_order=conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ time_embedding_mix=time_embedding_mix,
+ learn_time_embedding=learn_time_embedding,
+ num_attention_heads=num_attention_heads,
+ block_out_channels=block_out_channels,
+ base_block_out_channels=unet.config.block_out_channels,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ down_block_types=unet.config.down_block_types,
+ sample_size=unet.config.sample_size,
+ transformer_layers_per_block=unet.config.transformer_layers_per_block,
+ upcast_attention=unet.config.upcast_attention,
+ max_norm_num_groups=unet.config.norm_num_groups,
+ )
+
+ # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
+ model.to(unet.dtype)
+
+ return model
+
+ def forward(self, *args, **kwargs):
+ raise ValueError(
+ "A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel."
+ )
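+    # A minimal usage sketch (illustrative only, not executed by the library). The
+    # checkpoint id below is an assumption; any UNet2DConditionModel works:
+    #
+    #     from diffusers import UNet2DConditionModel
+    #
+    #     unet = UNet2DConditionModel.from_pretrained(
+    #         "runwayml/stable-diffusion-v1-5", subfolder="unet"
+    #     )
+    #     # size_ratio=0.1 scales each control block to 10% of the base UNet's channels
+    #     adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)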
+
+
+class UNetControlNetXSModel(ModelMixin, ConfigMixin):
+ r"""
+    A UNet fused with a ControlNet-XS adapter model.
+
+    This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
+    methods implemented for all models (such as downloading or saving).
+
+    `UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. Its default parameters are
+    compatible with StableDiffusion.
+
+    Its parameters are either passed to the underlying `UNet2DConditionModel` or used exactly as in
+    `ControlNetXSAdapter`. See their documentation for details.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ # unet configs
+ sample_size: Optional[int] = 96,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ norm_num_groups: Optional[int] = 32,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = 8,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ upcast_attention: bool = True,
+ time_cond_proj_dim: Optional[int] = None,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ # additional controlnet configs
+ time_embedding_mix: float = 1.0,
+ ctrl_conditioning_channels: int = 3,
+ ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
+ ctrl_conditioning_channel_order: str = "rgb",
+ ctrl_learn_time_embedding: bool = False,
+ ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
+ ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
+ ctrl_max_norm_num_groups: int = 32,
+ ):
+ super().__init__()
+
+ if time_embedding_mix < 0 or time_embedding_mix > 1:
+ raise ValueError("`time_embedding_mix` needs to be between 0 and 1.")
+ if time_embedding_mix < 1 and not ctrl_learn_time_embedding:
+ raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`")
+
+ if addition_embed_type is not None and addition_embed_type != "text_time":
+ raise ValueError(
+ "As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`."
+ )
+
+ if not isinstance(transformer_layers_per_block, (list, tuple)):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+ if not isinstance(cross_attention_dim, (list, tuple)):
+ cross_attention_dim = [cross_attention_dim] * len(down_block_types)
+ if not isinstance(num_attention_heads, (list, tuple)):
+ num_attention_heads = [num_attention_heads] * len(down_block_types)
+ if not isinstance(ctrl_num_attention_heads, (list, tuple)):
+ ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types)
+
+ base_num_attention_heads = num_attention_heads
+
+ self.in_channels = 4
+
+ # # Input
+ self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=ctrl_block_out_channels[0],
+ block_out_channels=ctrl_conditioning_embedding_out_channels,
+ conditioning_channels=ctrl_conditioning_channels,
+ )
+ self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1)
+ self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0])
+
+ # # Time
+ time_embed_input_dim = block_out_channels[0]
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.base_time_embedding = TimestepEmbedding(
+ time_embed_input_dim,
+ time_embed_dim,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+ self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
+
+ if addition_embed_type is None:
+ self.base_add_time_proj = None
+ self.base_add_embedding = None
+ else:
+ self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ # # Create down blocks
+ down_blocks = []
+ base_out_channels = block_out_channels[0]
+ ctrl_out_channels = ctrl_block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ base_in_channels = base_out_channels
+ base_out_channels = block_out_channels[i]
+ ctrl_in_channels = ctrl_out_channels
+ ctrl_out_channels = ctrl_block_out_channels[i]
+ has_crossattn = "CrossAttn" in down_block_type
+ is_final_block = i == len(down_block_types) - 1
+
+ down_blocks.append(
+ ControlNetXSCrossAttnDownBlock2D(
+ base_in_channels=base_in_channels,
+ base_out_channels=base_out_channels,
+ ctrl_in_channels=ctrl_in_channels,
+ ctrl_out_channels=ctrl_out_channels,
+ temb_channels=time_embed_dim,
+ norm_num_groups=norm_num_groups,
+ ctrl_max_norm_num_groups=ctrl_max_norm_num_groups,
+ has_crossattn=has_crossattn,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ base_num_attention_heads=base_num_attention_heads[i],
+ ctrl_num_attention_heads=ctrl_num_attention_heads[i],
+ cross_attention_dim=cross_attention_dim[i],
+ add_downsample=not is_final_block,
+ upcast_attention=upcast_attention,
+ )
+ )
+
+ # # Create mid block
+ self.mid_block = ControlNetXSCrossAttnMidBlock2D(
+ base_channels=block_out_channels[-1],
+ ctrl_channels=ctrl_block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ norm_num_groups=norm_num_groups,
+ ctrl_max_norm_num_groups=ctrl_max_norm_num_groups,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ base_num_attention_heads=base_num_attention_heads[-1],
+ ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ upcast_attention=upcast_attention,
+ )
+
+ # # Create up blocks
+ up_blocks = []
+ rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ rev_num_attention_heads = list(reversed(base_num_attention_heads))
+ rev_cross_attention_dim = list(reversed(cross_attention_dim))
+
+ # The skip connection channels are the output of the conv_in and of all the down subblocks
+ ctrl_skip_channels = [ctrl_block_out_channels[0]]
+ for i, out_channels in enumerate(ctrl_block_out_channels):
+ number_of_subblocks = (
+ 3 if i < len(ctrl_block_out_channels) - 1 else 2
+ ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
+ ctrl_skip_channels.extend([out_channels] * number_of_subblocks)
+
+ reversed_block_out_channels = list(reversed(block_out_channels))
+
+ out_channels = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = out_channels
+ out_channels = reversed_block_out_channels[i]
+ in_channels = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+ ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]
+
+ has_crossattn = "CrossAttn" in up_block_type
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_blocks.append(
+ ControlNetXSCrossAttnUpBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ ctrl_skip_channels=ctrl_skip_channels_,
+ temb_channels=time_embed_dim,
+ resolution_idx=i,
+ has_crossattn=has_crossattn,
+ transformer_layers_per_block=rev_transformer_layers_per_block[i],
+ num_attention_heads=rev_num_attention_heads[i],
+ cross_attention_dim=rev_cross_attention_dim[i],
+ add_upsample=not is_final_block,
+ upcast_attention=upcast_attention,
+ norm_num_groups=norm_num_groups,
+ )
+ )
+
+ self.down_blocks = nn.ModuleList(down_blocks)
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ self.base_conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups)
+ self.base_conv_act = nn.SiLU()
+ self.base_conv_out = nn.Conv2d(block_out_channels[0], 4, kernel_size=3, padding=1)
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet: Optional[ControlNetXSAdapter] = None,
+ size_ratio: Optional[float] = None,
+ ctrl_block_out_channels: Optional[List[float]] = None,
+ time_embedding_mix: Optional[float] = None,
+ ctrl_optional_kwargs: Optional[Dict] = None,
+ ):
+ r"""
+ Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAdapter`]
+ .
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model we want to control.
+ controlnet (`ControlNetXSAdapter`):
+                The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
+                adapter will be created.
+            size_ratio (float, *optional*, defaults to `None`):
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+            ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
+                where this parameter is called `block_out_channels`.
+            time_embedding_mix (`float`, *optional*, defaults to None):
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+            ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
+                Passed to the `__init__` of the new controlnet if no controlnet is given.
+ """
+ if controlnet is None:
+ controlnet = ControlNetXSAdapter.from_unet(
+                unet, size_ratio, ctrl_block_out_channels, **(ctrl_optional_kwargs or {})
+ )
+ else:
+ if any(
+ o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs)
+ ):
+ raise ValueError(
+ "When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs."
+ )
+
+ # # get params
+ params_for_unet = [
+ "sample_size",
+ "down_block_types",
+ "up_block_types",
+ "block_out_channels",
+ "norm_num_groups",
+ "cross_attention_dim",
+ "transformer_layers_per_block",
+ "addition_embed_type",
+ "addition_time_embed_dim",
+ "upcast_attention",
+ "time_cond_proj_dim",
+ "projection_class_embeddings_input_dim",
+ ]
+ params_for_unet = {k: v for k, v in unet.config.items() if k in params_for_unet}
+ # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
+ params_for_unet["num_attention_heads"] = unet.config.attention_head_dim
+
+ params_for_controlnet = [
+ "conditioning_channels",
+ "conditioning_embedding_out_channels",
+ "conditioning_channel_order",
+ "learn_time_embedding",
+ "block_out_channels",
+ "num_attention_heads",
+ "max_norm_num_groups",
+ ]
+ params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet}
+ params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix
+
+ # # create model
+ model = cls.from_config({**params_for_unet, **params_for_controlnet})
+
+ # # load weights
+ # from unet
+ modules_from_unet = [
+ "time_embedding",
+ "conv_in",
+ "conv_norm_out",
+ "conv_out",
+ ]
+ for m in modules_from_unet:
+ getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
+
+ optional_modules_from_unet = [
+ "add_time_proj",
+ "add_embedding",
+ ]
+ for m in optional_modules_from_unet:
+ if hasattr(unet, m) and getattr(unet, m) is not None:
+ getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
+
+ # from controlnet
+ model.controlnet_cond_embedding.load_state_dict(controlnet.controlnet_cond_embedding.state_dict())
+ model.ctrl_conv_in.load_state_dict(controlnet.conv_in.state_dict())
+ if controlnet.time_embedding is not None:
+ model.ctrl_time_embedding.load_state_dict(controlnet.time_embedding.state_dict())
+ model.control_to_base_for_conv_in.load_state_dict(controlnet.control_to_base_for_conv_in.state_dict())
+
+ # from both
+ model.down_blocks = nn.ModuleList(
+ ControlNetXSCrossAttnDownBlock2D.from_modules(b, c)
+ for b, c in zip(unet.down_blocks, controlnet.down_blocks)
+ )
+ model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block)
+ model.up_blocks = nn.ModuleList(
+ ControlNetXSCrossAttnUpBlock2D.from_modules(b, c)
+ for b, c in zip(unet.up_blocks, controlnet.up_connections)
+ )
+
+ # ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel
+ model.to(unet.dtype)
+
+ return model
+
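+    # A minimal sketch of fusing a UNet with an adapter (illustrative; `unet` and
+    # `adapter` are assumed to exist, e.g. created as in `ControlNetXSAdapter.from_unet`):
+    #
+    #     model = UNetControlNetXSModel.from_unet(unet, adapter)
+    #
+    #     # or let `from_unet` build a fresh adapter internally:
+    #     model = UNetControlNetXSModel.from_unet(unet, size_ratio=0.1)
+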
+ def freeze_unet_params(self) -> None:
+ """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
+ tuning."""
+        # Unfreeze everything
+        for param in self.parameters():
+            param.requires_grad = True
+
+        # Freeze the parts belonging to the base UNet
+ base_parts = [
+ "base_time_proj",
+ "base_time_embedding",
+ "base_add_time_proj",
+ "base_add_embedding",
+ "base_conv_in",
+ "base_conv_norm_out",
+ "base_conv_act",
+ "base_conv_out",
+ ]
+ base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None]
+ for part in base_parts:
+ for param in part.parameters():
+ param.requires_grad = False
+
+ for d in self.down_blocks:
+ d.freeze_base_params()
+ self.mid_block.freeze_base_params()
+ for u in self.up_blocks:
+ u.freeze_base_params()
+
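+    # A minimal fine-tuning sketch (illustrative): freeze the base UNet weights so that
+    # only the control parts are trained.
+    #
+    #     model.freeze_unet_params()
+    #     trainable_params = [p for p in model.parameters() if p.requires_grad]
+    #     optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)
+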
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
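+    # A minimal sketch (illustrative), using the AttnProcessor2_0 class shipped in
+    # diffusers.models.attention_processor:
+    #
+    #     from diffusers.models.attention_processor import AttnProcessor2_0
+    #
+    #     model.set_attn_processor(AttnProcessor2_0())
+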
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
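+    # A minimal sketch (illustrative values; see the FreeU repository for values tuned
+    # per base model):
+    #
+    #     model.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
+    #     ...  # run inference
+    #     model.disable_freeu()
+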
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
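+    # A minimal sketch (illustrative): fuse the projections around inference, then restore
+    # the original processors afterwards.
+    #
+    #     model.fuse_qkv_projections()
+    #     ...  # run inference
+    #     model.unfuse_qkv_projections()
+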
+ def forward(
+ self,
+ sample: FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: Optional[torch.Tensor] = None,
+ conditioning_scale: Optional[float] = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ return_dict: bool = True,
+ apply_control: bool = True,
+ ) -> Union[ControlNetXSOutput, Tuple]:
+ """
+        The [`UNetControlNetXSModel`] forward method.
+
+ Args:
+ sample (`FloatTensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+                The timestep at which to denoise the input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`FloatTensor`):
+                The conditional input tensor of shape `(batch_size, conditioning_channels, height, width)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ How much the control model affects the base model outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ return_dict (`bool`, defaults to `True`):
+                Whether or not to return a [`~models.controlnetxs.ControlNetXSOutput`] instead of a plain tuple.
+ apply_control (`bool`, defaults to `True`):
+ If `False`, the input is run only through the base model.
+
+ Returns:
+ [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+ """
+
+ # check channel order
+ if self.config.ctrl_conditioning_channel_order == "bgr":
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.base_time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ if self.config.ctrl_learn_time_embedding and apply_control:
+ ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond)
+ base_temb = self.base_time_embedding(t_emb, timestep_cond)
+ interpolation_param = self.config.time_embedding_mix**0.3
+
+ temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
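+            # e.g. time_embedding_mix=0.5 gives interpolation_param = 0.5**0.3 ≈ 0.81,
+            # i.e. temb ≈ 0.81 * ctrl_temb + 0.19 * base_temb (illustrative arithmetic)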
+ else:
+ temb = self.base_time_embedding(t_emb)
+
+ # added time & text embeddings
+ aug_emb = None
+
+ if self.config.addition_embed_type is None:
+ pass
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.base_add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(temb.dtype)
+ aug_emb = self.base_add_embedding(add_embeds)
+ else:
+ raise ValueError(
+ f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported."
+ )
+
+ temb = temb + aug_emb if aug_emb is not None else temb
+
+ # text embeddings
+ cemb = encoder_hidden_states
+
+ # Preparation
+ h_ctrl = h_base = sample
+ hs_base, hs_ctrl = [], []
+
+ # Cross Control
+ guided_hint = self.controlnet_cond_embedding(controlnet_cond)
+
+ # 1 - conv in & down
+
+ h_base = self.base_conv_in(h_base)
+ h_ctrl = self.ctrl_conv_in(h_ctrl)
+ if guided_hint is not None:
+ h_ctrl += guided_hint
+ if apply_control:
+ h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base
+
+ hs_base.append(h_base)
+ hs_ctrl.append(h_ctrl)
+
+ for down in self.down_blocks:
+ h_base, h_ctrl, residual_hb, residual_hc = down(
+ hidden_states_base=h_base,
+ hidden_states_ctrl=h_ctrl,
+ temb=temb,
+ encoder_hidden_states=cemb,
+ conditioning_scale=conditioning_scale,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ apply_control=apply_control,
+ )
+ hs_base.extend(residual_hb)
+ hs_ctrl.extend(residual_hc)
+
+ # 2 - mid
+ h_base, h_ctrl = self.mid_block(
+ hidden_states_base=h_base,
+ hidden_states_ctrl=h_ctrl,
+ temb=temb,
+ encoder_hidden_states=cemb,
+ conditioning_scale=conditioning_scale,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ apply_control=apply_control,
+ )
+
+ # 3 - up
+ for up in self.up_blocks:
+ n_resnets = len(up.resnets)
+ skips_hb = hs_base[-n_resnets:]
+ skips_hc = hs_ctrl[-n_resnets:]
+ hs_base = hs_base[:-n_resnets]
+ hs_ctrl = hs_ctrl[:-n_resnets]
+ h_base = up(
+ hidden_states=h_base,
+ res_hidden_states_tuple_base=skips_hb,
+ res_hidden_states_tuple_ctrl=skips_hc,
+ temb=temb,
+ encoder_hidden_states=cemb,
+ conditioning_scale=conditioning_scale,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ apply_control=apply_control,
+ )
+
+ # 4 - conv out
+ h_base = self.base_conv_norm_out(h_base)
+ h_base = self.base_conv_act(h_base)
+ h_base = self.base_conv_out(h_base)
+
+ if not return_dict:
+ return (h_base,)
+
+ return ControlNetXSOutput(sample=h_base)
+
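+# A minimal sketch of calling the fused model (illustrative; shapes assume a 512x512
+# latent-diffusion setup and `model` built via UNetControlNetXSModel.from_unet):
+#
+#     sample = torch.randn(1, 4, 64, 64)                    # noisy latents
+#     controlnet_cond = torch.randn(1, 3, 512, 512)         # e.g. a canny edge image
+#     encoder_hidden_states = torch.randn(1, 77, 1024)      # must match config.cross_attention_dim
+#     out = model(
+#         sample=sample,
+#         timestep=10,
+#         encoder_hidden_states=encoder_hidden_states,
+#         controlnet_cond=controlnet_cond,
+#         conditioning_scale=1.0,
+#     ).sample
+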
+
+class ControlNetXSCrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ base_in_channels: int,
+ base_out_channels: int,
+ ctrl_in_channels: int,
+ ctrl_out_channels: int,
+ temb_channels: int,
+ norm_num_groups: int = 32,
+ ctrl_max_norm_num_groups: int = 32,
+ has_crossattn=True,
+ transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
+ base_num_attention_heads: Optional[int] = 1,
+ ctrl_num_attention_heads: Optional[int] = 1,
+ cross_attention_dim: Optional[int] = 1024,
+ add_downsample: bool = True,
+ upcast_attention: Optional[bool] = False,
+ ):
+ super().__init__()
+ base_resnets = []
+ base_attentions = []
+ ctrl_resnets = []
+ ctrl_attentions = []
+ ctrl_to_base = []
+ base_to_ctrl = []
+
+ num_layers = 2 # only support sd + sdxl
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ base_in_channels = base_in_channels if i == 0 else base_out_channels
+ ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
+
+ # Before the resnet/attention application, information is concatted from base to control.
+ # Concat doesn't require change in number of channels
+ base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))
+
+ base_resnets.append(
+ ResnetBlock2D(
+ in_channels=base_in_channels,
+ out_channels=base_out_channels,
+ temb_channels=temb_channels,
+ groups=norm_num_groups,
+ )
+ )
+ ctrl_resnets.append(
+ ResnetBlock2D(
+ in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl
+ out_channels=ctrl_out_channels,
+ temb_channels=temb_channels,
+ groups=find_largest_factor(
+ ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups
+ ),
+ groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
+ eps=1e-5,
+ )
+ )
+
+ if has_crossattn:
+ base_attentions.append(
+ Transformer2DModel(
+ base_num_attention_heads,
+ base_out_channels // base_num_attention_heads,
+ in_channels=base_out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ norm_num_groups=norm_num_groups,
+ )
+ )
+ ctrl_attentions.append(
+ Transformer2DModel(
+ ctrl_num_attention_heads,
+ ctrl_out_channels // ctrl_num_attention_heads,
+ in_channels=ctrl_out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
+ )
+ )
+
+ # After the resnet/attention application, information is added from control to base
+ # Addition requires change in number of channels
+ ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
+
+ if add_downsample:
+ # Before the downsampler application, information is concatted from base to control
+ # Concat doesn't require change in number of channels
+ base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))
+
+ self.base_downsamplers = Downsample2D(
+ base_out_channels, use_conv=True, out_channels=base_out_channels, name="op"
+ )
+ self.ctrl_downsamplers = Downsample2D(
+ ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
+ )
+
+ # After the downsampler application, information is added from control to base
+ # Addition requires change in number of channels
+ ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
+ else:
+ self.base_downsamplers = None
+ self.ctrl_downsamplers = None
+
+ self.base_resnets = nn.ModuleList(base_resnets)
+ self.ctrl_resnets = nn.ModuleList(ctrl_resnets)
+ self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None] * num_layers
+ self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None] * num_layers
+ self.base_to_ctrl = nn.ModuleList(base_to_ctrl)
+ self.ctrl_to_base = nn.ModuleList(ctrl_to_base)
+
+ self.gradient_checkpointing = False
+
+ @classmethod
+ def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter):
+ # get params
+ def get_first_cross_attention(block):
+ return block.attentions[0].transformer_blocks[0].attn2
+
+ base_in_channels = base_downblock.resnets[0].in_channels
+ base_out_channels = base_downblock.resnets[0].out_channels
+ ctrl_in_channels = (
+ ctrl_downblock.resnets[0].in_channels - base_in_channels
+ ) # base channels are concatted to ctrl channels in init
+ ctrl_out_channels = ctrl_downblock.resnets[0].out_channels
+ temb_channels = base_downblock.resnets[0].time_emb_proj.in_features
+ num_groups = base_downblock.resnets[0].norm1.num_groups
+ ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups
+ if hasattr(base_downblock, "attentions"):
+ has_crossattn = True
+ transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks)
+ base_num_attention_heads = get_first_cross_attention(base_downblock).heads
+ ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
+ cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
+ upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
+ else:
+ has_crossattn = False
+ transformer_layers_per_block = None
+ base_num_attention_heads = None
+ ctrl_num_attention_heads = None
+ cross_attention_dim = None
+ upcast_attention = None
+ add_downsample = base_downblock.downsamplers is not None
+
+ # create model
+ model = cls(
+ base_in_channels=base_in_channels,
+ base_out_channels=base_out_channels,
+ ctrl_in_channels=ctrl_in_channels,
+ ctrl_out_channels=ctrl_out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=num_groups,
+ ctrl_max_norm_num_groups=ctrl_num_groups,
+ has_crossattn=has_crossattn,
+ transformer_layers_per_block=transformer_layers_per_block,
+ base_num_attention_heads=base_num_attention_heads,
+ ctrl_num_attention_heads=ctrl_num_attention_heads,
+ cross_attention_dim=cross_attention_dim,
+ add_downsample=add_downsample,
+ upcast_attention=upcast_attention,
+ )
+
+ # # load weights
+ model.base_resnets.load_state_dict(base_downblock.resnets.state_dict())
+ model.ctrl_resnets.load_state_dict(ctrl_downblock.resnets.state_dict())
+ if has_crossattn:
+ model.base_attentions.load_state_dict(base_downblock.attentions.state_dict())
+ model.ctrl_attentions.load_state_dict(ctrl_downblock.attentions.state_dict())
+ if add_downsample:
+ model.base_downsamplers.load_state_dict(base_downblock.downsamplers[0].state_dict())
+ model.ctrl_downsamplers.load_state_dict(ctrl_downblock.downsamplers.state_dict())
+ model.base_to_ctrl.load_state_dict(ctrl_downblock.base_to_ctrl.state_dict())
+ model.ctrl_to_base.load_state_dict(ctrl_downblock.ctrl_to_base.state_dict())
+
+ return model
+
+ def freeze_base_params(self) -> None:
+ """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
+ tuning."""
+ # Unfreeze everything
+ for param in self.parameters():
+ param.requires_grad = True
+
+ # Freeze base part
+ base_parts = [self.base_resnets]
+ if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones
+ base_parts.append(self.base_attentions)
+ if self.base_downsamplers is not None:
+ base_parts.append(self.base_downsamplers)
+ for part in base_parts:
+ for param in part.parameters():
+ param.requires_grad = False
+
+ def forward(
+ self,
+ hidden_states_base: FloatTensor,
+ temb: FloatTensor,
+ encoder_hidden_states: Optional[FloatTensor] = None,
+ hidden_states_ctrl: Optional[FloatTensor] = None,
+ conditioning_scale: Optional[float] = 1.0,
+ attention_mask: Optional[FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[FloatTensor] = None,
+ apply_control: bool = True,
+ ) -> Tuple[FloatTensor, FloatTensor, Tuple[FloatTensor, ...], Tuple[FloatTensor, ...]]:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ h_base = hidden_states_base
+ h_ctrl = hidden_states_ctrl
+
+ base_output_states = ()
+ ctrl_output_states = ()
+
+ base_blocks = list(zip(self.base_resnets, self.base_attentions))
+ ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
+ base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
+ ):
+ # concat base -> ctrl
+ if apply_control:
+ h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
+
+ # apply base subblock
+ if self.training and self.gradient_checkpointing:
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ h_base = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(b_res),
+ h_base,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ h_base = b_res(h_base, temb)
+
+ if b_attn is not None:
+ h_base = b_attn(
+ h_base,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ # apply ctrl subblock
+ if apply_control:
+ if self.training and self.gradient_checkpointing:
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ h_ctrl = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(c_res),
+ h_ctrl,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ h_ctrl = c_res(h_ctrl, temb)
+ if c_attn is not None:
+ h_ctrl = c_attn(
+ h_ctrl,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ # add ctrl -> base
+ if apply_control:
+ h_base = h_base + c2b(h_ctrl) * conditioning_scale
+
+ base_output_states = base_output_states + (h_base,)
+ ctrl_output_states = ctrl_output_states + (h_ctrl,)
+
+ if self.base_downsamplers is not None: # if we have a base_downsampler, then also a ctrl_downsampler
+ b2c = self.base_to_ctrl[-1]
+ c2b = self.ctrl_to_base[-1]
+
+ # concat base -> ctrl
+ if apply_control:
+ h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
+ # apply base subblock
+ h_base = self.base_downsamplers(h_base)
+ # apply ctrl subblock
+ if apply_control:
+ h_ctrl = self.ctrl_downsamplers(h_ctrl)
+ # add ctrl -> base
+ if apply_control:
+ h_base = h_base + c2b(h_ctrl) * conditioning_scale
+
+ base_output_states = base_output_states + (h_base,)
+ ctrl_output_states = ctrl_output_states + (h_ctrl,)
+
+ return h_base, h_ctrl, base_output_states, ctrl_output_states
+
+
+class ControlNetXSCrossAttnMidBlock2D(nn.Module):
+ def __init__(
+ self,
+ base_channels: int,
+ ctrl_channels: int,
+ temb_channels: Optional[int] = None,
+ norm_num_groups: int = 32,
+ ctrl_max_norm_num_groups: int = 32,
+ transformer_layers_per_block: int = 1,
+ base_num_attention_heads: Optional[int] = 1,
+ ctrl_num_attention_heads: Optional[int] = 1,
+ cross_attention_dim: Optional[int] = 1024,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+
+ # Before the midblock application, information is concatted from base to control.
+ # Concat doesn't require change in number of channels
+ self.base_to_ctrl = make_zero_conv(base_channels, base_channels)
+
+ self.base_midblock = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=base_channels,
+ temb_channels=temb_channels,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=base_num_attention_heads,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ )
+
+ self.ctrl_midblock = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=ctrl_channels + base_channels,
+ out_channels=ctrl_channels,
+ temb_channels=temb_channels,
+            # number of norm groups must divide both in_channels and out_channels
+ resnet_groups=find_largest_factor(
+ gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups
+ ),
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=ctrl_num_attention_heads,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ )
+
+ # After the midblock application, information is added from control to base
+ # Addition requires change in number of channels
+ self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)
+
+ self.gradient_checkpointing = False
+
+ @classmethod
+ def from_modules(
+ cls,
+ base_midblock: UNetMidBlock2DCrossAttn,
+ ctrl_midblock: MidBlockControlNetXSAdapter,
+ ):
+ base_to_ctrl = ctrl_midblock.base_to_ctrl
+ ctrl_to_base = ctrl_midblock.ctrl_to_base
+ ctrl_midblock = ctrl_midblock.midblock
+
+ # get params
+ def get_first_cross_attention(midblock):
+ return midblock.attentions[0].transformer_blocks[0].attn2
+
+ base_channels = ctrl_to_base.out_channels
+ ctrl_channels = ctrl_to_base.in_channels
+ transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks)
+ temb_channels = base_midblock.resnets[0].time_emb_proj.in_features
+ num_groups = base_midblock.resnets[0].norm1.num_groups
+ ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups
+ base_num_attention_heads = get_first_cross_attention(base_midblock).heads
+ ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
+ cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
+ upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
+
+ # create model
+ model = cls(
+ base_channels=base_channels,
+ ctrl_channels=ctrl_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=num_groups,
+ ctrl_max_norm_num_groups=ctrl_num_groups,
+ transformer_layers_per_block=transformer_layers_per_block,
+ base_num_attention_heads=base_num_attention_heads,
+ ctrl_num_attention_heads=ctrl_num_attention_heads,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ )
+
+ # load weights
+ model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict())
+ model.base_midblock.load_state_dict(base_midblock.state_dict())
+ model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict())
+ model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict())
+
+ return model
+
+ def freeze_base_params(self) -> None:
+ """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
+ tuning."""
+ # Unfreeze everything
+ for param in self.parameters():
+ param.requires_grad = True
+
+ # Freeze base part
+ for param in self.base_midblock.parameters():
+ param.requires_grad = False
+
+ def forward(
+ self,
+ hidden_states_base: FloatTensor,
+ temb: FloatTensor,
+ encoder_hidden_states: FloatTensor,
+ hidden_states_ctrl: Optional[FloatTensor] = None,
+ conditioning_scale: Optional[float] = 1.0,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ attention_mask: Optional[FloatTensor] = None,
+ encoder_attention_mask: Optional[FloatTensor] = None,
+ apply_control: bool = True,
+ ) -> Tuple[FloatTensor, FloatTensor]:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ h_base = hidden_states_base
+ h_ctrl = hidden_states_ctrl
+
+ joint_args = {
+ "temb": temb,
+ "encoder_hidden_states": encoder_hidden_states,
+ "attention_mask": attention_mask,
+ "cross_attention_kwargs": cross_attention_kwargs,
+ "encoder_attention_mask": encoder_attention_mask,
+ }
+
+ if apply_control:
+ h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1) # concat base -> ctrl
+ h_base = self.base_midblock(h_base, **joint_args) # apply base mid block
+ if apply_control:
+ h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block
+ h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base
+
+ return h_base, h_ctrl
+
+
+class ControlNetXSCrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ ctrl_skip_channels: List[int],
+ temb_channels: int,
+ norm_num_groups: int = 32,
+ resolution_idx: Optional[int] = None,
+ has_crossattn=True,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1024,
+ add_upsample: bool = True,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ ctrl_to_base = []
+
+ num_layers = 3 # only SD and SDXL are supported (both use 3 resnets per up block)
+
+ self.has_cross_attention = has_crossattn
+ self.num_attention_heads = num_attention_heads
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ groups=norm_num_groups,
+ )
+ )
+
+ if has_crossattn:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ norm_num_groups=norm_num_groups,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers
+ self.ctrl_to_base = nn.ModuleList(ctrl_to_base)
+
+ if add_upsample:
+ self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ @classmethod
+ def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter):
+ ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base
+
+ # get params
+ def get_first_cross_attention(block):
+ return block.attentions[0].transformer_blocks[0].attn2
+
+ out_channels = base_upblock.resnets[0].out_channels
+ in_channels = base_upblock.resnets[-1].in_channels - out_channels
+ prev_output_channels = base_upblock.resnets[0].in_channels - out_channels
+ ctrl_skip_channels = [c.in_channels for c in ctrl_to_base_skip_connections]
+ temb_channels = base_upblock.resnets[0].time_emb_proj.in_features
+ num_groups = base_upblock.resnets[0].norm1.num_groups
+ resolution_idx = base_upblock.resolution_idx
+ if hasattr(base_upblock, "attentions"):
+ has_crossattn = True
+ transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks)
+ num_attention_heads = get_first_cross_attention(base_upblock).heads
+ cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
+ upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
+ else:
+ has_crossattn = False
+ transformer_layers_per_block = None
+ num_attention_heads = None
+ cross_attention_dim = None
+ upcast_attention = None
+ add_upsample = base_upblock.upsamplers is not None
+
+ # create model
+ model = cls(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channels,
+ ctrl_skip_channels=ctrl_skip_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=num_groups,
+ resolution_idx=resolution_idx,
+ has_crossattn=has_crossattn,
+ transformer_layers_per_block=transformer_layers_per_block,
+ num_attention_heads=num_attention_heads,
+ cross_attention_dim=cross_attention_dim,
+ add_upsample=add_upsample,
+ upcast_attention=upcast_attention,
+ )
+
+ # load weights
+ model.resnets.load_state_dict(base_upblock.resnets.state_dict())
+ if has_crossattn:
+ model.attentions.load_state_dict(base_upblock.attentions.state_dict())
+ if add_upsample:
+ model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict())
+ model.ctrl_to_base.load_state_dict(ctrl_to_base_skip_connections.state_dict())
+
+ return model
+
+ def freeze_base_params(self) -> None:
+ """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
+ tuning."""
+ # Unfreeze everything
+ for param in self.parameters():
+ param.requires_grad = True
+
+ # Freeze base part
+ base_parts = [self.resnets]
+ if isinstance(self.attentions, nn.ModuleList): # attentions can be a list of Nones
+ base_parts.append(self.attentions)
+ if self.upsamplers is not None:
+ base_parts.append(self.upsamplers)
+ for part in base_parts:
+ for param in part.parameters():
+ param.requires_grad = False
+
+ def forward(
+ self,
+ hidden_states: FloatTensor,
+ res_hidden_states_tuple_base: Tuple[FloatTensor, ...],
+ res_hidden_states_tuple_ctrl: Tuple[FloatTensor, ...],
+ temb: FloatTensor,
+ encoder_hidden_states: Optional[FloatTensor] = None,
+ conditioning_scale: Optional[float] = 1.0,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ attention_mask: Optional[FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ encoder_attention_mask: Optional[FloatTensor] = None,
+ apply_control: bool = True,
+ ) -> FloatTensor:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ return apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_h_base,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+ else:
+ return hidden_states, res_h_base
+
+ for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(
+ self.resnets,
+ self.attentions,
+ self.ctrl_to_base,
+ reversed(res_hidden_states_tuple_base),
+ reversed(res_hidden_states_tuple_ctrl),
+ ):
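+ # Each layer: add the scaled control skip connection to the base stream, concatenate the base skip
+ # connection, then run the resnet (and attention, if present).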
+ if apply_control:
+ hidden_states += c2b(res_h_ctrl) * conditioning_scale
+
+ hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
+ hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if attn is not None:
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if self.upsamplers is not None:
+ hidden_states = self.upsamplers(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+def make_zero_conv(in_channels, out_channels=None):
+ # A 1x1 convolution whose weights are initialized to zero, so the control branch initially
+ # contributes nothing and the base model's behavior is preserved at the start of training.
+ return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
+
+
+def zero_module(module):
+ # Set all parameters of `module` to zero.
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
+
+
+def find_largest_factor(number, max_factor):
+ """Return the largest factor of `number` that is at most `max_factor`; since 1 divides everything, it always returns at least 1."""
+ factor = max_factor
+ if factor >= number:
+ return number
+ while factor != 0:
+ residual = number % factor
+ if residual == 0:
+ return factor
+ factor -= 1
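+
+
+# Example (illustrative values only): find_largest_factor(12, max_factor=32) returns 12, and
+# find_largest_factor(100, max_factor=32) returns 25, the largest divisor of 100 not exceeding 32.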
diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py
index d54630376961..ef75fad25e44 100644
--- a/src/diffusers/models/unets/unet_2d_blocks.py
+++ b/src/diffusers/models/unets/unet_2d_blocks.py
@@ -746,6 +746,7 @@ def __init__(
self,
in_channels: int,
temb_channels: int,
+ out_channels: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -753,6 +754,7 @@ def __init__(
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
+ resnet_groups_out: Optional[int] = None,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
@@ -764,6 +766,10 @@ def __init__(
):
super().__init__()
+ out_channels = out_channels or in_channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -772,14 +778,17 @@ def __init__(
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+ resnet_groups_out = resnet_groups_out or resnet_groups
+
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
- out_channels=in_channels,
+ out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
+ groups_out=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
@@ -794,11 +803,11 @@ def __init__(
attentions.append(
Transformer2DModel(
num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
+ norm_num_groups=resnet_groups_out,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
@@ -808,8 +817,8 @@ def __init__(
attentions.append(
DualTransformer2DModel(
num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
@@ -817,11 +826,11 @@ def __init__(
)
resnets.append(
ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
+ in_channels=out_channels,
+ out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
- groups=resnet_groups,
+ groups=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 2b2277809b38..ab7c13b56eb8 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -134,6 +134,12 @@
"StableDiffusionXLControlNetPipeline",
]
)
+ _import_structure["controlnet_xs"].extend(
+ [
+ "StableDiffusionControlNetXSPipeline",
+ "StableDiffusionXLControlNetXSPipeline",
+ ]
+ )
_import_structure["deepfloyd_if"] = [
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -378,6 +384,10 @@
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
)
+ from .controlnet_xs import (
+ StableDiffusionControlNetXSPipeline,
+ StableDiffusionXLControlNetXSPipeline,
+ )
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
diff --git a/src/diffusers/pipelines/controlnet_xs/__init__.py b/src/diffusers/pipelines/controlnet_xs/__init__.py
new file mode 100644
index 000000000000..978278b184f9
--- /dev/null
+++ b/src/diffusers/pipelines/controlnet_xs/__init__.py
@@ -0,0 +1,68 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_flax_available,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"]
+ _import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"]
+try:
+ if not (is_transformers_available() and is_flax_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_flax_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
+else:
+ pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
+ from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline
+
+ try:
+ if not (is_transformers_available() and is_flax_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
+ else:
+ pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
similarity index 82%
rename from examples/research_projects/controlnetxs/pipeline_controlnet_xs.py
rename to src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
index 88a586e9271d..2f450b9c2cea 100644
--- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
@@ -19,30 +19,75 @@
import PIL.Image
import torch
import torch.nn.functional as F
-from controlnetxs import ControlNetXSModel
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
-from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.models.lora import adjust_lora_scale_text_encoder
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import (
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
+ replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
-from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+ >>> negative_prompt = "low quality, bad quality, sketches"
+
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+ ... )
+
+ >>> # initialize the models and pipeline
+ >>> controlnet_conditioning_scale = 0.5
+
+ >>> controlnet = ControlNetXSAdapter.from_pretrained(
+ ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
+ ... )
+ >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # get canny image
+ >>> image = np.array(image)
+ >>> image = cv2.Canny(image, 100, 200)
+ >>> image = image[:, :, None]
+ >>> image = np.concatenate([image, image, image], axis=2)
+ >>> canny_image = Image.fromarray(image)
+
+ >>> # generate image
+ >>> image = pipe(
+ ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
+ ... ).images[0]
+ ```
+"""
+
+
class StableDiffusionControlNetXSPipeline(
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
@@ -56,7 +101,7 @@ class StableDiffusionControlNetXSPipeline(
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
@@ -66,9 +111,9 @@ class StableDiffusionControlNetXSPipeline(
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
- A `UNet2DConditionModel` to denoise the encoded image latents.
- controlnet ([`ControlNetXSModel`]):
- Provides additional conditioning to the `unet` during the denoising process.
+ A [`UNet2DConditionModel`] used to create a [`UNetControlNetXSModel`] to denoise the encoded image latents.
+ controlnet ([`ControlNetXSAdapter`]):
+ A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -80,17 +125,18 @@ class StableDiffusionControlNetXSPipeline(
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
- model_cpu_offload_seq = "text_encoder->unet->vae>controlnet"
+ model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
- unet: UNet2DConditionModel,
- controlnet: ControlNetXSModel,
+ unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
+ controlnet: ControlNetXSAdapter,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
@@ -98,6 +144,9 @@ def __init__(
):
super().__init__()
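+ # If a plain UNet2DConditionModel is passed, fuse it with the ControlNet-XS adapter into a
+ # single UNetControlNetXSModel so the rest of the pipeline only has to call one model.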
+ if isinstance(unet, UNet2DConditionModel):
+ unet = UNetControlNetXSModel.from_unet(unet, controlnet)
+
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -114,14 +163,6 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible(
- vae
- )
- if not vae_compatible:
- raise ValueError(
- f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
- )
-
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -403,20 +444,19 @@ def check_inputs(
self,
prompt,
image,
- callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
):
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
- f" {type(callback_steps)}."
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
@@ -445,25 +485,16 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)
- # Check `image`
+ # Check `image` and `controlnet_conditioning_scale`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ self.unet, torch._dynamo.eval_frame.OptimizedModule
)
if (
- isinstance(self.controlnet, ControlNetXSModel)
+ isinstance(self.unet, UNetControlNetXSModel)
or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
+ and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
):
self.check_image(image, prompt, prompt_embeds)
- else:
- assert False
-
- # Check `controlnet_conditioning_scale`
- if (
- isinstance(self.controlnet, ControlNetXSModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
- ):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
else:
@@ -563,7 +594,33 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
+ def clip_skip(self):
+ return self._clip_skip
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
+ def num_timesteps(self):
+ return self._num_timesteps
+
@torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
@@ -581,13 +638,13 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
- callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
control_guidance_start: float = 0.0,
control_guidance_end: float = 1.0,
clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
r"""
The call function to the pipeline for generation.
@@ -595,7 +652,7 @@ def __call__(
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
- image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
@@ -639,12 +696,6 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
- callback (`Callable`, *optional*):
- A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, *optional*, defaults to 1):
- The frequency at which the `callback` function is called. If not specified, the callback is called at
- every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -659,7 +710,15 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
-
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+ the `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
@@ -669,21 +728,27 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
- controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
image,
- callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
+ callback_on_step_end_tensor_inputs,
)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -713,6 +778,7 @@ def __call__(
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
)
+
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
@@ -720,27 +786,24 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare image
- if isinstance(controlnet, ControlNetXSModel):
- image = self.prepare_image(
- image=image,
- width=width,
- height=height,
- batch_size=batch_size * num_images_per_prompt,
- num_images_per_prompt=num_images_per_prompt,
- device=device,
- dtype=controlnet.dtype,
- do_classifier_free_guidance=do_classifier_free_guidance,
- )
- height, width = image.shape[-2:]
- else:
- assert False
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=unet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+ height, width = image.shape[-2:]
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
- num_channels_latents = self.unet.config.in_channels
+ num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -757,42 +820,33 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- is_unet_compiled = is_compiled_module(self.unet)
- is_controlnet_compiled = is_compiled_module(self.controlnet)
+ self._num_timesteps = len(timesteps)
+ is_controlnet_compiled = is_compiled_module(self.unet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ if is_controlnet_compiled and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
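+ # Apply control only while the current step's position in the schedule lies within
+ # [control_guidance_start, control_guidance_end]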
- dont_control = (
- i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end
+ apply_control = (
+ i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
)
- if dont_control:
- noise_pred = self.unet(
- sample=latent_model_input,
- timestep=t,
- encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=cross_attention_kwargs,
- return_dict=True,
- ).sample
- else:
- noise_pred = self.controlnet(
- base_model=self.unet,
- sample=latent_model_input,
- timestep=t,
- encoder_hidden_states=prompt_embeds,
- controlnet_cond=image,
- conditioning_scale=controlnet_conditioning_scale,
- cross_attention_kwargs=cross_attention_kwargs,
- return_dict=True,
- ).sample
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=image,
+ conditioning_scale=controlnet_conditioning_scale,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=True,
+ apply_control=apply_control,
+ ).sample
# perform guidance
if do_classifier_free_guidance:
@@ -801,12 +855,18 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
- if callback is not None and i % callback_steps == 0:
- step_idx = i // getattr(self.scheduler, "order", 1)
- callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
similarity index 83%
rename from examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
rename to src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
index d0186573fa9c..ff270d20d11e 100644
--- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -19,41 +19,94 @@
import PIL.Image
import torch
import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+)
+
+from diffusers.utils.import_utils import is_invisible_watermark_available
-from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from diffusers.models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
-from diffusers.models.attention_processor import (
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
-from diffusers.models.lora import adjust_lora_scale_text_encoder
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
-from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import (
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
logging,
+ replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
-from diffusers.utils.import_utils import is_invisible_watermark_available
-from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
- from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+ >>> negative_prompt = "low quality, bad quality, sketches"
+
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+ ... )
+
+ >>> # initialize the models and pipeline
+ >>> controlnet_conditioning_scale = 0.5
+ >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+ >>> controlnet = ControlNetXSAdapter.from_pretrained(
+ ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
+ ... )
+ >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # get canny image
+ >>> image = np.array(image)
+ >>> image = cv2.Canny(image, 100, 200)
+ >>> image = image[:, :, None]
+ >>> image = np.concatenate([image, image, image], axis=2)
+ >>> canny_image = Image.fromarray(image)
+
+ >>> # generate image
+ >>> image = pipe(
+ ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
+ ... ).images[0]
+ ```
+"""
+
+
class StableDiffusionXLControlNetXSPipeline(
DiffusionPipeline,
- StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
FromSingleFileMixin,
@@ -66,9 +119,8 @@ class StableDiffusionXLControlNetXSPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
@@ -83,9 +135,9 @@ class StableDiffusionXLControlNetXSPipeline(
tokenizer_2 ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
- A `UNet2DConditionModel` to denoise the encoded image latents.
- controlnet ([`ControlNetXSModel`]:
- Provides additional conditioning to the `unet` during the denoising process.
+ A [`UNet2DConditionModel`] used to create a [`UNetControlNetXSModel`] to denoise the encoded image latents.
+ controlnet ([`ControlNetXSAdapter`]):
+ A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -98,9 +150,15 @@ class StableDiffusionXLControlNetXSPipeline(
watermarker is used.
"""
- # leave controlnet out on purpose because it iterates with unet
- model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet"
- _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -109,21 +167,17 @@ def __init__(
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
- unet: UNet2DConditionModel,
- controlnet: ControlNetXSModel,
+ unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
+ controlnet: ControlNetXSAdapter,
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
+ feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
- vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible(
- vae
- )
- if not vae_compatible:
- raise ValueError(
- f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
- )
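+ # As in the Stable Diffusion pipeline, fuse a plain UNet2DConditionModel with the
+ # ControlNet-XS adapter into a single UNetControlNetXSModel.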
+ if isinstance(unet, UNet2DConditionModel):
+ unet = UNetControlNetXSModel.from_unet(unet, controlnet)
self.register_modules(
vae=vae,
@@ -134,6 +188,7 @@ def __init__(
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
+ feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -417,15 +472,21 @@ def check_inputs(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
):
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
- ):
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -474,25 +535,16 @@ def check_inputs(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
- # Check `image`
+ # Check `image` and ``controlnet_conditioning_scale``
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ self.unet, torch._dynamo.eval_frame.OptimizedModule
)
if (
- isinstance(self.controlnet, ControlNetXSModel)
+ isinstance(self.unet, UNetControlNetXSModel)
or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
+ and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
):
self.check_image(image, prompt, prompt_embeds)
- else:
- assert False
-
- # Check `controlnet_conditioning_scale`
- if (
- isinstance(self.controlnet, ControlNetXSModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
- ):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
else:
@@ -593,7 +645,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
):
@@ -602,7 +653,7 @@ def _get_add_time_ids(
passed_add_embed_dim = (
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
)
- expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+ expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
@@ -632,7 +683,33 @@ def upcast_vae(self):
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
+ def clip_skip(self):
+ return self._clip_skip
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
+ def num_timesteps(self):
+ return self._num_timesteps
+
@torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
@@ -654,8 +731,6 @@ def __call__(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
- callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
control_guidance_start: float = 0.0,
@@ -667,6 +742,9 @@ def __call__(
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -677,7 +755,7 @@ def __call__(
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders.
- image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
@@ -735,12 +813,6 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
- callback (`Callable`, *optional*):
- A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, *optional*, defaults to 1):
- The frequency at which the `callback` function is called. If not specified, the callback is called at
- every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -783,6 +855,15 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+ the `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -791,7 +872,24 @@ def __call__(
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is
returned, otherwise a `tuple` is returned containing the output images.
"""
- controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -808,8 +906,14 @@ def __call__(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
+ callback_on_step_end_tensor_inputs,
)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -850,7 +954,7 @@ def __call__(
)
# 4. Prepare image
- if isinstance(controlnet, ControlNetXSModel):
+ if isinstance(unet, UNetControlNetXSModel):
image = self.prepare_image(
image=image,
width=width,
@@ -858,7 +962,7 @@ def __call__(
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
- dtype=controlnet.dtype,
+ dtype=unet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
height, width = image.shape[-2:]
@@ -870,7 +974,7 @@ def __call__(
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
- num_channels_latents = self.unet.config.in_channels
+ num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -928,14 +1032,14 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- is_unet_compiled = is_compiled_module(self.unet)
- is_controlnet_compiled = is_compiled_module(self.controlnet)
+ self._num_timesteps = len(timesteps)
+ is_controlnet_compiled = is_compiled_module(self.unet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ if is_controlnet_compiled and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -944,30 +1048,20 @@ def __call__(
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# predict the noise residual
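+ # Apply control only while the current step's position in the schedule lies within
+ # [control_guidance_start, control_guidance_end]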
- dont_control = (
- i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end
+ apply_control = (
+ i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
)
- if dont_control:
- noise_pred = self.unet(
- sample=latent_model_input,
- timestep=t,
- encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
- return_dict=True,
- ).sample
- else:
- noise_pred = self.controlnet(
- base_model=self.unet,
- sample=latent_model_input,
- timestep=t,
- encoder_hidden_states=prompt_embeds,
- controlnet_cond=image,
- conditioning_scale=controlnet_conditioning_scale,
- cross_attention_kwargs=cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
- return_dict=True,
- ).sample
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=image,
+ conditioning_scale=controlnet_conditioning_scale,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=True,
+ apply_control=apply_control,
+ ).sample
# perform guidance
if do_classifier_free_guidance:
@@ -977,6 +1071,16 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -984,6 +1088,11 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ # If the VAE runs in float16 and requires upcasting, upcast it and cast the latents to the VAE's dtype
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 1c55d088aa0a..3c3bd526692d 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -2238,6 +2238,7 @@ def __init__(
self,
in_channels: int,
temb_channels: int,
+ out_channels: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -2245,6 +2246,7 @@ def __init__(
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
+ resnet_groups_out: Optional[int] = None,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
@@ -2256,6 +2258,10 @@ def __init__(
):
super().__init__()
+ out_channels = out_channels or in_channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -2264,14 +2270,17 @@ def __init__(
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+ resnet_groups_out = resnet_groups_out or resnet_groups
+
# there is always at least one resnet
resnets = [
ResnetBlockFlat(
in_channels=in_channels,
- out_channels=in_channels,
+ out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
+ groups_out=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
@@ -2286,11 +2295,11 @@ def __init__(
attentions.append(
Transformer2DModel(
num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
+ norm_num_groups=resnet_groups_out,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
@@ -2300,8 +2309,8 @@ def __init__(
attentions.append(
DualTransformer2DModel(
num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
@@ -2309,11 +2318,11 @@ def __init__(
)
resnets.append(
ResnetBlockFlat(
- in_channels=in_channels,
- out_channels=in_channels,
+ in_channels=out_channels,
+ out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
- groups=resnet_groups,
+ groups=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 14947848a43f..b04006cb5ee6 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class ControlNetXSAdapter(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class I2VGenXLUNet(metaclass=DummyObject):
_backends = ["torch"]
@@ -287,6 +302,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class UNetControlNetXSModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class UNetMotionModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index f64c15702087..8ad2f4b4876d 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -902,6 +902,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1247,6 +1262,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py
new file mode 100644
index 000000000000..8c9b43a20ad6
--- /dev/null
+++ b/tests/models/unets/test_models_unet_controlnetxs.py
@@ -0,0 +1,352 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import unittest
+
+import numpy as np
+import torch
+from torch import nn
+
+from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+from diffusers.utils import logging
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+logger = logging.get_logger(__name__)
+
+enable_full_determinism()
+
+
+class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = UNetControlNetXSModel
+ main_input_name = "sample"
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 4
+ sizes = (16, 16)
+ conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
+
+ noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+ time_step = torch.tensor([10]).to(torch_device)
+ encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
+ controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
+ conditioning_scale = 1
+
+ return {
+ "sample": noise,
+ "timestep": time_step,
+ "encoder_hidden_states": encoder_hidden_states,
+ "controlnet_cond": controlnet_cond,
+ "conditioning_scale": conditioning_scale,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "sample_size": 16,
+ "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
+ "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
+ "block_out_channels": (4, 8),
+ "cross_attention_dim": 8,
+ "transformer_layers_per_block": 1,
+ "num_attention_heads": 2,
+ "norm_num_groups": 4,
+ "upcast_attention": False,
+ "ctrl_block_out_channels": [2, 4],
+ "ctrl_num_attention_heads": 4,
+ "ctrl_max_norm_num_groups": 2,
+ "ctrl_conditioning_embedding_out_channels": (2, 2),
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def get_dummy_unet(self):
+ """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
+ return UNet2DConditionModel(
+ block_out_channels=(4, 8),
+ layers_per_block=2,
+ sample_size=16,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=8,
+ norm_num_groups=4,
+ use_linear_projection=True,
+ )
+
+ def get_dummy_controlnet_from_unet(self, unet, **kwargs):
+ """For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
+ # size_ratio and conditioning_embedding_out_channels chosen to keep model small
+ return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
+
+ def test_from_unet(self):
+ unet = self.get_dummy_unet()
+ controlnet = self.get_dummy_controlnet_from_unet(unet)
+
+ model = UNetControlNetXSModel.from_unet(unet, controlnet)
+ model_state_dict = model.state_dict()
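+ # from_unet is expected to copy the base UNet and adapter weights unchanged into the merged model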
+
+ def assert_equal_weights(module, weight_dict_prefix):
+ for param_name, param_value in module.named_parameters():
+ assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
+
+ # # check unet
+ # everything except the down, mid, and up blocks
+ modules_from_unet = [
+ "time_embedding",
+ "conv_in",
+ "conv_norm_out",
+ "conv_out",
+ ]
+ for p in modules_from_unet:
+ assert_equal_weights(getattr(unet, p), "base_" + p)
+ optional_modules_from_unet = [
+ "class_embedding",
+ "add_time_proj",
+ "add_embedding",
+ ]
+ for p in optional_modules_from_unet:
+ if hasattr(unet, p) and getattr(unet, p) is not None:
+ assert_equal_weights(getattr(unet, p), "base_" + p)
+ # down blocks
+ assert len(unet.down_blocks) == len(model.down_blocks)
+ for i, d in enumerate(unet.down_blocks):
+ assert_equal_weights(d.resnets, f"down_blocks.{i}.base_resnets")
+ if hasattr(d, "attentions"):
+ assert_equal_weights(d.attentions, f"down_blocks.{i}.base_attentions")
+ if hasattr(d, "downsamplers") and getattr(d, "downsamplers") is not None:
+ assert_equal_weights(d.downsamplers[0], f"down_blocks.{i}.base_downsamplers")
+ # mid block
+ assert_equal_weights(unet.mid_block, "mid_block.base_midblock")
+ # up blocks
+ assert len(unet.up_blocks) == len(model.up_blocks)
+ for i, u in enumerate(unet.up_blocks):
+ assert_equal_weights(u.resnets, f"up_blocks.{i}.resnets")
+ if hasattr(u, "attentions"):
+ assert_equal_weights(u.attentions, f"up_blocks.{i}.attentions")
+ if hasattr(u, "upsamplers") and getattr(u, "upsamplers") is not None:
+ assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
+
+ # # check controlnet
+ # everything except the down, mid, and up blocks
+ modules_from_controlnet = {
+ "controlnet_cond_embedding": "controlnet_cond_embedding",
+ "conv_in": "ctrl_conv_in",
+ "control_to_base_for_conv_in": "control_to_base_for_conv_in",
+ }
+ optional_modules_from_controlnet = {"time_embedding": "ctrl_time_embedding"}
+ for name_in_controlnet, name_in_unetcnxs in modules_from_controlnet.items():
+ assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
+
+ for name_in_controlnet, name_in_unetcnxs in optional_modules_from_controlnet.items():
+ if hasattr(controlnet, name_in_controlnet) and getattr(controlnet, name_in_controlnet) is not None:
+ assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
+ # down blocks
+ assert len(controlnet.down_blocks) == len(model.down_blocks)
+ for i, d in enumerate(controlnet.down_blocks):
+ assert_equal_weights(d.resnets, f"down_blocks.{i}.ctrl_resnets")
+ assert_equal_weights(d.base_to_ctrl, f"down_blocks.{i}.base_to_ctrl")
+ assert_equal_weights(d.ctrl_to_base, f"down_blocks.{i}.ctrl_to_base")
+ if d.attentions is not None:
+ assert_equal_weights(d.attentions, f"down_blocks.{i}.ctrl_attentions")
+ if d.downsamplers is not None:
+ assert_equal_weights(d.downsamplers, f"down_blocks.{i}.ctrl_downsamplers")
+ # mid block
+ assert_equal_weights(controlnet.mid_block.base_to_ctrl, "mid_block.base_to_ctrl")
+ assert_equal_weights(controlnet.mid_block.midblock, "mid_block.ctrl_midblock")
+ assert_equal_weights(controlnet.mid_block.ctrl_to_base, "mid_block.ctrl_to_base")
+ # up blocks
+ assert len(controlnet.up_connections) == len(model.up_blocks)
+ for i, u in enumerate(controlnet.up_connections):
+ assert_equal_weights(u.ctrl_to_base, f"up_blocks.{i}.ctrl_to_base")
+
+ def test_freeze_unet(self):
+ def assert_frozen(module):
+ for p in module.parameters():
+ assert not p.requires_grad
+
+ def assert_unfrozen(module):
+ for p in module.parameters():
+ assert p.requires_grad
+
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = UNetControlNetXSModel(**init_dict)
+ model.freeze_unet_params()
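+ # after freezing, all base (UNet) parameters should be frozen while all ControlNet-XS parameters stay trainable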
+
+ # # check unet
+ # everything except the down, mid, and up blocks
+ modules_from_unet = [
+ model.base_time_embedding,
+ model.base_conv_in,
+ model.base_conv_norm_out,
+ model.base_conv_out,
+ ]
+ for m in modules_from_unet:
+ assert_frozen(m)
+
+ optional_modules_from_unet = [
+ model.base_add_time_proj,
+ model.base_add_embedding,
+ ]
+ for m in optional_modules_from_unet:
+ if m is not None:
+ assert_frozen(m)
+
+ # down blocks
+ for i, d in enumerate(model.down_blocks):
+ assert_frozen(d.base_resnets)
+ if isinstance(d.base_attentions, nn.ModuleList): # attentions can be a list of Nones
+ assert_frozen(d.base_attentions)
+ if d.base_downsamplers is not None:
+ assert_frozen(d.base_downsamplers)
+
+ # mid block
+ assert_frozen(model.mid_block.base_midblock)
+
+ # up blocks
+ for i, u in enumerate(model.up_blocks):
+ assert_frozen(u.resnets)
+ if isinstance(u.attentions, nn.ModuleList): # attentions can be a list of Nones
+ assert_frozen(u.attentions)
+ if u.upsamplers is not None:
+ assert_frozen(u.upsamplers)
+
+ # # check controlnet
+ # everything except the down, mid, and up blocks
+ modules_from_controlnet = [
+ model.controlnet_cond_embedding,
+ model.ctrl_conv_in,
+ model.control_to_base_for_conv_in,
+ ]
+ optional_modules_from_controlnet = [model.ctrl_time_embedding]
+
+ for m in modules_from_controlnet:
+ assert_unfrozen(m)
+ for m in optional_modules_from_controlnet:
+ if m is not None:
+ assert_unfrozen(m)
+
+ # down blocks
+ for d in model.down_blocks:
+ assert_unfrozen(d.ctrl_resnets)
+ assert_unfrozen(d.base_to_ctrl)
+ assert_unfrozen(d.ctrl_to_base)
+ if isinstance(d.ctrl_attentions, nn.ModuleList): # attentions can be a list of Nones
+ assert_unfrozen(d.ctrl_attentions)
+ if d.ctrl_downsamplers is not None:
+ assert_unfrozen(d.ctrl_downsamplers)
+ # mid block
+ assert_unfrozen(model.mid_block.base_to_ctrl)
+ assert_unfrozen(model.mid_block.ctrl_midblock)
+ assert_unfrozen(model.mid_block.ctrl_to_base)
+ # up blocks
+ for u in model.up_blocks:
+ assert_unfrozen(u.ctrl_to_base)
+
+ def test_gradient_checkpointing_is_applied(self):
+ model_class_copy = copy.copy(UNetControlNetXSModel)
+
+ modules_with_gc_enabled = {}
+
+ # now monkey patch the following function:
+ # def _set_gradient_checkpointing(self, module, value=False):
+ # if hasattr(module, "gradient_checkpointing"):
+ # module.gradient_checkpointing = value
+
+ def _set_gradient_checkpointing_new(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+ modules_with_gc_enabled[module.__class__.__name__] = True
+
+ model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
+
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = model_class_copy(**init_dict)
+
+ model.enable_gradient_checkpointing()
+
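+ # every module class listed here should have had gradient_checkpointing switched on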
+ EXPECTED_SET = {
+ "Transformer2DModel",
+ "UNetMidBlock2DCrossAttn",
+ "ControlNetXSCrossAttnDownBlock2D",
+ "ControlNetXSCrossAttnMidBlock2D",
+ "ControlNetXSCrossAttnUpBlock2D",
+ }
+
+ assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
+ assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+
+ def test_forward_no_control(self):
+ unet = self.get_dummy_unet()
+ controlnet = self.get_dummy_controlnet_from_unet(unet)
+
+ model = UNetControlNetXSModel.from_unet(unet, controlnet)
+
+ unet = unet.to(torch_device)
+ model = model.to(torch_device)
+
+ input_ = self.dummy_input
+
+ control_specific_input = ["controlnet_cond", "conditioning_scale"]
+ input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
+
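+ # with apply_control=False the merged model should reproduce the plain UNet output (up to numerical tolerance)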
+ with torch.no_grad():
+ unet_output = unet(**input_for_unet).sample.cpu()
+ unet_controlnet_output = model(**input_, apply_control=False).sample.cpu()
+
+ assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 3e-4
+
+ def test_time_embedding_mixing(self):
+ unet = self.get_dummy_unet()
+ controlnet = self.get_dummy_controlnet_from_unet(unet)
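+ # with learn_time_embedding=True, time_embedding_mix=0.5 blends the adapter's learned time embedding
+ # with the base UNet's; the test only checks that the output shapes still match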
+ controlnet_mix_time = self.get_dummy_controlnet_from_unet(
+ unet, time_embedding_mix=0.5, learn_time_embedding=True
+ )
+
+ model = UNetControlNetXSModel.from_unet(unet, controlnet)
+ model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time)
+
+ unet = unet.to(torch_device)
+ model = model.to(torch_device)
+ model_mix_time = model_mix_time.to(torch_device)
+
+ input_ = self.dummy_input
+
+ with torch.no_grad():
+ output = model(**input_).sample
+ output_mix_time = model_mix_time(**input_).sample
+
+ assert output.shape == output_mix_time.shape
+
+ def test_forward_with_norm_groups(self):
+ # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
+ pass
diff --git a/tests/pipelines/controlnet_xs/__init__.py b/tests/pipelines/controlnet_xs/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py
new file mode 100644
index 000000000000..5ac78129ef34
--- /dev/null
+++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py
@@ -0,0 +1,366 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import traceback
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AsymmetricAutoencoderKL,
+ AutoencoderKL,
+ AutoencoderTiny,
+ ConsistencyDecoderVAE,
+ ControlNetXSAdapter,
+ DDIMScheduler,
+ LCMScheduler,
+ StableDiffusionControlNetXSPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ load_image,
+ load_numpy,
+ require_python39_or_higher,
+ require_torch_2,
+ require_torch_gpu,
+ run_test_in_subprocess,
+ slow,
+ torch_device,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...models.autoencoders.test_models_vae import (
+ get_asym_autoencoder_kl_config,
+ get_autoencoder_kl_config,
+ get_autoencoder_tiny_config,
+ get_consistency_vae_config,
+)
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_BATCH_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import (
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+ SDFunctionTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+# Will be run via run_test_in_subprocess
+def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
+ error = None
+ try:
+ _ = in_queue.get(timeout=timeout)
+
+ controlnet = ControlNetXSAdapter.from_pretrained(
+ "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
+ )
+ pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1-base",
+ controlnet=controlnet,
+ safety_checker=None,
+ torch_dtype=torch.float16,
+ )
+ pipe.to("cuda")
+ pipe.set_progress_bar_config(disable=None)
+
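+ # compiling the merged UNetControlNetXSModel should work and still reproduce the expected image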
+ pipe.unet.to(memory_format=torch.channels_last)
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "bird"
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+ ).resize((512, 512))
+
+ output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
+ image = output.images[0]
+
+ assert image.shape == (512, 512, 3)
+
+ expected_image = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
+ )
+ expected_image = np.resize(expected_image, (512, 512, 3))
+
+ assert np.abs(expected_image - image).max() < 1.0
+
+ except Exception:
+ error = f"{traceback.format_exc()}"
+
+ results = {"error": error}
+ out_queue.put(results, timeout=timeout)
+ out_queue.join()
+
+
+class ControlNetXSPipelineFastTests(
+ PipelineLatentTesterMixin,
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineTesterMixin,
+ SDFunctionTesterMixin,
+ unittest.TestCase,
+):
+ pipeline_class = StableDiffusionControlNetXSPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+
+ test_attention_slicing = False
+
+ def get_dummy_components(self, time_cond_proj_dim=None):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(4, 8),
+ layers_per_block=2,
+ sample_size=16,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=8,
+ norm_num_groups=4,
+ time_cond_proj_dim=time_cond_proj_dim,
+ use_linear_projection=True,
+ )
+ torch.manual_seed(0)
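+ # size_ratio=1 gives the adapter the same block widths as the tiny base UNet;
+ # learn_time_embedding lets the adapter learn its own time embedding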
+ controlnet = ControlNetXSAdapter.from_unet(
+ unet=unet,
+ size_ratio=1,
+ learn_time_embedding=True,
+ conditioning_embedding_out_channels=(2, 2),
+ )
+ torch.manual_seed(0)
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[4, 8],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ norm_num_groups=2,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=8,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "controlnet": controlnet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "safety_checker": None,
+ "feature_extractor": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ controlnet_embedder_scale_factor = 2
+ image = randn_tensor(
+ (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
+ generator=generator,
+ device=torch.device(device),
+ )
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "output_type": "numpy",
+ "image": image,
+ }
+
+ return inputs
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+ def test_controlnet_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+
+ components = self.get_dummy_components(time_cond_proj_dim=8)
+ sd_pipe = StableDiffusionControlNetXSPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = sd_pipe(**inputs)
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 16, 16, 3)
+ expected_slice = np.array([0.745, 0.753, 0.767, 0.543, 0.523, 0.502, 0.314, 0.521, 0.478])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+ def test_multi_vae(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ block_out_channels = pipe.vae.config.block_out_channels
+ norm_num_groups = pipe.vae.config.norm_num_groups
+
+ vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
+ configs = [
+ get_autoencoder_kl_config(block_out_channels, norm_num_groups),
+ get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
+ get_consistency_vae_config(block_out_channels, norm_num_groups),
+ get_autoencoder_tiny_config(block_out_channels),
+ ]
+
+ out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+ for vae_cls, config in zip(vae_classes, configs):
+ vae = vae_cls(**config)
+ vae = vae.to(torch_device)
+ components["vae"] = vae
+ vae_pipe = self.pipeline_class(**components)
+
+ # the pipeline creates a new UNetControlNetXSModel under the hood, which isn't on the device yet,
+ # so we need to move the new pipeline to the device.
+ vae_pipe.to(torch_device)
+ vae_pipe.set_progress_bar_config(disable=None)
+
+ out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+ assert out_vae_np.shape == out_np.shape
+
+
+@slow
+@require_torch_gpu
+class ControlNetXSPipelineSlowTests(unittest.TestCase):
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_canny(self):
+ controlnet = ControlNetXSAdapter.from_pretrained(
+ "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
+ )
+ pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "bird"
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+ )
+
+ output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
+
+ image = output.images[0]
+
+ assert image.shape == (768, 512, 3)
+
+ original_image = image[-3:, -3:, -1].flatten()
+ expected_image = np.array([0.1963, 0.229, 0.2659, 0.2109, 0.2332, 0.2827, 0.2534, 0.2422, 0.2808])
+ assert np.allclose(original_image, expected_image, atol=1e-04)
+
+ def test_depth(self):
+ controlnet = ControlNetXSAdapter.from_pretrained(
+ "UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16
+ )
+ pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "Stormtrooper's lecture"
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
+ )
+
+ output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
+
+ image = output.images[0]
+
+ assert image.shape == (512, 512, 3)
+
+ original_image = image[-3:, -3:, -1].flatten()
+ expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941])
+ assert np.allclose(original_image, expected_image, atol=1e-04)
+
+ @require_python39_or_higher
+ @require_torch_2
+ def test_stable_diffusion_compile(self):
+ run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
new file mode 100644
index 000000000000..ee0d15ec3472
--- /dev/null
+++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
@@ -0,0 +1,425 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import (
+ AsymmetricAutoencoderKL,
+ AutoencoderKL,
+ AutoencoderTiny,
+ ConsistencyDecoderVAE,
+ ControlNetXSAdapter,
+ EulerDiscreteScheduler,
+ StableDiffusionXLControlNetXSPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...models.autoencoders.test_models_vae import (
+ get_asym_autoencoder_kl_config,
+ get_autoencoder_kl_config,
+ get_autoencoder_tiny_config,
+ get_consistency_vae_config,
+)
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_BATCH_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import (
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+ SDXLOptionalComponentsTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class StableDiffusionXLControlNetXSPipelineFastTests(
+ PipelineLatentTesterMixin,
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineTesterMixin,
+ SDXLOptionalComponentsTesterMixin,
+ unittest.TestCase,
+):
+ pipeline_class = StableDiffusionXLControlNetXSPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+
+ test_attention_slicing = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(4, 8),
+ layers_per_block=2,
+ sample_size=16,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ use_linear_projection=True,
+ norm_num_groups=4,
+ # SDXL-specific config below
+ attention_head_dim=(2, 4),
+ addition_embed_type="text_time",
+ addition_time_embed_dim=8,
+ transformer_layers_per_block=(1, 2),
+ projection_class_embeddings_input_dim=56, # 6 * 8 (addition_time_embed_dim) + 8 (projection_dim of the pooled text embeddings)
+ cross_attention_dim=8,
+ )
+ torch.manual_seed(0)
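+ # size_ratio=0.5 makes the ControlNet-XS adapter half as wide as the base UNet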
+ controlnet = ControlNetXSAdapter.from_unet(
+ unet=unet,
+ size_ratio=0.5,
+ learn_time_embedding=True,
+ conditioning_embedding_out_channels=(2, 2),
+ )
+ torch.manual_seed(0)
+ scheduler = EulerDiscreteScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ steps_offset=1,
+ beta_schedule="scaled_linear",
+ timestep_spacing="leading",
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[4, 8],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ norm_num_groups=2,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=4,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ # SDXL-specific config below
+ hidden_act="gelu",
+ projection_dim=8,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "controlnet": controlnet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "feature_extractor": None,
+ }
+ return components
+
+ # copied from test_controlnet_sdxl.py
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ controlnet_embedder_scale_factor = 2
+ image = randn_tensor(
+ (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
+ generator=generator,
+ device=torch.device(device),
+ )
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "output_type": "np",
+ "image": image,
+ }
+
+ return inputs
+
+ # copied from test_controlnet_sdxl.py
+ def test_attention_slicing_forward_pass(self):
+ return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
+
+ # copied from test_controlnet_sdxl.py
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+ # copied from test_controlnet_sdxl.py
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+ # copied from test_controlnet_sdxl.py
+ @require_torch_gpu
+ def test_stable_diffusion_xl_offloads(self):
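+ # images from the default, model-offloaded, and sequentially-offloaded pipelines should all match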
+ pipes = []
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components).to(torch_device)
+ pipes.append(sd_pipe)
+
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe.enable_model_cpu_offload()
+ pipes.append(sd_pipe)
+
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe.enable_sequential_cpu_offload()
+ pipes.append(sd_pipe)
+
+ image_slices = []
+ for pipe in pipes:
+ pipe.unet.set_default_attn_processor()
+
+ inputs = self.get_dummy_inputs(torch_device)
+ image = pipe(**inputs).images
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+ assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
+
+ # copied from test_controlnet_sdxl.py
+ def test_stable_diffusion_xl_multi_prompts(self):
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components).to(torch_device)
+
+ # forward with single prompt
+ inputs = self.get_dummy_inputs(torch_device)
+ output = sd_pipe(**inputs)
+ image_slice_1 = output.images[0, -3:, -3:, -1]
+
+ # forward with same prompt duplicated
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = inputs["prompt"]
+ output = sd_pipe(**inputs)
+ image_slice_2 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are equal
+ assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
+ # forward with different prompt
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "different prompt"
+ output = sd_pipe(**inputs)
+ image_slice_3 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are not equal
+ assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
+
+ # manually set a negative_prompt
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["negative_prompt"] = "negative prompt"
+ output = sd_pipe(**inputs)
+ image_slice_1 = output.images[0, -3:, -3:, -1]
+
+ # forward with same negative_prompt duplicated
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["negative_prompt"] = "negative prompt"
+ inputs["negative_prompt_2"] = inputs["negative_prompt"]
+ output = sd_pipe(**inputs)
+ image_slice_2 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are equal
+ assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
+ # forward with different negative_prompt
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["negative_prompt"] = "negative prompt"
+ inputs["negative_prompt_2"] = "different negative prompt"
+ output = sd_pipe(**inputs)
+ image_slice_3 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are not equal
+ assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
+
+ # copied from test_stable_diffusion_xl.py
+ def test_stable_diffusion_xl_prompt_embeds(self):
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ # forward without prompt embeds
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt"] = 2 * [inputs["prompt"]]
+ inputs["num_images_per_prompt"] = 2
+
+ output = sd_pipe(**inputs)
+ image_slice_1 = output.images[0, -3:, -3:, -1]
+
+ # forward with prompt embeds
+ inputs = self.get_dummy_inputs(torch_device)
+ prompt = 2 * [inputs.pop("prompt")]
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = sd_pipe.encode_prompt(prompt)
+
+ output = sd_pipe(
+ **inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+ image_slice_2 = output.images[0, -3:, -3:, -1]
+
+ # make sure that it's equal
+ assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4
+
+ # copied from test_stable_diffusion_xl.py
+ def test_save_load_optional_components(self):
+ self._test_save_load_optional_components()
+
+ # copied from test_controlnetxs.py
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+ def test_multi_vae(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ block_out_channels = pipe.vae.config.block_out_channels
+ norm_num_groups = pipe.vae.config.norm_num_groups
+
+ vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
+ configs = [
+ get_autoencoder_kl_config(block_out_channels, norm_num_groups),
+ get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
+ get_consistency_vae_config(block_out_channels, norm_num_groups),
+ get_autoencoder_tiny_config(block_out_channels),
+ ]
+
+ out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+ for vae_cls, config in zip(vae_classes, configs):
+ vae = vae_cls(**config)
+ vae = vae.to(torch_device)
+ components["vae"] = vae
+ vae_pipe = self.pipeline_class(**components)
+
+ # the pipeline creates a new UNetControlNetXSModel under the hood, which isn't on the device yet,
+ # so we need to move the new pipeline to the device.
+ vae_pipe.to(torch_device)
+ vae_pipe.set_progress_bar_config(disable=None)
+
+ out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+ assert out_vae_np.shape == out_np.shape
+
+
+@slow
+@require_torch_gpu
+class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_canny(self):
+ controlnet = ControlNetXSAdapter.from_pretrained(
+ "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
+ )
+ pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_sequential_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "bird"
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+ )
+
+ images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+ assert images[0].shape == (768, 512, 3)
+
+ original_image = images[0, -3:, -3:, -1].flatten()
+ expected_image = np.array([0.3202, 0.3151, 0.3328, 0.3172, 0.337, 0.3381, 0.3378, 0.3389, 0.3224])
+ assert np.allclose(original_image, expected_image, atol=1e-04)
+
+ def test_depth(self):
+ controlnet = ControlNetXSAdapter.from_pretrained(
+ "UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16
+ )
+ pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_sequential_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "Stormtrooper's lecture"
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
+ )
+
+ images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+ assert images[0].shape == (512, 512, 3)
+
+ original_image = images[0, -3:, -3:, -1].flatten()
+ expected_image = np.array([0.5448, 0.5437, 0.5426, 0.5543, 0.553, 0.5475, 0.5595, 0.5602, 0.5529])
+ assert np.allclose(original_image, expected_image, atol=1e-04)
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index acff5f2cdf8f..2ff53374be56 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -32,6 +32,7 @@
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import IPAdapterMixin
from diffusers.models.attention_processor import AttnProcessor
+from diffusers.models.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
from diffusers.models.unets.unet_motion_model import UNetMotionModel
@@ -1685,7 +1686,10 @@ def test_StableDiffusionMixin_component(self):
self.assertTrue(hasattr(pipe, "vae") and isinstance(pipe.vae, (AutoencoderKL, AutoencoderTiny)))
self.assertTrue(
hasattr(pipe, "unet")
- and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel))
+ and isinstance(
+ pipe.unet,
+ (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel, UNetControlNetXSModel),
+ )
)