Skip to content

Commit

Permalink
add Deep Shrink
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Nov 23, 2023
1 parent d0923d6 commit 6d6d862
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 3 deletions.
76 changes: 76 additions & 0 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2501,6 +2501,10 @@ def __getattr__(self, item):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()

# Deep Shrink
if args.ds_depth_1 is not None:
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)

# Extended Textual Inversion および Textual Inversionを処理する
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
Expand Down Expand Up @@ -3085,6 +3089,13 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
clip_prompt = None
network_muls = None

# Deep Shrink
ds_depth_1 = None # means no override
ds_timesteps_1 = args.ds_timesteps_1
ds_depth_2 = args.ds_depth_2
ds_timesteps_2 = args.ds_timesteps_2
ds_ratio = args.ds_ratio

prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
Expand Down Expand Up @@ -3156,10 +3167,51 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
print(f"network mul: {network_muls}")
continue

# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}")
continue

m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue

m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}")
continue

m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue

m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}")
continue

except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)

# override Deep Shrink
if ds_depth_1 is not None:
if ds_depth_1 < 0:
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)

# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
Expand Down Expand Up @@ -3509,6 +3561,30 @@ def setup_parser() -> argparse.ArgumentParser:
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )

# Deep Shrink
parser.add_argument(
"--ds_depth_1",
type=int,
default=None,
help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする",
)
parser.add_argument(
"--ds_timesteps_1",
type=int,
default=650,
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
)
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
parser.add_argument(
"--ds_timesteps_2",
type=int,
default=650,
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
)
parser.add_argument(
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
)

return parser


Expand Down
68 changes: 66 additions & 2 deletions library/original_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,23 @@ def get_timestep_embedding(
return emb


# Deep Shrink: We do not common this function, because minimize dependencies.
def resize_like(x, target, mode="bicubic", align_corners=False):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.to(torch.float32)

if x.shape[-2:] != target.shape[-2:]:
if mode == "nearest":
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
else:
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)

if org_dtype == torch.bfloat16:
x = x.to(org_dtype)
return x


class SampleOutput:
def __init__(self, sample):
self.sample = sample
Expand Down Expand Up @@ -1130,6 +1147,11 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]

# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

if self.training and self.gradient_checkpointing:
Expand Down Expand Up @@ -1221,6 +1243,11 @@ def forward(
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]

# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

if self.training and self.gradient_checkpointing:
Expand Down Expand Up @@ -1417,6 +1444,31 @@ def __init__(
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)

# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None

def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio

# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
Expand Down Expand Up @@ -1519,9 +1571,21 @@ def forward(
# 2. pre-process
sample = self.conv_in(sample)

# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
for depth, downsample_block in enumerate(self.down_blocks):
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
self.ds_depth_2 is not None
and depth == self.ds_depth_2
and timesteps[0] < self.ds_timesteps_1
and timesteps[0] >= self.ds_timesteps_2
):
org_dtype = sample.dtype
if org_dtype == torch.bfloat16:
sample = sample.to(torch.float32)
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)

# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
# まあこちらのほうがわかりやすいかもしれない
if downsample_block.has_cross_attention:
Expand Down
70 changes: 69 additions & 1 deletion library/sdxl_original_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,23 @@ def get_timestep_embedding(
return emb


# Deep Shrink: We do not common this function, because minimize dependencies.
def resize_like(x, target, mode="bicubic", align_corners=False):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.to(torch.float32)

if x.shape[-2:] != target.shape[-2:]:
if mode == "nearest":
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
else:
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)

if org_dtype == torch.bfloat16:
x = x.to(org_dtype)
return x


class GroupNorm32(nn.GroupNorm):
def forward(self, x):
if self.weight.dtype != torch.float32:
Expand Down Expand Up @@ -996,6 +1013,31 @@ def __init__(
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
)

# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None

def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio

# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
Expand Down Expand Up @@ -1077,16 +1119,42 @@ def call_module(module, h, emb, context):

# h = x.type(self.dtype)
h = x
for module in self.input_blocks:

for depth, module in enumerate(self.input_blocks):
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
self.ds_depth_2 is not None
and depth == self.ds_depth_2
and timesteps[0] < self.ds_timesteps_1
and timesteps[0] >= self.ds_timesteps_2
):
# print("downsample", h.shape, self.ds_ratio)
org_dtype = h.dtype
if org_dtype == torch.bfloat16:
h = h.to(torch.float32)
h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)

h = call_module(module, h, emb, context)
hs.append(h)

h = call_module(self.middle_block, h, emb, context)

for module in self.output_blocks:
# Deep Shrink
if self.ds_depth_1 is not None:
if hs[-1].shape[-2:] != h.shape[-2:]:
# print("upsample", h.shape, hs[-1].shape)
h = resize_like(h, hs[-1])

h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)

# Deep Shrink: in case of depth 0
if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
# print("upsample", h.shape, x.shape)
h = resize_like(h, x)

h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)

Expand Down
Loading

0 comments on commit 6d6d862

Please sign in to comment.