From d42577c5c2e3b06b03a1c6aa04f51c85a9a59488 Mon Sep 17 00:00:00 2001
From: okotaku
Date: Sat, 2 Dec 2023 09:22:03 +0000
Subject: [PATCH] Support xformers

---
 Dockerfile                                    | 19 +++++++-----
 configs/stable_diffusion_xl/README.md         |  9 +++---
 ...able_diffusion_xl_pokemon_blip_xformers.py | 31 +++++++++++++++++++
 .../editors/deepfloyd_if/deepfloyd_if.py      | 17 ++++++++++
 .../editors/distill_sd/distill_sd_xl.py       | 13 ++++++++
 diffengine/models/editors/esd/esd_xl.py       | 14 +++++++++
 diffengine/models/editors/lcm/lcm_xl.py       | 15 +++++++++
 diffengine/models/editors/ssd_1b/ssd_1b.py    | 18 +++++++++++
 .../stable_diffusion/stable_diffusion.py      | 17 ++++++++++
 .../stable_diffusion_controlnet.py            | 13 ++++++++
 .../stable_diffusion_xl.py                    | 19 +++++++++++-
 .../stable_diffusion_xl_controlnet.py         | 13 ++++++++
 .../stable_diffusion_xl_t2i_adapter.py        | 12 +++++++
 13 files changed, 198 insertions(+), 12 deletions(-)
 create mode 100644 configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_xformers.py

diff --git a/Dockerfile b/Dockerfile
index a1a10c1..b782e90 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -15,14 +15,19 @@ RUN sh -c "$(wget -O- https://github.com/deluan/zsh-in-docker/releases/download/
 # Install python package.
 WORKDIR /diffengine
-COPY ./ /diffengine
-RUN pip install --upgrade pip && \
-    pip install --no-cache-dir openmim==0.3.9 && \
-    pip install . && \
-    pip install pre-commit
+RUN pip install --no-cache-dir --upgrade pip
 
 # Install xformers
-# RUN export TORCH_CUDA_ARCH_LIST="9.0+PTX" MAX_JOBS=1 && \
-#     pip install -v -U git+https://github.com/facebookresearch/xformers.git@v0.0.22.post7#egg=xformers
+# Built before the source tree is copied so that editing DiffEngine does not
+# invalidate this slow (MAX_JOBS=1) source-build layer.
+RUN export TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6 9.0+PTX" MAX_JOBS=1 && \
+    pip install --no-cache-dir ninja && \
+    pip install --no-cache-dir -v -U git+https://github.com/facebookresearch/xformers.git@v0.0.20#egg=xformers
+
+# Install DiffEngine
+COPY ./ /diffengine
+RUN pip install --no-cache-dir openmim==0.3.9 && \
+    pip install --no-cache-dir . && \
+    pip install --no-cache-dir pre-commit
 
 # Language settings
 ENV LANG C.UTF-8
diff --git a/configs/stable_diffusion_xl/README.md b/configs/stable_diffusion_xl/README.md
index 5549985..ee0e88a 100644
--- a/configs/stable_diffusion_xl/README.md
+++ b/configs/stable_diffusion_xl/README.md
@@ -40,10 +40,11 @@ Settings:
 
 - 1epoch training.
 
-|                  Model                  | total time |
-| :-------------------------------------: | :--------: |
-| stable_diffusion_xl_pokemon_blip (fp16) | 12 m 37 s  |
-|  stable_diffusion_xl_pokemon_blip_fast  |  9 m 47 s  |
+|                   Model                   | total time |
+| :---------------------------------------: | :--------: |
+|  stable_diffusion_xl_pokemon_blip (fp16)  | 12 m 37 s  |
+| stable_diffusion_xl_pokemon_blip_xformers |  10 m 6 s  |
+|   stable_diffusion_xl_pokemon_blip_fast   |  9 m 47 s  |
 
 Note that `stable_diffusion_xl_pokemon_blip_fast` took a few minutes to compile. We will disregard it.
 
