feat: implement block swapping for FLUX.1 LoRA (WIP)
kohya-ss committed Nov 11, 2024 · 1 parent 7feaae5 · commit cde90b8
Showing 5 changed files with 87 additions and 5 deletions.
2 changes: 1 addition & 1 deletion flux_train.py
@@ -519,7 +519,7 @@ def grad_hook(parameter: torch.Tensor):
num_parameters_per_group[opt_idx] += 1

# add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook
if is_swapping_blocks:
if False: # is_swapping_blocks:
import library.custom_offloading_utils as custom_offloading_utils

num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
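The comment above relies on per-parameter hooks firing in the order they were registered: the fused-optimizer hook comes first, so the block-swap hook only runs once that parameter's step is done. A minimal sketch of that ordering, assuming PyTorch 2.1+ (register_post_accumulate_grad_hook); the hook bodies are placeholders, not the trainer's code:

import torch

p = torch.nn.Parameter(torch.randn(4))

def fused_optimizer_hook(param: torch.Tensor):
    # stand-in for the fused_backward_pass / blockwise_fused_optimizers hook
    print("optimizer step for", tuple(param.shape))

def block_swap_hook(param: torch.Tensor):
    # stand-in for scheduling a block offload after the step
    print("schedule block swap")

p.register_post_accumulate_grad_hook(fused_optimizer_hook)
p.register_post_accumulate_grad_hook(block_swap_hook)  # registered second, fires second

(p * 2).sum().backward()  # prints the optimizer message, then the swap message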
33 changes: 33 additions & 0 deletions flux_train_network.py
@@ -25,6 +25,7 @@ def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False

def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
@@ -78,6 +79,12 @@ def load_target_model(self, args, weight_dtype, accelerator):
if args.split_mode:
model = self.prepare_split_model(model, weight_dtype, accelerator)

self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU during the forward and backward passes to reduce memory usage.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)

clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
clip_l.eval()

@@ -285,6 +292,8 @@ def sample_images(self, accelerator, args, epoch, global_step, device, ae, token
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

if not args.split_mode:
if self.is_swapping_blocks:
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
@@ -539,6 +548,19 @@ def forward(hidden_states):
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)

def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)

# if we don't swap blocks, we can simply move the whole model to the device (handled by the default implementation above)
flux: flux_models.Flux = unet
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage

return flux
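For context on the device_placement argument used above: Accelerator.prepare accepts one boolean per object passed in, controlling whether that object is automatically moved to the device. A self-contained sketch with an illustrative model (not the Flux definition):

import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)])  # stand-in for the Flux blocks

# wrap the model for training but keep its weights where they are,
# so the offloader decides which blocks live on the GPU
model = accelerator.prepare(model, device_placement=[False])
# the non-swapped parts are then moved explicitly, as
# move_to_device_except_swap_blocks() does in the diff above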


def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
@@ -550,6 +572,17 @@ def setup_parser() -> argparse.ArgumentParser:
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
)

parser.add_argument(
"--blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of blocks to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップするブロックの数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
)
return parser


40 changes: 39 additions & 1 deletion library/custom_offloading_utils.py
@@ -183,9 +183,47 @@ class ModelOffloader(Offloader):
supports forward offloading
"""

def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)

# register backward hooks
self.remove_handles = []
for i, block in enumerate(blocks):
hook = self.create_backward_hook(blocks, i)
if hook is not None:
handle = block.register_full_backward_hook(hook)
self.remove_handles.append(handle)

def __del__(self):
for handle in self.remove_handles:
handle.remove()

def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
waiting = block_index > 0 and block_index <= self.blocks_to_swap

if not swapping and not waiting:
return None

# create hook
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1

def backward_hook(module, grad_input, grad_output):
if self.debug:
print(f"Backward hook for block {block_index}")

if swapping:
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
if waiting:
self._wait_blocks_move(block_idx_to_wait)
return None

return backward_hook

def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
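A worked example of the index arithmetic in create_backward_hook above, using illustrative values (FLUX.1 has 19 double blocks, so num_blocks=19 with blocks_to_swap=9 is a plausible configuration):

# Backward runs through the blocks in reverse, so when the hook for block 12 fires,
# the 6 blocks after it (13..18) have already been propagated.
num_blocks, blocks_to_swap, block_index = 19, 9, 12

num_blocks_propagated = num_blocks - block_index - 1        # 6
swapping = 0 < num_blocks_propagated <= blocks_to_swap      # True
waiting = 0 < block_index <= blocks_to_swap                 # False (12 > 9)

block_idx_to_cpu = num_blocks - num_blocks_propagated       # 13: the just-finished block is offloaded
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated  # 3: an early block is prefetched for later
block_idx_to_wait = block_index - 1                         # 11 (only used when waiting is True)

# For a hook late in the backward pass, e.g. block_index = 5: num_blocks_propagated = 13
# exceeds blocks_to_swap, so no new move is scheduled; waiting is True and the hook
# blocks until block 4 has arrived on the GPU.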
8 changes: 6 additions & 2 deletions library/flux_models.py
@@ -970,8 +970,12 @@ def enable_block_swap(self, num_blocks: int, device: torch.device):
double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2

self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device)
self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device #, debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device #, debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
)
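To make the split above concrete, a worked example with an illustrative flag value (single blocks are counted twice, presumably because each is roughly half the size of a double block):

# e.g. --blocks_to_swap 18
num_blocks = 18
double_blocks_to_swap = num_blocks // 2                           # 9 of FLUX.1's 19 double blocks
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2  # 18 of FLUX.1's 38 single blocks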
9 changes: 8 additions & 1 deletion train_network.py
@@ -18,6 +18,7 @@
init_ipex()

from accelerate.utils import set_seed
from accelerate import Accelerator
from diffusers import DDPMScheduler
from library import deepspeed_utils, model_util, strategy_base, strategy_sd

@@ -272,6 +273,11 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
text_encoder.text_model.embeddings.to(dtype=weight_dtype)

def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
return accelerator.prepare(unet)

def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
pass

@@ -627,7 +633,8 @@ def train(self, args):
training_model = ds_model
else:
if train_unet:
unet = accelerator.prepare(unet)
# default implementation is: unet = accelerator.prepare(unet)
unet = self.prepare_unet_with_accelerator(args, accelerator, unet)  # subclasses may skip automatic device placement here (e.g. for block swapping)
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
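Making prepare_unet_with_accelerator a base-class hook keeps block swapping out of train_network.py itself: the default still calls accelerator.prepare(unet), while the FLUX trainer overrides it to skip automatic device placement and let the offloader manage block placement.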
