[Feature] Support xformers #105

Merged 1 commit on Dec 2, 2023
15 changes: 9 additions & 6 deletions Dockerfile
@@ -15,14 +15,17 @@ RUN sh -c "$(wget -O- https://github.com/deluan/zsh-in-docker/releases/download/
# Install python package.
WORKDIR /diffengine
COPY ./ /diffengine
RUN pip install --upgrade pip && \
pip install --no-cache-dir openmim==0.3.9 && \
pip install . && \
pip install pre-commit
RUN pip install --upgrade pip

# Install xformers
# RUN export TORCH_CUDA_ARCH_LIST="9.0+PTX" MAX_JOBS=1 && \
# pip install -v -U git+https://github.com/facebookresearch/[email protected]#egg=xformers
RUN pip install ninja
RUN export TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6 9.0+PTX" MAX_JOBS=1 && \
pip install -v -U git+https://github.com/facebookresearch/[email protected]#egg=xformers

# Install DiffEngine
RUN pip install --no-cache-dir openmim==0.3.9 && \
pip install . && \
pip install pre-commit

# Language settings
ENV LANG C.UTF-8
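Since the image now builds xformers from source (pinned to v0.0.23, with kernels compiled for every architecture in `TORCH_CUDA_ARCH_LIST`; `MAX_JOBS=1` caps the compiler's peak memory), it is worth smoke-testing the wheel inside the built image. A minimal sketch, assuming a CUDA-capable runtime; the script name and tensor shapes are illustrative:

```python
# check_xformers.py -- smoke test for the xformers wheel baked into the image.
# Run inside the container, e.g.: docker run --gpus all <image> python check_xformers.py
import torch
import xformers
import xformers.ops as xops

print("xformers", xformers.__version__)  # expect 0.0.23, per the Dockerfile pin

# memory_efficient_attention takes (batch, seq_len, num_heads, head_dim) tensors.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
out = xops.memory_efficient_attention(q, q, q)  # self-attention with q = k = v
print("ok:", out.shape)  # torch.Size([2, 128, 8, 64])
```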
9 changes: 5 additions & 4 deletions configs/stable_diffusion_xl/README.md
@@ -40,10 +40,11 @@ Settings:

- 1epoch training.

| Model | total time |
| :-------------------------------------: | :--------: |
| stable_diffusion_xl_pokemon_blip (fp16) | 12 m 37 s |
| stable_diffusion_xl_pokemon_blip_fast | 9 m 47 s |
| Model | total time |
| :---------------------------------------: | :--------: |
| stable_diffusion_xl_pokemon_blip (fp16) | 12 m 37 s |
| stable_diffusion_xl_pokemon_blip_xformers | 10 m 6 s |
| stable_diffusion_xl_pokemon_blip_fast | 9 m 47 s |

Note that `stable_diffusion_xl_pokemon_blip_fast` took a few extra minutes to compile; we disregard that compile time in the comparison.

@@ -0,0 +1,31 @@
_base_ = [
"../_base_/models/stable_diffusion_xl.py",
"../_base_/datasets/pokemon_blip_xl.py",
"../_base_/schedules/stable_diffusion_xl_50e.py",
"../_base_/default_runtime.py",
]

model = dict(
enable_xformers=True,
gradient_checkpointing=False)

train_dataloader = dict(batch_size=1)

optim_wrapper = dict(
dtype="float16",
accumulative_counts=4)

env_cfg = dict(
cudnn_benchmark=True,
)

custom_hooks = [
dict(
type="VisualizationHook",
prompt=["yoda pokemon"] * 4,
height=1024,
width=1024),
dict(type="SDCheckpointHook"),
dict(type="FastNormHook", fuse_unet_ln=True, fuse_gn=True),
dict(type="CompileHook", compile_unet=False),
]
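This new config is the `stable_diffusion_xl_pokemon_blip_xformers` entry in the benchmark table above. As a usage sketch, it can be launched programmatically through mmengine's `Runner`; the config path and work dir below are assumptions inferred from the README naming, and the repo's usual training entry point can be used instead:

```python
# Hedged sketch: train the xformers config via mmengine's Runner API.
from mmengine.config import Config
from mmengine.runner import Runner

# Assumed path, inferred from the README's config naming.
cfg = Config.fromfile(
    "configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_xformers.py")
cfg.work_dir = "./work_dirs/stable_diffusion_xl_pokemon_blip_xformers"

runner = Runner.from_cfg(cfg)
runner.train()
```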
17 changes: 17 additions & 0 deletions diffengine/models/editors/deepfloyd_if/deepfloyd_if.py
@@ -60,6 +60,8 @@ class DeepFloydIF(BaseModel):
gradient_checkpointing (bool): Whether or not to use gradient
checkpointing to save memory at the expense of slower backward
pass. Defaults to False.
enable_xformers (bool): Whether or not to enable memory efficient
attention. Defaults to False.
"""

def __init__(
@@ -78,6 +80,7 @@ def __init__(
*,
finetune_text_encoder: bool = False,
gradient_checkpointing: bool = False,
enable_xformers: bool = False,
) -> None:
if data_preprocessor is None:
data_preprocessor = {"type": "SDDataPreprocessor"}
@@ -117,6 +120,7 @@ def __init__(
self.gradient_checkpointing = gradient_checkpointing
self.tokenizer_max_length = tokenizer_max_length
self.input_perturbation_gamma = input_perturbation_gamma
self.enable_xformers = enable_xformers

if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
@@ -138,6 +142,7 @@ def __init__(
self.timesteps_generator = MODELS.build(timesteps_generator)
self.prepare_model()
self.set_lora()
self.set_xformers()

def set_lora(self) -> None:
"""Set LORA for model."""
@@ -166,6 +171,18 @@ def prepare_model(self) -> None:
self.text_encoder.requires_grad_(requires_grad=False)
print_log("Set Text Encoder untrainable.", "current")

def set_xformers(self) -> None:
"""Set xformers for model."""
if self.enable_xformers:
from diffusers.utils.import_utils import is_xformers_available
if is_xformers_available():
self.unet.enable_xformers_memory_efficient_attention()
else:
msg = "Please install xformers to enable memory efficient attention."
raise ImportError(
msg,
)

@property
def device(self) -> torch.device:
"""Get device information.
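The `set_xformers` hook added above (and mirrored in the editors below, each also covering the extra UNets it owns) is a thin wrapper over diffusers' standard switch, which works on any diffusers `ModelMixin`. A standalone sketch of the same pattern outside diffengine; the model id is illustrative, not from this PR:

```python
# Standalone sketch of what set_xformers() does, using diffusers directly.
from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")

if is_xformers_available():
    # Swap every attention processor for xformers' memory-efficient kernels.
    unet.enable_xformers_memory_efficient_attention()
else:
    msg = "Please install xformers to enable memory efficient attention."
    raise ImportError(msg)
```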
13 changes: 13 additions & 0 deletions diffengine/models/editors/distill_sd/distill_sd_xl.py
@@ -146,6 +146,19 @@ def hook(model, input, output): # noqa
get_activation(
self.student_feats, "u" + str(i), residuals_present=False))

def set_xformers(self) -> None:
"""Set xformers for model."""
if self.enable_xformers:
from diffusers.utils.import_utils import is_xformers_available
if is_xformers_available():
self.unet.enable_xformers_memory_efficient_attention()
self.orig_unet.enable_xformers_memory_efficient_attention()
else:
msg = "Please install xformers to enable memory efficient attention."
raise ImportError(
msg,
)

def forward(
self,
inputs: dict,
14 changes: 14 additions & 0 deletions diffengine/models/editors/esd/esd_xl.py
@@ -81,6 +81,20 @@ def _freeze_unet(self) -> None:
name.startswith("out.")):
module.eval()

def set_xformers(self) -> None:
"""Set xformers for model."""
if self.enable_xformers:
from diffusers.utils.import_utils import is_xformers_available
if is_xformers_available():
self.unet.enable_xformers_memory_efficient_attention()
if self.unet_lora_config is None:
self.orig_unet.enable_xformers_memory_efficient_attention()
else:
msg = "Please install xformers to enable memory efficient attention."
raise ImportError(
msg,
)

def train(self, *, mode=True) -> None:
"""Convert the model into training mode."""
super().train(mode)
15 changes: 15 additions & 0 deletions diffengine/models/editors/lcm/lcm_xl.py
@@ -103,6 +103,21 @@ def prepare_model(self) -> None:

super().prepare_model()

def set_xformers(self) -> None:
"""Set xformers for model."""
if self.enable_xformers:
from diffusers.utils.import_utils import is_xformers_available
if is_xformers_available():
self.unet.enable_xformers_memory_efficient_attention()
self.teacher_unet.enable_xformers_memory_efficient_attention()
if self.unet_lora_config is None:
self.target_unet.enable_xformers_memory_efficient_attention()
else:
msg = "Please install xformers to enable memory efficient attention."
raise ImportError(
msg,
)

@torch.no_grad()
def infer(self,
prompt: list[str],
18 changes: 18 additions & 0 deletions diffengine/models/editors/ssd_1b/ssd_1b.py
@@ -72,6 +72,8 @@ class SSD1B(StableDiffusionXL):
pass. Defaults to False.
pre_compute_text_embeddings(bool): Whether or not to pre-compute text
embeddings to save memory. Defaults to False.
enable_xformers (bool): Whether or not to enable memory efficient
attention. Defaults to False.
"""

def __init__(
@@ -93,6 +95,7 @@ def __init__(
finetune_text_encoder: bool = False,
gradient_checkpointing: bool = False,
pre_compute_text_embeddings: bool = False,
enable_xformers: bool = False,
) -> None:
assert unet_lora_config is None, \
"`unet_lora_config` should be None when training SSD1B"
@@ -118,6 +121,7 @@ def __init__(
self.gradient_checkpointing = gradient_checkpointing
self.pre_compute_text_embeddings = pre_compute_text_embeddings
self.input_perturbation_gamma = input_perturbation_gamma
self.enable_xformers = enable_xformers

if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
@@ -163,6 +167,7 @@ def __init__(
self.timesteps_generator = MODELS.build(timesteps_generator)
self.prepare_model()
self.set_lora()
self.set_xformers()

def set_lora(self) -> None:
"""Set LORA for model."""
@@ -242,6 +247,19 @@ def hook(model, input, output): # noqa
get_activation(self.student_feats,f"u{nb}a{i}",
residuals_present=True))

def set_xformers(self) -> None:
"""Set xformers for model."""
if self.enable_xformers:
from diffusers.utils.import_utils import is_xformers_available
if is_xformers_available():
self.unet.enable_xformers_memory_efficient_attention()
self.orig_unet.enable_xformers_memory_efficient_attention()
else:
msg = "Please install xformers to enable memory efficient attention."
raise ImportError(
msg,
)

def forward(
self,
inputs: dict,
17 changes: 17 additions & 0 deletions diffengine/models/editors/stable_diffusion/stable_diffusion.py
@@ -58,6 +58,8 @@ class StableDiffusion(BaseModel):
gradient_checkpointing (bool): Whether or not to use gradient
checkpointing to save memory at the expense of slower backward
pass. Defaults to False.
enable_xformers (bool): Whether or not to enable memory efficient
attention. Defaults to False.
"""

def __init__(
@@ -75,6 +77,7 @@ def __init__(
*,
finetune_text_encoder: bool = False,
gradient_checkpointing: bool = False,
enable_xformers: bool = False,
) -> None:
if data_preprocessor is None:
data_preprocessor = {"type": "SDDataPreprocessor"}
@@ -112,6 +115,7 @@ def __init__(
self.prior_loss_weight = prior_loss_weight
self.gradient_checkpointing = gradient_checkpointing
self.input_perturbation_gamma = input_perturbation_gamma
self.enable_xformers = enable_xformers

if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
@@ -134,6 +138,7 @@ def __init__(
self.timesteps_generator = MODELS.build(timesteps_generator)
self.prepare_model()
self.set_lora()
self.set_xformers()

def set_lora(self) -> None:
"""Set LORA for model."""
@@ -164,6 +169,18 @@ def prepare_model(self) -> None:
self.text_encoder.requires_grad_(requires_grad=False)
print_log("Set Text Encoder untrainable.", "current")

def set_xformers(self) -> None:
"""Set xformers for model."""
if self.enable_xformers:
from diffusers.utils.import_utils import is_xformers_available
if is_xformers_available():
self.unet.enable_xformers_memory_efficient_attention()
else:
msg = "Please install xformers to enable memory efficient attention."
raise ImportError(
msg,
)

@property
def device(self) -> torch.device:
"""Get device information.
@@ -112,6 +112,19 @@ def prepare_model(self) -> None:
self.unet.requires_grad_(requires_grad=False)
print_log("Set Unet untrainable.", "current")

def set_xformers(self) -> None:
"""Set xformers for model."""
if self.enable_xformers:
from diffusers.utils.import_utils import is_xformers_available
if is_xformers_available():
self.unet.enable_xformers_memory_efficient_attention()
self.controlnet.enable_xformers_memory_efficient_attention()
else:
msg = "Please install xformers to enable memory efficient attention."
raise ImportError(
msg,
)

@torch.no_grad()
def infer(self,
prompt: list[str],
@@ -86,8 +86,10 @@ class StableDiffusionXL(BaseModel):
gradient_checkpointing (bool): Whether or not to use gradient
checkpointing to save memory at the expense of slower backward
pass. Defaults to False.
pre_compute_text_embeddings(bool): Whether or not to pre-compute text
pre_compute_text_embeddings (bool): Whether or not to pre-compute text
embeddings to save memory. Defaults to False.
enable_xformers (bool): Whether or not to enable memory efficient
attention. Defaults to False.
"""

def __init__( # noqa: C901
@@ -107,6 +109,7 @@ def __init__( # noqa: C901
finetune_text_encoder: bool = False,
gradient_checkpointing: bool = False,
pre_compute_text_embeddings: bool = False,
enable_xformers: bool = False,
) -> None:
if data_preprocessor is None:
data_preprocessor = {"type": "SDXLDataPreprocessor"}
@@ -148,6 +151,7 @@ def __init__( # noqa: C901
self.gradient_checkpointing = gradient_checkpointing
self.pre_compute_text_embeddings = pre_compute_text_embeddings
self.input_perturbation_gamma = input_perturbation_gamma
self.enable_xformers = enable_xformers

if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
@@ -183,6 +187,7 @@ def __init__( # noqa: C901
self.timesteps_generator = MODELS.build(timesteps_generator)
self.prepare_model()
self.set_lora()
self.set_xformers()

def set_lora(self) -> None:
"""Set LORA for model."""
@@ -219,6 +224,18 @@ def prepare_model(self) -> None:
self.text_encoder_two.requires_grad_(requires_grad=False)
print_log("Set Text Encoder untrainable.", "current")

def set_xformers(self) -> None:
"""Set xformers for model."""
if self.enable_xformers:
from diffusers.utils.import_utils import is_xformers_available
if is_xformers_available():
self.unet.enable_xformers_memory_efficient_attention()
else:
msg = "Please install xformers to enable memory efficient attention."
raise ImportError(
msg,
)

@property
def device(self) -> torch.device:
"""Get device information.
@@ -113,6 +113,19 @@ def prepare_model(self) -> None:
self.unet.requires_grad_(requires_grad=False)
print_log("Set Unet untrainable.", "current")

def set_xformers(self) -> None:
"""Set xformers for model."""
if self.enable_xformers:
from diffusers.utils.import_utils import is_xformers_available
if is_xformers_available():
self.unet.enable_xformers_memory_efficient_attention()
self.controlnet.enable_xformers_memory_efficient_attention()
else:
msg = "Please install xformers to enable memory efficient attention."
raise ImportError(
msg,
)

@torch.no_grad()
def infer(self,
prompt: list[str],