Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add numpy alternative to FE using torchaudio #26339

Merged
merged 14 commits into from
Nov 8, 2023

Conversation

ylacombe
Copy link
Contributor

What does this PR do?

Following on from #26182, which ported torchaudio.compliance.kaldi.fbank to numpy in audio_utils, this PR aims to enable the use of numpy porting in previous Feature Extractors (AST and SpeechToText) that used torchaudio. It was discussed here.

This serves two purposes:

  1. to give some examples of how to use audio_utils instead of torchaudio for future Feature Extractors
  2. the possibility of removing torchaudio altogether in the future.

A next step would be to port audio_utils to torch, which might be faster (cc @sanchit-gandhi), but this is still open to discussion. Is this really relevant? And will it be really faster?

cc @ArthurZucker and @sanchit-gandhi

Comment on lines -97 to -106
fbank = ta_kaldi.fbank(
waveform,
htk_compat=True,
sample_frequency=self.sampling_rate,
use_energy=False,
window_type="hanning",
num_mel_bins=self.num_mel_bins,
dither=0.0,
frame_shift=10,
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also took the opportunity to remove some unnecessary parameters here

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Sep 22, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally think it's better to remove the torchaudio dependency entirely and align these two outliers with the rest of the numpy-based audio feature extractors! Especially since we'll probably support a torch version in audio_utils in an upcoming PR, so the speed diff will be recovered.

dither=0.0,
frame_shift=10,
)
if self.use_torchaudio:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there's a need to have the use_torchaudio argument. IMO we can execute the torchaudio code if_torchaudio_is_available (thus maintaining backwards comp), and the NumPy code otherwise

if_torchaudio_is_available():
    # do legacy code
else:
    # do numpy code

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also fine with removing the legacy torchaudio code altogether. I know this makes the feature extraction quite a bit slower, but I think this is fine to remove the extra dependencies to bring these models in-line with the rest of the audio library.

Personally, I would favour this approach over supporting both methods for feature extraction (torchaudio and numpy). IMO having both methods convolutes the code quite a lot, which is something we want to avoid.

Copy link
Collaborator

@ArthurZucker ArthurZucker Sep 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine with me to remove the previous code, it won’t be performance wise backward compatible 🫠

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go with the first option here then? Decorate with if_torchaudio_is_available?

