diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 01bf51b0e2c63..5533b50922f8f 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import random import warnings from dataclasses import dataclass from functools import partial @@ -319,6 +320,7 @@ def get_lhotse_dataloader_from_config( ReverbWithImpulseResponse( rir_recordings=RecordingSet.from_file(config.rir_path) if config.rir_path is not None else None, p=config.rir_prob, + randgen=random.Random(seed), ) ) diff --git a/tests/collections/common/test_lhotse_dataloading.py b/tests/collections/common/test_lhotse_dataloading.py index 111c00df392ac..31a8d332814e2 100644 --- a/tests/collections/common/test_lhotse_dataloading.py +++ b/tests/collections/common/test_lhotse_dataloading.py @@ -32,10 +32,6 @@ from nemo.collections.common.data.lhotse.text_adapters import TextExample from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model -requires_torchaudio = pytest.mark.skipif( - not lhotse.utils.is_torchaudio_available(), reason="Lhotse Shar format support requires torchaudio." -) - @pytest.fixture(scope="session") def cutset_path(tmp_path_factory) -> Path: @@ -348,7 +344,6 @@ def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path): assert torch.equal(b_cs["audio"], batches[n]["audio"][:, channel_selector, :]) -@requires_torchaudio def test_dataloader_from_lhotse_shar_cuts(cutset_shar_path: Path): config = OmegaConf.create( { @@ -682,7 +677,6 @@ def test_dataloader_from_tarred_nemo_manifest_concat(nemo_tarred_manifest_path: torch.testing.assert_close(b["audio_lens"], expected_audio_lens) -@requires_torchaudio def test_dataloader_from_lhotse_shar_cuts_combine_datasets_unweighted( cutset_shar_path: Path, cutset_shar_path_other: Path ): @@ -723,19 +717,18 @@ def test_dataloader_from_lhotse_shar_cuts_combine_datasets_unweighted( assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2 b = batches[1] - assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 2 # dataset 1 - assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 1 # dataset 2 + assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 0 # dataset 1 + assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 3 # dataset 2 b = batches[2] - assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1 - assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2 + assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 2 # dataset 1 + assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 1 # dataset 2 b = batches[3] assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1 assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2 -@requires_torchaudio def test_dataloader_from_lhotse_shar_cuts_combine_datasets_weighted( cutset_shar_path: Path, cutset_shar_path_other: Path ): @@ -776,12 +769,12 @@ def test_dataloader_from_lhotse_shar_cuts_combine_datasets_weighted( assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2 b = batches[1] - assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1 - assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2 + assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1 + assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2 b = batches[2] - assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1 - assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2 + assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 2 # dataset 1 + assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 1 # dataset 2 b = batches[3] assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1 @@ -792,8 +785,8 @@ def test_dataloader_from_lhotse_shar_cuts_combine_datasets_weighted( assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2 b = batches[5] - assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1 - assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2 + assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1 + assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2 class TextDataset(torch.utils.data.Dataset):