diff --git a/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_xformers.py b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_xformers.py
new file mode 100644
index 0000000..3bedd1f
--- /dev/null
+++ b/configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_xformers.py
@@ -0,0 +1,31 @@
+_base_ = [
+    "../_base_/models/stable_diffusion_xl.py",
+    "../_base_/datasets/pokemon_blip_xl.py",
+    "../_base_/schedules/stable_diffusion_xl_50e.py",
+    "../_base_/default_runtime.py",
+]
+
+model = dict(
+    enable_xformers=True,
+    gradient_checkpointing=False)
+
+train_dataloader = dict(batch_size=1)
+
+optim_wrapper = dict(
+    dtype="float16",
+    accumulative_counts=4)
+
+env_cfg = dict(
+    cudnn_benchmark=True,
+)
+
+custom_hooks = [
+    dict(
+        type="VisualizationHook",
+        prompt=["yoda pokemon"] * 4,
+        height=1024,
+        width=1024),
+    dict(type="SDCheckpointHook"),
+    dict(type="FastNormHook", fuse_unet_ln=True, fuse_gn=True),
+    dict(type="CompileHook", compile_unet=False),
+]
diff --git a/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py b/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py
index 034ef99..015c671 100644
--- a/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py
+++ b/diffengine/models/editors/deepfloyd_if/deepfloyd_if.py
@@ -60,6 +60,8 @@ class DeepFloydIF(BaseModel):
         gradient_checkpointing (bool): Whether or not to use gradient
             checkpointing to save memory at the expense of slower backward
             pass. Defaults to False.
+        enable_xformers (bool): Whether or not to enable memory efficient
+            attention. Defaults to False.
     """
 
     def __init__(
@@ -78,6 +80,7 @@ def __init__(
         *,
         finetune_text_encoder: bool = False,
         gradient_checkpointing: bool = False,
+        enable_xformers: bool = False,
     ) -> None:
         if data_preprocessor is None:
             data_preprocessor = {"type": "SDDataPreprocessor"}
@@ -117,6 +120,7 @@ def __init__(
         self.gradient_checkpointing = gradient_checkpointing
         self.tokenizer_max_length = tokenizer_max_length
         self.input_perturbation_gamma = input_perturbation_gamma
+        self.enable_xformers = enable_xformers
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(loss)
@@ -138,6 +142,7 @@ def __init__(
         self.timesteps_generator = MODELS.build(timesteps_generator)
         self.prepare_model()
         self.set_lora()
+        self.set_xformers()
 
     def set_lora(self) -> None:
         """Set LORA for model."""
@@ -166,6 +171,18 @@ def prepare_model(self) -> None:
             self.text_encoder.requires_grad_(requires_grad=False)
             print_log("Set Text Encoder untrainable.", "current")
 
+    def set_xformers(self) -> None:
+        """Enable xformers attention on ``unet``; raise if not installed."""
+        if not self.enable_xformers:
+            return
+
+        from diffusers.utils.import_utils import is_xformers_available
+        if not is_xformers_available():
+            msg = "Please install xformers to enable memory efficient attention."
+            raise ImportError(msg)
+
+        self.unet.enable_xformers_memory_efficient_attention()
+
     @property
     def device(self) -> torch.device:
         """Get device information.
diff --git a/diffengine/models/editors/distill_sd/distill_sd_xl.py b/diffengine/models/editors/distill_sd/distill_sd_xl.py
index 2bd1f3a..b0b1447 100644
--- a/diffengine/models/editors/distill_sd/distill_sd_xl.py
+++ b/diffengine/models/editors/distill_sd/distill_sd_xl.py
@@ -146,6 +146,19 @@ def hook(model, input, output):  # noqa
                 get_activation(
                     self.student_feats, "u" + str(i), residuals_present=False))
 