@@ -198,3 +235,16 @@ def __call__(
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

return padded_inputs

def to_dict(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this method strictly necessary? If it is, shouldn't it go in the base FeatureExtractionMixin class? Rather than copying it out for every feature extractor?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which method do you mean? to_dict?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it's about time we add this to the base feature class!
(it's necessary if we support the numpy part)

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, aligned with @sanchit-gandhi on not adding the np support

@@ -198,3 +235,16 @@ def __call__(
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

return padded_inputs

def to_dict(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it's about time we add this to the base feature class!
(it's necessary if we support the numpy part)

dither=0.0,
frame_shift=10,
)
if self.use_torchaudio:
Copy link
Collaborator

@ArthurZucker ArthurZucker Sep 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine with me to remove the previous code, it won’t be performance wise backward compatible 🫠

@ylacombe
Copy link
Contributor Author

Hey @ArthurZucker and @sanchit-gandhi, thanks for the review!

However, I'm not sure about what you meant here:

Looks good to me, aligned with @sanchit-gandhi on not adding the np support

And here:

I'm fine with adding a comment somewhere or a section in the doc to not lose the info on how to use numpy to get the same results as torchaudio for futur references when we'll improve or numpy port!

@sanchit-gandhi seems to be in favor of removing torchaudio support to only focus on the numpy port here, whereas @ArthurZucker seems to be in favor on not adding the numpy support.

Maybe I misunderstood the comments here! Thanks for your help!

@ArthurZucker
Copy link
Collaborator

Sorry I was confused! I agree that we should remove the old code, but worried about the performance issue, since we had to re introduce torch STFT for Whisper for example. (Performance wise backward compatible)

@ylacombe
Copy link
Contributor Author

I've made a quick benchmark, on AST, with results here:
image

Basically, torchaudio is at least 19 faster than the numpy porting. If I haven't made any mistake in my benchmark, I'll be strongly in favor of keeping torchaudio compatibility.

WDYT @ArthurZucker and @sanchit-gandhi ? Can you also take a quick look at the benchmark code to make sure that my results are correct (or redirect me to an expert at HF haha) ?

For reference, here is the benchmark code:

from datasets import load_dataset
import pytest
from transformers import ASTFeatureExtractor

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
speech_samples = ds.sort("id").select(range(64))[:64]["audio"]
speech_samples = [x["array"] for x in speech_samples]


def torchaudio_unbatch():
    fe = ASTFeatureExtractor(use_torchaudio=True)
    
    for sample in speech_samples:
        input_features = fe(sample, padding=True, return_tensors="pt")

def np_unbatch():
    fe = ASTFeatureExtractor(use_torchaudio=False)
    
    for sample in speech_samples:
        input_features = fe(sample, padding=True, return_tensors="pt")

def torchaudio_batch_8():
    fe = ASTFeatureExtractor(use_torchaudio=True)
    
    for i in range(0,len(speech_samples),8):
        samples = speech_samples[i:i+8]
        input_features = fe(samples, padding=True, return_tensors="pt")

def np_batch_8():
    fe = ASTFeatureExtractor(use_torchaudio=False)
    
    for i in range(0,len(speech_samples),8):
        samples = speech_samples[i:i+8]
        input_features = fe(samples, padding=True, return_tensors="pt")

@pytest.mark.benchmark(
    min_rounds=5, disable_gc=True, warmup=False
)
def test_torchaudio_unbatch(benchmark):
    benchmark(torchaudio_unbatch)

@pytest.mark.benchmark(
    min_rounds=5, disable_gc=True, warmup=False
)
def test_torchaudio_batch_8(benchmark):
    benchmark(torchaudio_batch_8)


@pytest.mark.benchmark(
    min_rounds=5, disable_gc=True, warmup=False
)
def test_np_unbatch(benchmark):
    benchmark(np_unbatch)

@pytest.mark.benchmark(
    min_rounds=5, disable_gc=True, warmup=False
)
def test_np_batch_8(benchmark):
    benchmark(np_batch_8)

@ylacombe
Copy link
Contributor Author

For future reference, here is the same benchmark with Speech2TextFeatureExtractor:
Previous conclusions still hold:
image

@ylacombe
Copy link
Contributor Author

It's also possible that we can optimize our audio_utils.py, WDYT?

@sanchit-gandhi
Copy link
Contributor

Alright that's quite a significant difference - this probably requires overhauling the audio_utils file as you've suggested (use torch/torchaudio if available, or see where our numpy implementation is bottlenecked and try to improve it here).

@ylacombe ylacombe mentioned this pull request Sep 28, 2023
7 tasks
Copy link
Contributor Author

@ylacombe ylacombe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @ArthurZucker and @sanchit-gandhi, thanks for your help here.

To sum it up, I removed torchaudio dependency for both FE, but those FE still use it if it installed to ensure speed.
I've also simulated torchaudio absence to make sure everything is in order.

I'm requesting your reviews again!

Comment on lines 575 to 588
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.

Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["feature_extractor_type"] = self.__class__.__name__

if "mel_filters" in output:
del output["mel_filters"]
if "window" in output:
del output["window"]
return output

@classmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As requested I've modified to_dict directly in feature_extraction_utils.py

Comment on lines +204 to +207
@unittest.mock.patch(
"transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer.is_speech_available",
lambda: False,
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how I simulate the absence of torchaudio in the test suite

Comment on lines +214 to +219
def test_using_audio_utils(self):
# Tests that it uses audio_utils instead of torchaudio
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())

self.assertTrue(hasattr(feat_extract, "window"))
self.assertTrue(hasattr(feat_extract, "mel_filters"))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also ensuring that we use the audio_utils package and torchaudio is indeed not used in this class

Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This on paper LGTM and nice job on getting the tests working. My only thought is that maybe we should overhaul audio_utils.py with these changes, rather than do the if/else in the feature extraction code? This way, all the is_xxx_available logic stays in audio_utils (which is fine if it gets complex, since most people won't interact with it), and the feature extraction code can stay simple

Open do either refactoring this PR to make this change, or merging this and doing it in a follow-up (along with #26119)

Comment on lines +29 to +32
if is_speech_available():
import torchaudio.compliance.kaldi as ta_kaldi

if is_torch_available():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In speech-to-text we bundle these imports into one:

if is_speech_available():
    import torchaudio.compliance.kaldi as ta_kaldi
    import torch

Should we do the same here since we can only use torch if torchaudio is available?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, torch is also used here even when torchaudio isn't used. I can maybe refactor the code to change that, but I'm not sure it's worth the time, WDYT ?

@ylacombe
Copy link
Contributor Author

ylacombe commented Oct 2, 2023

Hey @sanchit-gandhi, thanks for the review here!

My only thought is that maybe we should overhaul audio_utils.py with these changes, rather than do the if/else in the feature extraction code?

We'd have to create a fbank method to audio_utils which would create mel_filters and window on-the-fly in that case right ? (with hindsight, it doesn't matter much since creating mel_filters and `window isn't the bottleneck here)

In any case, I'd rather refactor that in another PR, which would maybe add the torch correspondence for every possible case in audio_utils

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Looks good to me.
Aligned with @sanchit-gandhi on profiling our numpy code to see what's our huge bottlneck sometime soon!

"transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer.is_speech_available",
lambda: False,
)
class ASTFeatureExtractionWithoutTorchaudioTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice. I think you can either add # Copied from on top of some tests, or (not sure if it's possible) find a way to use parametrized to mock the absence of the package as a parameter to avoid code duplications. Not very important # Copied from will be great

Copy link
Contributor Author

@ylacombe ylacombe Oct 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @ydshieh, I'd love to have your take on how to best manage this.

In a few words, here I try to simulate that a library is missing in ASTFeatureExtractionTest. The thing is that now, I had to create another class: ASTFeatureExtractionWithoutTorchaudioTest, which is a copy of the previous one with a unittest.mock.patch decorator to simulate the library absence.

I've looked over the internet to avoid test duplication, but without success. Do you have any take on how to parametrize the library absence ?

Thanks for your help!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @ylacombe ! There is something like unittest.mock.Mock() but I never used it (yet) myself.

Search in tests/models/t5/test_modeling_t5.py

    def test_fp16_fp32_conversion(self):
        r"""
        A test to check whether the argument `keep_in_fp32_modules` correctly does its job
        """
        orig_import = __import__
        accelerate_mock = unittest.mock.Mock()

        # mock import of accelerate
        def import_accelerate_mock(name, *args, **kwargs):
            if name == "accelerate":
                if accelerate_available:
                    return accelerate_mock
                else:
                    raise ImportError
            return orig_import(name, *args, **kwargs)

and let me know how you feel. I can take a look too (good to learn anyway)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @ydshieh, thanks for the quick response and for this example! I didn't know about it! I'm not sure this is the right fit for this purpose though

The idea is really to run ASTFeatureExtractionTest twice, one without context and the other with the missing library context!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I missed the above comment. So here is 2 classes instead of 2 test methods.

Is the code in the new ASTFeatureExtractionWithoutTorchaudioTest be identical to the original ASTFeatureExtractionTest? If so, maybe try make it a subclass of ASTFeatureExtractionTest but decorated with unittest.mock.patch or something similar?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this!

Making a subclass works for me, I'll try it

@@ -104,7 +103,213 @@ def _flatten(list_of_lists):
@require_torch
@require_torchaudio
class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing some copied from as well here!

Copy link

github-actions bot commented Nov 5, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Comment on lines +308 to +319
# exact same tests than before, except that we simulate that torchaudio is not available
@require_torch
@unittest.mock.patch(
"transformers.models.speech_to_text.feature_extraction_speech_to_text.is_speech_available", lambda: False
)
class Speech2TextFeatureExtractionWithoutTorchaudioTest(Speech2TextFeatureExtractionTest):
def test_using_audio_utils(self):
# Tests that it uses audio_utils instead of torchaudio
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())

self.assertTrue(hasattr(feat_extract, "window"))
self.assertTrue(hasattr(feat_extract, "mel_filters"))
Copy link
Contributor Author

@ylacombe ylacombe Nov 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just want to make sure that this inheritance works with you @ArthurZucker and/or @amyeroberts, following @ydshieh suggestion!

As soon as I have approval, I'll merge!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to assert that is_speech_available is False so we are sure the patch works 😄

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(sometimes we have surprise ...)

Copy link
Contributor Author

@ylacombe ylacombe Nov 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.mel_fitlers and self.window are not defined unless is_speech_available=False but it's best to be on the safe side

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks alright with me!

@ylacombe ylacombe merged commit be74b2e into huggingface:main Nov 8, 2023
21 checks passed
@ylacombe ylacombe deleted the torchaudio-alternative branch November 8, 2023 07:39
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* add audio_utils usage in the FE of SpeechToText

* clean unecessary parameters of AudioSpectrogramTransformer FE

* add audio_utils usage in AST

* add serialization tests and function to FEs

* make style

* remove use_torchaudio and move to_dict to FE

* test audio_utils usage

* make style and fix import (remove torchaudio dependency import)

* fix torch dependency for jax and tensor tests

* fix typo

* clean tests with suggestions

* add lines to test if is_speech_availble is False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants