Commit
fix the implementation of CoPE
csukuangfj committed Jul 3, 2024
1 parent 06232dc commit 36808b8
Showing 2 changed files with 47 additions and 10 deletions.
8 changes: 6 additions & 2 deletions egs/librispeech/ASR/zipformer/test_cope.py
@@ -1,14 +1,17 @@
 #!/usr/bin/env python3

 import torch
 from zipformer import ContextualPositionalEncoding


 def test():
     embed_dim = 5
     npos_max = 10
+
     cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
-    q = torch.rand(2, 3, 4, embed_dim)
-    qk = torch.rand(2, 3, 4, 6)
+    q = torch.rand(2, 3, npos_max, embed_dim)
+
+    qk = torch.rand(2, 3, npos_max, npos_max)
+
     p = cope(q=q, qk=qk)
     print(p.shape)
@@ -19,4 +22,5 @@ def main():


 if __name__ == "__main__":
+    torch.manual_seed(20240703)
     main()
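As a shape note on the updated test: with the fixed gather call in zipformer.py below (no expand()), the returned tensor takes the shape of qk, so print(p.shape) should report torch.Size([2, 3, 10, 10]). A minimal sketch of just the gather semantics, using hypothetical stand-in tensors rather than the real module:

```python
import torch

# Stand-in shapes mirroring test_cope.py (hypothetical; not the real module)
head, batch, time, embed_dim, npos_max = 2, 3, 10, 5, 10

logits_int = torch.rand(head, batch, time, npos_max)  # like q @ embedding.weight.t()
pos_ceil = torch.randint(0, npos_max, (head, batch, time, time))

# gather(-1, index) only requires the non-gathered dims to match and the
# index values to be < npos_max; the output takes the shape of the index,
# so no expand() is needed.
out = logits_int.gather(-1, pos_ceil)
print(out.shape)  # torch.Size([2, 3, 10, 10])
```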
49 changes: 41 additions & 8 deletions egs/librispeech/ASR/zipformer/zipformer.py
@@ -1402,26 +1402,59 @@ def forward(self, q: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
           qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
         Returns:
           Return a tensor of shape (head, batch, time1, npos_max)
+
+          Note: The implementation assumes time1 == time2 and npos_max <= time2,
+          which is reasonable for a streaming ASR encoder where only
+          self-attention is used.
         """
+        # The implementation in Listing 1 on page 13 of the paper does not
+        # use a mask, so it computes gates[:, :, i, j] even for j >= i.
+        #
+        # Here we fix that by introducing a mask.
+        mask = torch.triu(
+            torch.full((qk.size(3), qk.size(3)), True, dtype=torch.bool),
+            diagonal=0,
+        )
+        # If qk.size(3) is 4, mask is
+        #
+        # tensor([[ True,  True,  True,  True],
+        #         [False,  True,  True,  True],
+        #         [False, False,  True,  True],
+        #         [False, False, False,  True]])
+        #
+        # i.e., mask[i, j] is True if j >= i, so masking with it keeps
+        # only the entries with j < i.
         gates = torch.sigmoid(qk)
-        pos = gates.sum(dim=-1, keepdim=True)  # (head, batch, dim1, 1)
-        # Note: We don't use cumulative sum here for non-streaming
-        # speech recognition
+
+        # We don't use an in-place operation here for the sake of autograd
+        gates = gates.masked_fill(mask, 0)
+
+        # cumsum() computes an inclusive sum in PyTorch
+        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)  # (head, batch, time1, time2)
+        # pos[:, :, i, j] is 0 for j >= i.
+        # pos[:, :, i, j] contains the position between i and j: if gates
+        # is a 0-1 matrix, then pos[:, :, i, j] equals i - j (for j < i).
+        # Note: Page 4 of the paper says it equals i - j + 1 instead of i - j.

         pos = pos.clamp(max=self.npos_max - 1)
         pos_ceil = pos.ceil().long()
         pos_floor = pos.floor().long()

         # We assume query_head_dim equals embed_dim

         logits_int = torch.matmul(
             q, self.embedding.weight.t()
         )  # (head, batch, time1, npos_max)
-        logits_cell = logits_int.gather(-1, pos_ceil.expand(*logits_int.shape))
-        logits_floor = logits_int.gather(-1, pos_floor.expand(*logits_int.shape))
+
+        # We assume that npos_max <= time2
+        logits_cell = logits_int.gather(-1, pos_ceil)
+        logits_floor = logits_int.gather(-1, pos_floor)

         w = pos - pos_floor
-        return logits_cell * w + logits_floor * (1 - w)
-
-    def streaming_forward(self):
-        raise RuntimeError("To be implemented")
+        # Note: The code on page 13 of the paper is correct,
+        # while the description in equation (5) on page 4 is wrong.
+        return logits_cell * w + logits_floor * (1 - w)


class CompactRelPositionalEncoding(torch.nn.Module):
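To see the fixed logic end to end, here is a self-contained sketch of the new forward pass on a toy example. It assumes the module stores an nn.Embedding(npos_max, embed_dim), which is consistent with the matmul against self.embedding.weight.t() in the diff; with gates saturated to 1, pos[:, :, i, j] comes out as i - j for j < i, matching the comments above:

```python
import torch

head, batch, time, embed_dim, npos_max = 1, 1, 4, 5, 10

embedding = torch.nn.Embedding(npos_max, embed_dim)  # assumed module state
q = torch.rand(head, batch, time, embed_dim)
qk = torch.full((head, batch, time, time), 100.0)  # sigmoid(100.0) == 1.0 in float32

mask = torch.triu(torch.ones(time, time, dtype=torch.bool), diagonal=0)
gates = torch.sigmoid(qk).masked_fill(mask, 0)  # keep only j < i

# Inclusive suffix sum over the last dim: pos[..., i, j] = sum_{k >= j} gates[..., i, k]
pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
print(pos[0, 0])
# tensor([[0., 0., 0., 0.],
#         [1., 0., 0., 0.],
#         [2., 1., 0., 0.],
#         [3., 2., 1., 0.]])  i.e. pos[i, j] == i - j for j < i

pos = pos.clamp(max=npos_max - 1)
pos_ceil = pos.ceil().long()
pos_floor = pos.floor().long()

logits_int = torch.matmul(q, embedding.weight.t())  # (head, batch, time, npos_max)
logits_ceil = logits_int.gather(-1, pos_ceil)
logits_floor = logits_int.gather(-1, pos_floor)

w = pos - pos_floor  # interpolation weight in [0, 1)
p = logits_ceil * w + logits_floor * (1 - w)
print(p.shape)  # torch.Size([1, 1, 4, 4])
```

The flip/cumsum/flip idiom is just an inclusive suffix sum; it replaces the old gates.sum(dim=-1, keepdim=True), which collapsed all gate values for query i into a single position instead of producing one position per (i, j) pair.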
