Spear-TTS - Pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch

The text-to-semantic module built here will be used to condition SoundStorm.


  • Stability for their generous sponsorship to work on and open source cutting-edge artificial intelligence research

  • Lucas Newman for completing the backtranslation portion, as well as beam search decoding!

  • Lucas Newman for completing the final text to semantic transformer training code!


$ pip install spear-tts-pytorch


import torch

from audiolm_pytorch import HubertWithKmeans

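from spear_tts_pytorch import (
    TextToSemantic,
    SemanticToTextDatasetGenerator,
    GeneratedAudioTextDataset,
    MockDataset
)

# HuBERT + k-means quantizer that turns audio into semantic token ids
wav2vec = HubertWithKmeans(
    checkpoint_path = './',
    kmeans_path = './hubert_base_ls960_L9_km500.bin'
)

model = TextToSemantic(
    wav2vec = wav2vec,
    dim = 512,
    num_text_token_ids = 256,
    heads = 8,
    target_kv_heads = 2, # grouped query attention, for memory efficient decoding
    source_depth = 1,
    target_depth = 1
)

# mock dataset of 10 samples, standing in for real audio
ds = MockDataset(10)

# uses the model to pseudo-label the audio with text, writing the pairs to a folder
dataset_generator = SemanticToTextDatasetGenerator(
    model = model,
    dataset = ds,
    folder = './output_folder'
)

dataset_generator(max_length = 2)

# load the generated audio-text pairs back from disk
generated_dataset = GeneratedAudioTextDataset(
    folder = './output_folder'
)

assert len(generated_dataset) == 10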
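The `target_kv_heads = 2` argument above is annotated as grouped query attention. As a rough, hypothetical illustration of that idea (not this library's implementation), the pure-Python sketch below lets 8 query heads share 2 key/value heads: query head `h` reads key/value head `h // (heads // kv_heads)`, so a decoding cache stores keys and values for 2 heads rather than 8. The function name `grouped_query_attention` and the nested-list `[heads][seq][dim]` layout are assumptions made for the example.

```python
import math

def softmax(xs):
    # numerically stable softmax over a list of floats
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def grouped_query_attention(q, k, v):
    """Hypothetical grouped query attention on nested lists.

    q:    [heads][seq_q][dim]    one query set per query head
    k, v: [kv_heads][seq_k][dim] keys / values, each shared by a group of query heads
    returns [heads][seq_q][dim]
    """
    heads, kv_heads = len(q), len(k)
    assert heads % kv_heads == 0, 'query heads must split evenly into kv groups'
    group_size = heads // kv_heads
    dim = len(q[0][0])
    scale = 1 / math.sqrt(dim)

    out = []
    for h in range(heads):
        g = h // group_size  # index of the shared key / value head for this query head
        head_out = []
        for query in q[h]:
            scores = [scale * sum(a * b for a, b in zip(query, key)) for key in k[g]]
            weights = softmax(scores)
            # attention output: softmax-weighted sum of the shared value vectors
            head_out.append([sum(w * val[d] for w, val in zip(weights, v[g])) for d in range(dim)])
        out.append(head_out)
    return out

# toy sizes mirroring heads = 8, target_kv_heads = 2 from the usage example
heads, kv_heads, seq, dim = 8, 2, 3, 4
q = [[[0.1 * (h + i + d) for d in range(dim)] for i in range(seq)] for h in range(heads)]
k = [[[0.2 * (g - i + d) for d in range(dim)] for i in range(seq)] for g in range(kv_heads)]
v = [[[float(10 * g + i) for d in range(dim)] for i in range(seq)] for g in range(kv_heads)]

out = grouped_query_attention(q, k, v)
```

Because each group of 4 query heads reuses one key/value head, the toy outputs of heads 0 to 3 stay within the range of the first value head (0 to 2) and those of heads 4 to 7 within the second (10 to 12).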


  • add eos logic + generate, and hook up end-to-end generation in SoundStorm

  • add the first pretraining step: speech-to-speech, reconstructing sequences with 60% of their tokens deleted

  • add dropouts for this project, as it targets a low-resource setting

  • add total flexibility over which layers of the encoder / decoder to freeze during training

  • add step for training on small speech -> text corpus and generating pseudo-labelled dataset + finetuning (thanks to @lucasnewman)

  • add final step of finetuning on text -> speech + pseudolabelled dataset

  • figure out the best way to store and manage the pseudo-labelled generated dataset

  • batched beam search decoding

  • allow for using rotary positions in decoder + flash attention, give Tri another citation

  • integrate speculative decoding with some improvisation, done in the same model using an early exit strategy

  • add cached keys / values for starters, plus single / grouped key values; make sure flash attention can support the specialized causal mask before Flash Attention 2 is in PyTorch core

  • polish the audio-text generation workflow

  • concatenate the real audio-text dataset with the generated one, or be able to convert the real audio-text dataset into the generated format
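The pretraining item in the list above reconstructs speech token sequences from which 60% of the tokens were deleted. A minimal sketch of building one such training pair follows; it assumes exactly `round(0.6 * n)` positions are removed uniformly at random, that the surviving tokens keep their order, and that the model learns to map the corrupted sequence back to the original. The helper `delete_tokens` and the example token ids are hypothetical and not part of `spear_tts_pytorch`.

```python
import random

def delete_tokens(tokens, delete_frac=0.6, rng=None):
    """Hypothetical corruption step for speech-to-speech denoising pretraining.

    Deletes round(delete_frac * len(tokens)) randomly chosen positions, keeps the
    surviving tokens in their original order, and returns
    (corrupted_source, reconstruction_target).
    """
    if rng is None:
        rng = random.Random()
    n = len(tokens)
    num_delete = round(delete_frac * n)
    deleted = set(rng.sample(range(n), num_delete))
    corrupted = [tok for i, tok in enumerate(tokens) if i not in deleted]
    return corrupted, list(tokens)

# made-up semantic token ids (e.g. k-means cluster indices), for illustration only
semantic_ids = [17, 4, 250, 99, 3, 3, 41, 8, 120, 7]
source, target = delete_tokens(semantic_ids, rng=random.Random(0))
```

The source side keeps 4 of the 10 tokens, and the target is the untouched sequence the decoder is trained to reproduce.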
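Beam search decoding and the eos logic from the list above can be illustrated together. The sketch below handles a single sequence (the batched variant is still a todo item), retires a beam once it emits the end-of-sequence token, and ranks finished beams by summed log-probability without length normalization. The `step_fn` interface, which returns next-token log-probabilities for a prefix, and the toy `toy_step` distribution are assumptions for illustration, not the library's decoder.

```python
import math

def beam_search(step_fn, bos, eos, beam_width=3, max_length=10):
    """Hypothetical beam search over a next-token log-probability function.

    step_fn(prefix) -> {token: log_prob} for the token following `prefix`.
    A beam that emits `eos` is retired to `finished` and no longer expanded.
    Returns (best sequence without bos, its summed log-probability).
    """
    beams = [([bos], 0.0)]
    finished = []
    for _ in range(max_length):
        candidates = [
            (seq + [tok], score + logp)
            for seq, score in beams
            for tok, logp in step_fn(seq).items()
        ]
        # keep only the top `beam_width` expansions (stable sort keeps ties in order)
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_width]:
            if seq[-1] == eos:
                finished.append((seq, score))
            else:
                beams.append((seq, score))
        if not beams:
            break
    # beams still open when max_length runs out also compete
    finished.extend(beams)
    best_seq, best_score = max(finished, key=lambda c: c[1])
    return best_seq[1:], best_score

BOS, EOS = 0, 1

def toy_step(prefix):
    # toy next-token distributions; any prefix not listed emits EOS with probability 1
    table = {
        (0,): {2: 0.6, 3: 0.4},
        (0, 2): {2: 0.5, 3: 0.5},
        (0, 3): {EOS: 0.9, 2: 0.1},
    }
    probs = table.get(tuple(prefix), {EOS: 1.0})
    return {tok: math.log(p) for tok, p in probs.items()}

best, best_score = beam_search(toy_step, BOS, EOS, beam_width=2)
greedy, greedy_score = beam_search(toy_step, BOS, EOS, beam_width=1)
```

With `beam_width = 1` this reduces to greedy decoding, which commits to the locally likelier first token and ends with a lower total probability (0.3) than the sequence beam search finds (0.36).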
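The speculative decoding item describes drafting with an early exit from the same model. Below is a hedged sketch of the greedy variant: a cheap predictor (standing in for an early-exit head) drafts up to `k` tokens, the full model checks them, the longest agreeing prefix is kept, and the first disagreement is replaced by the full model's token, so the result matches plain greedy decoding with the full model. The callables `draft_next` and `full_next` are hypothetical toy functions; a real implementation would verify the whole draft in one batched forward pass rather than the per-token loop used here for readability.

```python
def speculative_decode(draft_next, full_next, prefix, num_new, k=4):
    """Hypothetical greedy speculative decoding sketch.

    draft_next(seq) -> next token from a cheap predictor (e.g. an early-exit head)
    full_next(seq)  -> next token from the full model under greedy decoding
    Returns (sequence, number of draft tokens accepted). The sequence is identical
    to greedy decoding with full_next alone.
    """
    seq = list(prefix)
    target_len = len(seq) + num_new
    accepted = 0
    while len(seq) < target_len:
        # 1. draft up to k tokens autoregressively with the cheap predictor
        draft = []
        for _ in range(min(k, target_len - len(seq))):
            draft.append(draft_next(seq + draft))
        # 2. verify with the full model (a real model scores the whole draft in one pass)
        num_ok, correction = 0, None
        for tok in draft:
            expected = full_next(seq + draft[:num_ok])
            if expected != tok:
                correction = expected
                break
            num_ok += 1
        # 3. keep the agreeing prefix, then the full model's token at the first mismatch
        accepted += num_ok
        seq += draft[:num_ok]
        if correction is not None:
            seq.append(correction)
    return seq, accepted

def greedy_decode(next_fn, prefix, num_new):
    # reference: plain greedy decoding, one full-model call per token
    seq = list(prefix)
    for _ in range(num_new):
        seq.append(next_fn(seq))
    return seq

def full_next(seq):
    # toy "full model": a deterministic next-token rule over ids 0..9
    return (seq[-1] * 3 + 2) % 10

def draft_next(seq):
    # toy "early exit" predictor: agrees with the full model except right after token 7
    return 0 if seq[-1] == 7 else full_next(seq)

spec_seq, accepted = speculative_decode(draft_next, full_next, [1], num_new=8, k=4)
reference = greedy_decode(full_next, [1], 8)
```

Here 6 of the 8 new tokens come from accepted drafts; the 2 positions right after token 7 are filled by the full model's correction.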
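The cached key / values item refers to incremental decoding. A minimal single-head sketch, assuming plain scaled dot-product attention: each step appends the new position's key and value to a cache and attends over it, which gives the same result as recomputing causal attention over the whole prefix. With grouped key values, one such cache would be kept per key/value head rather than per query head. `KVCache`, `attend`, and `causal_attention` are hypothetical names.

```python
import math

def attend(query, keys, values):
    # single-head scaled dot-product attention for one query vector
    scale = 1 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(len(values[0]))]

class KVCache:
    """Hypothetical key / value cache for one attention head during decoding."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, query, key, value):
        # store the new position's key / value, then attend over everything cached so far
        self.keys.append(key)
        self.values.append(value)
        return attend(query, self.keys, self.values)

def causal_attention(queries, keys, values):
    # reference: position t attends to positions 0..t, recomputed from scratch
    return [attend(q, keys[: t + 1], values[: t + 1]) for t, q in enumerate(queries)]

qs = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [0.2, 0.2]]
ks = [[0.5, 0.1], [-0.2, 0.4], [0.3, 0.3], [0.1, -0.5]]
vs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0]]

cache = KVCache()
incremental = [cache.step(q, k, v) for q, k, v in zip(qs, ks, vs)]
full = causal_attention(qs, ks, vs)
```

The cached path computes each position's key and value once, while the reference rereads the whole prefix at every step; both produce the same outputs.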


