Add static cache #89

Merged · 69 commits · Aug 7, 2024

Conversation

@eustlb (Contributor) commented Jul 22, 2024

This PR enables compilation of the forward pass of Parler-TTS.

With dynamic caching, the keys and values accumulated during auto-regressive decoding give the tensors in past_key_values changing shapes and strides, triggering a recompilation at every call of the forward method. This PR implements a static key/value cache, so that the model is compiled only once for generation. The work here is inspired by the equivalent PR already done on Whisper.
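For intuition only, here is a minimal sketch of the idea (a hypothetical helper, not the PR's actual implementation): preallocate the key/value buffers at a fixed maximum length and update them in place, so the shapes and strides seen by torch.compile never change across decoding steps.

import torch

class StaticKVCacheSketch:
    """Toy static cache: fixed-shape buffers updated in place."""

    def __init__(self, batch_size, num_heads, max_seq_len, head_dim, device="cpu", dtype=torch.float32):
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.keys = torch.zeros(shape, device=device, dtype=dtype)
        self.values = torch.zeros(shape, device=device, dtype=dtype)

    def update(self, new_keys, new_values, cache_position):
        # In-place writes keep shapes and strides constant, so a compiled
        # forward pass is not retraced at each decoding step.
        self.keys.index_copy_(2, cache_position, new_keys)
        self.values.index_copy_(2, cache_position, new_values)
        return self.keys, self.values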

Env notes

Build your environment as such:

conda create -n parler-tts python=3.11 --yes
conda activate parler-tts
pip install git+https://github.com/eustlb/parler-tts.git@add-static-cached

Usage

import time

import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import torch
import torch._dynamo.config
import torch._inductor.config

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True

# debugging
torch._logging.set_logs(graph_breaks=True, recompiles=True)

# reproducibility
torch.manual_seed(0)

# set-up device args
torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# attn_implementation = "eager"
attn_implementation = "sdpa"

class Timer:
    def __init__(self, name):
        self.name = name
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        torch.manual_seed(0)
        torch.cuda.synchronize()
        self.start_event.record()
        self.start = time.perf_counter()

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.start_event.elapsed_time(self.end_event) * 1.0e-3

        print(f'{self.name} execution time:', elapsed_time, 'seconds')

# model
model_name = "parler-tts/parler_tts_mini_v0.1"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation
).to(torch_device, dtype=torch_dtype)
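# static KV cache + compile: fixed-shape cache tensors let torch.compile trace the forward pass once and reuse it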
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# tokenizers
padding_side = "left"
description_tokenizer = AutoTokenizer.from_pretrained(model_name) 
prompt_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side)

def tokenize_inputs(description, prompt):
    tokenized_description = description_tokenizer(description, return_tensors="pt", padding='max_length', max_length=50)
    input_ids = tokenized_description.input_ids.to(torch_device)
    attention_mask = tokenized_description.attention_mask.to(torch_device)

    tokenized_prompt = prompt_tokenizer(prompt, return_tensors="pt", padding='max_length', max_length=50)
    prompt_input_ids = tokenized_prompt.input_ids.to(torch_device)
    prompt_attention_mask = tokenized_prompt.attention_mask.to(torch_device)

    return input_ids, prompt_input_ids, attention_mask, prompt_attention_mask 

# first generation
prompt = "Hey, how are you doing today?"
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."
input_ids, prompt_input_ids, attention_mask, prompt_attention_mask = tokenize_inputs(description, prompt)

with Timer("First run"):
    generation = model.generate(
        input_ids=input_ids, 
        prompt_input_ids=prompt_input_ids,
        attention_mask=attention_mask,
        prompt_attention_mask=prompt_attention_mask
    )

audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
sf.write("./output_1.wav", audio_arr, model.config.sampling_rate)

# second generation, debugging parameters will show us if recompilation happens
prompt = "Hey, how are you doing?"
description = "A male speaker with a slightly low-pitched voice delivers his words quite expressively, in a very confined sounding environment with clear audio quality. He speaks very fast."
input_ids, prompt_input_ids, attention_mask, prompt_attention_mask = tokenize_inputs(description, prompt)

with Timer("Second run"):
    generation = model.generate(
        input_ids=input_ids, 
        prompt_input_ids=prompt_input_ids,
        attention_mask=attention_mask,
        prompt_attention_mask=prompt_attention_mask
    )

audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
sf.write("./output_2.wav", audio_arr, model.config.sampling_rate)

# third generation, necessary to see speed improvements with compilation modes "reduce-overhead" and "max-autotune" because the CUDA graphs capture occurs on the second run
prompt = "Wassup"
description = "A male speaker with a slightly low-pitched voice delivers his words quite expressively." 
input_ids, prompt_input_ids, attention_mask, prompt_attention_mask = tokenize_inputs(description, prompt)

with Timer("Third run"):
    generation = model.generate(
        input_ids=input_ids, 
        prompt_input_ids=prompt_input_ids,
        attention_mask=attention_mask,
        prompt_attention_mask=prompt_attention_mask
    )

audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
sf.write("./output_3.wav", audio_arr, model.config.sampling_rate)

Benchmarks

Benchmarking code can be found here.
Reported results compare the best configuration (attention implementation, dtype) with compile against the best configuration without compile, for generating 43 tokens (~0.5 s of audio):

  • RTX 4090: 2.7x for mini architecture, 1.7x for large architecture (details here)
  • A100 80GB: 4.5x for mini architecture, 3.5x for large architecture (details here)
  • A100 40GB: 4.5x for mini architecture, 3.5x for large architecture (details here)

Tests

Generation has been tested by comparing outputs from this branch against the branch it was built on. Code for these tests can be found here.
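As an illustration, a hedged sketch of the kind of output-equivalence check described above (the linked script is authoritative; the file names and tolerance below are assumptions):

import numpy as np
import soundfile as sf

def compare_generations(reference_wav, candidate_wav, atol=1e-4):
    # compare audio generated on the base branch vs. this branch
    ref, sr_ref = sf.read(reference_wav)
    cand, sr_cand = sf.read(candidate_wav)
    assert sr_ref == sr_cand, "sampling rates differ"
    assert ref.shape == cand.shape, "generated audio lengths differ"
    max_diff = float(np.abs(ref - cand).max())
    print(f"max abs difference: {max_diff:.2e}")
    assert max_diff <= atol, "outputs diverge beyond tolerance"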

sanchit-gandhi and others added 30 commits May 17, 2024 14:26
@sang-nguyen-ts (Contributor):

Hi @eustlb, thanks for your great work. I'm trying to reproduce the results, but the speed is very slow on an A100 80GB, maybe due to graph recompilation. Here are my full steps to reproduce:

  1. Create a fresh env using conda:
conda create -n parler-static-cache python==3.10
  2. Clone, check out, and install parler-tts:
git clone https://github.com/eustlb/parler-tts.git && cd parler-tts && git checkout add-static-cache && pip install -e .
  3. Upgrade to the PyTorch nightly:
pip install torchaudio torch==2.2.0.dev20240525+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121
  4. Add some timer stubs:
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import time

import torch
import torch._dynamo.config
import torch._inductor.config


class Timer:
    def __init__(self, name):
        self.name = name
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        torch.cuda.synchronize()
        self.start_event.record()
        self.start = time.perf_counter()

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.start_event.elapsed_time(self.end_event) * 1.0e-3

        print('Execution time:', elapsed_time, 'seconds')
...
with Timer("First run"):
    generation = model.generate(
        input_ids=input_ids,
        prompt_input_ids=prompt_input_ids,
        attention_mask=attention_mask,
        prompt_attention_mask=prompt_attention_mask,
    ).to(torch.float32)

...
with Timer("Second run"):
    generation = model.generate(
        input_ids=input_ids,
        prompt_input_ids=prompt_input_ids,
        attention_mask=attention_mask,
        prompt_attention_mask=prompt_attention_mask,
    ).to(torch.float32)
    audio_arr = generation.cpu().numpy().squeeze()
    sf.write("./output_2.wav", audio_arr, model.config.sampling_rate)
  5. Run the test code you provided:
python test.py

Logs:

python test.py
Flash attention 2 is not installed
/home/sangnguyen/.conda/envs/parler-3/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
  WeightNorm.apply(module, name, dim)
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Using the model-agnostic default `max_length` (=2580) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour.
V0725 10:54:10.850172 1970753 torch/_dynamo/guards.py:2688] [0/1] [__recompiles] Recompiling function forward in /mnt/disk1/sangnguyen/pr/parler-static/parler_tts/modeling_parler_tts.py:2563
V0725 10:54:10.850172 1970753 torch/_dynamo/guards.py:2688] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0725 10:54:10.850172 1970753 torch/_dynamo/guards.py:2688] [0/1] [__recompiles]     - 0/0: ___check_obj_id(L['past_key_values'].is_updated[0], 93874464073408)

Execution time: 63.699273437500004 seconds

Using the model-agnostic default `max_length` (=2580) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.

Execution time: 81.87266406250001 seconds

Can you help me take a look and see whether there is any problem in my env setup?

@eustlb (Contributor, Author) commented Jul 25, 2024

Hey @sang-nguyen-ts,

Thanks a lot for reporting this. I looked into it: you'll notice that your code gives the expected run times when compile_mode is set to "default". However, compile_mode="reduce-overhead" and "max-autotune" use CUDA graphs, which are recorded at the second pass of the model (see here for more details), while torch compilation happens at the first pass. That is why we get the expected results with only two calls to generate in "default" mode, and why the expected speed-up only shows up at the third call to generate for the other two modes.

The good news is that for benchmarking (code here), I happened to use three warmup steps without knowing this, which explains why I had not observed this behavior before.
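For illustration, a minimal warmup pattern consistent with this behaviour, reusing the Timer and inputs from the usage snippet above (the run count is the point here, not the prompts):

# with mode="reduce-overhead": run 1 compiles, run 2 records the CUDA graphs,
# run 3 reflects steady-state latency
for i in range(3):
    with Timer(f"Run {i + 1}"):
        _ = model.generate(
            input_ids=input_ids,
            prompt_input_ids=prompt_input_ids,
            attention_mask=attention_mask,
            prompt_attention_mask=prompt_attention_mask,
        )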

@Artyom17 commented Jul 25, 2024

Wow, great job! I haven't looked into the code yet, but does it include the port to Transformers 4.42?

PS. Ah, it requires even newer Transformers, probably due to EncoderDecoderCache usage

@sang-nguyen-ts (Contributor):

(Quoting @eustlb's reply above about torch compilation at the first pass and CUDA graph recording at the second.)

I've checked the benchmark code and found that it uses the same prompt for every run. I tried a test case where the prompt differs across runs, and it causes a graph recompile:

prompt_1 = "A paragraph is defined as “a group of sentences or a single sentence that forms a unit” (Lunsford and Connors 116). Length and appearance do not determine whether a section in a paper is a paragraph."
prompt_2 = "Hey, how are you doing today?"
prompt_3 = "Hey, how are you doing mate?"
prompt_4 = "Hey, how are you doing my friend?"

logs:

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Using the model-agnostic default `max_length` (=2580) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour.
V0726 13:03:07.418674 3697053 torch/_dynamo/guards.py:2688] [0/1] [__recompiles] Recompiling function forward in /mnt/disk1/sangnguyen/pr/parler-static/parler_tts/modeling_parler_tts.py:2563
V0726 13:03:07.418674 3697053 torch/_dynamo/guards.py:2688] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0726 13:03:07.418674 3697053 torch/_dynamo/guards.py:2688] [0/1] [__recompiles]     - 0/0: ___check_obj_id(L['past_key_values'].is_updated[0], 94694898878144)
Execution time: 80.577703125 seconds
Using the model-agnostic default `max_length` (=2580) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
V0726 13:03:48.552983 3697053 torch/_dynamo/guards.py:2688] [0/2] [__recompiles] Recompiling function forward in /mnt/disk1/sangnguyen/pr/parler-static/parler_tts/modeling_parler_tts.py:2563
V0726 13:03:48.552983 3697053 torch/_dynamo/guards.py:2688] [0/2] [__recompiles]     triggered by the following guard failure(s):
V0726 13:03:48.552983 3697053 torch/_dynamo/guards.py:2688] [0/2] [__recompiles]     - 0/1: ___check_obj_id(L['past_key_values'].is_updated[0], 94694898878176)
V0726 13:03:48.552983 3697053 torch/_dynamo/guards.py:2688] [0/2] [__recompiles]     - 0/0: tensor 'L['prompt_attention_mask']' stride mismatch at index 0. expected 56, actual 50
V0726 13:04:44.735695 3697053 torch/_dynamo/guards.py:2688] [0/3] [__recompiles] Recompiling function forward in /mnt/disk1/sangnguyen/pr/parler-static/parler_tts/modeling_parler_tts.py:2563
V0726 13:04:44.735695 3697053 torch/_dynamo/guards.py:2688] [0/3] [__recompiles]     triggered by the following guard failure(s):
V0726 13:04:44.735695 3697053 torch/_dynamo/guards.py:2688] [0/3] [__recompiles]     - 0/2: ___check_obj_id(L['past_key_values'].is_updated[0], 94694898878144)
V0726 13:04:44.735695 3697053 torch/_dynamo/guards.py:2688] [0/3] [__recompiles]     - 0/1: tensor 'L['prompt_attention_mask']' stride mismatch at index 0. expected 56, actual 50
V0726 13:04:44.735695 3697053 torch/_dynamo/guards.py:2688] [0/3] [__recompiles]     - 0/0: tensor 'L['prompt_attention_mask']' stride mismatch at index 0. expected 56, actual 50
Execution time: 88.9798984375 seconds
Using the model-agnostic default `max_length` (=2580) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
V0726 13:05:17.538536 3697053 torch/_dynamo/guards.py:2688] [0/4] [__recompiles] Recompiling function forward in /mnt/disk1/sangnguyen/pr/parler-static/parler_tts/modeling_parler_tts.py:2563
V0726 13:05:17.538536 3697053 torch/_dynamo/guards.py:2688] [0/4] [__recompiles]     triggered by the following guard failure(s):
V0726 13:05:17.538536 3697053 torch/_dynamo/guards.py:2688] [0/4] [__recompiles]     - 0/3: ___check_obj_id(L['past_key_values'].is_updated[0], 94694898878176)
V0726 13:05:17.538536 3697053 torch/_dynamo/guards.py:2688] [0/4] [__recompiles]     - 0/2: Ne(Mod(2636*L['prompt_hidden_states'].size()[1] + 2636, 8), 0)  # _dynamo/output_graph.py:452 in init_ambient_guards
V0726 13:05:17.538536 3697053 torch/_dynamo/guards.py:2688] [0/4] [__recompiles]     - 0/1: tensor 'L['prompt_attention_mask']' stride mismatch at index 0. expected 56, actual 55
V0726 13:05:17.538536 3697053 torch/_dynamo/guards.py:2688] [0/4] [__recompiles]     - 0/0: tensor 'L['prompt_attention_mask']' stride mismatch at index 0. expected 56, actual 55
Execution time: 64.24132421875 seconds
Using the model-agnostic default `max_length` (=2580) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
Execution time: 236.198390625 seconds

@eustlb (Contributor, Author) commented Jul 26, 2024

@sang-nguyen-ts compiling the forward pass requires every input to have a constant shape. That is why, in the benchmarking code, we pad the tokenized description and prompt. The first prompt in your example is 56 tokens long, which explains why the model is recompiled for the second prompt, padded to 50 tokens. Consider either setting the prompt tokenizer's max_length parameter to a value >= 56, or setting truncation=True. You'll then find that the compiled model is indeed reused without recompilation.
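For example, a hedged sketch of keeping prompt shapes constant across calls, following the tokenize_inputs helper from the usage snippet above (the max_length value below is an assumption, not the benchmark's exact setting):

pad_to = 64  # choose a value >= the longest prompt you expect
tokenized_prompt = prompt_tokenizer(
    prompt,
    return_tensors="pt",
    padding="max_length",
    max_length=pad_to,
    truncation=True,  # guard against prompts longer than pad_to
)
# every call now sees prompt_input_ids of shape (batch, pad_to), so the
# compiled forward pass is reused without recompilation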

@ylacombe (Collaborator):

Thanks @eustlb, I'll review as soon as #65 is merged, hopefully Monday morning, let's try to merge this Monday afternoon ;)

@gante (Member) commented Jul 27, 2024

Very cool use of the static caches 💛

Possibly useful information for follow-up work on this repo -- on transformers, we're working to have:

  1. compileable quantized cache (not sure if quantization is detrimental for tts)
  2. compileable cache with CPU offload (good for large batch sizes, loads into the GPU the KV cache for layer n+1 while computing layer n)

@ylacombe (Collaborator) left a comment:

Hey @eustlb, congrats on the great PR! LGTM!

I've left a few questions, to make sure everything's under control and to satisfy my curiosity.

One last question on my side: how can I use the model to get the previous behaviour (i.e. without the cache)?

Also, note that I have yet to check whether it's still compatible with training; will do right now.

training/run_parler_tts_training.py (outdated, resolved)
@@ -244,7 +248,7 @@ def get_embedding(num_embeddings: int, embedding_dim: int):
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
bsz, seq_len, _ = input_ids.size()
# Create the position ids from the input token ids.
position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
position_ids = (torch.arange(seq_len, device=input_ids.device) + past_key_values_length).to(input_ids.device)
Collaborator:

Why do we cast the torch.arange to input_ids.device here? Is it because past_key_values_length is a tensor now?
If so, can you update the type in the signature?

And also, if so, do we have to cast to the input_ids device twice?

Contributor (Author):

Nice catch.
The issue here is an inconsistency in the type of past_key_values_length: it can be either 0, so an int, or a tensor (code taken from here). Torch handles this fine when doing torch.arange(seq_len) + past_key_values_length, yet when past_key_values_length is a tensor that is not on the CPU, we need to make sure torch.arange creates its tensor on the same device. I did not use past_key_values_length.device, to avoid the case where past_key_values_length is the int 0, and used this trick instead.

Contributor (Author):

Agreed that this is not very elegant, yet I don't see a better way to do it. Maybe change the logic copied from Whisper's code that introduces this inconsistency?

Collaborator:

torch.arange(seq_len, device=input_ids.device) + past_key_values_length should be enough then, right?

parler_tts/modeling_parler_tts.py (outdated, resolved)
setup.py (outdated, resolved)
attention_mask: Optional[torch.Tensor] = None,
cos: Optional[torch.LongTensor] = None,
sin: Optional[torch.LongTensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
Collaborator:

Great refactoring!

Besides the cache-related changes, what motivated the changes in shape? Did you test speed-ups? Or did you just modify it to match Whisper's implementation?

Collaborator:

You also previously told me about small differences between the (static KV cache + compile) version and the previous implementation; are you sure they're not coming from the changes in shape you've introduced here?

E.g. you do attn_output = torch.matmul(attn_probs, value_states) with 4D tensors. We previously did attn_output = torch.bmm(attn_probs, value_states) with 3D tensors, which might result in small differences.

Contributor (Author):

Modified to match Whisper's implementation, which is allegedly faster (see Whisper's static cache PR), yet I have not benchmarked it myself.
Concerning the change from torch.bmm to torch.matmul, the tests I've run showed exactly the same results for every dtype.
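For reference, a quick numerical check of that bmm/matmul equivalence (the shapes below are illustrative, not the model's):

import torch

bsz, heads, tgt, src, dim = 2, 4, 8, 8, 16
attn_probs = torch.rand(bsz, heads, tgt, src)
value_states = torch.rand(bsz, heads, src, dim)

# 4D batched matmul vs. the old 3D bmm over flattened (batch * heads)
out_4d = torch.matmul(attn_probs, value_states)
out_3d = torch.bmm(
    attn_probs.reshape(bsz * heads, tgt, src),
    value_states.reshape(bsz * heads, src, dim),
).reshape(bsz, heads, tgt, dim)

print(torch.allclose(out_4d, out_3d), (out_4d - out_3d).abs().max())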

parler_tts/modeling_parler_tts.py (two outdated threads, resolved)
@@ -1948,10 +2030,11 @@ def generate(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = input_ids.shape[0] // self.num_codebooks
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
Collaborator:

Even though we never use ParlerTTSForCausalLM, have you still tested it? No worries if not.

Contributor (Author):

Not tested at all, I have not propagated the changes to ParlerTTSForCausalLM.

parler_tts/modeling_parler_tts.py (resolved)
Comment on lines +2909 to +2920
prompt_hidden_states = model_kwargs["prompt_hidden_states"]
num_codebooks = self.decoder.num_codebooks
input = decoder_input_ids.reshape(-1, num_codebooks, decoder_input_ids.shape[-1])
inputs_embeds = sum(
[
self.decoder.model.decoder.embed_tokens[codebook](input[:, codebook])
for codebook in range(num_codebooks)
]
)
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
model_kwargs["inputs_embeds"] = inputs_embeds

Collaborator:

Not sure I follow! You prepend here, which makes total sense, but it still leaves prompt_hidden_states in model_kwargs, so isn't it supposed to be used again in the forward pass?

Contributor (Author):

We prepend here to make sure that "inputs_embeds" is in model_kwargs when _get_initial_cache_position is called. I have not removed prompt_hidden_states from model_kwargs here since I wanted to change as little as possible of the current logic. Indeed, in the current implementation of prepare_inputs_for_generation, decoder_inputs_embeds (which corresponds to the inputs_embeds created here by prepending) is not handled, and I have not changed that. Were we to remove prompt_hidden_states from model_kwargs at this stage, we would then need to handle decoder_inputs_embeds in prepare_inputs_for_generation, which is doable, yet I felt the impact was negligible and preferred a version that changes the current logic as little as possible.

@ylacombe (Collaborator) commented Aug 1, 2024

Another question on my side!
You said you need two generations in order to fully warm up the compilation, but a single generation already does multiple forward passes.
Since we only need two forward passes, isn't one generation enough?

@eustlb (Contributor, Author) commented Aug 5, 2024

I had the same intuition, but it seems a full call to generate is necessary to warm up the compilation, likely due to torch.compile's inner workings. Note that the same warmup is done for other Transformers models.

@ylacombe (Collaborator) left a comment:

I've left some final remarks regarding training compatibility. Thanks for the great work!

For the final changes, it would be great if you could check that they don't break compile compatibility, and run your standard tests to make sure everything's okay!

parler_tts/dac_wrapper/modeling_dac.py (outdated, resolved)
parler_tts/modeling_parler_tts.py (outdated, resolved)
@@ -244,7 +248,7 @@ def get_embedding(num_embeddings: int, embedding_dim: int):
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
bsz, seq_len, _ = input_ids.size()
# Create the position ids from the input token ids.
position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
position_ids = (torch.arange(seq_len, device=input_ids.device) + past_key_values_length).to(input_ids.device)
Collaborator:

torch.arange(seq_len, device=input_ids.device) + past_key_values_length should be enough then, right?

Comment on lines +515 to +519
if isinstance(past_key_value, StaticCache):
raise ValueError(
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
)
Collaborator:

Let's focus on the current behaviour; if we need to add it, we'll do it in a subsequent PR.

parler_tts/modeling_parler_tts.py (resolved)
parler_tts/modeling_parler_tts.py (outdated, resolved)
@ylacombe (Collaborator) commented Aug 7, 2024

Thanks @eustlb, next steps:

  • Update the README (or add a separate .md?) with torch.compile usage

Merging now ;)
