Skip to content

Commit

Permalink
Small fix (#1686)
Browse files Browse the repository at this point in the history
  • Loading branch information
yfyeung authored Jul 11, 2024
1 parent 785f3f0 commit d65187e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
5 changes: 3 additions & 2 deletions egs/librispeech/ASR/zipformer/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,8 +636,9 @@ def __init__(
)

def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
"""
Forward function. Args:
"""Forward function.
Args:
x: a Tensor of shape (batch_size, channels, seq_len)
chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
"""
Expand Down
6 changes: 3 additions & 3 deletions egs/librispeech/ASR/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)

parser.add_argument(
Expand All @@ -429,7 +429,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" "part.",
help="The scale to smooth the loss with am (output of encoder network) part.",
)

parser.add_argument(
Expand Down Expand Up @@ -848,7 +848,7 @@ def compute_loss(
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
warmup: a floating point value which increases throughout training;
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
Expand Down

0 comments on commit d65187e

Please sign in to comment.