diff --git a/egs/librispeech/ASR/zipformer/test_cope.py b/egs/librispeech/ASR/zipformer/test_cope.py
new file mode 100755
index 0000000000..5eb6ccfd98
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/test_cope.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python3
+
+import torch
+
+from zipformer import ContextualPositionalEncoding
+
+
+def test():
+    embed_dim = 5
+    npos_max = 10
+    cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
+    q = torch.rand(2, 3, 4, embed_dim)
+    qk = torch.rand(2, 3, 4, 6)
+
+    p = cope(q=q, qk=qk)
+    print(p.shape)
+
+
+def main():
+    test()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 69059287b4..6a94e3ab00 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -95,6 +95,7 @@ class Zipformer2(EncoderInterface):
            context chunks for causal training; will be rounded to a number of chunks.  Must not
            be less than cnn_module_kernel (after factoring in rounding and downsampling); an
            error will be thrown if this is violated.
+        use_cope (bool): If True, use contextual positional encoding (CoPE).
     """

     def __init__(
@@ -116,6 +117,7 @@ def __init__(
         causal: bool = False,
         chunk_size: Tuple[int] = [-1],
         left_context_frames: Tuple[int] = [-1],
+        use_cope: bool = False,
     ) -> None:
         super(Zipformer2, self).__init__()

@@ -183,6 +185,7 @@ def _to_tuple(x):
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                 final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+                use_cope=use_cope,
             )

             if downsampling_factor[i] != 1:
@@ -1017,6 +1020,7 @@ def __init__(
         warmup_end: float,
         initial_layerdrop_rate: float = 0.5,
         final_layerdrop_rate: float = 0.05,
+        use_cope: bool = False,
     ) -> None:
         super().__init__()
         self.encoder_pos = CompactRelPositionalEncoding(
@@ -1372,6 +1376,54 @@ def forward(self, src: Tensor) -> Tensor:
         return src


+class ContextualPositionalEncoding(torch.nn.Module):
+    """
+    This class implements the following paper:
+    Contextual Position Encoding: Learning to Count What's Important
+    https://arxiv.org/abs/2405.18719
+
+    Args:
+      embed_dim: Embedding dimension.
+      npos_max: The maximum context size.
+    """
+
+    def __init__(self, embed_dim: int, npos_max: int):
+        super().__init__()
+        self.npos_max = npos_max
+        self.embedding = nn.Embedding(
+            num_embeddings=npos_max,
+            embedding_dim=embed_dim,
+        )
+
+    def forward(self, q: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          q (torch.Tensor): A tensor of shape (head, batch, time1, query_head_dim)
+          qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
+        Returns:
+          A tensor of shape (head, batch, time1, npos_max)
+        """
+        gates = torch.sigmoid(qk)
+        pos = gates.sum(dim=-1, keepdim=True)  # (head, batch, time1, 1)
+        # Note: Unlike the CoPE paper, we use a plain sum rather than a
+        # cumulative sum here, since this is for non-streaming ASR.
+
+        pos = pos.clamp(max=self.npos_max - 1)
+        pos_ceil = pos.ceil().long()
+        pos_floor = pos.floor().long()
+        logits_int = torch.matmul(
+            q, self.embedding.weight.t()
+        )  # (head, batch, time1, npos_max)
+        logits_ceil = logits_int.gather(-1, pos_ceil.expand(*logits_int.shape))
+        logits_floor = logits_int.gather(-1, pos_floor.expand(*logits_int.shape))
+
+        w = pos - pos_floor
+        return logits_ceil * w + logits_floor * (1 - w)
+
+    def streaming_forward(self):
+        raise NotImplementedError("streaming_forward is not implemented yet")
+
+
 class CompactRelPositionalEncoding(torch.nn.Module):
     """
     Relative positional encoding module.  This version is "compact" meaning it is able to encode
@@ -1609,7 +1661,11 @@ def forward(
         k = x[..., query_dim : 2 * query_dim]
         # p is the position-encoding query
         p = x[..., 2 * query_dim :]
-        assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], num_heads, pos_head_dim)
+        assert p.shape[-1] == num_heads * pos_head_dim, (
+            p.shape[-1],
+            num_heads,
+            pos_head_dim,
+        )
         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.