Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: nithinraok <[email protected]>
  • Loading branch information
nithinraok authored and Nithin Rao Koluguri committed Oct 25, 2024
1 parent dbc4b66 commit 3908b79
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def transcribe(
)

decoding_cfg = self.cfg.aux_ctc.decoding if self.cur_decoder == "ctc" else self.cfg.decoding

logging.info(
"Timestamps requested, setting decoding timestamps to True. Capture them in Hypothesis object, with output[idx].timestep['word'/'segment'/'char']"
)
Expand Down
8 changes: 6 additions & 2 deletions tests/collections/asr/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

from dataclasses import dataclass
from typing import Optional, Type
from nemo.collections.asr.models import ASRModel

import numpy as np
import pytest
import torch

from nemo.collections.asr.models import ASRModel


class RNNTTestHelper:
@staticmethod
Expand Down Expand Up @@ -355,14 +356,17 @@ def rnnt_test_helper() -> Type[RNNTTestHelper]:
def rnn_loss_sample_data() -> Type[RnntLossSampleData]:
    """Hand back ``RnntLossSampleData`` itself so callers can reach it by name."""
    sample_data_cls = RnntLossSampleData
    return sample_data_cls


@pytest.fixture(scope='session')
def fast_conformer_transducer_model():
    """Session-scoped fixture: load the pretrained model once via ``ASRModel.from_pretrained``.

    Returns:
        Whatever ``ASRModel.from_pretrained`` returns for the
        ``stt_en_fastconformer_transducer_large`` checkpoint name.
    """
    pretrained_name = "stt_en_fastconformer_transducer_large"
    model = ASRModel.from_pretrained(pretrained_name)
    return model


@pytest.fixture(scope='session')
def fast_conformer_ctc_model():
    """Session-scoped fixture: load the pretrained model once via ``ASRModel.from_pretrained``.

    Returns:
        Whatever ``ASRModel.from_pretrained`` returns for the
        ``stt_en_fastconformer_ctc_large`` checkpoint name.
    """
    pretrained_name = "stt_en_fastconformer_ctc_large"
    model = ASRModel.from_pretrained(pretrained_name)
    return model


@pytest.fixture(scope='session')
def fast_conformer_hybrid_model():
    """Session-scoped fixture: load the pretrained model once via ``ASRModel.from_pretrained``.

    Returns:
        Whatever ``ASRModel.from_pretrained`` returns for the
        ``parakeet-tdt_ctc-110m`` checkpoint name.
    """
    # A single return: the duplicated second ``return`` was unreachable.
    return ASRModel.from_pretrained("parakeet-tdt_ctc-110m")
32 changes: 15 additions & 17 deletions tests/collections/asr/mixins/test_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def audio_files(test_data_dir):
Returns a list of audio files for testing.
"""
import soundfile as sf

audio_file1 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav")
audio_file2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an104-mrcb-b.wav")

Expand Down Expand Up @@ -310,6 +311,7 @@ class OverrideConfig(TranscribeConfig):
assert outputs[0][2] == 3.0

pytest.mark.with_downloads()

@pytest.mark.unit
def test_transcribe_return_hypothesis(self, test_data_dir, fast_conformer_ctc_model):
audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav")
Expand All @@ -327,7 +329,7 @@ def test_transcribe_return_hypothesis(self, test_data_dir, fast_conformer_ctc_mo
@pytest.mark.with_downloads()
@pytest.mark.unit
def test_transcribe_tensor(self, audio_files, fast_conformer_ctc_model):

audio, _ = audio_files
# Numpy array test
outputs = fast_conformer_ctc_model.transcribe(audio, batch_size=1)
Expand All @@ -353,7 +355,7 @@ def test_transcribe_multiple_tensor(self, audio_files, fast_conformer_ctc_model)
def test_transcribe_dataloader(self, audio_files, fast_conformer_ctc_model):

audio, audio2 = audio_files

dataset = DummyDataset([audio, audio2])
collate_fn = lambda x: _speech_collate_fn(x, pad_id=0)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn)
Expand All @@ -369,43 +371,39 @@ def test_transcribe_dataloader(self, audio_files, fast_conformer_ctc_model):
def test_timestamps_with_transcribe(self, audio_files, fast_conformer_ctc_model):
    """Check ``transcribe(..., timestamps=True)`` on the CTC model fixture.

    Verifies the number of outputs, that each output is a ``Hypothesis``,
    the transcribed text, and the first segment's start/end timestamps.

    Args:
        audio_files: fixture value unpacked into two audio inputs.
        fast_conformer_ctc_model: model fixture exposing ``transcribe``.
    """
    audio1, audio2 = audio_files

    # Transcribe exactly once; a second identical call only repeats the work.
    output = fast_conformer_ctc_model.transcribe([audio1, audio2], timestamps=True)

    # check len of output
    assert len(output) == 2

    # check hypothesis object
    assert isinstance(output[0], Hypothesis)
    # check transcript
    assert output[0].text == 'stop'
    assert output[1].text == 'start'

    # check timestamp
    assert output[0].timestep['segment'][0]['start'] == pytest.approx(0.4)
    assert output[0].timestep['segment'][0]['end'] == pytest.approx(0.48)


@pytest.mark.with_downloads()
@pytest.mark.unit
def test_timestamps_with_transcribe_hybrid(self, audio_files, fast_conformer_hybrid_model):
    """Check ``transcribe(..., timestamps=True)`` on the hybrid model fixture.

    Verifies the number of outputs, that the selected results are
    ``Hypothesis`` objects, the transcribed text, and the first segment's
    start/end timestamps.

    Args:
        audio_files: fixture value unpacked into two audio inputs.
        fast_conformer_hybrid_model: model fixture exposing ``transcribe``.
    """
    audio1, audio2 = audio_files

    # Transcribe exactly once; a second identical call only repeats the work.
    output = fast_conformer_hybrid_model.transcribe([audio1, audio2], timestamps=True)

    # check len of output
    assert len(output) == 2

    # Select element 1 exactly once: indexing twice would descend into the
    # selected result instead of keeping the list of hypotheses.
    output = output[1]  # Transducer returns tuple

    # check hypothesis object
    assert isinstance(output[0], Hypothesis)
    # check transcript
    assert output[0].text == 'Stop?'
    assert output[1].text == 'Start.'

    # check timestamp
    assert output[0].timestep['segment'][0]['start'] == pytest.approx(0.48)
    assert output[0].timestep['segment'][0]['end'] == pytest.approx(0.72)



0 comments on commit 3908b79

Please sign in to comment.