Skip to content

Commit

Permalink
Update results using Zipformer-large on multi-hans-zh (k2-fsa#1679)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuekaizhang authored and Your Name committed Aug 9, 2024
1 parent 7dcb965 commit 9f38679
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 575 deletions.
55 changes: 55 additions & 0 deletions egs/multi_zh-hans/ASR/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,61 @@ Fine-tuned models, training logs, decoding logs, tensorboard and decoding result
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>

### Multi Chinese datasets char-based training results (streaming) on zipformer large model

#### Streaming (with CTC head)

The training command for large model (num of params : ~160M):

Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/prepare.sh) to prepare fbank features.

```
./zipformer/train.py \
--world-size 8 \
--num-epochs 20 \
--use-fp16 1 \
--max-duration 1200 \
--num-workers 8 \
--use-ctc 1 \
--exp-dir zipformer/exp-large \
--causal 1 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 768,1024,1536,2048,1536,768 \
--encoder-dim 256,384,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
```

The decoding command for transducer greedy search:

```
./zipformer/decode.py \
--epoch 999 \
--avg 1 \
--causal 1 \
--use-averaged-model False \
--chunk_size -1
--left-context-frames -1 \
--use-ctc 1 \
--exp-dir zipformer/exp-large \
--max-duration 1200 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 768,1024,1536,2048,1536,768 \
--encoder-dim 256,384,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
```

Character Error Rates (CERs) listed below are produced by the checkpoint of the 18th epoch using BPE model ( # tokens is 2000, byte fallback enabled).

| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| CTC Greedy Streaming | 26.50 | 28.10| 1.71 | 1.97| 3.89| 4.06 | 17.23 | 3.69 | 2.87 | 8.14 | 3.61 |9.51 | 6.11 | 8.13 | 10.62 |
| CTC Greedy Offline | 23.47 | 25.02 | 1.39 | 1.50 | 3.15 | 3.41 | 15.14 | 3.07 | 2.37 | 6.06 | 2.90 | 7.13 | 5.40 | 6.52 | 9.64 |
| Transducer Greedy Offline | 23.16 | 24.78 | 1.33 | 1.38 | 3.06 | 3.23 | 15.36 | 2.54 | 2.09 | 5.24 | 2.28 | 6.26 | 4.87 | 6.26 | 7.07 |
| Transducer Greedy Streaming | 26.83|28.74 | 1.75 | 1.91 | 3.84 | 4.12 | 17.83 | 3.23 | 2.71 | 7.31 | 3.16 | 8.69 | 5.71 | 7.91 | 8.54 |

Pre-trained model can be found here : https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-large

### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model

Expand Down
247 changes: 0 additions & 247 deletions egs/multi_zh-hans/ASR/whisper/multi_dataset.py

This file was deleted.

1 change: 1 addition & 0 deletions egs/multi_zh-hans/ASR/whisper/multi_dataset.py
13 changes: 5 additions & 8 deletions egs/multi_zh-hans/ASR/zipformer/ctc_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params
from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting

from icefall.checkpoint import (
average_checkpoints,
Expand Down Expand Up @@ -367,21 +367,18 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
G=G,
)

for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = list(ref_text.replace(" ", ""))
hyp_words = list("".join(hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
ref_text = normalize_text_alimeeting(ref_text)
hyp_text = "".join(hyp_words)
this_batch.append((cut_id, ref_text, hyp_text))

results[name].extend(this_batch)

Expand Down Expand Up @@ -583,7 +580,7 @@ def main():
data_module = AsrDataModule(args)
multi_dataset = MultiDataset(args.manifest_dir)

test_sets_cuts = multi_dataset.test_cuts()
test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()}

def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2
Expand Down
6 changes: 3 additions & 3 deletions egs/multi_zh-hans/ASR/zipformer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@
)
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params
from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting

from icefall.checkpoint import (
average_checkpoints,
Expand Down Expand Up @@ -532,7 +532,6 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

hyps_dict = decode_one_batch(
Expand All @@ -548,6 +547,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text)
hyp_text = "".join(hyp_words)
this_batch.append((cut_id, ref_text, hyp_text))

Expand Down Expand Up @@ -795,7 +795,7 @@ def remove_short_utt(c: Cut):
)
return T > 0

test_sets_cuts = multi_dataset.test_cuts()
test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()}

test_sets = test_sets_cuts.keys()
test_dl = [
Expand Down
Loading

0 comments on commit 9f38679

Please sign in to comment.