
Add ControlNet-XS support #5827

Merged: 100 commits merged into huggingface:main on Dec 6, 2023
Conversation

@UmerHA (Contributor) commented Nov 16, 2023

What does this PR do?

Adds ControlNet-XS support (and therefore fixes #5168).
Project page: https://vislearn.github.io/ControlNet-XS/


See here for a full working example

This PR is a work in progress. Still to do:

  • Add other versions of ControlNet-XS: SD canny ✅, SD depth ✅, SDXL depth
  • Add documentation
  • A few other (iiuc) minor things

Still, I would love your feedback!

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

If you know how to use git blame, that is the easiest way; otherwise, here is a rough guide of who to tag. Please tag fewer than 3 people.

Core library:

@sayakpaul (Member) commented:

Cc: @DN6

@patrickvonplaten patrickvonplaten merged commit e192ae0 into huggingface:main Dec 6, 2023
14 checks passed
@patrickvonplaten (Contributor) commented:

Amazing job @UmerHA

@UmerHA UmerHA deleted the controlnet-xs branch December 6, 2023 22:50
@universewill commented Dec 11, 2023

@UmerHA Great work! Can you provide a ControlNet-XS training example?

@UmerHA (Contributor, Author) commented Dec 11, 2023

@universewill Sure - see https://github.com/UmerHA/diffusers/tree/cnxs-training/examples/controlnet_xs. I've tested that the scripts run, but haven't done full training runs yet. When I have more time, I'll do that and open a PR.

In the meantime, let me know if you encounter any issues!

donhardman pushed a commit to donhardman/diffusers that referenced this pull request Dec 18, 2023
* Check in 23-10-05

* check-in 23-10-06

* check-in 23-10-07 2pm

* check-in 23-10-08

* check-in 231009T1200

* check-in 230109

* checkin 231010

* init + forward run

* checkin

* checkin

* ControlNetXSModel is now saveable+loadable

* Forward works

* checkin

* Pipeline works with `no_control=True`

* checkin

* debug: save intermediate outputs of resnet

* checkin

* Understood time error + fixed connection error

* checkin

* checkin 231106T1600

* turned off detailled debug prints

* time debug logs

* small fix

* Separated control_scale for connections/time

* simplified debug logging

* Full denoising works with control scale = 0

* aligned logs

* Added control_attention_head_dim param

* Passing n_heads instead of dim_head into ctrl unet

* Fixed ctrl midblock bug

* Cleanup

* Fixed time dtype bug

* checkin

* 1. from_unet, 2. base passed, 3. all unet params

* checkin

* Finished docstrings

* cleanup

* make style

* checkin

* more tests pass

* Fixed tests

* removed debug logs

* make style + quality

* make fix-copies

* fixed documentation

* added cnxs to doc toc

* added control start/end param

* Update controlnetxs_sdxl.md

* tried to fix copies..

* Fixed norm_num_groups in from_unet

* added sdxl-depth test

* created SD2.1 controlnet-xs pipeline

* re-added debug logs

* Adjusting group norm ; readded logs

* Added debug log statements

* removed debug logs ; started tests for sd2.1

* updated sd21 tests

* fixed tests

* fixed tests

* slightly increased error tolerance for 1 test

* make style & quality

* Added docs for CNXS-SD

* make fix-copies

* Fixed sd compile test ; fixed gradient ckpointing

* vae downs = cnxs conditioning downs; removed guess

* make style & quality

* Fixed tests

* fixed test

* Incorporated review feedback

* simplified control model surgery

* fixed tests & make style / quality

* Updated docs; deleted pip & cursor files

* Rolled back minimal change to resnet

* Update resnet.py

* Update resnet.py

* Update src/diffusers/models/controlnetxs.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/models/controlnetxs.py

Co-authored-by: Patrick von Platen <[email protected]>

* Incorporated review feedback

* Update docs/source/en/api/pipelines/controlnetxs_sdxl.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/api/pipelines/controlnetxs.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/api/pipelines/controlnetxs.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/api/pipelines/controlnetxs.md

Co-authored-by: Steven Liu <[email protected]>

* Update src/diffusers/models/controlnetxs.py

Co-authored-by: Steven Liu <[email protected]>

* Update src/diffusers/models/controlnetxs.py

Co-authored-by: Steven Liu <[email protected]>

* Update src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/api/pipelines/controlnetxs.md

Co-authored-by: Steven Liu <[email protected]>

* Update src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py

Co-authored-by: Steven Liu <[email protected]>

* Incorporated doc feedback

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
@patrickvonplaten (Contributor) commented Dec 27, 2023

Sorry, we had to move the implementation to the research folder for now, as the design was not in line with the usual diffusers design (e.g. the UNet is passed into the ControlNet-XS forward method, etc.). We should have caught that when reviewing the PR, but sadly failed to do so. We still very much want to add ControlNet-XS to diffusers, but we'll need to apply the changes suggested in the new PR review below.

Very sorry @UmerHA that we missed these things in the initial review 🙏

super().__init__()

# 1 - Create control unet
self.control_model = UNet2DConditionModel(
Review comment (Contributor):
The control_model should not be a UNet2DConditionModel if we have to apply a lot of surgery afterwards. Let's make sure we directly instantiate the correct torch.nn.Modules right away

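To make the suggestion concrete, here is a minimal sketch (illustrative class names and channel sizes, not the eventual diffusers implementation) of building the control encoder directly with the widened input channels, so that no layers need to be modified or deleted after construction:

import torch
from torch import nn


class ControlEncoderBlock(nn.Module):
    # Born with the widened input size: room for the concatenated base features
    # is part of the constructor arguments, so no surgery is needed later.
    def __init__(self, ctrl_in_channels: int, base_in_channels: int, out_channels: int):
        super().__init__()
        in_channels = ctrl_in_channels + base_in_channels
        self.norm = nn.GroupNorm(num_groups=4, num_channels=in_channels)
        self.act = nn.SiLU()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.conv(self.act(self.norm(hidden_states)))


class ControlEncoderSketch(nn.Module):
    # Only down blocks (and, in the real model, a mid block) are ever created;
    # there are no up blocks to delete afterwards.
    def __init__(self, ctrl_channels=(4, 8, 16), base_channels=(8, 16, 32)):
        super().__init__()
        blocks = []
        ctrl_in = ctrl_channels[0]
        for ctrl_out, base_in in zip(ctrl_channels, base_channels):
            blocks.append(ControlEncoderBlock(ctrl_in, base_in, ctrl_out))
            ctrl_in = ctrl_out
        self.down_blocks = nn.ModuleList(blocks)


encoder = ControlEncoderSketch()
print(sum(p.numel() for p in encoder.parameters()))  # the control add-on stays small

The real control blocks would of course reuse diffusers' resnet and transformer blocks; the point is only that the widened in_channels are fixed at construction time.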
Comment on lines +294 to +319
# 2 - Do model surgery on control model
# 2.1 - Allow to use the same time information as the base model
adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim)

# 2.2 - Allow for information infusion from base model

# We concat the output of each base encoder subblocks to the input of the next control encoder subblock
# (We ignore the 1st element, as it represents the `conv_in`.)
extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]]
it_extra_input_channels = iter(extra_input_channels)

for b, block in enumerate(self.control_model.down_blocks):
    for r in range(len(block.resnets)):
        increase_block_input_in_encoder_resnet(
            self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels)
        )

    if block.downsamplers:
        increase_block_input_in_encoder_downsampler(
            self.control_model, block_no=b, by=next(it_extra_input_channels)
        )

increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1])

# 2.3 - Make group norms work with modified channel sizes
adjust_group_norms(self.control_model)
Review comment (Contributor):
We can't do surgery here, let's make sure to instead instantiate the correct classes right away

Comment on lines +359 to +363
# In the minimal implementation setting, we only need the control model up to the mid block
del self.control_model.up_blocks
del self.control_model.conv_norm_out
del self.control_model.conv_out

Review comment (Contributor):
They should instead never have been instantiated

Comment on lines +522 to +537
def set_attention_slice(self, slice_size):
    r"""
    Enable sliced attention computation.

    When this option is enabled, the attention module splits the input tensor in slices to compute attention in
    several steps. This is useful for saving some memory in exchange for a small decrease in speed.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
            `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
            provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
            must be a multiple of `slice_size`.
    """
    self.control_model.set_attention_slice(slice_size)

Review comment (Contributor):
Let's not provide a set_attention_slice() operation anymore since with FlashAttention it's pretty useless to slice the attention

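For reference, a hedged sketch of the alternative implied here: with PyTorch 2.x, diffusers attention layers can run through torch's scaled_dot_product_attention (which dispatches to FlashAttention / memory-efficient kernels), so a dedicated slicing entry point adds little. The checkpoint name is only an example.

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

# Example checkpoint; any SD-style UNet works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# Route every attention layer through torch.nn.functional.scaled_dot_product_attention
# instead of exposing a set_attention_slice() API on the ControlNet-XS model.
unet.set_attn_processor(AttnProcessor2_0())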

def forward(
    self,
    base_model: UNet2DConditionModel,
Review comment (Contributor):
We can't pass the base_model here into the forward method

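For illustration, a sketch (an assumption, loosely mirroring diffusers' existing ControlNetModel.forward) of a signature in which the control model owns all of its modules and therefore no longer needs the base UNet as an argument:

from typing import Any, Dict, Optional

import torch
from torch import nn


class ControlNetXSForwardSketch(nn.Module):
    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        # Signature sketch only: no `base_model` parameter; everything the control
        # branch needs (time/label/addition embeddings, blocks) lives on `self`.
        raise NotImplementedError("illustrative signature, not an implementation")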
    encoder_hidden_states: torch.Tensor,
    controlnet_cond: torch.Tensor,
    conditioning_scale: float = 1.0,
    class_labels: Optional[torch.Tensor] = None,
Review comment (Contributor):
Do we need the class_labels parameter?

Comment on lines +654 to +663
if base_model.class_embedding is not None:
    if class_labels is None:
        raise ValueError("class_labels should be provided when num_class_embeds > 0")

    if base_model.config.class_embed_type == "timestep":
        class_labels = base_model.time_proj(class_labels)

    class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype)
    temb = temb + class_emb

Review comment (Contributor):
Do we really need this? I think the class labels are only required for Stable Diffusion Upsampling

Comment on lines +664 to +689
if base_model.config.addition_embed_type is not None:
    if base_model.config.addition_embed_type == "text":
        aug_emb = base_model.add_embedding(encoder_hidden_states)
    elif base_model.config.addition_embed_type == "text_image":
        raise NotImplementedError()
    elif base_model.config.addition_embed_type == "text_time":
        # SDXL - style
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )
        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = added_cond_kwargs.get("time_ids")
        time_embeds = base_model.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(temb.dtype)
        aug_emb = base_model.add_embedding(add_embeds)
    elif base_model.config.addition_embed_type == "image":
        raise NotImplementedError()
    elif base_model.config.addition_embed_type == "image_hint":
        raise NotImplementedError()
Review comment (Contributor):
We should not call the base model here - instead self.add_embedding(...) should be called

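A minimal sketch of the suggested change (dimensions follow the SDXL "text_time" case; the module construction is an assumption): the control model owns its own addition-embedding layers, so the forward pass calls self.add_embedding(...) rather than reaching into base_model.

import torch
from torch import nn


class AdditionEmbeddingOwnerSketch(nn.Module):
    def __init__(self, addition_embed_dim: int = 2816, time_embed_dim: int = 1280):
        super().__init__()
        # Created (or copied from the base UNet) at construction time and owned by the
        # control model, mirroring diffusers' Linear -> SiLU -> Linear timestep embedding.
        self.add_embedding = nn.Sequential(
            nn.Linear(addition_embed_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def augment(self, text_embeds: torch.Tensor, time_embeds: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1).to(temb.dtype)
        # self.add_embedding(...) instead of base_model.add_embedding(...)
        return temb + self.add_embedding(add_embeds)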
return False


def to_sub_blocks(blocks):
Review comment (Contributor):
We should not have to call such a method in the forward pass of the controlnet

        for d in b.downsamplers:
            sub_blocks.append([d])

    return list(map(SubBlock, sub_blocks))
Review comment (Contributor):
Let's try to avoid map here as this breaks torch.compile

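A small, self-contained sketch of the requested change (SubBlock is stubbed here as a plain ModuleList wrapper): building the list with an explicit comprehension instead of map keeps the code straightforwardly traceable.

from torch import nn


class SubBlock(nn.ModuleList):
    # Stand-in for the SubBlock wrapper in the excerpt above.
    pass


def to_sub_blocks(blocks):
    sub_blocks = [[b] for b in blocks]  # toy grouping; the real logic splits resnets/attentions/downsamplers
    # Explicit comprehension instead of `list(map(SubBlock, sub_blocks))`:
    return [SubBlock(group) for group in sub_blocks]


print(len(to_sub_blocks([nn.SiLU(), nn.Identity()])))  # 2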
a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group)


def is_iterable(o):
Review comment (Contributor):
Let's not use such a try-except function; it breaks torch.compile

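A hedged sketch of an alternative (the accepted types are an assumption and can be widened): check for the concrete container types that can actually occur instead of probing with try/except around iter(), which adds data-dependent control flow that torch.compile tends to graph-break on.

from torch import nn


def is_iterable(o) -> bool:
    # Explicit type check instead of `try: iter(o) ... except TypeError`.
    return isinstance(o, (list, tuple, nn.ModuleList))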
unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim)


def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by):
Review comment (Contributor):
We should not have to use such surgery methods

@UmerHA (Contributor, Author) commented Jan 4, 2024

Hi @patrickvonplaten, no hurt feelings & fully understand. I'll start to change it to better fit diffusers, and open a new PR.

Happy New Year btw :)

@UmerHA (Contributor, Author) commented Jan 5, 2024

Hi @patrickvonplaten, could you answer two questions before I start the new implementation:


  1. Does the following design work for you?
# UNet as it currently is. Will not be changed. Shown for completeness
class UNet2DConditionModel(...):
    # ...
# Class containing the controlnet-relevant weights
# Having a separate class ensures users can save/load the controlnet-relevant weights separately,
# and don't always have to save/load the large base model
class ControlNetXSAddonModel(...):
    def __init__(self):
        # Initiate what currently is called the control model.
        # Instead of starting with a UNet2DConditionModel and doing model surgery, build it correctly from scratch
        
    def forward(self, x):
        raise ValueError("A ControlNetXSAddonModel cannot be run by itself. Pass it into a ControlNetXSModel model instead.")
# Full ControlNetXSModel that will be compilable
class ControlNetXSModel(...):
    def __init__(self, base_model: UNet2DConditionModel, ctrl_model: ControlNetXSAddonModel):
        self.ctrl_model = ctrl_model
        self.base_model = base_model

    def forward(self, x):
        # other inputs (time, conditioning, ...) omitted for brevity 

        # decompose both models into the parts that need to be run interwovenly.
        # for brevity, let's use blocks. in actual implementation we need to use "subblocks".
        base_down_blocks = self.base_model.down_blocks
        ctrl_down_blocks = self.ctrl_model.down_blocks

        x_base = x_ctrl = x

        for b, c in zip(base_down_blocks, ctrl_down_blocks):
            x_ctrl = torch.cat([x_base, x_ctrl])  # concat information from base to ctrl
            x_ctrl = c(x_ctrl)                    # execute ctrl block
            x_base = b(x_base)                    # execute base block
            x_base += x_ctrl                      # add information from ctrl to base

        return x_base
class StableDiffusionControlNetXSPipeline(...):
    def __init__(
        ...
        unet: UNet2DConditionModel,
        controlnet_addon: ControlNetXSAddonModel,
        ...
    ):
        self.controlnet = ControlNetXSModel(
            base_model=unet,
            ctrl_model=controlnet_addon
        )
        ...
    
    def __call__(self, prompt):
        # other inputs omitted for brevity
        ...
        latent = ...
        timesteps = ...
        ...
        for t in timesteps:
            noise_prediction = self.controlnet(sample=latent, timestep=t, ...)
        ...
        return ...
controlnet_addon = ControlNetXSAddonModel.from_pretrained(...)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(..., controlnet_addon=controlnet_addon)
pipe("A dog")
  2. Do you want ControlNet-XS to be part of diffusers "core" or community examples? I'm asking as they have different requirements regarding documentation, testing, ...

Thanks!

@patrickvonplaten (Contributor) commented:

Hey @UmerHA,

Thanks for the write-up! Yes this design makes a lot of sense to me :-)

  2. Do you want ControlNet-XS to be part of diffusers "core" or community examples? I'm asking as they have different requirements regarding documentation, testing, ...

I think ControlNet-XS should go into "core"

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024