Merge pull request #43 from drhead/patch-1
Optimizations to PAG and t2i-zero
v0xie authored May 18, 2024
2 parents 1b3d17c + ca60bff commit 1dd3b2b
Showing 2 changed files with 14 additions and 13 deletions.
8 changes: 4 additions & 4 deletions scripts/pag.py
@@ -333,11 +333,11 @@ def pag_pre_hook(module, input, kwargs, output):
         # oops we forgot to unhook
         return
 
-    batch_size, seq_len, inner_dim = output.shape
-    identity = torch.eye(seq_len).expand(batch_size, -1, -1).to(shared.device)
-
     # get the last to_v output and save it
     last_to_v = getattr(module, 'pag_last_to_v', None)
+
+    batch_size, seq_len, inner_dim = output.shape
+    identity = torch.eye(seq_len, dtype=last_to_v.dtype, device=shared.device).expand(batch_size, -1, -1)
     if last_to_v is not None:
         new_output = torch.einsum('bij,bjk->bik', identity, last_to_v)
         return new_output
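
A note on this hunk: the identity matrix is now built in a single call with an explicit dtype and device, instead of materializing a default float32 tensor and then copying it over with .to(); moving the construction below the last_to_v lookup is what makes last_to_v.dtype available. A minimal sketch of the before/after, with illustrative shapes and a stand-in for shared.device (not the extension's actual code path):

import torch

# Illustrative values; `device` stands in for shared.device and `dtype`
# for last_to_v.dtype in the real hook.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, seq_len, dtype = 2, 64, torch.float16

# Before: a float32 identity is allocated on the CPU, expanded, then
# copied to the device, and its dtype can still differ from last_to_v's.
identity_old = torch.eye(seq_len).expand(batch_size, -1, -1).to(device)

# After: one allocation, already on the target device and in the matching
# dtype, so the later torch.einsum sees consistent operands.
identity_new = torch.eye(seq_len, dtype=dtype, device=device).expand(batch_size, -1, -1)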
@@ -836,4 +836,4 @@ def _remove_child_hooks(
         _remove_child_hooks(module, hook_fn_name)
 
     # Remove hooks from the target module
-    _remove_hooks(module, hook_fn_name)
\ No newline at end of file
+    _remove_hooks(module, hook_fn_name)
19 changes: 10 additions & 9 deletions scripts/t2i_zero.py
@@ -407,15 +407,16 @@ def ready_hijack_forward(self, alpha, width, height, ema_factor, step_start, ste
         plot_num = 0
         for module in cross_attn_modules:
             self.add_field_cross_attn_modules(module, 't2i0_last_attn_map', None)
-            self.add_field_cross_attn_modules(module, 't2i0_step', torch.tensor([-1]).to(device=shared.device))
-            self.add_field_cross_attn_modules(module, 't2i0_step_start', torch.tensor([step_start]).to(device=shared.device))
-            self.add_field_cross_attn_modules(module, 't2i0_step_end', torch.tensor([step_end]).to(device=shared.device))
+            self.add_field_cross_attn_modules(module, 't2i0_step', int(-1))
+            self.add_field_cross_attn_modules(module, 't2i0_step_start', int(step_start))
+            self.add_field_cross_attn_modules(module, 't2i0_step_end', int(step_end))
             self.add_field_cross_attn_modules(module, 't2i0_ema', None)
-            self.add_field_cross_attn_modules(module, 't2i0_ema_factor', torch.tensor([ema_factor]).to(device=shared.device, dtype=torch.float16))
-            self.add_field_cross_attn_modules(module, 'plot_num', torch.tensor([plot_num]).to(device=shared.device))
+            self.add_field_cross_attn_modules(module, 't2i0_ema_factor', float(ema_factor))
+            self.add_field_cross_attn_modules(module, 'plot_num', int(plot_num))
             self.add_field_cross_attn_modules(module, 't2i0_to_v_map', None)
             self.add_field_cross_attn_modules(module.to_v, 't2i0_parent_module', [module])
-            self.add_field_cross_attn_modules(module, 't2i0_token_count', torch.tensor(token_count).to(device=shared.device, dtype=torch.int64))
+            self.add_field_cross_attn_modules(module, 't2i0_token_count', int(token_count))
+            self.add_field_cross_attn_modules(module, 'gaussian_blur', GaussianBlur(kernel_size=3, sigma=1).to(device=shared.device))
             if tokens is not None:
                 self.add_field_cross_attn_modules(module, 't2i0_tokens', torch.tensor(tokens).to(device=shared.device, dtype=torch.int64))
             else:
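
The scalar fields rewritten above (step counters, EMA factor, plot index, token count) are only consumed by host-side Python logic, so holding them as one-element GPU tensors cost an allocation plus a host-to-device copy per field per cross-attention module, and a device sync whenever they were read back; plain ints and floats avoid all of that. The new gaussian_blur field is consumed by the last hunk below. A small sketch of the difference, assuming the field only feeds comparisons such as a step-range check:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
step_start, current_step = 5, 3

# Before: a one-element tensor shipped to the GPU at setup time...
step_start_old = torch.tensor([step_start]).to(device=device)
# ...which every host-side comparison reads back, syncing the device.
in_range_old = current_step >= step_start_old.item()

# After: a plain Python int; the comparison never touches the device.
step_start_new = int(step_start)
in_range_new = current_step >= step_start_new

assert in_range_old == in_range_new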
@@ -476,9 +477,9 @@ def cross_token_non_maximum_suppression(module, input, kwargs, output):
     attention_map = output.view(batch_size, downscale_height, downscale_width, inner_dim)
 
     if token_indices is None:
-        selected_tokens = torch.tensor(list(range(1, token_count.item())))
+        selected_tokens = torch.arange(1, token_count, device=output.device)
     elif len(token_indices) == 0:
-        selected_tokens = torch.tensor(list(range(1, token_count.item())))
+        selected_tokens = torch.arange(1, token_count, device=output.device)
     else:
         selected_tokens = module.t2i0_tokens
 
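torch.arange builds the token-index vector in one allocation, placed directly on output.device, where the old line first built a Python list and left the resulting tensor on the CPU; and with t2i0_token_count now a plain int (first hunk), the .item() read-back disappears. A sketch of the equivalence:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
token_count = 10  # now a plain int, so no .item() sync is needed

# Before: Python list -> CPU tensor, moved to the device only later.
old = torch.tensor(list(range(1, token_count)))

# After: a single arange allocated where the attention output already lives.
new = torch.arange(1, token_count, device=device)

assert torch.equal(old.to(device), new)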
@@ -490,7 +491,7 @@ def cross_token_non_maximum_suppression(module, input, kwargs, output):
 
     # Extract and process the selected attention maps
     # GaussianBlur expects the input [..., C, H, W]
-    gaussian_blur = GaussianBlur(kernel_size=3, sigma=1)
+    gaussian_blur = module.gaussian_blur
     AC = AC.permute(0, 3, 1, 2)
     AC = gaussian_blur(AC) # Applying Gaussian smoothing
     AC = AC.permute(0, 2, 3, 1)
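
Since torchvision's GaussianBlur is an nn.Module, constructing it inside the hook rebuilt it on every forward pass of every hooked module; the setup hunk above now builds it once per module, and this hunk reuses the cached instance. A minimal sketch of the caching pattern, with a hypothetical stand-in for the hooked attention module:

from torch import nn
from torchvision.transforms import GaussianBlur

# Hypothetical stand-in for a hooked cross-attention module.
attn = nn.Identity()

# Setup time (ready_hijack_forward): construct the transform once per
# module; kernel_size and sigma values are those used in the diff.
attn.gaussian_blur = GaussianBlur(kernel_size=3, sigma=1)

# Hook time (cross_token_non_maximum_suppression): reuse the cached
# instance on every call instead of paying the construction cost again.
gaussian_blur = attn.gaussian_blur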
