Added support for diffusers
JunnYu authored Apr 1, 2023
1 parent 8bdcdfa commit d31063f
Showing 2 changed files with 124 additions and 9 deletions.
README.md: 18 additions & 0 deletions
@@ -93,6 +93,24 @@ tomesd.apply_patch(model, ratio=0.9, sx=4, sy=4, max_downsample=2) # Extreme merging
See above for what speeds and memory savings you can expect with different ratios.
If you want to remove the patch later, simply use `tomesd.remove_patch(model)`.

Apply ToMe for SD to a 🤗 Diffusers Stable Diffusion model with:
```py
import torch
import tomesd
# pip install diffusers==0.14.0
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe.to("cuda")
# Patch a Diffusers Pipeline with ToMe for SD using a 50% merging ratio.
tomesd.apply_patch(pipe, ratio=0.5)
# ...or patch just the Diffusers UNet with ToMe for SD using a 50% merging ratio.
# tomesd.apply_patch(pipe.unet, ratio=0.5)
image = pipe("a photo of an astronaut riding a horse on mars", guidance_scale=7.5, height=512, width=512, num_inference_steps=50).images[0]
image.save("astronaut.png")
```
See above for what speeds and memory savings you can expect with different ratios.
If you want to remove the patch later, simply use `tomesd.remove_patch(pipe)`.
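
For instance, a minimal sketch (not part of this commit) of changing the merging ratio on a live pipeline, assuming the `pipe` object from the snippet above:
```py
import tomesd

# Swap the merging ratio without reloading the pipeline:
# remove the existing ToMe patch, then apply a more aggressive one.
tomesd.remove_patch(pipe)
tomesd.apply_patch(pipe, ratio=0.75)
image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```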

### Example
To apply ToMe to the txt2img script of SDv2 or SDv1 for instance, add the following to [this line](https://github.com/Stability-AI/stablediffusion/blob/fc1488421a2761937b9d54784194157882cbc3b1/scripts/txt2img.py#L220) (SDv2) or [this line](https://github.com/runwayml/stable-diffusion/blob/08ab4d326c96854026c4eb3454cd3b02109ee982/scripts/txt2img.py#L241) (SDv1):
```py
import tomesd
tomesd.apply_patch(model, ratio=0.5)
```
tomesd/patch.py: 106 additions & 9 deletions
@@ -56,7 +56,93 @@ def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:

    return ToMeBlock

def make_diffusers_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """
    Make a patched class on the fly so we don't have to import any specific modules.
    This patch applies ToMe to the forward function of the block.
    """
    class ToMeBlock(block_class):
        # Save for unpatching later
        _parent = block_class

        def forward(
            self,
            hidden_states,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            timestep=None,
            cross_attention_kwargs=None,
            class_labels=None,
        ) -> torch.Tensor:
            # (1) ToMe
            m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(hidden_states, self._tome_info)

            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # (2) ToMe m_a
            norm_hidden_states = m_a(norm_hidden_states)

            # 1. Self-Attention
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output

            # (3) ToMe u_a
            hidden_states = u_a(attn_output) + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )
                # (4) ToMe m_c
                norm_hidden_states = m_c(norm_hidden_states)
                # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
                # prepare attention mask here

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                # (5) ToMe u_c
                hidden_states = u_c(attn_output) + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            # (6) ToMe m_m
            norm_hidden_states = m_m(norm_hidden_states)

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            # (7) ToMe u_m
            hidden_states = u_m(ff_output) + hidden_states

            return hidden_states

    return ToMeBlock

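The docstring above describes the core trick: build a subclass of the block's own class at runtime and overwrite `__class__` on the live module, so no diffusers symbols need to be imported. A minimal, self-contained sketch of that pattern (illustrative only, with a made-up `Block` class standing in for `BasicTransformerBlock`):

```py
import torch

class Block(torch.nn.Module):
    """Stand-in for diffusers' BasicTransformerBlock."""
    def forward(self, x):
        return x + 1

def make_patched_block(block_class):
    class PatchedBlock(block_class):
        _parent = block_class  # saved so the patch can be undone later
        def forward(self, x):
            return super().forward(x) * 2  # extra behaviour injected here
    return PatchedBlock

block = Block()
block.__class__ = make_patched_block(block.__class__)  # patch in place
print(block(torch.zeros(1)))                            # tensor([2.])
block.__class__ = block._parent                         # unpatch
print(block(torch.zeros(1)))                            # tensor([1.])
```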


@@ -123,11 +209,17 @@ def apply_patch(
    # Make sure the module is not currently patched
    remove_patch(model)

    is_diffusers = isinstance_str(model, "DiffusionPipeline") or isinstance_str(model, "ModelMixin")

    if not is_diffusers:
        if not hasattr(model, "model") or not hasattr(model.model, "diffusion_model"):
            # Provided model not supported
            raise RuntimeError("Provided model was not a Stable Diffusion / Latent Diffusion model, as expected.")
        diffusion_model = model.model.diffusion_model
    else:
        # support "pipe.unet" and "unet"
        diffusion_model = model.unet if hasattr(model, "unet") else model

    diffusion_model._tome_info = {
        "size": None,
        "args": {
@@ -145,12 +237,15 @@
    for _, module in diffusion_model.named_modules():
        # If for some reason this has a different name, create an issue and I'll fix it
        if isinstance_str(module, "BasicTransformerBlock"):
            make_tome_block_fn = make_diffusers_tome_block if is_diffusers else make_tome_block
            module.__class__ = make_tome_block_fn(module.__class__)
            module._tome_info = diffusion_model._tome_info

            # diffusers doesn't need this
            if not is_diffusers:
                # Something introduced in SD 2.0
                if not hasattr(module, "disable_self_attn"):
                    module.disable_self_attn = False

    return model

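As a rough sanity check on the loop above (a hypothetical snippet, not part of tomesd), one can count how many modules carry the patched class name after calling `apply_patch` on a Diffusers pipeline:

```py
import tomesd

tomesd.apply_patch(pipe, ratio=0.5)  # assumes `pipe` is an already-loaded StableDiffusionPipeline
n_patched = sum(
    1 for _, m in pipe.unet.named_modules()
    if m.__class__.__name__ == "ToMeBlock"
)
print(f"Patched {n_patched} BasicTransformerBlock modules")
```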
@@ -160,7 +255,9 @@ def apply_patch(

def remove_patch(model: torch.nn.Module):
    """ Removes a patch from a ToMe Diffusion module if it was already patched. """
    # For Diffusers pipelines, operate on the underlying UNet
    model = model.unet if hasattr(model, "unet") else model

    for _, module in model.named_modules():
        if module.__class__.__name__ == "ToMeBlock":
            module.__class__ = module._parent
