Skip to content

Commit

Permalink
WIP: Begin to add Contextual positional encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Jun 5, 2024
1 parent b880622 commit 06232dc
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
22 changes: 22 additions & 0 deletions egs/librispeech/ASR/zipformer/test_cope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env python3

from zipformer import ContextualPositionalEncoding


def test():
embed_dim = 5
npos_max = 10
cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
q = torch.rand(2, 3, 4, embed_dim)
qk = torch.rand(2, 3, 4, 6)

p = cope(q=q, qk=qk)
print(p.shape)


def main():
test()


if __name__ == "__main__":
main()
58 changes: 57 additions & 1 deletion egs/librispeech/ASR/zipformer/zipformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class Zipformer2(EncoderInterface):
context chunks for causal training; will be rounded to a number of
chunks. Must not be less than cnn_module_kernel (after factoring in
rounding and downsampling); an error will be thrown if this is violated.
use_cope (bool): If true, use contextual positional encoding
"""

def __init__(
Expand All @@ -116,6 +117,7 @@ def __init__(
causal: bool = False,
chunk_size: Tuple[int] = [-1],
left_context_frames: Tuple[int] = [-1],
use_cope: bool = False,
) -> None:
super(Zipformer2, self).__init__()

Expand Down Expand Up @@ -183,6 +185,7 @@ def _to_tuple(x):
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
use_cope=use_cope,
)

if downsampling_factor[i] != 1:
Expand Down Expand Up @@ -1017,6 +1020,7 @@ def __init__(
warmup_end: float,
initial_layerdrop_rate: float = 0.5,
final_layerdrop_rate: float = 0.05,
use_cope: bool = False,
) -> None:
super().__init__()
self.encoder_pos = CompactRelPositionalEncoding(
Expand Down Expand Up @@ -1372,6 +1376,54 @@ def forward(self, src: Tensor) -> Tensor:
return src


class ContextualPositionalEncoding(torch.nn.Module):
"""
This class implements the following paper:
Contextual Position Encoding: Learning to Count What's Important
https://arxiv.org/abs/2405.18719
Args:
embed_dim: Embedding dimension.
npos_max: The maximum context size.
"""

def __init__(self, embed_dim: int, npos_max: int):
super().__init__()
self.npos_max = npos_max
self.embedding = nn.Embedding(
num_embeddings=npos_max,
embedding_dim=embed_dim,
)

def forward(self, q: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
"""
Args:
q (torch.Tensor): A tensor of shape (head, batch, time1, query_head_dim)
qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
Returns:
Return a tensor of shape (head, batch, time1, npos_max)
"""
gates = torch.sigmoid(qk)
pos = gates.sum(dim=-1, keepdim=True) # (head, batch, dim1, 1)
# Note: We don't use cumulative sum here for non-streaming
# speech recognition

pos = pos.clamp(max=self.npos_max - 1)
pos_ceil = pos.ceil().long()
pos_floor = pos.floor().long()
logits_int = torch.matmul(
q, self.embedding.weight.t()
) # (head, batch, time1, npos_max)
logits_cell = logits_int.gather(-1, pos_ceil.expand(*logits_int.shape))
logits_floor = logits_int.gather(-1, pos_floor.expand(*logits_int.shape))

w = pos - pos_floor
return logits_cell * w + logits_floor * (1 - w)

def streaming_forward(self):
raise RuntimeError("To be implemented")


class CompactRelPositionalEncoding(torch.nn.Module):
"""
Relative positional encoding module. This version is "compact" meaning it is able to encode
Expand Down Expand Up @@ -1609,7 +1661,11 @@ def forward(
k = x[..., query_dim : 2 * query_dim]
# p is the position-encoding query
p = x[..., 2 * query_dim :]
assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], num_heads, pos_head_dim)
assert p.shape[-1] == num_heads * pos_head_dim, (
p.shape[-1],
num_heads,
pos_head_dim,
)

q = self.copy_query(q) # for diagnostics only, does nothing.
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
Expand Down

0 comments on commit 06232dc

Please sign in to comment.