Commit 21d12c9
Add support for NATTEN 0.17.0 (fused neighborhood attention)
crowsonkb committed May 6, 2024
1 parent 6ab5146 commit 21d12c9
Showing 1 changed file with 20 additions and 10 deletions: k_diffusion/models/image_transformer_v2.py
```diff
@@ -407,18 +407,28 @@ def forward(self, x, pos, cond):
         skip = x
         x = self.norm(x, cond)
         qkv = self.qkv_proj(x)
-        q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head)
-        q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6)
-        theta = self.pos_emb(pos).movedim(-2, -4)
-        q = apply_rotary_emb_(q, theta)
-        k = apply_rotary_emb_(k, theta)
         if natten is None:
             raise ModuleNotFoundError("natten is required for neighborhood attention")
-        flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
-        qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1)
-        a = torch.softmax(qk, dim=-1).to(v.dtype)
-        x = natten.functional.natten2dav(a, v, self.kernel_size, 1)
-        x = rearrange(x, "n nh h w e -> n h w (nh e)")
+        if natten.has_fused_na():
+            q, k, v = rearrange(qkv, "n h w (t nh e) -> t n h w nh e", t=3, e=self.d_head)
+            q, k = scale_for_cosine_sim(q, k, self.scale[:, None], 1e-6)
+            theta = self.pos_emb(pos)
+            q = apply_rotary_emb_(q, theta)
+            k = apply_rotary_emb_(k, theta)
+            flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
+            x = natten.functional.na2d(q, k, v, self.kernel_size, scale=1.0)
+            x = rearrange(x, "n h w nh e -> n h w (nh e)")
+        else:
+            q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head)
+            q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6)
+            theta = self.pos_emb(pos).movedim(-2, -4)
+            q = apply_rotary_emb_(q, theta)
+            k = apply_rotary_emb_(k, theta)
+            flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
+            qk = natten.functional.na2d_qk(q, k, self.kernel_size)
+            a = torch.softmax(qk, dim=-1).to(v.dtype)
+            x = natten.functional.na2d_av(a, v, self.kernel_size)
+            x = rearrange(x, "n nh h w e -> n h w (nh e)")
         x = self.dropout(x)
         x = self.out_proj(x)
         return x + skip
```
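For context: NATTEN 0.17.0 adds fused neighborhood attention and renames the two-stage functional API (`natten2dqk`/`natten2dav` become `na2d_qk`/`na2d_av`, with the explicit dilation argument of 1 dropped in favor of the default). The fused kernel, `na2d`, computes the whole attention in one call and takes a heads-last `(n, h, w, heads, d_head)` layout instead of the heads-first `(n, heads, h, w, d_head)` layout of the unfused path, which is why the `rearrange` pattern changes, `self.scale` needs fewer broadcast dims, and the rotary `theta` no longer needs a `movedim`. The commit probes for the feature at runtime with `natten.has_fused_na()` and keeps the unfused path as a fallback. Below is a minimal sketch of the two call paths; the shapes, device, and dtype are illustrative assumptions (fused NA typically requires a CUDA tensor, commonly in half precision), not the repo's code.

```python
# Minimal sketch of the two NATTEN call paths this commit switches between.
# Shapes, device, and dtype are illustrative assumptions, not the repo's code.
import torch
import natten

n, height, width, heads, d_head = 1, 32, 32, 4, 64
kernel_size = 7

if natten.has_fused_na():
    # Fused path (NATTEN >= 0.17.0): heads-last layout, one fused kernel.
    # Assumes a CUDA device in half precision, which fused NA typically wants.
    q = torch.randn(n, height, width, heads, d_head, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    # scale=1.0 overrides the kernel's default softmax scaling; in the commit,
    # q and k are already scaled by scale_for_cosine_sim before this call.
    x = natten.functional.na2d(q, k, v, kernel_size, scale=1.0)
else:
    # Unfused fallback: heads-first layout, explicit QK^T, softmax, then AV.
    q = torch.randn(n, heads, height, width, d_head)
    k, v = torch.randn_like(q), torch.randn_like(q)
    qk = natten.functional.na2d_qk(q, k, kernel_size)
    a = torch.softmax(qk, dim=-1).to(v.dtype)
    x = natten.functional.na2d_av(a, v, kernel_size)
```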
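One detail worth calling out: `na2d` is passed `scale=1.0` because the model uses cosine-similarity attention; `scale_for_cosine_sim` has already normalized q and k and folded in the learned per-head `self.scale`, so the kernel must not scale the logits again. The differing broadcasts (`self.scale[:, None]` versus `self.scale[:, None, None, None]`) simply align that per-head scale with the heads axis of each layout (dim -2 heads-last, dim 1 heads-first). A hedged sketch of what such a helper can look like, assuming a positive per-head scale; the repo's actual `scale_for_cosine_sim` may differ in detail:

```python
# Hedged sketch of a cosine-similarity q/k scaler (an assumption, not the
# repo's exact helper). q and k are L2-normalized over the head-feature
# dimension and each multiplied by sqrt(scale), so q @ k^T is already the
# desired logit and the attention kernel is then called with scale=1.0.
import torch

def scale_for_cosine_sim_sketch(q, k, scale, eps):
    # Normalize along the last (head feature) dimension, in float32 for safety.
    q = q * torch.rsqrt(q.float().pow(2).sum(dim=-1, keepdim=True) + eps).to(q.dtype)
    k = k * torch.rsqrt(k.float().pow(2).sum(dim=-1, keepdim=True) + eps).to(k.dtype)
    # Split the learned per-head scale evenly between q and k.
    sqrt_scale = scale.sqrt()
    return q * sqrt_scale.to(q.dtype), k * sqrt_scale.to(k.dtype)
```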
