
Commit

fix usages of returned losses after adding attention-decoder in zipformer
yaozengwei committed Jul 12, 2024
1 parent f6febd6 commit c58cf28
Showing 20 changed files with 42 additions and 62 deletions.
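The change is the same at every call site: the zipformer model's forward pass gained an attention-decoder loss, so it now returns one more value than before, and call sites that unpacked a fixed number of values would fail. Binding the whole tuple and slicing keeps the old behavior while tolerating the extra element. A minimal sketch of the failure mode and the fix; the four-element return signature is an assumption based on the commit message, not copied from the model code:

    def forward_before():
        # previously returned (simple_loss, pruned_loss, ctc_loss)
        return 1.0, 2.0, 3.0

    def forward_after():
        # assumed: now also returns attention_decoder_loss as a fourth value
        return 1.0, 2.0, 3.0, 4.0

    # The old pattern unpacked a fixed arity and breaks against the new signature:
    try:
        simple_loss, pruned_loss, _ = forward_after()
    except ValueError as e:
        print(e)  # too many values to unpack (expected 3)

    # The fix: bind the whole tuple, then slice out what each recipe needs.
    losses = forward_after()
    simple_loss, pruned_loss = losses[:2]            # recipes without a CTC head
    simple_loss, pruned_loss, ctc_loss = losses[:3]  # recipes with a CTC head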
3 changes: 2 additions & 1 deletion egs/aishell/ASR/zipformer/train.py
@@ -758,14 +758,15 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
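For context, the lines truncated below this hunk (`s = params.simple_loss_scale` here, `loss = 0.0` in the CTC recipes further down) begin the weighted combination of the returned losses. Roughly, it looks like the sketch below; the warm-up schedule and the `params.use_transducer` / `params.use_ctc` flags are assumptions drawn from the recipe family, not part of this commit:

    def combine_losses(losses, params, batch_idx_train, warm_step):
        # Sketch only: flag and scale names are assumed, not from this commit.
        simple_loss, pruned_loss, ctc_loss = losses[:3]
        loss = 0.0
        if params.use_transducer:
            s = params.simple_loss_scale
            # ramp the simple-loss scale down from 1.0 to s over warm-up
            simple_loss_scale = (
                s
                if batch_idx_train >= warm_step
                else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
            )
            # ramp the pruned-loss scale up from 0.1 to 1.0 over warm-up
            pruned_loss_scale = (
                1.0
                if batch_idx_train >= warm_step
                else 0.1 + 0.9 * (batch_idx_train / warm_step)
            )
            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
        if params.use_ctc:
            loss += params.ctc_loss_scale * ctc_loss
        return loss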
3 changes: 2 additions & 1 deletion egs/aishell/ASR/zipformer/train_bbpe.py
@@ -343,14 +343,15 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
3 changes: 2 additions & 1 deletion egs/commonvoice/ASR/zipformer/train.py
@@ -814,14 +814,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/commonvoice/ASR/zipformer/train_char.py
@@ -449,14 +449,15 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/gigaspeech/ASR/zipformer/train.py
@@ -803,14 +803,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/gigaspeech/KWS/zipformer/train.py
@@ -806,14 +806,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/ksponspeech/ASR/zipformer/train.py
@@ -787,14 +787,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
4 changes: 2 additions & 2 deletions egs/libriheavy/ASR/zipformer/train.py
@@ -55,7 +55,6 @@
 import argparse
 import copy
 import logging
-import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -804,14 +803,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/librispeech/ASR/zipformer/finetune.py
@@ -893,14 +893,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/librispeech/ASR/zipformer_adapter/train.py
@@ -890,14 +890,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/librispeech/ASR/zipformer_lora/finetune.py
@@ -903,14 +903,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/librispeech/ASR/zipformer_lora/train.py
@@ -792,14 +792,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/mdcc/ASR/zipformer/train.py
@@ -754,14 +754,15 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
3 changes: 2 additions & 1 deletion egs/multi_zh-hans/ASR/zipformer/train.py
@@ -832,14 +832,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/multi_zh_en/ASR/zipformer/train.py
@@ -814,14 +814,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
4 changes: 2 additions & 2 deletions egs/reazonspeech/ASR/zipformer/train.py
@@ -59,7 +59,6 @@
 
 import k2
 import optim
-import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -791,14 +790,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
4 changes: 2 additions & 2 deletions egs/spgispeech/ASR/zipformer/train.py
@@ -67,7 +67,6 @@
 from asr_datamodule import SPGISpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -792,14 +791,15 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
    with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
3 changes: 2 additions & 1 deletion egs/wenetspeech/ASR/zipformer/train.py
@@ -758,14 +758,15 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
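None of the recipes touched here consume the new attention-decoder loss yet; the slicing pattern simply leaves room for it. A hypothetical extension for a recipe that does train that head; the flag and scale names below are assumptions, not taken from this commit:

    # Hypothetical: take the fourth returned loss as well and fold it in.
    simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = losses[:4]
    if params.use_attention_decoder:  # assumed flag name
        loss += params.attention_decoder_loss_scale * attention_decoder_loss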
(2 more changed files not shown)
