diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 83693485d0e2..ba214f8e49f0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -282,6 +282,10 @@ title: ControlNet - local: api/pipelines/controlnet_sdxl title: ControlNet with Stable Diffusion XL + - local: api/pipelines/controlnetxs + title: ControlNet-XS + - local: api/pipelines/controlnetxs_sdxl + title: ControlNet-XS with Stable Diffusion XL - local: api/pipelines/dance_diffusion title: Dance Diffusion - local: api/pipelines/ddim diff --git a/examples/research_projects/controlnetxs/README.md b/docs/source/en/api/pipelines/controlnetxs.md similarity index 61% rename from examples/research_projects/controlnetxs/README.md rename to docs/source/en/api/pipelines/controlnetxs.md index 72ed91c01db2..2d4ae7b8ce46 100644 --- a/examples/research_projects/controlnetxs/README.md +++ b/docs/source/en/api/pipelines/controlnetxs.md @@ -1,3 +1,15 @@ + + # ControlNet-XS ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. @@ -12,5 +24,16 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️ + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. + + + +## StableDiffusionControlNetXSPipeline +[[autodoc]] StableDiffusionControlNetXSPipeline + - all + - __call__ -> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. \ No newline at end of file +## StableDiffusionPipelineOutput +[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput diff --git a/examples/research_projects/controlnetxs/README_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md similarity index 56% rename from examples/research_projects/controlnetxs/README_sdxl.md rename to docs/source/en/api/pipelines/controlnetxs_sdxl.md index d401c1e76698..31075c0ef96a 100644 --- a/examples/research_projects/controlnetxs/README_sdxl.md +++ b/docs/source/en/api/pipelines/controlnetxs_sdxl.md @@ -1,3 +1,15 @@ + + # ControlNet-XS with Stable Diffusion XL ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. @@ -12,4 +24,22 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). 
❤️ -> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. \ No newline at end of file + + +🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve! + + + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. + + + +## StableDiffusionXLControlNetXSPipeline +[[autodoc]] StableDiffusionXLControlNetXSPipeline + - all + - __call__ + +## StableDiffusionPipelineOutput +[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput diff --git a/examples/research_projects/controlnetxs/controlnetxs.py b/examples/research_projects/controlnetxs/controlnetxs.py deleted file mode 100644 index 14ad1d8a3af9..000000000000 --- a/examples/research_projects/controlnetxs/controlnetxs.py +++ /dev/null @@ -1,1014 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import functional as F -from torch.nn.modules.normalization import GroupNorm - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProcessor -from diffusers.models.autoencoders import AutoencoderKL -from diffusers.models.lora import LoRACompatibleConv -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unets.unet_2d_blocks import ( - CrossAttnDownBlock2D, - CrossAttnUpBlock2D, - DownBlock2D, - Downsample2D, - ResnetBlock2D, - Transformer2DModel, - UpBlock2D, - Upsample2D, -) -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel -from diffusers.utils import BaseOutput, logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class ControlNetXSOutput(BaseOutput): - """ - The output of [`ControlNetXSModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model - output, but is already the final output. 
- """ - - sample: torch.FloatTensor = None - - -# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding -class ControlNetConditioningEmbedding(nn.Module): - """ - Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN - [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized - training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the - convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides - (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full - model) to encode image-space conditions ... into feature maps ..." - """ - - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), - ): - super().__init__() - - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) - - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) - ) - - def forward(self, conditioning): - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - - return embedding - - -class ControlNetXSModel(ModelMixin, ConfigMixin): - r""" - A ControlNet-XS model - - This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic - methods implemented for all models (such as downloading or saving). - - Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation - of [`UNet2DConditionModel`] for them. - - Parameters: - conditioning_channels (`int`, defaults to 3): - Number of channels of conditioning input (e.g. an image) - controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `controlnet_cond_embedding` layer. - time_embedding_input_dim (`int`, defaults to 320): - Dimension of input into time embedding. Needs to be same as in the base model. - time_embedding_dim (`int`, defaults to 1280): - Dimension of output from time embedding. Needs to be same as in the base model. - learn_embedding (`bool`, defaults to `False`): - Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of - the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. - time_embedding_mix (`float`, defaults to 1.0): - Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the - control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. 
- base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`): - Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. - """ - - @classmethod - def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True): - """ - Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS). - - Parameters: - base_model (`UNet2DConditionModel`): - Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL. - is_sdxl (`bool`, defaults to `True`): - Whether passed `base_model` is a StableDiffusion-XL model. - """ - - def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int): - """ - Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why). - The original ControlNet-XS model, however, define the number of attention heads. - That's why compute the dimensions needed to get the correct number of attention heads. - """ - block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels] - dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels] - return dim_attn_heads - - if is_sdxl: - return ControlNetXSModel.from_unet( - base_model, - time_embedding_mix=0.95, - learn_embedding=True, - size_ratio=0.1, - conditioning_embedding_out_channels=(16, 32, 96, 256), - num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64), - ) - else: - return ControlNetXSModel.from_unet( - base_model, - time_embedding_mix=1.0, - learn_embedding=True, - size_ratio=0.0125, - conditioning_embedding_out_channels=(16, 32, 96, 256), - num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8), - ) - - @classmethod - def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str): - """To create correctly sized connections between base and control model, we need to know - the input and output channels of each subblock. - - Parameters: - unet (`UNet2DConditionModel`): - Unet of which the subblock channels sizes are to be gathered. - base_or_control (`str`): - Needs to be either "base" or "control". If "base", decoder is also considered. - """ - if base_or_control not in ["base", "control"]: - raise ValueError("`base_or_control` needs to be either `base` or `control`") - - channel_sizes = {"down": [], "mid": [], "up": []} - - # input convolution - channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels)) - - # encoder blocks - for module in unet.down_blocks: - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - for r in module.resnets: - channel_sizes["down"].append((r.in_channels, r.out_channels)) - if module.downsamplers: - channel_sizes["down"].append( - (module.downsamplers[0].channels, module.downsamplers[0].out_channels) - ) - else: - raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.") - - # middle block - channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels)) - - # decoder blocks - if base_or_control == "base": - for module in unet.up_blocks: - if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)): - for r in module.resnets: - channel_sizes["up"].append((r.in_channels, r.out_channels)) - else: - raise ValueError( - f"Encountered unknown module of type {type(module)} while creating ControlNet-XS." 
- ) - - return channel_sizes - - @register_to_config - def __init__( - self, - conditioning_channels: int = 3, - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - controlnet_conditioning_channel_order: str = "rgb", - time_embedding_input_dim: int = 320, - time_embedding_dim: int = 1280, - time_embedding_mix: float = 1.0, - learn_embedding: bool = False, - base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { - "down": [ - (4, 320), - (320, 320), - (320, 320), - (320, 320), - (320, 640), - (640, 640), - (640, 640), - (640, 1280), - (1280, 1280), - ], - "mid": [(1280, 1280)], - "up": [ - (2560, 1280), - (2560, 1280), - (1920, 1280), - (1920, 640), - (1280, 640), - (960, 640), - (960, 320), - (640, 320), - (640, 320), - ], - }, - sample_size: Optional[int] = None, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - norm_num_groups: Optional[int] = 32, - cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - upcast_attention: bool = False, - ): - super().__init__() - - # 1 - Create control unet - self.control_model = UNet2DConditionModel( - sample_size=sample_size, - down_block_types=down_block_types, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - attention_head_dim=num_attention_heads, - use_linear_projection=True, - upcast_attention=upcast_attention, - time_embedding_dim=time_embedding_dim, - ) - - # 2 - Do model surgery on control model - # 2.1 - Allow to use the same time information as the base model - adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim) - - # 2.2 - Allow for information infusion from base model - - # We concat the output of each base encoder subblocks to the input of the next control encoder subblock - # (We ignore the 1st element, as it represents the `conv_in`.) 
- extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]] - it_extra_input_channels = iter(extra_input_channels) - - for b, block in enumerate(self.control_model.down_blocks): - for r in range(len(block.resnets)): - increase_block_input_in_encoder_resnet( - self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels) - ) - - if block.downsamplers: - increase_block_input_in_encoder_downsampler( - self.control_model, block_no=b, by=next(it_extra_input_channels) - ) - - increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1]) - - # 2.3 - Make group norms work with modified channel sizes - adjust_group_norms(self.control_model) - - # 3 - Gather Channel Sizes - self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control") - self.ch_inout_base = base_model_channel_sizes - - # 4 - Build connections between base and control model - self.down_zero_convs_out = nn.ModuleList([]) - self.down_zero_convs_in = nn.ModuleList([]) - self.middle_block_out = nn.ModuleList([]) - self.middle_block_in = nn.ModuleList([]) - self.up_zero_convs_out = nn.ModuleList([]) - self.up_zero_convs_in = nn.ModuleList([]) - - for ch_io_base in self.ch_inout_base["down"]: - self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1])) - for i in range(len(self.ch_inout_ctrl["down"])): - self.down_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1]) - ) - - self.middle_block_out = self._make_zero_conv( - self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1] - ) - - self.up_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1]) - ) - for i in range(1, len(self.ch_inout_ctrl["down"])): - self.up_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1]) - ) - - # 5 - Create conditioning hint embedding - self.controlnet_cond_embedding = ControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - # In the mininal implementation setting, we only need the control model up to the mid block - del self.control_model.up_blocks - del self.control_model.conv_norm_out - del self.control_model.conv_out - - @classmethod - def from_unet( - cls, - unet: UNet2DConditionModel, - conditioning_channels: int = 3, - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - controlnet_conditioning_channel_order: str = "rgb", - learn_embedding: bool = False, - time_embedding_mix: float = 1.0, - block_out_channels: Optional[Tuple[int]] = None, - size_ratio: Optional[float] = None, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - norm_num_groups: Optional[int] = None, - ): - r""" - Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`]. - - Parameters: - unet (`UNet2DConditionModel`): - The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. - conditioning_channels (`int`, defaults to 3): - Number of channels of conditioning input (e.g. an image) - conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `controlnet_cond_embedding` layer. 
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - learn_embedding (`bool`, defaults to `False`): - Wether to use time embedding of the control model. If yes, the time embedding is a linear interpolation - of the time embeddings of the control and base model with interpolation parameter - `time_embedding_mix**3`. - time_embedding_mix (`float`, defaults to 1.0): - Linear interpolation parameter used if `learn_embedding` is `True`. - block_out_channels (`Tuple[int]`, *optional*): - Down blocks output channels in control model. Either this or `size_ratio` must be given. - size_ratio (float, *optional*): - When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. - Either this or `block_out_channels` must be given. - num_attention_heads (`Union[int, Tuple[int]]`, *optional*): - The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. - norm_num_groups (int, *optional*, defaults to `None`): - The number of groups to use for the normalization of the control unet. If `None`, - `int(unet.config.norm_num_groups * size_ratio)` is taken. - """ - - # Check input - fixed_size = block_out_channels is not None - relative_size = size_ratio is not None - if not (fixed_size ^ relative_size): - raise ValueError( - "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." - ) - - # Create model - if block_out_channels is None: - block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels] - - # Check that attention heads and group norms match channel sizes - # - attention heads - def attn_heads_match_channel_sizes(attn_heads, channel_sizes): - if isinstance(attn_heads, (tuple, list)): - return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes)) - else: - return all(c % attn_heads == 0 for c in channel_sizes) - - num_attention_heads = num_attention_heads or unet.config.attention_head_dim - if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels): - raise ValueError( - f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually." - ) - - # - group norms - def group_norms_match_channel_sizes(num_groups, channel_sizes): - return all(c % num_groups == 0 for c in channel_sizes) - - if norm_num_groups is None: - if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels): - norm_num_groups = unet.config.norm_num_groups - else: - norm_num_groups = min(block_out_channels) - - if group_norms_match_channel_sizes(norm_num_groups, block_out_channels): - print( - f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information." - ) - else: - raise ValueError( - f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels." 
- ) - - def get_time_emb_input_dim(unet: UNet2DConditionModel): - return unet.time_embedding.linear_1.in_features - - def get_time_emb_dim(unet: UNet2DConditionModel): - return unet.time_embedding.linear_2.out_features - - # Clone params from base unet if - # (i) it's required to build SD or SDXL, and - # (ii) it's not used for the time embedding (as time embedding of control model is never used), and - # (iii) it's not set further below anyway - to_keep = [ - "cross_attention_dim", - "down_block_types", - "sample_size", - "transformer_layers_per_block", - "up_block_types", - "upcast_attention", - ] - kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep} - kwargs.update(block_out_channels=block_out_channels) - kwargs.update(num_attention_heads=num_attention_heads) - kwargs.update(norm_num_groups=norm_num_groups) - - # Add controlnetxs-specific params - kwargs.update( - conditioning_channels=conditioning_channels, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - time_embedding_input_dim=get_time_emb_input_dim(unet), - time_embedding_dim=get_time_emb_dim(unet), - time_embedding_mix=time_embedding_mix, - learn_embedding=learn_embedding, - base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"), - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - ) - - return cls(**kwargs) - - @property - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - return self.control_model.attn_processors - - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - self.control_model.set_attn_processor(processor) - - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - self.control_model.set_default_attn_processor() - - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. 
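Since `set_attention_slice` simply forwards to the wrapped control UNet, usage mirrors the other diffusers models. A short, hedged example (the `controlnet` variable is assumed to be an already-instantiated `ControlNetXSModel`):

```python
# Assumes `controlnet` is an instantiated ControlNetXSModel.
controlnet.set_attention_slice("auto")  # halve the input to the attention heads (two steps)
controlnet.set_attention_slice("max")   # one slice at a time, maximum memory savings
controlnet.set_attention_slice(2)       # attention_head_dim must be divisible by 2
```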
- """ - self.control_model.set_attention_slice(slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (UNet2DConditionModel)): - if value: - module.enable_gradient_checkpointing() - else: - module.disable_gradient_checkpointing() - - def forward( - self, - base_model: UNet2DConditionModel, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - return_dict: bool = True, - ) -> Union[ControlNetXSOutput, Tuple]: - """ - The [`ControlNetModel`] forward method. - - Args: - base_model (`UNet2DConditionModel`): - The base unet model we want to control. - sample (`torch.FloatTensor`): - The noisy input tensor. - timestep (`Union[torch.Tensor, float, int]`): - The number of timesteps to denoise an input. - encoder_hidden_states (`torch.Tensor`): - The encoder hidden states. - controlnet_cond (`torch.FloatTensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - conditioning_scale (`float`, defaults to `1.0`): - How much the control model affects the base model outputs. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): - Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the - timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep - embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - added_cond_kwargs (`dict`): - Additional conditions for the Stable Diffusion XL UNet. - cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): - A kwargs dictionary that if specified is passed along to the `AttnProcessor`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. - - Returns: - [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - """ - # check channel order - channel_order = self.config.controlnet_conditioning_channel_order - - if channel_order == "rgb": - # in rgb order by default - ... 
- elif channel_order == "bgr": - controlnet_cond = torch.flip(controlnet_cond, dims=[1]) - else: - raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") - - # scale control strength - n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out) - scale_list = torch.full((n_connections,), conditioning_scale) - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = base_model.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - if self.config.learn_embedding: - ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond) - base_temb = base_model.time_embedding(t_emb, timestep_cond) - interpolation_param = self.config.time_embedding_mix**0.3 - - temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) - else: - temb = base_model.time_embedding(t_emb) - - # added time & text embeddings - aug_emb = None - - if base_model.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if base_model.config.class_embed_type == "timestep": - class_labels = base_model.time_proj(class_labels) - - class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype) - temb = temb + class_emb - - if base_model.config.addition_embed_type is not None: - if base_model.config.addition_embed_type == "text": - aug_emb = base_model.add_embedding(encoder_hidden_states) - elif base_model.config.addition_embed_type == "text_image": - raise NotImplementedError() - elif base_model.config.addition_embed_type == "text_time": - # SDXL - style - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = base_model.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(temb.dtype) - aug_emb = base_model.add_embedding(add_embeds) - elif 
base_model.config.addition_embed_type == "image": - raise NotImplementedError() - elif base_model.config.addition_embed_type == "image_hint": - raise NotImplementedError() - - temb = temb + aug_emb if aug_emb is not None else temb - - # text embeddings - cemb = encoder_hidden_states - - # Preparation - guided_hint = self.controlnet_cond_embedding(controlnet_cond) - - h_ctrl = h_base = sample - hs_base, hs_ctrl = [], [] - it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map( - iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out) - ) - scales = iter(scale_list) - - base_down_subblocks = to_sub_blocks(base_model.down_blocks) - ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks) - base_mid_subblocks = to_sub_blocks([base_model.mid_block]) - ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block]) - base_up_subblocks = to_sub_blocks(base_model.up_blocks) - - # Cross Control - # 0 - conv in - h_base = base_model.conv_in(h_base) - h_ctrl = self.control_model.conv_in(h_ctrl) - if guided_hint is not None: - h_ctrl += guided_hint - h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base - - hs_base.append(h_base) - hs_ctrl.append(h_ctrl) - - # 1 - down - for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks): - h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock - h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock - h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base - hs_base.append(h_base) - hs_ctrl.append(h_ctrl) - - # 2 - mid - h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl - for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks): - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock - h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock - h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base - - # 3 - up - for i, m_base in enumerate(base_up_subblocks): - h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder - h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) - - h_base = base_model.conv_norm_out(h_base) - h_base = base_model.conv_act(h_base) - h_base = base_model.conv_out(h_base) - - if not return_dict: - return h_base - - return ControlNetXSOutput(sample=h_base) - - def _make_zero_conv(self, in_channels, out_channels=None): - # keep running track of channels sizes - self.in_channels = in_channels - self.out_channels = out_channels or in_channels - - return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) - - @torch.no_grad() - def _check_if_vae_compatible(self, vae: AutoencoderKL): - condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1) - vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - compatible = condition_downscale_factor == vae_downscale_factor - return compatible, condition_downscale_factor, vae_downscale_factor - - -class SubBlock(nn.ModuleList): - """A SubBlock is the largest piece of 
either base or control model, that is executed independently of the other model respectively. - Before each subblock, information is concatted from base to control. And after each subblock, information is added from control to base. - """ - - def __init__(self, ms, *args, **kwargs): - if not is_iterable(ms): - ms = [ms] - super().__init__(ms, *args, **kwargs) - - def forward( - self, - x: torch.Tensor, - temb: torch.Tensor, - cemb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - """Iterate through children and pass correct information to each.""" - for m in self: - if isinstance(m, ResnetBlock2D): - x = m(x, temb) - elif isinstance(m, Transformer2DModel): - x = m(x, cemb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs).sample - elif isinstance(m, Downsample2D): - x = m(x) - elif isinstance(m, Upsample2D): - x = m(x) - else: - raise ValueError( - f"Type of m is {type(m)} but should be `ResnetBlock2D`, `Transformer2DModel`, `Downsample2D` or `Upsample2D`" - ) - - return x - - -def adjust_time_dims(unet: UNet2DConditionModel, in_dim: int, out_dim: int): - unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim) - - -def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by): - """Increase channels sizes to allow for additional concatted information from base model""" - r = unet.down_blocks[block_no].resnets[resnet_idx] - old_norm1, old_conv1 = r.norm1, r.conv1 - # norm - norm_args = "num_groups num_channels eps affine".split(" ") - for a in norm_args: - assert hasattr(old_norm1, a) - norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} - norm_kwargs["num_channels"] += by # surgery done here - # conv1 - conv1_args = [ - "in_channels", - "out_channels", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "bias", - "padding_mode", - ] - if not USE_PEFT_BACKEND: - conv1_args.append("lora_layer") - - for a in conv1_args: - assert hasattr(old_conv1, a) - - conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} - conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. 
- conv1_kwargs["in_channels"] += by # surgery done here - # conv_shortcut - # as we changed the input size of the block, the input and output sizes are likely different, - # therefore we need a conv_shortcut (simply adding won't work) - conv_shortcut_args_kwargs = { - "in_channels": conv1_kwargs["in_channels"], - "out_channels": conv1_kwargs["out_channels"], - # default arguments from resnet.__init__ - "kernel_size": 1, - "stride": 1, - "padding": 0, - "bias": True, - } - # swap old with new modules - unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs) - unet.down_blocks[block_no].resnets[resnet_idx].conv1 = ( - nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) - ) - unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = ( - nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) - ) - unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here - - -def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by): - """Increase channels sizes to allow for additional concatted information from base model""" - old_down = unet.down_blocks[block_no].downsamplers[0].conv - - args = [ - "in_channels", - "out_channels", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "bias", - "padding_mode", - ] - if not USE_PEFT_BACKEND: - args.append("lora_layer") - - for a in args: - assert hasattr(old_down, a) - kwargs = {a: getattr(old_down, a) for a in args} - kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor. - kwargs["in_channels"] += by # surgery done here - # swap old with new modules - unet.down_blocks[block_no].downsamplers[0].conv = ( - nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs) - ) - unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here - - -def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by): - """Increase channels sizes to allow for additional concatted information from base model""" - m = unet.mid_block.resnets[0] - old_norm1, old_conv1 = m.norm1, m.conv1 - # norm - norm_args = "num_groups num_channels eps affine".split(" ") - for a in norm_args: - assert hasattr(old_norm1, a) - norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} - norm_kwargs["num_channels"] += by # surgery done here - conv1_args = [ - "in_channels", - "out_channels", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "bias", - "padding_mode", - ] - if not USE_PEFT_BACKEND: - conv1_args.append("lora_layer") - - conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} - conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. 
- conv1_kwargs["in_channels"] += by # surgery done here - # conv_shortcut - # as we changed the input size of the block, the input and output sizes are likely different, - # therefore we need a conv_shortcut (simply adding won't work) - conv_shortcut_args_kwargs = { - "in_channels": conv1_kwargs["in_channels"], - "out_channels": conv1_kwargs["out_channels"], - # default arguments from resnet.__init__ - "kernel_size": 1, - "stride": 1, - "padding": 0, - "bias": True, - } - # swap old with new modules - unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs) - unet.mid_block.resnets[0].conv1 = ( - nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) - ) - unet.mid_block.resnets[0].conv_shortcut = ( - nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) - ) - unet.mid_block.resnets[0].in_channels += by # surgery done here - - -def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32): - def find_denominator(number, start): - if start >= number: - return number - while start != 0: - residual = number % start - if residual == 0: - return start - start -= 1 - - for block in [*unet.down_blocks, unet.mid_block]: - # resnets - for r in block.resnets: - if r.norm1.num_groups < max_num_group: - r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group) - - if r.norm2.num_groups < max_num_group: - r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group) - - # transformers - if hasattr(block, "attentions"): - for a in block.attentions: - if a.norm.num_groups < max_num_group: - a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group) - - -def is_iterable(o): - if isinstance(o, str): - return False - try: - iter(o) - return True - except TypeError: - return False - - -def to_sub_blocks(blocks): - if not is_iterable(blocks): - blocks = [blocks] - - sub_blocks = [] - - for b in blocks: - if hasattr(b, "resnets"): - if hasattr(b, "attentions") and b.attentions is not None: - for r, a in zip(b.resnets, b.attentions): - sub_blocks.append([r, a]) - - num_resnets = len(b.resnets) - num_attns = len(b.attentions) - - if num_resnets > num_attns: - # we can have more resnets than attentions, so add each resnet as separate subblock - for i in range(num_attns, num_resnets): - sub_blocks.append([b.resnets[i]]) - else: - for r in b.resnets: - sub_blocks.append([r]) - - # upsamplers are part of the same subblock - if hasattr(b, "upsamplers") and b.upsamplers is not None: - for u in b.upsamplers: - sub_blocks[-1].extend([u]) - - # downsamplers are own subblock - if hasattr(b, "downsamplers") and b.downsamplers is not None: - for d in b.downsamplers: - sub_blocks.append([d]) - - return list(map(SubBlock, sub_blocks)) - - -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module diff --git a/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py b/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py deleted file mode 100644 index 722b282a3251..000000000000 --- a/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py +++ /dev/null @@ -1,58 +0,0 @@ -# !pip install opencv-python transformers accelerate -import argparse - -import cv2 -import numpy as np -import torch -from controlnetxs import ControlNetXSModel -from PIL import Image -from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline - -from diffusers.utils import load_image - - -parser = 
argparse.ArgumentParser() -parser.add_argument( - "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" -) -parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches") -parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7) -parser.add_argument( - "--image_path", - type=str, - default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png", -) -parser.add_argument("--num_inference_steps", type=int, default=50) - -args = parser.parse_args() - -prompt = args.prompt -negative_prompt = args.negative_prompt -# download an image -image = load_image(args.image_path) - -# initialize the models and pipeline -controlnet_conditioning_scale = args.controlnet_conditioning_scale -controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16) -pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16 -) -pipe.enable_model_cpu_offload() - -# get canny image -image = np.array(image) -image = cv2.Canny(image, 100, 200) -image = image[:, :, None] -image = np.concatenate([image, image, image], axis=2) -canny_image = Image.fromarray(image) - -num_inference_steps = args.num_inference_steps - -# generate image -image = pipe( - prompt, - controlnet_conditioning_scale=controlnet_conditioning_scale, - image=canny_image, - num_inference_steps=num_inference_steps, -).images[0] -image.save("cnxs_sd.canny.png") diff --git a/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py b/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py deleted file mode 100644 index e5b8cfd88223..000000000000 --- a/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py +++ /dev/null @@ -1,57 +0,0 @@ -# !pip install opencv-python transformers accelerate -import argparse - -import cv2 -import numpy as np -import torch -from controlnetxs import ControlNetXSModel -from PIL import Image -from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline - -from diffusers.utils import load_image - - -parser = argparse.ArgumentParser() -parser.add_argument( - "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" -) -parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches") -parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7) -parser.add_argument( - "--image_path", - type=str, - default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png", -) -parser.add_argument("--num_inference_steps", type=int, default=50) - -args = parser.parse_args() - -prompt = args.prompt -negative_prompt = args.negative_prompt -# download an image -image = load_image(args.image_path) -# initialize the models and pipeline -controlnet_conditioning_scale = args.controlnet_conditioning_scale -controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16) -pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 -) -pipe.enable_model_cpu_offload() - -# get canny image -image = np.array(image) -image = cv2.Canny(image, 100, 200) -image = image[:, :, None] -image = np.concatenate([image, image, image], axis=2) -canny_image = 
Image.fromarray(image) - -num_inference_steps = args.num_inference_steps - -# generate image -image = pipe( - prompt, - controlnet_conditioning_scale=controlnet_conditioning_scale, - image=canny_image, - num_inference_steps=num_inference_steps, -).images[0] -image.save("cnxs_sdxl.canny.png") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 770045923d5d..5d6761663938 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -80,6 +80,7 @@ "AutoencoderTiny", "ConsistencyDecoderVAE", "ControlNetModel", + "ControlNetXSAdapter", "I2VGenXLUNet", "Kandinsky3UNet", "ModelMixin", @@ -94,6 +95,7 @@ "UNet2DConditionModel", "UNet2DModel", "UNet3DConditionModel", + "UNetControlNetXSModel", "UNetMotionModel", "UNetSpatioTemporalConditionModel", "UVit2DModel", @@ -270,6 +272,7 @@ "StableDiffusionControlNetImg2ImgPipeline", "StableDiffusionControlNetInpaintPipeline", "StableDiffusionControlNetPipeline", + "StableDiffusionControlNetXSPipeline", "StableDiffusionDepth2ImgPipeline", "StableDiffusionDiffEditPipeline", "StableDiffusionGLIGENPipeline", @@ -293,6 +296,7 @@ "StableDiffusionXLControlNetImg2ImgPipeline", "StableDiffusionXLControlNetInpaintPipeline", "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLControlNetXSPipeline", "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", @@ -474,6 +478,7 @@ AutoencoderTiny, ConsistencyDecoderVAE, ControlNetModel, + ControlNetXSAdapter, I2VGenXLUNet, Kandinsky3UNet, ModelMixin, @@ -487,6 +492,7 @@ UNet2DConditionModel, UNet2DModel, UNet3DConditionModel, + UNetControlNetXSModel, UNetMotionModel, UNetSpatioTemporalConditionModel, UVit2DModel, @@ -642,6 +648,7 @@ StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, + StableDiffusionControlNetXSPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionGLIGENPipeline, @@ -665,6 +672,7 @@ StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, + StableDiffusionXLControlNetXSPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index da77e4450c86..78b0efff921d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -32,6 +32,7 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["controlnet"] = ["ControlNetModel"] + _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] @@ -68,6 +69,7 @@ ConsistencyDecoderVAE, ) from .controlnet import ControlNetModel + from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py new file mode 100644 index 000000000000..4bbe1dd4dc25 --- /dev/null +++ b/src/diffusers/models/controlnet_xs.py @@ -0,0 +1,1892 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from math import gcd +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import FloatTensor, nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, is_torch_version, logging +from ..utils.torch_utils import apply_freeu +from .attention_processor import Attention, AttentionProcessor +from .controlnet import ControlNetConditioningEmbedding +from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + Downsample2D, + ResnetBlock2D, + Transformer2DModel, + UNetMidBlock2DCrossAttn, + Upsample2D, +) +from .unets.unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetXSOutput(BaseOutput): + """ + The output of [`UNetControlNetXSModel`]. + + Args: + sample (`FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base + model output, but is already the final output. 
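The distinction in the docstring above is worth spelling out. Below is a hedged sketch of a single denoising step with a classic ControlNet versus ControlNet-XS; variable names such as `unet_cnxs` are placeholders, and argument names follow the forward signatures shown in this diff and may be simplified:

```python
# Classic ControlNet: residuals are returned and must be fed back into the base UNet.
down_res, mid_res = controlnet(
    latents, t, encoder_hidden_states=prompt_embeds, controlnet_cond=cond_image, return_dict=False
)
noise_pred = unet(
    latents, t, encoder_hidden_states=prompt_embeds,
    down_block_additional_residuals=down_res, mid_block_additional_residual=mid_res,
).sample

# ControlNet-XS: the fused base + control model already returns the final noise prediction.
noise_pred = unet_cnxs(latents, t, encoder_hidden_states=prompt_embeds, controlnet_cond=cond_image).sample
```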
+ """ + + sample: FloatTensor = None + + +class DownBlockControlNetXSAdapter(nn.Module): + """Components that together with corresponding components from the base model will form a + `ControlNetXSCrossAttnDownBlock2D`""" + + def __init__( + self, + resnets: nn.ModuleList, + base_to_ctrl: nn.ModuleList, + ctrl_to_base: nn.ModuleList, + attentions: Optional[nn.ModuleList] = None, + downsampler: Optional[nn.Conv2d] = None, + ): + super().__init__() + self.resnets = resnets + self.base_to_ctrl = base_to_ctrl + self.ctrl_to_base = ctrl_to_base + self.attentions = attentions + self.downsamplers = downsampler + + +class MidBlockControlNetXSAdapter(nn.Module): + """Components that together with corresponding components from the base model will form a + `ControlNetXSCrossAttnMidBlock2D`""" + + def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList): + super().__init__() + self.midblock = midblock + self.base_to_ctrl = base_to_ctrl + self.ctrl_to_base = ctrl_to_base + + +class UpBlockControlNetXSAdapter(nn.Module): + """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`""" + + def __init__(self, ctrl_to_base: nn.ModuleList): + super().__init__() + self.ctrl_to_base = ctrl_to_base + + +def get_down_block_adapter( + base_in_channels: int, + base_out_channels: int, + ctrl_in_channels: int, + ctrl_out_channels: int, + temb_channels: int, + max_norm_num_groups: Optional[int] = 32, + has_crossattn=True, + transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + add_downsample: bool = True, + upcast_attention: Optional[bool] = False, +): + num_layers = 2 # only support sd + sdxl + + resnets = [] + attentions = [] + ctrl_to_base = [] + base_to_ctrl = [] + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + base_in_channels = base_in_channels if i == 0 else base_out_channels + ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels + + # Before the resnet/attention application, information is concatted from base to control. 
+ # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) + + resnets.append( + ResnetBlock2D( + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + out_channels=ctrl_out_channels, + temb_channels=temb_channels, + groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), + groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + eps=1e-5, + ) + ) + + if has_crossattn: + attentions.append( + Transformer2DModel( + num_attention_heads, + ctrl_out_channels // num_attention_heads, + in_channels=ctrl_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), + ) + ) + + # After the resnet/attention application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + + if add_downsample: + # Before the downsampler application, information is concatted from base to control + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) + + downsamplers = Downsample2D( + ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" + ) + + # After the downsampler application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + else: + downsamplers = None + + down_block_components = DownBlockControlNetXSAdapter( + resnets=nn.ModuleList(resnets), + base_to_ctrl=nn.ModuleList(base_to_ctrl), + ctrl_to_base=nn.ModuleList(ctrl_to_base), + ) + + if has_crossattn: + down_block_components.attentions = nn.ModuleList(attentions) + if downsamplers is not None: + down_block_components.downsamplers = downsamplers + + return down_block_components + + +def get_mid_block_adapter( + base_channels: int, + ctrl_channels: int, + temb_channels: Optional[int] = None, + max_norm_num_groups: Optional[int] = 32, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + upcast_attention: bool = False, +): + # Before the midblock application, information is concatted from base to control. 
+ # Concat doesn't require change in number of channels + base_to_ctrl = make_zero_conv(base_channels, base_channels) + + midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=ctrl_channels + base_channels, + out_channels=ctrl_channels, + temb_channels=temb_channels, + # number or norm groups must divide both in_channels and out_channels + resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + ) + + # After the midblock application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) + + return MidBlockControlNetXSAdapter(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base) + + +def get_up_block_adapter( + out_channels: int, + prev_output_channel: int, + ctrl_skip_channels: List[int], +): + ctrl_to_base = [] + num_layers = 3 # only support sd + sdxl + for i in range(num_layers): + resnet_in_channels = prev_output_channel if i == 0 else out_channels + ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) + + return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(ctrl_to_base)) + + +class ControlNetXSAdapter(ModelMixin, ConfigMixin): + r""" + A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a + `UNet2DConditionModel` base model). + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic + methods implemented for all models (such as downloading or saving). + + Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. It's + default parameters are compatible with StableDiffusion. + + Parameters: + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channels for each block in the `controlnet_cond_embedding` layer. + time_embedding_mix (`float`, defaults to 1.0): + If 0, then only the control adapters's time embedding is used. If 1, then only the base unet's time + embedding is used. Otherwise, both are combined. + learn_time_embedding (`bool`, defaults to `False`): + Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time + embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base + model's time embedding. + num_attention_heads (`list[int]`, defaults to `[4]`): + The number of attention heads. + block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): + The tuple of output channels for each block. + base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`): + The tuple of output channels for each block in the base unet. + cross_attention_dim (`int`, defaults to 1024): + The dimension of the cross attention features. + down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`): + The tuple of downsample blocks to use. 
+        sample_size (`int`, defaults to 96):
+            Height and width of input/output sample.
+        transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        upcast_attention (`bool`, defaults to `True`):
+            Whether the attention computation should always be upcasted.
+        max_norm_num_groups (`int`, defaults to 32):
+            Maximum number of groups used for group normalization. The actual number will be the largest divisor of
+            the respective channels that is <= `max_norm_num_groups`.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        conditioning_channels: int = 3,
+        conditioning_channel_order: str = "rgb",
+        conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
+        time_embedding_mix: float = 1.0,
+        learn_time_embedding: bool = False,
+        num_attention_heads: Union[int, Tuple[int]] = 4,
+        block_out_channels: Tuple[int] = (4, 8, 16, 16),
+        base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        cross_attention_dim: int = 1024,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "DownBlock2D",
+        ),
+        sample_size: Optional[int] = 96,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        upcast_attention: bool = True,
+        max_norm_num_groups: int = 32,
+    ):
+        super().__init__()
+
+        time_embedding_input_dim = base_block_out_channels[0]
+        time_embedding_dim = base_block_out_channels[0] * 4
+
+        # Check inputs
+        if conditioning_channel_order not in ["rgb", "bgr"]:
+            raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}")
+
+        if len(block_out_channels) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(transformer_layers_per_block, (list, tuple)):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+        if not isinstance(cross_attention_dim, (list, tuple)):
+            cross_attention_dim = [cross_attention_dim] * len(down_block_types)
+        # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAdapter` takes `num_attention_heads` instead of `attention_head_dim`
+        if not isinstance(num_attention_heads, (list, tuple)):
+            num_attention_heads = [num_attention_heads] * len(down_block_types)
+
+        if len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ ) + + # 5 - Create conditioning hint embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + # time + if learn_time_embedding: + self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim) + else: + self.time_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.up_connections = nn.ModuleList([]) + + # input + self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) + self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0]) + + # down + base_out_channels = base_block_out_channels[0] + ctrl_out_channels = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + base_in_channels = base_out_channels + base_out_channels = base_block_out_channels[i] + ctrl_in_channels = ctrl_out_channels + ctrl_out_channels = block_out_channels[i] + has_crossattn = "CrossAttn" in down_block_type + is_final_block = i == len(down_block_types) - 1 + + self.down_blocks.append( + get_down_block_adapter( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=time_embedding_dim, + max_norm_num_groups=max_norm_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block[i], + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + upcast_attention=upcast_attention, + ) + ) + + # mid + self.mid_block = get_mid_block_adapter( + base_channels=base_block_out_channels[-1], + ctrl_channels=block_out_channels[-1], + temb_channels=time_embedding_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + upcast_attention=upcast_attention, + ) + + # up + # The skip connection channels are the output of the conv_in and of all the down subblocks + ctrl_skip_channels = [block_out_channels[0]] + for i, out_channels in enumerate(block_out_channels): + number_of_subblocks = ( + 3 if i < len(block_out_channels) - 1 else 2 + ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler + ctrl_skip_channels.extend([out_channels] * number_of_subblocks) + + reversed_base_block_out_channels = list(reversed(base_block_out_channels)) + + base_out_channels = reversed_base_block_out_channels[0] + for i in range(len(down_block_types)): + prev_base_output_channel = base_out_channels + base_out_channels = reversed_base_block_out_channels[i] + ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] + + self.up_connections.append( + get_up_block_adapter( + out_channels=base_out_channels, + prev_output_channel=prev_base_output_channel, + ctrl_skip_channels=ctrl_skip_channels_, + ) + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + size_ratio: Optional[float] = None, + block_out_channels: Optional[List[int]] = None, + num_attention_heads: Optional[List[int]] = None, + learn_time_embedding: bool = False, + time_embedding_mix: int = 1.0, + conditioning_channels: int = 3, + conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + r""" + Instantiate a [`ControlNetXSAdapter`] from a 
[`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it. + size_ratio (float, *optional*, defaults to `None`): + When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this + or `block_out_channels` must be given. + block_out_channels (`List[int]`, *optional*, defaults to `None`): + Down blocks output channels in control model. Either this or `size_ratio` must be given. + num_attention_heads (`List[int]`, *optional*, defaults to `None`): + The dimension of the attention heads. The naming seems a bit confusing and it is, see + https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + learn_time_embedding (`bool`, defaults to `False`): + Whether the `ControlNetXSAdapter` should learn a time embedding. + time_embedding_mix (`float`, defaults to 1.0): + If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time + embedding is used. Otherwise, both are combined. + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + """ + + # Check input + fixed_size = block_out_channels is not None + relative_size = size_ratio is not None + if not (fixed_size ^ relative_size): + raise ValueError( + "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)." + ) + + # Create model + block_out_channels = block_out_channels or [int(b * size_ratio) for b in unet.config.block_out_channels] + if num_attention_heads is None: + # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + num_attention_heads = unet.config.attention_head_dim + + model = cls( + conditioning_channels=conditioning_channels, + conditioning_channel_order=conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + time_embedding_mix=time_embedding_mix, + learn_time_embedding=learn_time_embedding, + num_attention_heads=num_attention_heads, + block_out_channels=block_out_channels, + base_block_out_channels=unet.config.block_out_channels, + cross_attention_dim=unet.config.cross_attention_dim, + down_block_types=unet.config.down_block_types, + sample_size=unet.config.sample_size, + transformer_layers_per_block=unet.config.transformer_layers_per_block, + upcast_attention=unet.config.upcast_attention, + max_norm_num_groups=unet.config.norm_num_groups, + ) + + # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + + return model + + def forward(self, *args, **kwargs): + raise ValueError( + "A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel." + ) + + +class UNetControlNetXSModel(ModelMixin, ConfigMixin): + r""" + A UNet fused with a ControlNet-XS adapter model + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic + methods implemented for all models (such as downloading or saving). 
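+
+    A minimal construction sketch (illustration only; the checkpoint name is just an example, and
+    `ControlNetXSAdapter` / `UNetControlNetXSModel` are the classes defined in this module):
+
+    ```py
+    >>> from diffusers import UNet2DConditionModel
+
+    >>> unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+    >>> adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)
+    >>> model = UNetControlNetXSModel.from_unet(unet, adapter)
+    ```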
+ + `UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. It's default parameters are + compatible with StableDiffusion. + + It's parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in + `ControlNetXSAdapter` . See their documentation for details. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + # unet configs + sample_size: Optional[int] = 96, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + norm_num_groups: Optional[int] = 32, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: Union[int, Tuple[int]] = 8, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + upcast_attention: bool = True, + time_cond_proj_dim: Optional[int] = None, + projection_class_embeddings_input_dim: Optional[int] = None, + # additional controlnet configs + time_embedding_mix: float = 1.0, + ctrl_conditioning_channels: int = 3, + ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + ctrl_conditioning_channel_order: str = "rgb", + ctrl_learn_time_embedding: bool = False, + ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16), + ctrl_num_attention_heads: Union[int, Tuple[int]] = 4, + ctrl_max_norm_num_groups: int = 32, + ): + super().__init__() + + if time_embedding_mix < 0 or time_embedding_mix > 1: + raise ValueError("`time_embedding_mix` needs to be between 0 and 1.") + if time_embedding_mix < 1 and not ctrl_learn_time_embedding: + raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`") + + if addition_embed_type is not None and addition_embed_type != "text_time": + raise ValueError( + "As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`." 
+ ) + + if not isinstance(transformer_layers_per_block, (list, tuple)): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if not isinstance(cross_attention_dim, (list, tuple)): + cross_attention_dim = [cross_attention_dim] * len(down_block_types) + if not isinstance(num_attention_heads, (list, tuple)): + num_attention_heads = [num_attention_heads] * len(down_block_types) + if not isinstance(ctrl_num_attention_heads, (list, tuple)): + ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types) + + base_num_attention_heads = num_attention_heads + + self.in_channels = 4 + + # # Input + self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=ctrl_block_out_channels[0], + block_out_channels=ctrl_conditioning_embedding_out_channels, + conditioning_channels=ctrl_conditioning_channels, + ) + self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1) + self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0]) + + # # Time + time_embed_input_dim = block_out_channels[0] + time_embed_dim = block_out_channels[0] * 4 + + self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) + self.base_time_embedding = TimestepEmbedding( + time_embed_input_dim, + time_embed_dim, + cond_proj_dim=time_cond_proj_dim, + ) + self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim) + + if addition_embed_type is None: + self.base_add_time_proj = None + self.base_add_embedding = None + else: + self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + # # Create down blocks + down_blocks = [] + base_out_channels = block_out_channels[0] + ctrl_out_channels = ctrl_block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + base_in_channels = base_out_channels + base_out_channels = block_out_channels[i] + ctrl_in_channels = ctrl_out_channels + ctrl_out_channels = ctrl_block_out_channels[i] + has_crossattn = "CrossAttn" in down_block_type + is_final_block = i == len(down_block_types) - 1 + + down_blocks.append( + ControlNetXSCrossAttnDownBlock2D( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block[i], + base_num_attention_heads=base_num_attention_heads[i], + ctrl_num_attention_heads=ctrl_num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + add_downsample=not is_final_block, + upcast_attention=upcast_attention, + ) + ) + + # # Create mid block + self.mid_block = ControlNetXSCrossAttnMidBlock2D( + base_channels=block_out_channels[-1], + ctrl_channels=ctrl_block_out_channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, + transformer_layers_per_block=transformer_layers_per_block[-1], + base_num_attention_heads=base_num_attention_heads[-1], + ctrl_num_attention_heads=ctrl_num_attention_heads[-1], + 
cross_attention_dim=cross_attention_dim[-1], + upcast_attention=upcast_attention, + ) + + # # Create up blocks + up_blocks = [] + rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + rev_num_attention_heads = list(reversed(base_num_attention_heads)) + rev_cross_attention_dim = list(reversed(cross_attention_dim)) + + # The skip connection channels are the output of the conv_in and of all the down subblocks + ctrl_skip_channels = [ctrl_block_out_channels[0]] + for i, out_channels in enumerate(ctrl_block_out_channels): + number_of_subblocks = ( + 3 if i < len(ctrl_block_out_channels) - 1 else 2 + ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler + ctrl_skip_channels.extend([out_channels] * number_of_subblocks) + + reversed_block_out_channels = list(reversed(block_out_channels)) + + out_channels = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = out_channels + out_channels = reversed_block_out_channels[i] + in_channels = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] + + has_crossattn = "CrossAttn" in up_block_type + is_final_block = i == len(block_out_channels) - 1 + + up_blocks.append( + ControlNetXSCrossAttnUpBlock2D( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + ctrl_skip_channels=ctrl_skip_channels_, + temb_channels=time_embed_dim, + resolution_idx=i, + has_crossattn=has_crossattn, + transformer_layers_per_block=rev_transformer_layers_per_block[i], + num_attention_heads=rev_num_attention_heads[i], + cross_attention_dim=rev_cross_attention_dim[i], + add_upsample=not is_final_block, + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, + ) + ) + + self.down_blocks = nn.ModuleList(down_blocks) + self.up_blocks = nn.ModuleList(up_blocks) + + self.base_conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups) + self.base_conv_act = nn.SiLU() + self.base_conv_out = nn.Conv2d(block_out_channels[0], 4, kernel_size=3, padding=1) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet: Optional[ControlNetXSAdapter] = None, + size_ratio: Optional[float] = None, + ctrl_block_out_channels: Optional[List[float]] = None, + time_embedding_mix: Optional[float] = None, + ctrl_optional_kwargs: Optional[Dict] = None, + ): + r""" + Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAdapter`] + . + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. + controlnet (`ControlNetXSAdapter`): + The ConntrolNet-XS adapter with which the UNet will be fused. If none is given, a new ConntrolNet-XS + adapter will be created. + size_ratio (float, *optional*, defaults to `None`): + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details. + ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`): + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details, + where this parameter is called `block_out_channels`. + time_embedding_mix (`float`, *optional*, defaults to None): + Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details. 
+            ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
+                Passed to the `init` of the new controlnet if no controlnet was given.
+        """
+        if controlnet is None:
+            # default to no extra kwargs so the unpacking below also works when nothing was passed
+            ctrl_optional_kwargs = ctrl_optional_kwargs or {}
+            controlnet = ControlNetXSAdapter.from_unet(
+                unet, size_ratio, ctrl_block_out_channels, **ctrl_optional_kwargs
+            )
+        else:
+            if any(
+                o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs)
+            ):
+                raise ValueError(
+                    "When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs."
+                )
+
+        # # get params
+        params_for_unet = [
+            "sample_size",
+            "down_block_types",
+            "up_block_types",
+            "block_out_channels",
+            "norm_num_groups",
+            "cross_attention_dim",
+            "transformer_layers_per_block",
+            "addition_embed_type",
+            "addition_time_embed_dim",
+            "upcast_attention",
+            "time_cond_proj_dim",
+            "projection_class_embeddings_input_dim",
+        ]
+        params_for_unet = {k: v for k, v in unet.config.items() if k in params_for_unet}
+        # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
+        params_for_unet["num_attention_heads"] = unet.config.attention_head_dim
+
+        params_for_controlnet = [
+            "conditioning_channels",
+            "conditioning_embedding_out_channels",
+            "conditioning_channel_order",
+            "learn_time_embedding",
+            "block_out_channels",
+            "num_attention_heads",
+            "max_norm_num_groups",
+        ]
+        params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet}
+        params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix
+
+        # # create model
+        model = cls.from_config({**params_for_unet, **params_for_controlnet})
+
+        # # load weights
+        # from unet
+        modules_from_unet = [
+            "time_embedding",
+            "conv_in",
+            "conv_norm_out",
+            "conv_out",
+        ]
+        for m in modules_from_unet:
+            getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
+
+        optional_modules_from_unet = [
+            "add_time_proj",
+            "add_embedding",
+        ]
+        for m in optional_modules_from_unet:
+            if hasattr(unet, m) and getattr(unet, m) is not None:
+                getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
+
+        # from controlnet
+        model.controlnet_cond_embedding.load_state_dict(controlnet.controlnet_cond_embedding.state_dict())
+        model.ctrl_conv_in.load_state_dict(controlnet.conv_in.state_dict())
+        if controlnet.time_embedding is not None:
+            model.ctrl_time_embedding.load_state_dict(controlnet.time_embedding.state_dict())
+        model.control_to_base_for_conv_in.load_state_dict(controlnet.control_to_base_for_conv_in.state_dict())
+
+        # from both
+        model.down_blocks = nn.ModuleList(
+            ControlNetXSCrossAttnDownBlock2D.from_modules(b, c)
+            for b, c in zip(unet.down_blocks, controlnet.down_blocks)
+        )
+        model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block)
+        model.up_blocks = nn.ModuleList(
+            ControlNetXSCrossAttnUpBlock2D.from_modules(b, c)
+            for b, c in zip(unet.up_blocks, controlnet.up_connections)
+        )
+
+        # ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel
+        model.to(unet.dtype)
+
+        return model
+
+    def freeze_unet_params(self) -> None:
+        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else
+        unfrozen for fine tuning."""
+        # Unfreeze everything
+        for param in self.parameters():
+            param.requires_grad = True
+
+        # Freeze the parts belonging to the base UNet2DConditionModel; the ControlNet-XS parts stay trainable
+        base_parts = [
+
"base_time_proj", + "base_time_embedding", + "base_add_time_proj", + "base_add_embedding", + "base_conv_in", + "base_conv_norm_out", + "base_conv_act", + "base_conv_out", + ] + base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None] + for part in base_parts: + for param in part.parameters(): + param.requires_grad = False + + for d in self.down_blocks: + d.freeze_base_params() + self.mid_block.freeze_base_params() + for u in self.up_blocks: + u.freeze_base_params() + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. 
+ + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + sample: FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: Optional[torch.Tensor] = None, + conditioning_scale: Optional[float] = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + return_dict: bool = True, + apply_control: bool = True, + ) -> Union[ControlNetXSOutput, Tuple]: + """ + The [`ControlNetXSModel`] forward method. + + Args: + sample (`FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + How much the control model affects the base model outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + apply_control (`bool`, defaults to `True`): + If `False`, the input is run only through the base model. + + Returns: + [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + + # check channel order + if self.config.ctrl_conditioning_channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.base_time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
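+        # The cast below keeps `t_emb` in the sample dtype. If `ctrl_learn_time_embedding` is enabled (and
+        # control is applied), the control and base time embeddings are then blended with a softened mix factor:
+        #     interpolation_param = time_embedding_mix ** 0.3
+        #     temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
+        # otherwise only the base UNet's time embedding is used.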
+ t_emb = t_emb.to(dtype=sample.dtype) + + if self.config.ctrl_learn_time_embedding and apply_control: + ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond) + base_temb = self.base_time_embedding(t_emb, timestep_cond) + interpolation_param = self.config.time_embedding_mix**0.3 + + temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) + else: + temb = self.base_time_embedding(t_emb) + + # added time & text embeddings + aug_emb = None + + if self.config.addition_embed_type is None: + pass + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.base_add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(temb.dtype) + aug_emb = self.base_add_embedding(add_embeds) + else: + raise ValueError( + f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported." + ) + + temb = temb + aug_emb if aug_emb is not None else temb + + # text embeddings + cemb = encoder_hidden_states + + # Preparation + h_ctrl = h_base = sample + hs_base, hs_ctrl = [], [] + + # Cross Control + guided_hint = self.controlnet_cond_embedding(controlnet_cond) + + # 1 - conv in & down + + h_base = self.base_conv_in(h_base) + h_ctrl = self.ctrl_conv_in(h_ctrl) + if guided_hint is not None: + h_ctrl += guided_hint + if apply_control: + h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base + + hs_base.append(h_base) + hs_ctrl.append(h_ctrl) + + for down in self.down_blocks: + h_base, h_ctrl, residual_hb, residual_hc = down( + hidden_states_base=h_base, + hidden_states_ctrl=h_ctrl, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + apply_control=apply_control, + ) + hs_base.extend(residual_hb) + hs_ctrl.extend(residual_hc) + + # 2 - mid + h_base, h_ctrl = self.mid_block( + hidden_states_base=h_base, + hidden_states_ctrl=h_ctrl, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + apply_control=apply_control, + ) + + # 3 - up + for up in self.up_blocks: + n_resnets = len(up.resnets) + skips_hb = hs_base[-n_resnets:] + skips_hc = hs_ctrl[-n_resnets:] + hs_base = hs_base[:-n_resnets] + hs_ctrl = hs_ctrl[:-n_resnets] + h_base = up( + hidden_states=h_base, + res_hidden_states_tuple_base=skips_hb, + res_hidden_states_tuple_ctrl=skips_hc, + temb=temb, + encoder_hidden_states=cemb, + conditioning_scale=conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + apply_control=apply_control, + ) + + # 4 - conv out + h_base = self.base_conv_norm_out(h_base) + h_base 
= self.base_conv_act(h_base) + h_base = self.base_conv_out(h_base) + + if not return_dict: + return (h_base,) + + return ControlNetXSOutput(sample=h_base) + + +class ControlNetXSCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + base_in_channels: int, + base_out_channels: int, + ctrl_in_channels: int, + ctrl_out_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + ctrl_max_norm_num_groups: int = 32, + has_crossattn=True, + transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, + base_num_attention_heads: Optional[int] = 1, + ctrl_num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + add_downsample: bool = True, + upcast_attention: Optional[bool] = False, + ): + super().__init__() + base_resnets = [] + base_attentions = [] + ctrl_resnets = [] + ctrl_attentions = [] + ctrl_to_base = [] + base_to_ctrl = [] + + num_layers = 2 # only support sd + sdxl + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + base_in_channels = base_in_channels if i == 0 else base_out_channels + ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels + + # Before the resnet/attention application, information is concatted from base to control. + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) + + base_resnets.append( + ResnetBlock2D( + in_channels=base_in_channels, + out_channels=base_out_channels, + temb_channels=temb_channels, + groups=norm_num_groups, + ) + ) + ctrl_resnets.append( + ResnetBlock2D( + in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl + out_channels=ctrl_out_channels, + temb_channels=temb_channels, + groups=find_largest_factor( + ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups + ), + groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), + eps=1e-5, + ) + ) + + if has_crossattn: + base_attentions.append( + Transformer2DModel( + base_num_attention_heads, + base_out_channels // base_num_attention_heads, + in_channels=base_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, + ) + ) + ctrl_attentions.append( + Transformer2DModel( + ctrl_num_attention_heads, + ctrl_out_channels // ctrl_num_attention_heads, + in_channels=ctrl_out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), + ) + ) + + # After the resnet/attention application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + + if add_downsample: + # Before the downsampler application, information is concatted from base to control + # Concat doesn't require change in number of channels + base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) + + self.base_downsamplers = Downsample2D( + base_out_channels, use_conv=True, out_channels=base_out_channels, name="op" + ) + self.ctrl_downsamplers = Downsample2D( + ctrl_out_channels + base_out_channels, use_conv=True, 
out_channels=ctrl_out_channels, name="op" + ) + + # After the downsampler application, information is added from control to base + # Addition requires change in number of channels + ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) + else: + self.base_downsamplers = None + self.ctrl_downsamplers = None + + self.base_resnets = nn.ModuleList(base_resnets) + self.ctrl_resnets = nn.ModuleList(ctrl_resnets) + self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None] * num_layers + self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None] * num_layers + self.base_to_ctrl = nn.ModuleList(base_to_ctrl) + self.ctrl_to_base = nn.ModuleList(ctrl_to_base) + + self.gradient_checkpointing = False + + @classmethod + def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter): + # get params + def get_first_cross_attention(block): + return block.attentions[0].transformer_blocks[0].attn2 + + base_in_channels = base_downblock.resnets[0].in_channels + base_out_channels = base_downblock.resnets[0].out_channels + ctrl_in_channels = ( + ctrl_downblock.resnets[0].in_channels - base_in_channels + ) # base channels are concatted to ctrl channels in init + ctrl_out_channels = ctrl_downblock.resnets[0].out_channels + temb_channels = base_downblock.resnets[0].time_emb_proj.in_features + num_groups = base_downblock.resnets[0].norm1.num_groups + ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups + if hasattr(base_downblock, "attentions"): + has_crossattn = True + transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks) + base_num_attention_heads = get_first_cross_attention(base_downblock).heads + ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads + cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_downblock).upcast_attention + else: + has_crossattn = False + transformer_layers_per_block = None + base_num_attention_heads = None + ctrl_num_attention_heads = None + cross_attention_dim = None + upcast_attention = None + add_downsample = base_downblock.downsamplers is not None + + # create model + model = cls( + base_in_channels=base_in_channels, + base_out_channels=base_out_channels, + ctrl_in_channels=ctrl_in_channels, + ctrl_out_channels=ctrl_out_channels, + temb_channels=temb_channels, + norm_num_groups=num_groups, + ctrl_max_norm_num_groups=ctrl_num_groups, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block, + base_num_attention_heads=base_num_attention_heads, + ctrl_num_attention_heads=ctrl_num_attention_heads, + cross_attention_dim=cross_attention_dim, + add_downsample=add_downsample, + upcast_attention=upcast_attention, + ) + + # # load weights + model.base_resnets.load_state_dict(base_downblock.resnets.state_dict()) + model.ctrl_resnets.load_state_dict(ctrl_downblock.resnets.state_dict()) + if has_crossattn: + model.base_attentions.load_state_dict(base_downblock.attentions.state_dict()) + model.ctrl_attentions.load_state_dict(ctrl_downblock.attentions.state_dict()) + if add_downsample: + model.base_downsamplers.load_state_dict(base_downblock.downsamplers[0].state_dict()) + model.ctrl_downsamplers.load_state_dict(ctrl_downblock.downsamplers.state_dict()) + model.base_to_ctrl.load_state_dict(ctrl_downblock.base_to_ctrl.state_dict()) + model.ctrl_to_base.load_state_dict(ctrl_downblock.ctrl_to_base.state_dict()) + + 
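+        # Note: `load_state_dict` copies the weights, so the fused block does not share parameters with
+        # `base_downblock` / `ctrl_downblock`. Illustrative check only (not part of the model code):
+        #     assert torch.equal(model.base_resnets[0].conv1.weight, base_downblock.resnets[0].conv1.weight)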
return model + + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Unfreeze everything + for param in self.parameters(): + param.requires_grad = True + + # Freeze base part + base_parts = [self.base_resnets] + if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones + base_parts.append(self.base_attentions) + if self.base_downsamplers is not None: + base_parts.append(self.base_downsamplers) + for part in base_parts: + for param in part.parameters(): + param.requires_grad = False + + def forward( + self, + hidden_states_base: FloatTensor, + temb: FloatTensor, + encoder_hidden_states: Optional[FloatTensor] = None, + hidden_states_ctrl: Optional[FloatTensor] = None, + conditioning_scale: Optional[float] = 1.0, + attention_mask: Optional[FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[FloatTensor] = None, + apply_control: bool = True, + ) -> Tuple[FloatTensor, FloatTensor, Tuple[FloatTensor, ...], Tuple[FloatTensor, ...]]: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + h_base = hidden_states_base + h_ctrl = hidden_states_ctrl + + base_output_states = () + ctrl_output_states = () + + base_blocks = list(zip(self.base_resnets, self.base_attentions)) + ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip( + base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base + ): + # concat base -> ctrl + if apply_control: + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) + + # apply base subblock + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + h_base = torch.utils.checkpoint.checkpoint( + create_custom_forward(b_res), + h_base, + temb, + **ckpt_kwargs, + ) + else: + h_base = b_res(h_base, temb) + + if b_attn is not None: + h_base = b_attn( + h_base, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply ctrl subblock + if apply_control: + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + h_ctrl = torch.utils.checkpoint.checkpoint( + create_custom_forward(c_res), + h_ctrl, + temb, + **ckpt_kwargs, + ) + else: + h_ctrl = c_res(h_ctrl, temb) + if c_attn is not None: + h_ctrl = c_attn( + h_ctrl, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # add ctrl -> base + if apply_control: + h_base = h_base + c2b(h_ctrl) * conditioning_scale + + base_output_states = base_output_states + (h_base,) + ctrl_output_states = ctrl_output_states + (h_ctrl,) + + if self.base_downsamplers is not None: # if we have a 
base_downsampler, then also a ctrl_downsampler + b2c = self.base_to_ctrl[-1] + c2b = self.ctrl_to_base[-1] + + # concat base -> ctrl + if apply_control: + h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) + # apply base subblock + h_base = self.base_downsamplers(h_base) + # apply ctrl subblock + if apply_control: + h_ctrl = self.ctrl_downsamplers(h_ctrl) + # add ctrl -> base + if apply_control: + h_base = h_base + c2b(h_ctrl) * conditioning_scale + + base_output_states = base_output_states + (h_base,) + ctrl_output_states = ctrl_output_states + (h_ctrl,) + + return h_base, h_ctrl, base_output_states, ctrl_output_states + + +class ControlNetXSCrossAttnMidBlock2D(nn.Module): + def __init__( + self, + base_channels: int, + ctrl_channels: int, + temb_channels: Optional[int] = None, + norm_num_groups: int = 32, + ctrl_max_norm_num_groups: int = 32, + transformer_layers_per_block: int = 1, + base_num_attention_heads: Optional[int] = 1, + ctrl_num_attention_heads: Optional[int] = 1, + cross_attention_dim: Optional[int] = 1024, + upcast_attention: bool = False, + ): + super().__init__() + + # Before the midblock application, information is concatted from base to control. + # Concat doesn't require change in number of channels + self.base_to_ctrl = make_zero_conv(base_channels, base_channels) + + self.base_midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=base_channels, + temb_channels=temb_channels, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=base_num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + ) + + self.ctrl_midblock = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=ctrl_channels + base_channels, + out_channels=ctrl_channels, + temb_channels=temb_channels, + # number or norm groups must divide both in_channels and out_channels + resnet_groups=find_largest_factor( + gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups + ), + cross_attention_dim=cross_attention_dim, + num_attention_heads=ctrl_num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + ) + + # After the midblock application, information is added from control to base + # Addition requires change in number of channels + self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) + + self.gradient_checkpointing = False + + @classmethod + def from_modules( + cls, + base_midblock: UNetMidBlock2DCrossAttn, + ctrl_midblock: MidBlockControlNetXSAdapter, + ): + base_to_ctrl = ctrl_midblock.base_to_ctrl + ctrl_to_base = ctrl_midblock.ctrl_to_base + ctrl_midblock = ctrl_midblock.midblock + + # get params + def get_first_cross_attention(midblock): + return midblock.attentions[0].transformer_blocks[0].attn2 + + base_channels = ctrl_to_base.out_channels + ctrl_channels = ctrl_to_base.in_channels + transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks) + temb_channels = base_midblock.resnets[0].time_emb_proj.in_features + num_groups = base_midblock.resnets[0].norm1.num_groups + ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups + base_num_attention_heads = get_first_cross_attention(base_midblock).heads + ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads + cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_midblock).upcast_attention + + # 
create model + model = cls( + base_channels=base_channels, + ctrl_channels=ctrl_channels, + temb_channels=temb_channels, + norm_num_groups=num_groups, + ctrl_max_norm_num_groups=ctrl_num_groups, + transformer_layers_per_block=transformer_layers_per_block, + base_num_attention_heads=base_num_attention_heads, + ctrl_num_attention_heads=ctrl_num_attention_heads, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + + # load weights + model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict()) + model.base_midblock.load_state_dict(base_midblock.state_dict()) + model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict()) + model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict()) + + return model + + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Unfreeze everything + for param in self.parameters(): + param.requires_grad = True + + # Freeze base part + for param in self.base_midblock.parameters(): + param.requires_grad = False + + def forward( + self, + hidden_states_base: FloatTensor, + temb: FloatTensor, + encoder_hidden_states: FloatTensor, + hidden_states_ctrl: Optional[FloatTensor] = None, + conditioning_scale: Optional[float] = 1.0, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_mask: Optional[FloatTensor] = None, + encoder_attention_mask: Optional[FloatTensor] = None, + apply_control: bool = True, + ) -> Tuple[FloatTensor, FloatTensor]: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + h_base = hidden_states_base + h_ctrl = hidden_states_ctrl + + joint_args = { + "temb": temb, + "encoder_hidden_states": encoder_hidden_states, + "attention_mask": attention_mask, + "cross_attention_kwargs": cross_attention_kwargs, + "encoder_attention_mask": encoder_attention_mask, + } + + if apply_control: + h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1) # concat base -> ctrl + h_base = self.base_midblock(h_base, **joint_args) # apply base mid block + if apply_control: + h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block + h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base + + return h_base, h_ctrl + + +class ControlNetXSCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + ctrl_skip_channels: List[int], + temb_channels: int, + norm_num_groups: int = 32, + resolution_idx: Optional[int] = None, + has_crossattn=True, + transformer_layers_per_block: int = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1024, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + ctrl_to_base = [] + + num_layers = 3 # only support sd + sdxl + + self.has_cross_attention = has_crossattn + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) + + resnets.append( + 
ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=norm_num_groups, + ) + ) + + if has_crossattn: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=True, + upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers + self.ctrl_to_base = nn.ModuleList(ctrl_to_base) + + if add_upsample: + self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + @classmethod + def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter): + ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base + + # get params + def get_first_cross_attention(block): + return block.attentions[0].transformer_blocks[0].attn2 + + out_channels = base_upblock.resnets[0].out_channels + in_channels = base_upblock.resnets[-1].in_channels - out_channels + prev_output_channels = base_upblock.resnets[0].in_channels - out_channels + ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections] + temb_channels = base_upblock.resnets[0].time_emb_proj.in_features + num_groups = base_upblock.resnets[0].norm1.num_groups + resolution_idx = base_upblock.resolution_idx + if hasattr(base_upblock, "attentions"): + has_crossattn = True + transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks) + num_attention_heads = get_first_cross_attention(base_upblock).heads + cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim + upcast_attention = get_first_cross_attention(base_upblock).upcast_attention + else: + has_crossattn = False + transformer_layers_per_block = None + num_attention_heads = None + cross_attention_dim = None + upcast_attention = None + add_upsample = base_upblock.upsamplers is not None + + # create model + model = cls( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channels, + ctrl_skip_channels=ctrl_skip_channelss, + temb_channels=temb_channels, + norm_num_groups=num_groups, + resolution_idx=resolution_idx, + has_crossattn=has_crossattn, + transformer_layers_per_block=transformer_layers_per_block, + num_attention_heads=num_attention_heads, + cross_attention_dim=cross_attention_dim, + add_upsample=add_upsample, + upcast_attention=upcast_attention, + ) + + # load weights + model.resnets.load_state_dict(base_upblock.resnets.state_dict()) + if has_crossattn: + model.attentions.load_state_dict(base_upblock.attentions.state_dict()) + if add_upsample: + model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict()) + model.ctrl_to_base.load_state_dict(ctrl_to_base_skip_connections.state_dict()) + + return model + + def freeze_base_params(self) -> None: + """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine + tuning.""" + # Unfreeze everything + for param in self.parameters(): + param.requires_grad = True + + # Freeze base part + base_parts = [self.resnets] + if isinstance(self.attentions, nn.ModuleList): # attentions can be a list of Nones + 
base_parts.append(self.attentions) + if self.upsamplers is not None: + base_parts.append(self.upsamplers) + for part in base_parts: + for param in part.parameters(): + param.requires_grad = False + + def forward( + self, + hidden_states: FloatTensor, + res_hidden_states_tuple_base: Tuple[FloatTensor, ...], + res_hidden_states_tuple_ctrl: Tuple[FloatTensor, ...], + temb: FloatTensor, + encoder_hidden_states: Optional[FloatTensor] = None, + conditioning_scale: Optional[float] = 1.0, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_mask: Optional[FloatTensor] = None, + upsample_size: Optional[int] = None, + encoder_attention_mask: Optional[FloatTensor] = None, + apply_control: bool = True, + ) -> FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + return apply_freeu( + self.resolution_idx, + hidden_states, + res_h_base, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + else: + return hidden_states, res_h_base + + for resnet, attn, c2b, res_h_base, res_h_ctrl in zip( + self.resnets, + self.attentions, + self.ctrl_to_base, + reversed(res_hidden_states_tuple_base), + reversed(res_hidden_states_tuple_ctrl), + ): + if apply_control: + hidden_states += c2b(res_h_ctrl) * conditioning_scale + + hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base) + hidden_states = torch.cat([hidden_states, res_h_base], dim=1) + + if self.training and self.gradient_checkpointing: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + + if attn is not None: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + hidden_states = self.upsamplers(hidden_states, upsample_size) + + return hidden_states + + +def make_zero_conv(in_channels, out_channels=None): + return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +def find_largest_factor(number, max_factor): + factor = max_factor + if factor >= number: + return number + while factor != 0: + residual = number % factor + if residual == 0: + return factor + factor -= 1 diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index d54630376961..ef75fad25e44 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -746,6 +746,7 @@ def __init__( self, in_channels: int, 
temb_channels: int, + out_channels: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, @@ -753,6 +754,7 @@ def __init__( resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, num_attention_heads: int = 1, output_scale_factor: float = 1.0, @@ -764,6 +766,10 @@ def __init__( ): super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -772,14 +778,17 @@ def __init__( if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers + resnet_groups_out = resnet_groups_out or resnet_groups + # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, - out_channels=in_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, + groups_out=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, @@ -794,11 +803,11 @@ def __init__( attentions.append( Transformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + norm_num_groups=resnet_groups_out, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, @@ -808,8 +817,8 @@ def __init__( attentions.append( DualTransformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, @@ -817,11 +826,11 @@ def __init__( ) resnets.append( ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, + in_channels=out_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, - groups=resnet_groups, + groups=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 2b2277809b38..ab7c13b56eb8 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -134,6 +134,12 @@ "StableDiffusionXLControlNetPipeline", ] ) + _import_structure["controlnet_xs"].extend( + [ + "StableDiffusionControlNetXSPipeline", + "StableDiffusionXLControlNetXSPipeline", + ] + ) _import_structure["deepfloyd_if"] = [ "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -378,6 +384,10 @@ StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, ) + from .controlnet_xs import ( + StableDiffusionControlNetXSPipeline, + StableDiffusionXLControlNetXSPipeline, + ) from .deepfloyd_if import ( IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/pipelines/controlnet_xs/__init__.py b/src/diffusers/pipelines/controlnet_xs/__init__.py new file mode 100644 index 000000000000..978278b184f9 --- /dev/null +++ b/src/diffusers/pipelines/controlnet_xs/__init__.py 
@@ -0,0 +1,68 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"] + _import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"] +try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline + from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py similarity index 82% rename from examples/research_projects/controlnetxs/pipeline_controlnet_xs.py rename to src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 88a586e9271d..2f450b9c2cea 100644 --- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -19,30 +19,75 @@ import PIL.Image import torch import torch.nn.functional as F -from controlnetxs import ControlNetXSModel from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.schedulers import 
KarrasDiffusionSchedulers -from diffusers.utils import ( +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, + replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) -from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + + >>> controlnet = ControlNetXSAdapter.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + class StableDiffusionControlNetXSPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin ): @@ -56,7 +101,7 @@ class StableDiffusionControlNetXSPipeline( - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files Args: vae ([`AutoencoderKL`]): @@ -66,9 +111,9 @@ class StableDiffusionControlNetXSPipeline( tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet ([`ControlNetXSModel`]): - Provides additional conditioning to the `unet` during the denoising process. 
+ A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. @@ -80,17 +125,18 @@ class StableDiffusionControlNetXSPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->unet->vae>controlnet" + model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: ControlNetXSModel, + unet: Union[UNet2DConditionModel, UNetControlNetXSModel], + controlnet: ControlNetXSAdapter, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, @@ -98,6 +144,9 @@ def __init__( ): super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel.from_unet(unet, controlnet) + if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" @@ -114,14 +163,6 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible( - vae - ) - if not vae_compatible: - raise ValueError( - f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." - ) - self.register_modules( vae=vae, text_encoder=text_encoder, @@ -403,20 +444,19 @@ def check_inputs( self, prompt, image, - callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: @@ -445,25 +485,16 @@ def check_inputs( f" {negative_prompt_embeds.shape}." 
) - # Check `image` + # Check `image` and `controlnet_conditioning_scale` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule + self.unet, torch._dynamo.eval_frame.OptimizedModule ) if ( - isinstance(self.controlnet, ControlNetXSModel) + isinstance(self.unet, UNetControlNetXSModel) or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) ): self.check_image(image, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetXSModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) - ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") else: @@ -563,7 +594,33 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + def clip_skip(self): + return self._clip_skip + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -581,13 +638,13 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, control_guidance_start: float = 0.0, control_guidance_end: float = 1.0, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): r""" The call function to the pipeline for generation. @@ -595,7 +652,7 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 
- image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be @@ -639,12 +696,6 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -659,7 +710,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: @@ -669,21 +728,27 @@ def __call__( second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, image, - callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -713,6 +778,7 @@ def __call__( lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) + + # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes @@ -720,27 +786,24 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare image - if isinstance(controlnet, ControlNetXSModel): - image = self.prepare_image( - image=image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - height, width = image.shape[-2:] - else: - assert False + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=unet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels + num_channels_latents = self.unet.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -757,42 +820,33 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - is_unet_compiled = is_compiled_module(self.unet) - is_controlnet_compiled = is_compiled_module(self.controlnet) + self._num_timesteps = len(timesteps) + is_controlnet_compiled = is_compiled_module(self.unet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if is_controlnet_compiled and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - dont_control = ( - i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end + apply_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end ) - if dont_control: - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=True, - ).sample - else: - noise_pred = self.controlnet( - base_model=self.unet, - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=True, - ).sample + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=True, + apply_control=apply_control, + ).sample # perform guidance if do_classifier_free_guidance: @@ -801,12 +855,18 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - # call the 
callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py similarity index 83% rename from examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py rename to src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index d0186573fa9c..ff270d20d11e 100644 --- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -19,41 +19,94 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel -from diffusers.models.attention_processor import ( +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( USE_PEFT_BACKEND, + deprecate, logging, + replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) -from diffusers.utils.import_utils import is_invisible_watermark_available -from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput if is_invisible_watermark_available(): - from diffusers.pipelines.stable_diffusion_xl.watermark 
import StableDiffusionXLWatermarker + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> controlnet = ControlNetXSAdapter.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + class StableDiffusionXLControlNetXSPipeline( DiffusionPipeline, - StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, @@ -66,9 +119,8 @@ class StableDiffusionXLControlNetXSPipeline( The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files Args: vae ([`AutoencoderKL`]): @@ -83,9 +135,9 @@ class StableDiffusionXLControlNetXSPipeline( tokenizer_2 ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet ([`ControlNetXSModel`]: - Provides additional conditioning to the `unet` during the denoising process. + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
@@ -98,9 +150,15 @@ class StableDiffusionXLControlNetXSPipeline( watermarker is used. """ - # leave controlnet out on purpose because it iterates with unet - model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + ] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -109,21 +167,17 @@ def __init__( text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: ControlNetXSModel, + unet: Union[UNet2DConditionModel, UNetControlNetXSModel], + controlnet: ControlNetXSAdapter, scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, ): super().__init__() - vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible( - vae - ) - if not vae_compatible: - raise ValueError( - f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." - ) + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel.from_unet(unet, controlnet) self.register_modules( vae=vae, @@ -134,6 +188,7 @@ def __init__( unet=unet, controlnet=controlnet, scheduler=scheduler, + feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) @@ -417,15 +472,21 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -474,25 +535,16 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
) - # Check `image` + # Check `image` and ``controlnet_conditioning_scale`` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule + self.unet, torch._dynamo.eval_frame.OptimizedModule ) if ( - isinstance(self.controlnet, ControlNetXSModel) + isinstance(self.unet, UNetControlNetXSModel) or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) ): self.check_image(image, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetXSModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) - ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") else: @@ -593,7 +645,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): @@ -602,7 +653,7 @@ def _get_add_time_ids( passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( @@ -632,7 +683,33 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + def clip_skip(self): + return self._clip_skip + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -654,8 +731,6 @@ def __call__( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, control_guidance_start: float = 0.0, @@ -667,6 +742,9 @@ def __call__( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, 
int]] = None, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" The call function to the pipeline for generation. @@ -677,7 +755,7 @@ def __call__( prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders. - image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be @@ -735,12 +813,6 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -783,6 +855,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. Examples: @@ -791,7 +872,24 @@ def __call__( If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is returned, otherwise a `tuple` is returned containing the output images. 
""" - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -808,8 +906,14 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -850,7 +954,7 @@ def __call__( ) # 4. Prepare image - if isinstance(controlnet, ControlNetXSModel): + if isinstance(unet, UNetControlNetXSModel): image = self.prepare_image( image=image, width=width, @@ -858,7 +962,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=controlnet.dtype, + dtype=unet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, ) height, width = image.shape[-2:] @@ -870,7 +974,7 @@ def __call__( timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels + num_channels_latents = self.unet.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -928,14 +1032,14 @@ def __call__( # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - is_unet_compiled = is_compiled_module(self.unet) - is_controlnet_compiled = is_compiled_module(self.controlnet) + self._num_timesteps = len(timesteps) + is_controlnet_compiled = is_compiled_module(self.unet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if is_controlnet_compiled and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -944,30 +1048,20 @@ def __call__( added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # predict the noise residual - dont_control = ( - i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end + apply_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end ) - if dont_control: - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=True, - ).sample - else: - noise_pred = self.controlnet( - base_model=self.unet, - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=True, - ).sample + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + apply_control=apply_control, + ).sample # perform guidance if do_classifier_free_guidance: @@ -977,6 +1071,16 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() @@ -984,6 +1088,11 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + # manually for max memory savings + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git 
a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 1c55d088aa0a..3c3bd526692d 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -2238,6 +2238,7 @@ def __init__( self, in_channels: int, temb_channels: int, + out_channels: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, @@ -2245,6 +2246,7 @@ def __init__( resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, num_attention_heads: int = 1, output_scale_factor: float = 1.0, @@ -2256,6 +2258,10 @@ def __init__( ): super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -2264,14 +2270,17 @@ def __init__( if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers + resnet_groups_out = resnet_groups_out or resnet_groups + # there is always at least one resnet resnets = [ ResnetBlockFlat( in_channels=in_channels, - out_channels=in_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, + groups_out=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, @@ -2286,11 +2295,11 @@ def __init__( attentions.append( Transformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + norm_num_groups=resnet_groups_out, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, @@ -2300,8 +2309,8 @@ def __init__( attentions.append( DualTransformer2DModel( num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, + out_channels // num_attention_heads, + in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, @@ -2309,11 +2318,11 @@ def __init__( ) resnets.append( ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, + in_channels=out_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, - groups=resnet_groups, + groups=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 14947848a43f..b04006cb5ee6 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ControlNetXSAdapter(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def 
from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] @@ -287,6 +302,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class UNetControlNetXSModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class UNetMotionModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index f64c15702087..8ad2f4b4876d 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -902,6 +902,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionControlNetXSPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1247,6 +1262,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py new file mode 100644 index 000000000000..8c9b43a20ad6 --- /dev/null +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -0,0 +1,352 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import unittest + +import numpy as np +import torch +from torch import nn + +from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from diffusers.utils import logging +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +logger = logging.get_logger(__name__) + +enable_full_determinism() + + +class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = UNetControlNetXSModel + main_input_name = "sample" + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + sizes = (16, 16) + conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) + controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device) + conditioning_scale = 1 + + return { + "sample": noise, + "timestep": time_step, + "encoder_hidden_states": encoder_hidden_states, + "controlnet_cond": controlnet_cond, + "conditioning_scale": conditioning_scale, + } + + @property + def input_shape(self): + return (4, 16, 16) + + @property + def output_shape(self): + return (4, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 16, + "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), + "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), + "block_out_channels": (4, 8), + "cross_attention_dim": 8, + "transformer_layers_per_block": 1, + "num_attention_heads": 2, + "norm_num_groups": 4, + "upcast_attention": False, + "ctrl_block_out_channels": [2, 4], + "ctrl_num_attention_heads": 4, + "ctrl_max_norm_num_groups": 2, + "ctrl_conditioning_embedding_out_channels": (2, 2), + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def get_dummy_unet(self): + """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter""" + return UNet2DConditionModel( + block_out_channels=(4, 8), + layers_per_block=2, + sample_size=16, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=8, + norm_num_groups=4, + use_linear_projection=True, + ) + + def get_dummy_controlnet_from_unet(self, unet, **kwargs): + """For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter""" + # size_ratio and conditioning_embedding_out_channels chosen to keep model small + return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs) + + def test_from_unet(self): + unet = self.get_dummy_unet() + controlnet = self.get_dummy_controlnet_from_unet(unet) + + model = UNetControlNetXSModel.from_unet(unet, controlnet) + model_state_dict = model.state_dict() + + def assert_equal_weights(module, weight_dict_prefix): + for param_name, param_value in module.named_parameters(): + assert torch.equal(model_state_dict[weight_dict_prefix + "." 
+ param_name], param_value) + + # # check unet + # everything expect down,mid,up blocks + modules_from_unet = [ + "time_embedding", + "conv_in", + "conv_norm_out", + "conv_out", + ] + for p in modules_from_unet: + assert_equal_weights(getattr(unet, p), "base_" + p) + optional_modules_from_unet = [ + "class_embedding", + "add_time_proj", + "add_embedding", + ] + for p in optional_modules_from_unet: + if hasattr(unet, p) and getattr(unet, p) is not None: + assert_equal_weights(getattr(unet, p), "base_" + p) + # down blocks + assert len(unet.down_blocks) == len(model.down_blocks) + for i, d in enumerate(unet.down_blocks): + assert_equal_weights(d.resnets, f"down_blocks.{i}.base_resnets") + if hasattr(d, "attentions"): + assert_equal_weights(d.attentions, f"down_blocks.{i}.base_attentions") + if hasattr(d, "downsamplers") and getattr(d, "downsamplers") is not None: + assert_equal_weights(d.downsamplers[0], f"down_blocks.{i}.base_downsamplers") + # mid block + assert_equal_weights(unet.mid_block, "mid_block.base_midblock") + # up blocks + assert len(unet.up_blocks) == len(model.up_blocks) + for i, u in enumerate(unet.up_blocks): + assert_equal_weights(u.resnets, f"up_blocks.{i}.resnets") + if hasattr(u, "attentions"): + assert_equal_weights(u.attentions, f"up_blocks.{i}.attentions") + if hasattr(u, "upsamplers") and getattr(u, "upsamplers") is not None: + assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers") + + # # check controlnet + # everything expect down,mid,up blocks + modules_from_controlnet = { + "controlnet_cond_embedding": "controlnet_cond_embedding", + "conv_in": "ctrl_conv_in", + "control_to_base_for_conv_in": "control_to_base_for_conv_in", + } + optional_modules_from_controlnet = {"time_embedding": "ctrl_time_embedding"} + for name_in_controlnet, name_in_unetcnxs in modules_from_controlnet.items(): + assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs) + + for name_in_controlnet, name_in_unetcnxs in optional_modules_from_controlnet.items(): + if hasattr(controlnet, name_in_controlnet) and getattr(controlnet, name_in_controlnet) is not None: + assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs) + # down blocks + assert len(controlnet.down_blocks) == len(model.down_blocks) + for i, d in enumerate(controlnet.down_blocks): + assert_equal_weights(d.resnets, f"down_blocks.{i}.ctrl_resnets") + assert_equal_weights(d.base_to_ctrl, f"down_blocks.{i}.base_to_ctrl") + assert_equal_weights(d.ctrl_to_base, f"down_blocks.{i}.ctrl_to_base") + if d.attentions is not None: + assert_equal_weights(d.attentions, f"down_blocks.{i}.ctrl_attentions") + if d.downsamplers is not None: + assert_equal_weights(d.downsamplers, f"down_blocks.{i}.ctrl_downsamplers") + # mid block + assert_equal_weights(controlnet.mid_block.base_to_ctrl, "mid_block.base_to_ctrl") + assert_equal_weights(controlnet.mid_block.midblock, "mid_block.ctrl_midblock") + assert_equal_weights(controlnet.mid_block.ctrl_to_base, "mid_block.ctrl_to_base") + # up blocks + assert len(controlnet.up_connections) == len(model.up_blocks) + for i, u in enumerate(controlnet.up_connections): + assert_equal_weights(u.ctrl_to_base, f"up_blocks.{i}.ctrl_to_base") + + def test_freeze_unet(self): + def assert_frozen(module): + for p in module.parameters(): + assert not p.requires_grad + + def assert_unfrozen(module): + for p in module.parameters(): + assert p.requires_grad + + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = UNetControlNetXSModel(**init_dict) + 
model.freeze_unet_params() + + # # check unet + # everything expect down,mid,up blocks + modules_from_unet = [ + model.base_time_embedding, + model.base_conv_in, + model.base_conv_norm_out, + model.base_conv_out, + ] + for m in modules_from_unet: + assert_frozen(m) + + optional_modules_from_unet = [ + model.base_add_time_proj, + model.base_add_embedding, + ] + for m in optional_modules_from_unet: + if m is not None: + assert_frozen(m) + + # down blocks + for i, d in enumerate(model.down_blocks): + assert_frozen(d.base_resnets) + if isinstance(d.base_attentions, nn.ModuleList): # attentions can be list of Nones + assert_frozen(d.base_attentions) + if d.base_downsamplers is not None: + assert_frozen(d.base_downsamplers) + + # mid block + assert_frozen(model.mid_block.base_midblock) + + # up blocks + for i, u in enumerate(model.up_blocks): + assert_frozen(u.resnets) + if isinstance(u.attentions, nn.ModuleList): # attentions can be list of Nones + assert_frozen(u.attentions) + if u.upsamplers is not None: + assert_frozen(u.upsamplers) + + # # check controlnet + # everything expect down,mid,up blocks + modules_from_controlnet = [ + model.controlnet_cond_embedding, + model.ctrl_conv_in, + model.control_to_base_for_conv_in, + ] + optional_modules_from_controlnet = [model.ctrl_time_embedding] + + for m in modules_from_controlnet: + assert_unfrozen(m) + for m in optional_modules_from_controlnet: + if m is not None: + assert_unfrozen(m) + + # down blocks + for d in model.down_blocks: + assert_unfrozen(d.ctrl_resnets) + assert_unfrozen(d.base_to_ctrl) + assert_unfrozen(d.ctrl_to_base) + if isinstance(d.ctrl_attentions, nn.ModuleList): # attentions can be list of Nones + assert_unfrozen(d.ctrl_attentions) + if d.ctrl_downsamplers is not None: + assert_unfrozen(d.ctrl_downsamplers) + # mid block + assert_unfrozen(model.mid_block.base_to_ctrl) + assert_unfrozen(model.mid_block.ctrl_midblock) + assert_unfrozen(model.mid_block.ctrl_to_base) + # up blocks + for u in model.up_blocks: + assert_unfrozen(u.ctrl_to_base) + + def test_gradient_checkpointing_is_applied(self): + model_class_copy = copy.copy(UNetControlNetXSModel) + + modules_with_gc_enabled = {} + + # now monkey patch the following function: + # def _set_gradient_checkpointing(self, module, value=False): + # if hasattr(module, "gradient_checkpointing"): + # module.gradient_checkpointing = value + + def _set_gradient_checkpointing_new(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + modules_with_gc_enabled[module.__class__.__name__] = True + + model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new + + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = model_class_copy(**init_dict) + + model.enable_gradient_checkpointing() + + EXPECTED_SET = { + "Transformer2DModel", + "UNetMidBlock2DCrossAttn", + "ControlNetXSCrossAttnDownBlock2D", + "ControlNetXSCrossAttnMidBlock2D", + "ControlNetXSCrossAttnUpBlock2D", + } + + assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET + assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + + def test_forward_no_control(self): + unet = self.get_dummy_unet() + controlnet = self.get_dummy_controlnet_from_unet(unet) + + model = UNetControlNetXSModel.from_unet(unet, controlnet) + + unet = unet.to(torch_device) + model = model.to(torch_device) + + input_ = self.dummy_input + + control_specific_input = ["controlnet_cond", "conditioning_scale"] + input_for_unet = {k: v for k, v in 
input_.items() if k not in control_specific_input} + + with torch.no_grad(): + unet_output = unet(**input_for_unet).sample.cpu() + unet_controlnet_output = model(**input_, apply_control=False).sample.cpu() + + assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 3e-4 + + def test_time_embedding_mixing(self): + unet = self.get_dummy_unet() + controlnet = self.get_dummy_controlnet_from_unet(unet) + controlnet_mix_time = self.get_dummy_controlnet_from_unet( + unet, time_embedding_mix=0.5, learn_time_embedding=True + ) + + model = UNetControlNetXSModel.from_unet(unet, controlnet) + model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time) + + unet = unet.to(torch_device) + model = model.to(torch_device) + model_mix_time = model_mix_time.to(torch_device) + + input_ = self.dummy_input + + with torch.no_grad(): + output = model(**input_).sample + output_mix_time = model_mix_time(**input_).sample + + assert output.shape == output_mix_time.shape + + def test_forward_with_norm_groups(self): + # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups. + pass diff --git a/tests/pipelines/controlnet_xs/__init__.py b/tests/pipelines/controlnet_xs/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py new file mode 100644 index 000000000000..5ac78129ef34 --- /dev/null +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -0,0 +1,366 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import traceback +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AsymmetricAutoencoderKL, + AutoencoderKL, + AutoencoderTiny, + ConsistencyDecoderVAE, + ControlNetXSAdapter, + DDIMScheduler, + LCMScheduler, + StableDiffusionControlNetXSPipeline, + UNet2DConditionModel, +) +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import ( + enable_full_determinism, + load_image, + load_numpy, + require_python39_or_higher, + require_torch_2, + require_torch_gpu, + run_test_in_subprocess, + slow, + torch_device, +) +from diffusers.utils.torch_utils import randn_tensor + +from ...models.autoencoders.test_models_vae import ( + get_asym_autoencoder_kl_config, + get_autoencoder_kl_config, + get_autoencoder_tiny_config, + get_consistency_vae_config, +) +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDFunctionTesterMixin, +) + + +enable_full_determinism() + + +# Will be run via run_test_in_subprocess +def _test_stable_diffusion_compile(in_queue, out_queue, timeout): + error = None + try: + _ = in_queue.get(timeout=timeout) + + controlnet = ControlNetXSAdapter.from_pretrained( + "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 + ) + pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", + controlnet=controlnet, + safety_checker=None, + torch_dtype=torch.float16, + ) + pipe.to("cuda") + pipe.set_progress_bar_config(disable=None) + + pipe.unet.to(memory_format=torch.channels_last) + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + + output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy" + ) + expected_image = np.resize(expected_image, (512, 512, 3)) + + assert np.abs(expected_image - image).max() < 1.0 + + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + +class ControlNetXSPipelineFastTests( + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + SDFunctionTesterMixin, + unittest.TestCase, +): + pipeline_class = StableDiffusionControlNetXSPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + test_attention_slicing = False + + def get_dummy_components(self, time_cond_proj_dim=None): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(4, 8), + layers_per_block=2, + sample_size=16, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + 
cross_attention_dim=8, + norm_num_groups=4, + time_cond_proj_dim=time_cond_proj_dim, + use_linear_projection=True, + ) + torch.manual_seed(0) + controlnet = ControlNetXSAdapter.from_unet( + unet=unet, + size_ratio=1, + learn_time_embedding=True, + conditioning_embedding_out_channels=(2, 2), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[4, 8], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + norm_num_groups=2, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=8, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + image = randn_tensor( + (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + } + + return inputs + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + def test_controlnet_lcm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components(time_cond_proj_dim=8) + sd_pipe = StableDiffusionControlNetXSPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 16, 16, 3) + expected_slice = np.array([0.745, 0.753, 0.767, 0.543, 0.523, 0.502, 0.314, 0.521, 0.478]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + # pipeline creates a new UNetControlNetXSModel under the hood. 
So we need to check the dtype from pipe.components + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + pipe.to(dtype=torch.float16) + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + + def test_multi_vae(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + block_out_channels = pipe.vae.config.block_out_channels + norm_num_groups = pipe.vae.config.norm_num_groups + + vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny] + configs = [ + get_autoencoder_kl_config(block_out_channels, norm_num_groups), + get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups), + get_consistency_vae_config(block_out_channels, norm_num_groups), + get_autoencoder_tiny_config(block_out_channels), + ] + + out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + + for vae_cls, config in zip(vae_classes, configs): + vae = vae_cls(**config) + vae = vae.to(torch_device) + components["vae"] = vae + vae_pipe = self.pipeline_class(**components) + + # pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device. + # So we need to move the new pipe to device. + vae_pipe.to(torch_device) + vae_pipe.set_progress_bar_config(disable=None) + + out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + + assert out_vae_np.shape == out_np.shape + + +@slow +@require_torch_gpu +class ControlNetXSPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet = ControlNetXSAdapter.from_pretrained( + "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 + ) + pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3) + + image = output.images[0] + + assert image.shape == (768, 512, 3) + + original_image = image[-3:, -3:, -1].flatten() + expected_image = np.array([0.1963, 0.229, 0.2659, 0.2109, 0.2332, 0.2827, 0.2534, 0.2422, 0.2808]) + assert np.allclose(original_image, expected_image, atol=1e-04) + + def test_depth(self): + controlnet = ControlNetXSAdapter.from_pretrained( + "UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16 + ) + pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "Stormtrooper's lecture" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png" + ) + + 
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3) + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + original_image = image[-3:, -3:, -1].flatten() + expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941]) + assert np.allclose(original_image, expected_image, atol=1e-04) + + @require_python39_or_higher + @require_torch_2 + def test_stable_diffusion_compile(self): + run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py new file mode 100644 index 000000000000..ee0d15ec3472 --- /dev/null +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -0,0 +1,425 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + AsymmetricAutoencoderKL, + AutoencoderKL, + AutoencoderTiny, + ConsistencyDecoderVAE, + ControlNetXSAdapter, + EulerDiscreteScheduler, + StableDiffusionXLControlNetXSPipeline, + UNet2DConditionModel, +) +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device +from diffusers.utils.torch_utils import randn_tensor + +from ...models.autoencoders.test_models_vae import ( + get_asym_autoencoder_kl_config, + get_autoencoder_kl_config, + get_autoencoder_tiny_config, + get_consistency_vae_config, +) +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, +) + + +enable_full_determinism() + + +class StableDiffusionXLControlNetXSPipelineFastTests( + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, +): + pipeline_class = StableDiffusionXLControlNetXSPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + test_attention_slicing = False + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(4, 8), + layers_per_block=2, + sample_size=16, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + use_linear_projection=True, + norm_num_groups=4, + # SD2-specific config below + attention_head_dim=(2, 4), + addition_embed_type="text_time", + 
addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=56, # 6 * 8 (addition_time_embed_dim) + 8 (cross_attention_dim) + cross_attention_dim=8, + ) + torch.manual_seed(0) + controlnet = ControlNetXSAdapter.from_unet( + unet=unet, + size_ratio=0.5, + learn_time_embedding=True, + conditioning_embedding_out_channels=(2, 2), + ) + torch.manual_seed(0) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[4, 8], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + norm_num_groups=2, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=4, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=8, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "feature_extractor": None, + } + return components + + # copied from test_controlnet_sdxl.py + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + image = randn_tensor( + (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "np", + "image": image, + } + + return inputs + + # copied from test_controlnet_sdxl.py + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + # copied from test_controlnet_sdxl.py + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + # copied from test_controlnet_sdxl.py + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + # copied from test_controlnet_sdxl.py + @require_torch_gpu + def test_stable_diffusion_xl_offloads(self): + pipes = [] + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components).to(torch_device) + pipes.append(sd_pipe) + + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe.enable_model_cpu_offload() + pipes.append(sd_pipe) + + components = self.get_dummy_components() + sd_pipe = 
self.pipeline_class(**components) + sd_pipe.enable_sequential_cpu_offload() + pipes.append(sd_pipe) + + image_slices = [] + for pipe in pipes: + pipe.unet.set_default_attn_processor() + + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs).images + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3 + + # copied from test_controlnet_sdxl.py + def test_stable_diffusion_xl_multi_prompts(self): + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components).to(torch_device) + + # forward with single prompt + inputs = self.get_dummy_inputs(torch_device) + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + # forward with same prompt duplicated + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = inputs["prompt"] + output = sd_pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + # ensure the results are equal + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + + # forward with different prompt + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "different prompt" + output = sd_pipe(**inputs) + image_slice_3 = output.images[0, -3:, -3:, -1] + + # ensure the results are not equal + assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + + # manually set a negative_prompt + inputs = self.get_dummy_inputs(torch_device) + inputs["negative_prompt"] = "negative prompt" + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + # forward with same negative_prompt duplicated + inputs = self.get_dummy_inputs(torch_device) + inputs["negative_prompt"] = "negative prompt" + inputs["negative_prompt_2"] = inputs["negative_prompt"] + output = sd_pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + # ensure the results are equal + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + + # forward with different negative_prompt + inputs = self.get_dummy_inputs(torch_device) + inputs["negative_prompt"] = "negative prompt" + inputs["negative_prompt_2"] = "different negative prompt" + output = sd_pipe(**inputs) + image_slice_3 = output.images[0, -3:, -3:, -1] + + # ensure the results are not equal + assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + + # copied from test_stable_diffusion_xl.py + def test_stable_diffusion_xl_prompt_embeds(self): + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # forward without prompt embeds + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = 2 * [inputs["prompt"]] + inputs["num_images_per_prompt"] = 2 + + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + # forward with prompt embeds + inputs = self.get_dummy_inputs(torch_device) + prompt = 2 * [inputs.pop("prompt")] + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = sd_pipe.encode_prompt(prompt) + + output = sd_pipe( + **inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + image_slice_2 = output.images[0, -3:, -3:, -1] + + # 
make sure that it's equal + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4 + + # copied from test_stable_diffusion_xl.py + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + # copied from test_controlnetxs.py + def test_to_dtype(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) + + pipe.to(dtype=torch.float16) + model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] + self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) + + def test_multi_vae(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + block_out_channels = pipe.vae.config.block_out_channels + norm_num_groups = pipe.vae.config.norm_num_groups + + vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny] + configs = [ + get_autoencoder_kl_config(block_out_channels, norm_num_groups), + get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups), + get_consistency_vae_config(block_out_channels, norm_num_groups), + get_autoencoder_tiny_config(block_out_channels), + ] + + out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + + for vae_cls, config in zip(vae_classes, configs): + vae = vae_cls(**config) + vae = vae.to(torch_device) + components["vae"] = vae + vae_pipe = self.pipeline_class(**components) + + # pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device. + # So we need to move the new pipe to device. 
+ vae_pipe.to(torch_device) + vae_pipe.set_progress_bar_config(disable=None) + + out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + + assert out_vae_np.shape == out_np.shape + + +@slow +@require_torch_gpu +class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet = ControlNetXSAdapter.from_pretrained( + "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 + ) + pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ) + pipe.enable_sequential_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images + + assert images[0].shape == (768, 512, 3) + + original_image = images[0, -3:, -3:, -1].flatten() + expected_image = np.array([0.3202, 0.3151, 0.3328, 0.3172, 0.337, 0.3381, 0.3378, 0.3389, 0.3224]) + assert np.allclose(original_image, expected_image, atol=1e-04) + + def test_depth(self): + controlnet = ControlNetXSAdapter.from_pretrained( + "UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16 + ) + pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ) + pipe.enable_sequential_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "Stormtrooper's lecture" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png" + ) + + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images + + assert images[0].shape == (512, 512, 3) + + original_image = images[0, -3:, -3:, -1].flatten() + expected_image = np.array([0.5448, 0.5437, 0.5426, 0.5543, 0.553, 0.5475, 0.5595, 0.5602, 0.5529]) + assert np.allclose(original_image, expected_image, atol=1e-04) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index acff5f2cdf8f..2ff53374be56 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -32,6 +32,7 @@ from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import IPAdapterMixin from diffusers.models.attention_processor import AttnProcessor +from diffusers.models.controlnet_xs import UNetControlNetXSModel from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet from diffusers.models.unets.unet_motion_model import UNetMotionModel @@ -1685,7 +1686,10 @@ def test_StableDiffusionMixin_component(self): self.assertTrue(hasattr(pipe, "vae") and isinstance(pipe.vae, (AutoencoderKL, AutoencoderTiny))) self.assertTrue( hasattr(pipe, "unet") - and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel)) + and isinstance( + pipe.unet, + (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel, UNetControlNetXSModel), + ) )
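The model-level API added in this change can be exercised roughly as sketched below; this mirrors how test_from_unet and test_freeze_unet build the fused model, with the base checkpoint and the size_ratio value chosen only for illustration.

import torch
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel

# Start from an existing Stable Diffusion UNet (checkpoint chosen for illustration).
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="unet"
)

# Create a ControlNet-XS adapter whose control branch is a fraction of the base UNet's width.
# The fast tests above use size_ratio=1 and size_ratio=0.5; 0.5 here is illustrative.
controlnet = ControlNetXSAdapter.from_unet(unet, size_ratio=0.5)

# Fuse base UNet and adapter into a single UNetControlNetXSModel, as the tests do.
model = UNetControlNetXSModel.from_unet(unet, controlnet)

# For training, freeze the base UNet weights so only the control branch is trainable,
# matching the expectations checked in test_freeze_unet.
model.freeze_unet_params()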
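At the pipeline level, the slow tests above translate into roughly the following usage sketch; the adapter and base checkpoints are the ones referenced in the tests, while the prompt, conditioning image, and output path are illustrative.

import torch
from diffusers import ControlNetXSAdapter, StableDiffusionControlNetXSPipeline
from diffusers.utils import load_image

# Canny-conditioned ControlNet-XS adapter for SD 2.1, as used in ControlNetXSPipelineSlowTests.
controlnet = ControlNetXSAdapter.from_pretrained(
    "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# Conditioning image: the canny-edge bird image used by the tests.
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

generator = torch.Generator(device="cpu").manual_seed(0)
result = pipe("bird", image, generator=generator)
result.images[0].save("bird_controlnetxs.png")

The SDXL variant follows the same pattern with StableDiffusionXLControlNetXSPipeline, the "stabilityai/stable-diffusion-xl-base-1.0" base model, and the "UmerHA/Testing-ConrolNetXS-SDXL-canny" adapter, as in StableDiffusionXLControlNetXSPipelineSlowTests.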