+    def set_xformers(self) -> None:
+        """Enable xformers attention on ``unet`` and ``orig_unet``."""
+        if not self.enable_xformers:
+            return
+
+        from diffusers.utils.import_utils import is_xformers_available
+        if not is_xformers_available():
+            msg = "Please install xformers to enable memory efficient attention."
+            raise ImportError(msg)
+
+        self.unet.enable_xformers_memory_efficient_attention()
+        self.orig_unet.enable_xformers_memory_efficient_attention()
+
     def forward(
         self,
         inputs: dict,
diff --git a/diffengine/models/editors/esd/esd_xl.py b/diffengine/models/editors/esd/esd_xl.py
index 203bfb8..10b204e 100644
--- a/diffengine/models/editors/esd/esd_xl.py
+++ b/diffengine/models/editors/esd/esd_xl.py
@@ -81,6 +81,20 @@ def _freeze_unet(self) -> None:
                     name.startswith("out.")):
                 module.eval()
 
+    def set_xformers(self) -> None:
+        """Enable xformers attention; ``orig_unet`` only without UNet LoRA."""
+        if not self.enable_xformers:
+            return
+
+        from diffusers.utils.import_utils import is_xformers_available
+        if not is_xformers_available():
+            msg = "Please install xformers to enable memory efficient attention."
+            raise ImportError(msg)
+
+        self.unet.enable_xformers_memory_efficient_attention()
+        if self.unet_lora_config is None:
+            self.orig_unet.enable_xformers_memory_efficient_attention()
+
     def train(self, *, mode=True) -> None:
         """Convert the model into training mode."""
         super().train(mode)
diff --git a/diffengine/models/editors/lcm/lcm_xl.py b/diffengine/models/editors/lcm/lcm_xl.py
index 5c67dd2..bda0e75 100644
--- a/diffengine/models/editors/lcm/lcm_xl.py
+++ b/diffengine/models/editors/lcm/lcm_xl.py
@@ -103,6 +103,21 @@ def prepare_model(self) -> None:
 
         super().prepare_model()
 
