Commit
* add rope
* don't include padding in rope
* possibly use cross-attn for prompt
* fix rope
* fix cross-attn
* fix self-attn
* fix dummy model
* clean up rope
* first GQA implementation
* fix WER eval
* feat: add flash attention and SDPA
* chore: add README for flash attention
* chore: add benchmark script
* chore: add benchmark attention approach
* multi-node support, fix WER and fix compile
* Update modeling_parler_tts.py
* fix FA2, SDPA and add cross-attn MHA and attention type forcing
* better default for the number of cross_attention key/value heads + add training arguments for attn implementation
* fix audio padding when using torch compile or pad_to_max_length=True
* correct multi-node
* make rope faster
* fix encoder SDPA
* fix training with cross attention + with FA2
* use fp32 as default model dtype + fix generation when using FA2 with autocast
* remove redundant passes in generate + clean and fix attentions
* fix edge case in WER evaluation with long-form generation
* better multi-node mapping and saving / add eval dataloader num workers
* remove old benchmarks
* faster audio encoding + checkpointing + fix generation step
* unpin transformers
* remove CFG
* imports and constants
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* attention modifications to handle static cache
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* decoder layer modification to handle static cache
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* ParlerTTSPreTrainedModel modifs to handle static cache
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* ParlerTTSDecoder modifs to handle static cache
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* ParlerTTSModel + ParlerTTSForCausalLM modifs
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* ParlerTTSForConditionalGeneration modifs
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* decoder_attention_mask for static cache
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* create inputs_embeds early to have a good cache initialization
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* _get_cache method
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* init the cache
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* ensure good device
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* pin transformers version
  Co-Authored-By: sang-nguyen-ts <[email protected]>
* fix attention_mask FA2
* remove unnecessary method
* Update parler_tts/modeling_parler_tts.py
  Co-authored-by: Sanchit Gandhi <[email protected]>
* Update parler_tts/modeling_parler_tts.py
  Co-authored-by: Sanchit Gandhi <[email protected]>
* remove unnecessary imports
* replace the hardcoded cache_position with a more elegant approach
* make style
* unpin transformers
* pin transformers
* pin torch
* refactor + unpin torch
* Update parler_tts/modeling_parler_tts.py
  Co-authored-by: Yoach Lacombe <[email protected]>
* update training script to match 11b209e
* Update parler_tts/modeling_parler_tts.py
  Co-authored-by: Yoach Lacombe <[email protected]>
* ensure compatibility with transformers 4.43.3, changes taken from #31980 on transformers
* fix input_ids_length
* warning on full attention mask creation
* changes for training compatibility

---------

Co-authored-by: sanchit-gandhi <[email protected]>
Co-authored-by: Yoach Lacombe <[email protected]>
Co-authored-by: Yoach Lacombe <[email protected]>
Co-authored-by: sang-nguyen-ts <[email protected]>
Co-authored-by: [email protected] <Yoach Lacombe>
Co-authored-by: sang-nguyen-ts <[email protected]>
Co-authored-by: Sanchit Gandhi <[email protected]>
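For context on the attention-related items above (the GQA implementation and the flash attention / SDPA paths), here is a minimal, self-contained PyTorch sketch. It is not taken from the Parler-TTS source; all shapes and names are illustrative. It shows the core idea: key/value heads are fewer than query heads and get repeated to match, then attention runs through `torch.nn.functional.scaled_dot_product_attention` (the SDPA path; the FA2 path replaces this call with the flash-attn kernel).

```python
# Minimal GQA + SDPA sketch (plain PyTorch, not the Parler-TTS implementation).
import torch
import torch.nn.functional as F


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq, head_dim) to (batch, num_kv_heads * n_rep, seq, head_dim)."""
    if n_rep == 1:
        return x
    batch, num_kv_heads, seq_len, head_dim = x.shape
    x = x[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


# Illustrative sizes: 16 query heads share 4 key/value heads (4 queries per KV head).
batch, q_heads, kv_heads, seq, head_dim = 2, 16, 4, 10, 64
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# Repeat each key/value head so the head dimension matches the queries.
key = repeat_kv(key, q_heads // kv_heads)
value = repeat_kv(value, q_heads // kv_heads)

# SDPA dispatches to an efficient fused attention kernel when one is available.
attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(attn_output.shape)  # torch.Size([2, 16, 10, 64])
```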