Cleaned up the diffusers patch and added to news. Closes #1.
dbolya committed Apr 1, 2023
1 parent d31063f commit 3680a98
Showing 3 changed files with 35 additions and 31 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -5,4 +5,5 @@
- Initial release.

### 2023.03.31
- Added support for more resolutions than multiples of 16. (Fixes #8)
- Added support for diffusers (thanks @JunnYu and @ExponentialML)! (Fixes #1)
37 changes: 18 additions & 19 deletions README.md
@@ -42,7 +42,7 @@ Token Merging (**ToMe**) speeds up transformers by _merging redundant tokens_, w
Even with more than half of the tokens merged (60%!), ToMe for SD still produces images close to the originals, while being _**2x** faster_ and using _**~5.7x** less memory_. Moreover, ToMe is not another efficient reimplementation of transformer modules. Instead, it actually _reduces_ the total work necessary to generate an image, so it can function _in conjunction_ with efficient implementations (see [Usage](#tome--xformers--flash-attn--torch-20)).

## News

- **[2023.03.31]** ToMe for SD now supports [Diffusers](https://github.com/huggingface/diffusers). Thanks @JunnYu and @ExponentialML!
- **[2023.03.30]** Initial release.

See the [changelog](CHANGELOG.md) for more details.
@@ -54,6 +54,7 @@ This repo includes code to patch an existing Stable Diffusion environment. Curre
- [x] [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion)
- [x] [Stable Diffusion v1](https://github.com/runwayml/stable-diffusion)
- [x] [Latent Diffusion](https://github.com/CompVis/latent-diffusion)
- [x] [Diffusers](https://github.com/huggingface/diffusers)
- [ ] And potentially others

**Note:** This also supports most downstream UIs that use these repositories.
@@ -93,24 +94,6 @@ tomesd.apply_patch(model, ratio=0.9, sx=4, sy=4, max_downsample=2) # Extreme mer
See above for what speeds and memory savings you can expect with different ratios.
If you want to remove the patch later, simply use `tomesd.remove_patch(model)`.

Apply ToMe for SD to a 🤗 Diffusers Stable Diffusion model with:
```py
import torch
import tomesd
# pip install diffusers==0.14.0
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe.to("cuda")
# Patch a Diffusers Pipeline with ToMe for SD using a 50% merging ratio.
tomesd.apply_patch(pipe, ratio=0.5)
# Or patch just the Diffusers UNet with ToMe for SD using a 50% merging ratio.
# tomesd.apply_patch(pipe.unet, ratio=0.5)
image = pipe("a photo of an astronaut riding a horse on mars", guidance_scale=7.5, height=512, width=512, num_inference_steps=50).images[0]
image.save("astronaut.png")
```
See above for what speeds and memory savings you can expect with different ratios.
If you want to remove the patch later, simply use `tomesd.remove_patch(pipe)`.

### Example
To apply ToMe to the txt2img script of SDv2 or SDv1 for instance, add the following to [this line](https://github.com/Stability-AI/stablediffusion/blob/fc1488421a2761937b9d54784194157882cbc3b1/scripts/txt2img.py#L220) (SDv2) or [this line](https://github.com/runwayml/stable-diffusion/blob/08ab4d326c96854026c4eb3454cd3b02109ee982/scripts/txt2img.py#L241) (SDv1):
```py
@@ -120,6 +103,22 @@ tomesd.apply_patch(model, ratio=0.5)
That's it! More examples and demos coming soon (_hopefully_).
**Note:** You may not see the full speed-up for the first image generated (while PyTorch sets up its graph). Since ToMe for SD uses random processes, you may need to set the seed every batch if you want consistent results.
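Because the merge partitioning is randomized, re-seeding before each batch is what makes runs repeatable. A minimal sketch of the idea, using `torch.randperm` as a stand-in for ToMe's random token partitioning (the `random_partition` helper is hypothetical, not part of tomesd):

```python
import torch

def random_partition(num_tokens: int) -> torch.Tensor:
    # Stand-in for ToMe's randomized token partitioning (hypothetical helper).
    return torch.randperm(num_tokens)

torch.manual_seed(0)           # seed before the first batch
first = random_partition(8)

torch.manual_seed(0)           # re-seed before the next batch
second = random_partition(8)

assert torch.equal(first, second)  # same seed, same merge pattern
```

Without the second `torch.manual_seed(0)`, the two partitions would almost certainly differ, and so would the merged-token layout between batches.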

### Diffusers
ToMe can also be used to patch a 🤗 Diffusers Stable Diffusion pipeline:
```py
import torch, tomesd
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

# Apply ToMe with a 50% merging ratio
tomesd.apply_patch(pipe, ratio=0.5) # Can also use pipe.unet in place of pipe here

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")
```
You can remove the patch with `tomesd.remove_patch(pipe)`.

### ToMe + xformers / flash attn / torch 2.0
Since ToMe only affects the forward function of the block, it should support most efficient transformer implementations out of the box. Just apply the patch as normal!
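One way to see why this composes: the patch only wraps `forward`, while the attention call inside the block (xformers, flash attention, or PyTorch 2.0's fused SDPA) is untouched. A toy sketch under that assumption, with `F.scaled_dot_product_attention` standing in for the efficient backend (the `Block` class and `make_tome_block` helper here are illustrative, not tomesd's actual classes):

```python
import torch
import torch.nn.functional as F

class Block(torch.nn.Module):
    # Stand-in transformer block; the attention backend is PyTorch 2.0 fused SDPA.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.scaled_dot_product_attention(x, x, x)

def make_tome_block(cls):
    # Patch the class: merge tokens first, then defer to the original forward.
    class ToMeBlock(cls):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x[:, ::2]  # toy "merge": keep every other token
            return super().forward(x)
    return ToMeBlock

blk = Block()
blk.__class__ = make_tome_block(blk.__class__)  # swap the class in place
out = blk(torch.randn(1, 8, 16))
print(out.shape)  # torch.Size([1, 4, 16]) -- fewer tokens reach attention
```

The efficient attention implementation never changes; it simply receives fewer tokens, which is why the savings stack.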

26 changes: 15 additions & 11 deletions tomesd/patch.py
@@ -56,9 +56,14 @@ def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tenso

return ToMeBlock


def make_diffusers_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
"""
Make a patched class on the fly so we don't have to import any specific modules.
Make a patched class for a diffusers model.
This patch applies ToMe to the forward function of the block.
"""
class ToMeBlock(block_class):
@@ -110,8 +115,6 @@ def forward(
)
# (4) ToMe m_c
norm_hidden_states = m_c(norm_hidden_states)
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
# prepare attention mask here

# 2. Cross-Attention
attn_output = self.attn2(
@@ -146,6 +149,9 @@ def forward(


def make_tome_model(model_class: Type[torch.nn.Module]):
"""
Make a patched class on the fly so we don't have to import any specific modules.
@@ -217,7 +223,7 @@ def apply_patch(
raise RuntimeError("Provided model was not a Stable Diffusion / Latent Diffusion model, as expected.")
diffusion_model = model.model.diffusion_model
else:
# support "pipe.unet" and "unet"
# Supports "pipe.unet" and "unet"
diffusion_model = model.unet if hasattr(model, "unet") else model

diffusion_model._tome_info = {
@@ -241,11 +247,9 @@ def apply_patch(
module.__class__ = make_tome_block_fn(module.__class__)
module._tome_info = diffusion_model._tome_info

# diffusers does not need this
if not is_diffusers:
# Something introduced in SD 2.0
if not hasattr(module, "disable_self_attn"):
module.disable_self_attn = False
# Something introduced in SD 2.0 (LDM only)
if not hasattr(module, "disable_self_attn") and not is_diffusers:
module.disable_self_attn = False

return model
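The patching loop above boils down to walking `named_modules()` and swapping each matching block's `__class__` for a subclass generated on the fly, so the instance's state (weights, buffers) is untouched. A self-contained sketch of that mechanic (the `BasicTransformerBlock` stand-in, the toy merge, and the `_parent` bookkeeping are illustrative, not tomesd's exact code):

```python
import torch

class BasicTransformerBlock(torch.nn.Module):
    # Stand-in for the block class that apply_patch looks for.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

def make_tome_block_fn(cls):
    class ToMeBlock(cls):
        _parent = cls  # remember the original class so the patch can be undone
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return super().forward(x[:, ::2])  # toy merge before the block runs
    return ToMeBlock

model = torch.nn.Sequential(BasicTransformerBlock(), torch.nn.Identity())
for _, module in model.named_modules():
    if isinstance(module, BasicTransformerBlock):
        module.__class__ = make_tome_block_fn(module.__class__)

out = model(torch.zeros(1, 8, 4))
print(out.shape)  # torch.Size([1, 4, 4])
```

Because only the class pointer changes, undoing the patch is just the reverse swap, which is what makes a clean `remove_patch` possible.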

@@ -254,8 +258,8 @@ def apply_patch(


def remove_patch(model: torch.nn.Module):
""" Removes a patch from a ToMe Diffusion module if it was already patched. """\

""" Removes a patch from a ToMe Diffusion module if it was already patched. """
# For diffusers
model = model.unet if hasattr(model, "unet") else model

for _, module in model.named_modules():