+    def set_xformers(self) -> None:
+        """Enable xformers attention; ``target_unet`` only without UNet LoRA."""
+        if not self.enable_xformers:
+            return
+
+        from diffusers.utils.import_utils import is_xformers_available
+        if not is_xformers_available():
+            msg = "Please install xformers to enable memory efficient attention."
+            raise ImportError(msg)
+
+        self.unet.enable_xformers_memory_efficient_attention()
+        self.teacher_unet.enable_xformers_memory_efficient_attention()
+        if self.unet_lora_config is None:
+            self.target_unet.enable_xformers_memory_efficient_attention()
+
     @torch.no_grad()
     def infer(self,
               prompt: list[str],
diff --git a/diffengine/models/editors/ssd_1b/ssd_1b.py b/diffengine/models/editors/ssd_1b/ssd_1b.py
index 923d1b2..2ed3dae 100644
--- a/diffengine/models/editors/ssd_1b/ssd_1b.py
+++ b/diffengine/models/editors/ssd_1b/ssd_1b.py
@@ -72,6 +72,8 @@ class SSD1B(StableDiffusionXL):
             pass. Defaults to False.
         pre_compute_text_embeddings(bool): Whether or not to pre-compute text
             embeddings to save memory. Defaults to False.
+        enable_xformers (bool): Whether or not to enable memory efficient
+            attention. Defaults to False.
     """
 
     def __init__(
@@ -93,6 +95,7 @@ def __init__(
         finetune_text_encoder: bool = False,
         gradient_checkpointing: bool = False,
         pre_compute_text_embeddings: bool = False,
+        enable_xformers: bool = False,
     ) -> None:
         assert unet_lora_config is None, \
             "`unet_lora_config` should be None when training SSD1B"
@@ -118,6 +121,7 @@ def __init__(
         self.gradient_checkpointing = gradient_checkpointing
         self.pre_compute_text_embeddings = pre_compute_text_embeddings
         self.input_perturbation_gamma = input_perturbation_gamma
+        self.enable_xformers = enable_xformers
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(loss)
@@ -163,6 +167,7 @@ def __init__(
         self.timesteps_generator = MODELS.build(timesteps_generator)
         self.prepare_model()
         self.set_lora()
+        self.set_xformers()
 
     def set_lora(self) -> None:
         """Set LORA for model."""
@@ -242,6 +247,19 @@ def hook(model, input, output):  # noqa
                     get_activation(self.student_feats,f"u{nb}a{i}",
                                    residuals_present=True))
 
+    def set_xformers(self) -> None:
+        """Enable xformers attention on ``unet`` and ``orig_unet``."""
+        if not self.enable_xformers:
+            return
+
+        from diffusers.utils.import_utils import is_xformers_available
+        if not is_xformers_available():
+            msg = "Please install xformers to enable memory efficient attention."
+            raise ImportError(msg)
+
+        self.unet.enable_xformers_memory_efficient_attention()
+        self.orig_unet.enable_xformers_memory_efficient_attention()
+
     def forward(
         self,
         inputs: dict,
diff --git a/diffengine/models/editors/stable_diffusion/stable_diffusion.py b/diffengine/models/editors/stable_diffusion/stable_diffusion.py
index 5dd4803..940425e 100644
--- a/diffengine/models/editors/stable_diffusion/stable_diffusion.py
+++ b/diffengine/models/editors/stable_diffusion/stable_diffusion.py
@@ -58,6 +58,8 @@ class StableDiffusion(BaseModel):
         gradient_checkpointing (bool): Whether or not to use gradient
             checkpointing to save memory at the expense of slower backward
             pass. Defaults to False.
+        enable_xformers (bool): Whether or not to enable memory efficient
+            attention. Defaults to False.
     """
 
     def __init__(
@@ -75,6 +77,7 @@ def __init__(
         *,
         finetune_text_encoder: bool = False,
         gradient_checkpointing: bool = False,
+        enable_xformers: bool = False,
     ) -> None:
         if data_preprocessor is None:
             data_preprocessor = {"type": "SDDataPreprocessor"}
@@ -112,6 +115,7 @@ def __init__(
         self.prior_loss_weight = prior_loss_weight
         self.gradient_checkpointing = gradient_checkpointing
         self.input_perturbation_gamma = input_perturbation_gamma
+        self.enable_xformers = enable_xformers
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(loss)
@@ -134,6 +138,7 @@ def __init__(
         self.timesteps_generator = MODELS.build(timesteps_generator)
         self.prepare_model()
         self.set_lora()
+        self.set_xformers()
 
     def set_lora(self) -> None:
         """Set LORA for model."""
@@ -164,6 +169,18 @@ def prepare_model(self) -> None:
             self.text_encoder.requires_grad_(requires_grad=False)
             print_log("Set Text Encoder untrainable.", "current")
 
+    def set_xformers(self) -> None:
+        """Enable xformers attention on ``unet``; raise if not installed."""
+        if not self.enable_xformers:
+            return
+
+        from diffusers.utils.import_utils import is_xformers_available
+        if not is_xformers_available():
+            msg = "Please install xformers to enable memory efficient attention."
+            raise ImportError(msg)
+
+        self.unet.enable_xformers_memory_efficient_attention()
+
     @property
     def device(self) -> torch.device:
         """Get device information.
diff --git a/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py b/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py
index 1e672fe..70a84ef 100644
--- a/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py
+++ b/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py
@@ -112,6 +112,19 @@ def prepare_model(self) -> None:
         self.unet.requires_grad_(requires_grad=False)
         print_log("Set Unet untrainable.", "current")
 
+    def set_xformers(self) -> None:
+        """Enable xformers attention on ``unet`` and ``controlnet``."""
+        if not self.enable_xformers:
+            return
+
+        from diffusers.utils.import_utils import is_xformers_available
+        if not is_xformers_available():
+            msg = "Please install xformers to enable memory efficient attention."
+            raise ImportError(msg)
+
+        self.unet.enable_xformers_memory_efficient_attention()
+        self.controlnet.enable_xformers_memory_efficient_attention()
+
     @torch.no_grad()
     def infer(self,
               prompt: list[str],
diff --git a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py
index 8972d3e..576c05e 100644
--- a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py
+++ b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py
@@ -86,8 +86,10 @@ class StableDiffusionXL(BaseModel):
         gradient_checkpointing (bool): Whether or not to use gradient
             checkpointing to save memory at the expense of slower backward
             pass. Defaults to False.
-        pre_compute_text_embeddings(bool): Whether or not to pre-compute text
+        pre_compute_text_embeddings (bool): Whether or not to pre-compute text
             embeddings to save memory. Defaults to False.
+        enable_xformers (bool): Whether or not to enable memory efficient
+            attention. Defaults to False.
     """
 
     def __init__(  # noqa: C901
@@ -107,6 +109,7 @@ def __init__(  # noqa: C901
         finetune_text_encoder: bool = False,
         gradient_checkpointing: bool = False,
         pre_compute_text_embeddings: bool = False,
+        enable_xformers: bool = False,
     ) -> None:
         if data_preprocessor is None:
             data_preprocessor = {"type": "SDXLDataPreprocessor"}
@@ -148,6 +151,7 @@ def __init__(  # noqa: C901
         self.gradient_checkpointing = gradient_checkpointing
         self.pre_compute_text_embeddings = pre_compute_text_embeddings
         self.input_perturbation_gamma = input_perturbation_gamma
+        self.enable_xformers = enable_xformers
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(loss)
@@ -183,6 +187,7 @@ def __init__(  # noqa: C901
         self.timesteps_generator = MODELS.build(timesteps_generator)
         self.prepare_model()
         self.set_lora()
+        self.set_xformers()
 
     def set_lora(self) -> None:
         """Set LORA for model."""
@@ -219,6 +224,18 @@ def prepare_model(self) -> None:
             self.text_encoder_two.requires_grad_(requires_grad=False)
             print_log("Set Text Encoder untrainable.", "current")
 
+    def set_xformers(self) -> None:
+        """Enable xformers attention on ``unet``; raise if not installed."""
+        if not self.enable_xformers:
+            return
+
+        from diffusers.utils.import_utils import is_xformers_available
+        if not is_xformers_available():
+            msg = "Please install xformers to enable memory efficient attention."
+            raise ImportError(msg)
+
+        self.unet.enable_xformers_memory_efficient_attention()
+
     @property
     def device(self) -> torch.device:
         """Get device information.
diff --git a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py
index 37630ae..41912fa 100644
--- a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py
+++ b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py
@@ -113,6 +113,19 @@ def prepare_model(self) -> None:
         self.unet.requires_grad_(requires_grad=False)
         print_log("Set Unet untrainable.", "current")
 
+    def set_xformers(self) -> None:
+        """Enable xformers attention on ``unet`` and ``controlnet``."""
+        if not self.enable_xformers:
+            return
+
+        from diffusers.utils.import_utils import is_xformers_available
+        if not is_xformers_available():
+            msg = "Please install xformers to enable memory efficient attention."
+            raise ImportError(msg)
+
+        self.unet.enable_xformers_memory_efficient_attention()
+        self.controlnet.enable_xformers_memory_efficient_attention()
+
     @torch.no_grad()
     def infer(self,
               prompt: list[str],
diff --git a/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py b/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py
index a761282..085ac39 100644
--- a/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py
+++ b/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py
@@ -110,6 +110,18 @@ def prepare_model(self) -> None:
         self.unet.requires_grad_(requires_grad=False)
         print_log("Set Unet untrainable.", "current")
 
+    def set_xformers(self) -> None:
+        """Enable xformers attention on ``unet``; raise if not installed."""
+        if not self.enable_xformers:
+            return
+
+        from diffusers.utils.import_utils import is_xformers_available
+        if not is_xformers_available():
+            msg = "Please install xformers to enable memory efficient attention."
+            raise ImportError(msg)
+
+        self.unet.enable_xformers_memory_efficient_attention()
+
     @torch.no_grad()
     def infer(self,
               prompt: list[str],