Add static cache (#89)
* add rope

* don't include padding in rope

* possibly use cross-attn for prompt

* fix rope

* fix cross-attn

* fix self-attn

* fix dummy model

* clean-up rope

* first gqa implementation

* fix wer eval

* feat: add flash attention and sdpa

* chore: add README for flash attention

* chore: add benchmark script

* chore: add benchmark for attention approaches

* multi-node support; fix wer eval and fix compile

* Update modeling_parler_tts.py

* fix FA2, SDPA and add cross-attn MHA and attention type forcing

* better default for the number of cross-attention key/value heads + add training arguments for attn implementation (see the usage example after the commit message)

* fix audio padding when using torch compile or pad_to_max_length=True

* correct multi-node setup

* make rope faster

* fix encoder sdpa

* fix training with cross-attention + with FA2

* use fp32 as default model dtype + fix generation when using FA2 with autocast

* remove redundant passes in generate + clean and fix attentions

* fix edge case in WER evaluation during long-form generation

* better multi-node mapping and saving / add eval dataloader num workers

* remove old benchmarks

* faster audio encoding + checkpointing + fix generation step

* unpin trfms

* remove CFG

* imports and constants

Co-Authored-By: sang-nguyen-ts <[email protected]>

* attention modifications to handle static cache

Co-Authored-By: sang-nguyen-ts <[email protected]>

* decoder layer modifications to handle static cache

Co-Authored-By: sang-nguyen-ts <[email protected]>

* ParlerTTSPreTrainedModel modifs to handle static cache

Co-Authored-By: sang-nguyen-ts <[email protected]>

* ParlerTTSDecoder modifs to handle static cache

Co-Authored-By: sang-nguyen-ts <[email protected]>

* ParlerTTSModel + ParlerTTSForCausalLM modifs

Co-Authored-By: sang-nguyen-ts <[email protected]>

* ParlerTTSForConditionalGeneration modifs

Co-Authored-By: sang-nguyen-ts <[email protected]>

* decoder_attention_mask for static cache

Co-Authored-By: sang-nguyen-ts <[email protected]>

* create inputs_embeds early so the cache is initialized correctly

Co-Authored-By: sang-nguyen-ts <[email protected]>

* _get_cache method (sketched after the changed-files summary)

Co-Authored-By: sang-nguyen-ts <[email protected]>

* init the cache

Co-Authored-By: sang-nguyen-ts <[email protected]>

* ensure correct device placement

Co-Authored-By: sang-nguyen-ts <[email protected]>

* pin trfms version

Co-Authored-By: sang-nguyen-ts <[email protected]>

* fix attention_mask for FA2

* remove unnecessary method

* Update parler_tts/modeling_parler_tts.py

Co-authored-by: Sanchit Gandhi <[email protected]>

* Update parler_tts/modeling_parler_tts.py

Co-authored-by: Sanchit Gandhi <[email protected]>

* remove unnecessary imports

* replace the hardcoded cache_position with a more elegant approach

* make style

* unpin transformers

* pin transformers

* pin torch

* refactor + unpin torch

* Update parler_tts/modeling_parler_tts.py

Co-authored-by: Yoach Lacombe <[email protected]>

* update training script to match 11b209e

* Update parler_tts/modeling_parler_tts.py

Co-authored-by: Yoach Lacombe <[email protected]>

* ensure compatibility with trfms 4.43.3, changes taken from #31980 on trfms

* fix input_ids_length

* warn when creating the full attention mask

* changes for training compatibility

---------

Co-authored-by: sanchit-gandhi <[email protected]>
Co-authored-by: Yoach Lacombe <[email protected]>
Co-authored-by: Yoach Lacombe <[email protected]>
Co-authored-by: sang-nguyen-ts <[email protected]>
Co-authored-by: [email protected] <Yoach Lacombe>
Co-authored-by: sang-nguyen-ts <[email protected]>
Co-authored-by: Sanchit Gandhi <[email protected]>
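
A minimal usage sketch of the attention backends the bullets above wire up ("eager", "sdpa", "flash_attention_2"). This is an assumption-laden illustration, not code from the PR: the checkpoint name and dtype are placeholders, and the flash_attention_2 path needs the flash-attn package and a supported GPU.

    import torch

    from parler_tts import ParlerTTSForConditionalGeneration

    # Choose the attention implementation at load time; "sdpa" is a safe
    # default, "flash_attention_2" is fastest where supported.
    model = ParlerTTSForConditionalGeneration.from_pretrained(
        "parler-tts/parler_tts_mini_v0.1",  # placeholder checkpoint name
        attn_implementation="sdpa",
        torch_dtype=torch.float16,
    ).to("cuda:0")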
7 people authored Aug 7, 2024
1 parent 11b209e commit 862f841
Showing 13 changed files with 576 additions and 346 deletions.
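
Before the per-file diff, a rough sketch of the static-cache pattern the commit message describes: a _get_cache helper that lazily allocates a fixed-shape StaticCache, and a cache_position derived from the tokens already processed rather than hardcoded. This is a hedged reconstruction against the transformers v4.43 API, not the PR's exact code; the _cache attribute name and the shown shapes are assumptions.

    import torch
    from transformers import StaticCache

    def _get_cache(self, max_batch_size, max_cache_len):
        # Allocate the fixed-shape cache once and reuse it across generate()
        # calls, so torch.compile always sees the same tensor shapes.
        if (
            not hasattr(self, "_cache")
            or self._cache.max_batch_size < max_batch_size
            or self._cache.max_cache_len < max_cache_len
        ):
            self._cache = StaticCache(
                config=self.config.decoder,
                max_batch_size=max_batch_size,
                max_cache_len=max_cache_len,
                device=self.device,
                dtype=self.dtype,
            )
        else:
            # Zero the pre-allocated key/value tensors instead of reallocating.
            self._cache.reset()
        return self._cache

    # "replace the hardcoded cache_position": derive positions from how many
    # tokens the cache has already seen instead of assuming step 0.
    past_seen_tokens = 0  # e.g. past_key_values.get_seq_length()
    seq_len = 4           # current chunk length (illustrative)
    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)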
5 changes: 3 additions & 2 deletions helpers/gradio_demo/app.py
@@ -1,8 +1,9 @@
 import gradio as gr
 import torch
+from transformers import AutoFeatureExtractor, AutoTokenizer, set_seed
+
 from parler_tts import ParlerTTSForConditionalGeneration
-from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed


 device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -57,7 +58,7 @@ def gen_tts(text, description):
         background-color: #000000;
         justify-content: center;
         align-items: center;
-        border-radius: 9999px !important;
+        border-radius: 9999px !important;
         width: 13rem;
         margin-top: 10px;
         margin-left: auto;
8 changes: 5 additions & 3 deletions helpers/model_init_scripts/init_dummy_model.py
@@ -1,7 +1,9 @@
-from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
-from transformers import AutoConfig
-import os
 import argparse
+import os
+
+from transformers import AutoConfig
+
+from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration


 if __name__ == "__main__":
9 changes: 6 additions & 3 deletions helpers/model_init_scripts/init_dummy_model_with_encodec.py
@@ -1,7 +1,10 @@
-from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
-from transformers import AutoConfig
-import os
 import argparse
+import os
+
+from transformers import AutoConfig
+
+from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
8 changes: 5 additions & 3 deletions helpers/model_init_scripts/init_model_600M.py
@@ -1,7 +1,9 @@
-from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
-from transformers import AutoConfig
-import os
 import argparse
+import os
+
+from transformers import AutoConfig
+
+from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration


 if __name__ == "__main__":
2 changes: 2 additions & 0 deletions helpers/push_to_hub_scripts/push_dac_to_hub.py
@@ -1,4 +1,6 @@
 import dac
+from transformers import AutoConfig, AutoModel, EncodecFeatureExtractor
+
 from parler_tts import DACConfig, DACModel
 from transformers import AutoConfig, AutoModel
 from transformers import EncodecFeatureExtractor
@@ -1,5 +1,7 @@
+from transformers import AutoFeatureExtractor, AutoTokenizer
+
 from parler_tts import ParlerTTSForConditionalGeneration
-from transformers import AutoTokenizer, AutoFeatureExtractor


 path = "TODO"
 repo_id = "parler_tts_600M"
5 changes: 3 additions & 2 deletions parler_tts/__init__.py
@@ -1,16 +1,17 @@
 __version__ = "0.1"


+from transformers import AutoConfig, AutoModel
+
 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
+from .dac_wrapper import DACConfig, DACModel
 from .modeling_parler_tts import (
     ParlerTTSForCausalLM,
     ParlerTTSForConditionalGeneration,
     apply_delay_pattern_mask,
     build_delay_pattern_mask,
 )

-from .dac_wrapper import DACConfig, DACModel
-from transformers import AutoConfig, AutoModel

 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
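
With the two register calls above in place, the generic Auto classes can resolve the "dac" model type. A small usage sketch (default config values assumed; the dac package must be installed):

    from transformers import AutoConfig, AutoModel

    config = AutoConfig.for_model("dac")   # resolves to the registered DACConfig
    model = AutoModel.from_config(config)  # builds a DACModel from that config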
2 changes: 1 addition & 1 deletion parler_tts/dac_wrapper/configuration_dac.py
@@ -1,5 +1,5 @@

 from transformers import PretrainedConfig
-from typing import List
+

 class DACConfig(PretrainedConfig):
9 changes: 4 additions & 5 deletions parler_tts/dac_wrapper/modeling_dac.py
@@ -1,10 +1,9 @@
 import torch

+from dac.model import DAC
 from transformers import PreTrainedModel
-from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput
-from .configuration_dac import DACConfig
+from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput

-from dac.model import DAC
-
+from .configuration_dac import DACConfig

 # model doesn't support batching yet
@@ -134,4 +133,4 @@ def decode(
         return EncodecDecoderOutput(audio_values)

     def forward(self, tensor):
-        raise ValueError(f"`DACModel.forward` not implemented yet")
+        raise ValueError("`DACModel.forward` not implemented yet")