
add chunk_length parameter to Whisper #1909

Closed · wants to merge 6 commits

Conversation


@MahmoudAshraf97 (Contributor) commented Jul 7, 2024

distil-whisper models perform best with chunk sizes shorter than the 30 s used by the original Whisper models, so this PR introduces the option to build the engine with a different chunk length.
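
For reference, here is the arithmetic behind that 30 s limit: Whisper computes 100 mel frames per second of audio and its conv stack downsamples by 2, so a 30 s chunk occupies 1500 encoder positions. A minimal sketch (the function name is made up for illustration):

def encoder_positions(chunk_length_s: int,
                      frames_per_second: int = 100,
                      conv_downsample: int = 2) -> int:
    # Rows of the positional embedding table used by a chunk of this length.
    return chunk_length_s * frames_per_second // conv_downsample

print(encoder_positions(30))  # 1500, the original Whisper chunk
print(encoder_positions(15))  # 750, e.g. a shorter distil-whisper chunk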

Summary of the changes in this PR:

  • Whisper encoder now supports changing chunk_size
  • The example has been updated to support remove_input_padding in the decoder
  • conv1d now supports inputs with more than one dynamic dimension
  • The Whisper decoder should now support inflight batching when built with paged_kv_cache and run through the executor, although there is no clear way to feed the encoder input and the prompt to the tensorrt_llm.bindings.Request class, since it only accepts lists of tokens for all of its inputs while the encoder output is a float tensor

Enabling remove_input_padding in the encoder wasn't as easy as I thought: all of my attempts failed at the step where the positional embeddings are added to the conv output, because the chunk size is not defined at build time and the positional embeddings tensor's first dim is 1500, which corresponds to 30 s inputs. When the chunk_size is known at build time it's easy to slice the positional embeddings tensor to the correct size and add it to the conv output, but when the chunk size is unknown, the build fails at fetching the correct indices, for example:

import tensorrt_llm.functional as F

# Build one [0, input_length) index sequence per request, concatenate them,
# and gather the matching rows of the positional embedding table.
positional_embeddings = F.gather(
    positional_embedding,
    dim=0,
    indices=F.concat(
        [F.arange(0, input_length, "int32") for input_length in input_lengths.unbind()]
    ),
)

## only for padded input
positional_embeddings = F.view(positional_embeddings, [-1, chunk_size, hidden_size])
##
x = x + positional_embeddings

input_lengths.unbind() fails because the shape of input_lengths is [-1], i.e. dynamic, so the number of sequences to unbind isn't known at build time.
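
For contrast, the case that does work (chunk_size fixed at build time) only needs a static slice of the embedding table. A minimal sketch of that idea, written with plain PyTorch tensors rather than TensorRT-LLM graph ops and with made-up sizes:

import torch

chunk_length = 15                            # seconds, fixed at build time
encoder_seq_len = chunk_length * 100 // 2    # 100 mel frames/s, conv stride 2 -> 750
hidden_size = 1280                           # e.g. the large / distil-large width

positional_embedding = torch.randn(1500, hidden_size)    # full 30 s table
x = torch.randn(4, encoder_seq_len, hidden_size)          # conv output, batch of 4

# Static slice: the length is a constant at build time, so no dynamic
# gather/unbind over input_lengths is required.
x = x + positional_embedding[:encoder_seq_len]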

Removing input padding from the encoder isn't that important, TBH, since we expect encoder inputs to have the same shape and size except for the last window of an audio. It would mainly be beneficial in scenarios where the requests are many separate audio files that are all shorter than 30 s and vary a lot in length.

On the other hand, remove_input_padding is important on the decoder side because it's required to enable inflight batching. From a quick trial on a 30 min audio file, the larger the batch size, the slower the generation:

efficiency = generation loops needed / actual generation loops (where the actual loops are determined by the longest sequence in the output)

Time is for decoding only, over the whole 30 min (mean ± std. dev. of 7 runs, 1 loop each):

batch_size   efficiency   time per run
1            1.00         626 ms ± 17.5 ms
2            0.87         716 ms ± 50.9 ms
4            0.80         755 ms ± 31.1 ms
8            0.67         1.05 s ± 21 ms
16           0.56         1.33 s ± 38 ms
32           0.45         1.85 s ± 34.4 ms

As we can see, the time taken increases with batch size, which is counterproductive for large workloads; hence the need for inflight batching.
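
A minimal sketch of one plausible reading of the efficiency metric above: with static batching each batch runs for as many decode loops as its longest sequence, so efficiency is the share of those loops doing useful work (the function name and example lengths are made up):

def batching_efficiency(output_lengths, batch_size):
    # Useful decode steps divided by total decode slots: every batch runs
    # for max(lengths) loops, and its shorter members idle for the difference.
    useful = sum(output_lengths)
    total = 0
    for i in range(0, len(output_lengths), batch_size):
        batch = output_lengths[i:i + batch_size]
        total += max(batch) * len(batch)
    return useful / total

lengths = [120, 80, 60, 140]                 # tokens per chunk of a long audio
print(batching_efficiency(lengths, 1))       # 1.0
print(batching_efficiency(lengths, 4))       # ~0.71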

@yuekaizhang

@MahmoudAshraf97 Hi, thanks for your effort. I will take this PR into our internal GitLab. Also, we will add your name to the co-author list and credit your work in the release notes for the Whisper IFB feature.

@yuekaizhang

@MahmoudAshraf97 Hi, I just tried the conv1d solution with more than one dynamic dimension by setting the input up as below:

    x = Tensor(name="x",
               dtype=self._dtype,
               shape=[-1, self.config.n_mels, -1],
               dim_range=OrderedDict([
                   ("batch_size", [bs_range]),
                   ("feature_dim", [self.config.n_mels]),
                   ("feature_len_range", [1, 1000, 3000]),
               ]))

However, the build process failed. It seems the slice operator needs to know the value of x.shape[1].

I was wondering why you set a fixed config.chunk_length here rather than letting it be dynamic.

@MahmoudAshraf97
Contributor Author

@MahmoudAshraf97 Hi, I just tried the conv1d solution with more than one dynamic dimension by setting the input up as below:

    x = Tensor(name="x",
               dtype=self._dtype,
               shape=[-1, self.config.n_mels, -1],
               dim_range=OrderedDict([
                   ("batch_size", [bs_range]),
                   ("feature_dim", [self.config.n_mels]),
                   ("feature_len_range", [1, 1000, 3000]),
               ]))

However, the build process failed. It seems the slice operator needs to know the value of x.shape[1].

I was wondering why you set a fixed config.chunk_length here rather than letting it be dynamic.

As I mentioned in my trials in the PR description, this was a step towards making it work, but I couldn't complete it because of the slice operator (or the other operators that add the positional embeddings to x). Before this change the build failed at the first conv layer; now it passes the conv layers and fails at a later stage, so I guess we are halfway there.

@yuekaizhang

As I mentioned in my trials in the PR description, this was a step towards making it work, but I couldn't complete it because of the slice operator (or the other operators that add the positional embeddings to x). Before this change the build failed at the first conv layer; now it passes the conv layers and fails at a later stage, so I guess we are halfway there.

@MahmoudAshraf97 I see, thanks. Btw, the remove_input_padding issue for the decoder has been fixed. The code will sync to GitHub about a week later.

cross_attention_mask = torch.ones(
    [encoder_outputs.shape[0], 1,
     encoder_outputs.shape[1]]).int().cuda()
cross_attention_mask = (
@galv

Hi @MahmoudAshraf97, if I understand correctly, you are making this change because distil-whisper can work with dynamic chunk sizes, unlike Whisper, which must use fixed 30-second chunks. Am I understanding correctly? Thank you.

@MahmoudAshraf97
Contributor Author

Hi @galv, this PR contains 2 main changes:

  1. The encoder is no longer restricted to 30 s inputs; this helps in the case of distil-whisper, as you mentioned
  2. The decoder now supports remove_input_padding and accepts packed inputs to save memory (sketched below)
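
A minimal sketch of what packed (remove_input_padding) decoder inputs look like compared to padded ones; the token ids and tensor names here are only illustrative, not the exact TensorRT-LLM interface:

import torch

requests = [[50258, 50259, 50359],           # three requests of uneven length
            [50258, 50259],
            [50258, 50259, 50359, 50363]]

# Padded layout: [batch, max_len], unused slots filled with a pad id.
max_len = max(len(r) for r in requests)
padded = torch.zeros(len(requests), max_len, dtype=torch.int32)
for i, r in enumerate(requests):
    padded[i, :len(r)] = torch.tensor(r, dtype=torch.int32)

# Packed layout: one flat [total_tokens] tensor plus per-request lengths,
# so no memory or compute is spent on padding.
packed = torch.tensor([t for r in requests for t in r], dtype=torch.int32)
input_lengths = torch.tensor([len(r) for r in requests], dtype=torch.int32)

print(padded.shape)            # torch.Size([3, 4])
print(packed.shape)            # torch.Size([9])
print(input_lengths.tolist())  # [3, 2, 4]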

@kaiyux mentioned this pull request Jul 23, 2024
@MahmoudAshraf97
Contributor Author

Closing this since it was merged.
