How do I use the weights trained on MeloTTS with the converter? #310

Open
kaushal-gawri9899 opened this issue Sep 18, 2024 · 1 comment
Labels
question Further information is requested

Comments

@kaushal-gawri9899

Hey, I've fine-tuned MeloTTS for an Indian accent and a few Indian languages. I wanted to use the weights with the tone converter, but realized voice_conversion expects the averaged tensor values for the source embeddings. How should I post-process my weights to get the average? I've tried using the extractor to get the embeddings instead of using MeloTTS; the tone color does change, but the voice isn't cloned.

dosubot bot added the question (Further information is requested) label on Sep 18, 2024

dosubot bot commented Sep 18, 2024

Hey @kaushal-gawri9899! I'm here to help you with any bugs, questions, or contributions you have. Let's tackle this together!

To use the weights trained with MeloTTS alongside the tone color converter, and to post-process them into the averaged tensor values expected for the source embeddings, follow these steps:

  1. Initialize the ToneColorConverter and Load Checkpoints:

    import os
    import torch
    from openvoice.api import ToneColorConverter
    
    ckpt_converter = 'checkpoints/converter'
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    output_dir = 'outputs'
    
    tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
    tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
    
    os.makedirs(output_dir, exist_ok=True)
  2. Obtain Tone Color Embedding for the Target Speaker:

    from openvoice import se_extractor
    
    reference_speaker = 'resources/example_reference.mp3'  # This is the voice you want to clone
    target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, vad=True)
  3. Use the Base Speaker TTS (or Your MeloTTS Model) to Generate Speech:

    from openvoice.api import BaseSpeakerTTS
    
    ckpt_base = 'checkpoints/base_speakers/EN'
    base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)
    base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')
    
    source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)
    
    texts = [
        "MyShell is a decentralized and comprehensive platform for discovering, creating, and staking AI-native apps.",
        "MyShell es una plataforma descentralizada y completa para descubrir, crear y apostar por aplicaciones nativas de IA.",
        "MyShell est une plateforme décentralisée et complète pour découvrir, créer et miser sur des applications natives d'IA.",
        "MyShell ist eine dezentralisierte und umfassende Plattform zum Entdecken, Erstellen und Staken von KI-nativen Apps.",
        "MyShell è una piattaforma decentralizzata e completa per scoprire, creare e scommettere su app native di intelligenza artificiale.",
        "MyShellは、AIネイティブアプリの発見、作成、およびステーキングのための分散型かつ包括的なプラットフォームです。",
        "MyShell — это децентрализованная и всеобъемлющая платформа для обнаружения, создания и стейкинга AI-ориентированных приложений.",
        "MyShell هي منصة لامركزية وشاملة لاكتشاف وإنشاء ورهان تطبيقات الذكاء الاصطناعي الأصلية.",
        "MyShell是一个去中心化且全面的平台,用于发现、创建和投资AI原生应用程序。",
        "MyShell एक विकेंद्रीकृत और व्यापक मंच है, जो AI-मूल ऐप्स की खोज, सृजन और स्टेकिंग के लिए है।",
        "MyShell é uma plataforma descentralizada e abrangente para descobrir, criar e apostar em aplicativos nativos de IA."
    ]
    
    src_path = f'{output_dir}/tmp.wav'
    
    for i, t in enumerate(texts):
        base_speaker_tts.tts(t, src_path, speaker='default', language='English', speed=1.0)
    
        save_path = f'{output_dir}/output_crosslingual_{i}.wav'
    
        # Run the tone color converter
        encode_message = "@MyShell"
        tone_color_converter.convert(
            audio_src_path=src_path, 
            src_se=source_se, 
            tgt_se=target_se, 
            output_path=save_path,
            message=encode_message)
  4. Post-Process Weights to Get Averaged Tensor Values for Source Embeddings:

    To average the tensor values for source embeddings, you can use the following code snippet:

    import os
    import torch
    
    def average_tensors(tensor_list):
        # Element-wise mean of a list of equally shaped embedding tensors
        return torch.stack(tensor_list).mean(dim=0)
    
    source_embeddings_dir = 'checkpoints/base_speakers/ses'
    source_embeddings = []
    
    # Load every saved source embedding (.pth) in the directory
    for file_name in os.listdir(source_embeddings_dir):
        if file_name.endswith('.pth'):
            tensor = torch.load(os.path.join(source_embeddings_dir, file_name), map_location=device)
            source_embeddings.append(tensor)
    
    averaged_tensor = average_tensors(source_embeddings)
    torch.save(averaged_tensor, 'checkpoints/averaged_source_embedding.pth')

This code will load all the source embeddings, compute their average, and save the averaged tensor [1][2][3].
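
If your fine-tuned MeloTTS checkpoint does not ship with a precomputed source embedding, one option is to derive it the same way: synthesize a few utterances with the fine-tuned voice, extract a tone color embedding from each, and average them. The sketch below is only an illustration under some assumptions: it assumes the melo.api.TTS interface (config_path/ckpt_path, hps.data.spk2id, tts_to_file) for loading a custom checkpoint, reuses device, output_dir and tone_color_converter from step 1, and the checkpoint paths and sample sentences are placeholders for your own.

    import torch
    from melo.api import TTS            # MeloTTS, assumed installed as the `melo` package
    from openvoice import se_extractor
    
    # Placeholder paths to your fine-tuned MeloTTS checkpoint -- adjust to your setup
    ft_config = 'checkpoints/melo_finetuned/config.json'
    ft_ckpt = 'checkpoints/melo_finetuned/checkpoint.pth'
    
    melo_model = TTS(language='EN', device=device,
                     config_path=ft_config, ckpt_path=ft_ckpt)
    speaker_id = list(melo_model.hps.data.spk2id.values())[0]
    
    # Synthesize a few utterances and extract a tone color embedding from each
    sample_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "OpenVoice separates tone color from the base speech.",
        "This sentence only exists to sample the speaker's timbre.",
    ]
    ses = []
    for i, text in enumerate(sample_texts):
        wav_path = f'{output_dir}/ft_sample_{i}.wav'
        melo_model.tts_to_file(text, speaker_id, wav_path, speed=1.0)
        se, _ = se_extractor.get_se(wav_path, tone_color_converter, vad=True)
        ses.append(se)
    
    # Average the per-utterance embeddings into a single source SE, as in step 4
    source_se = torch.stack(ses).mean(dim=0)
    torch.save(source_se, 'checkpoints/melo_finetuned/ft_source_se.pth')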

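Whichever averaged embedding you end up with (the one saved in step 4, or a file like ft_source_se.pth from the sketch above), load it and pass it as src_se to the converter, keeping target_se from step 2. A brief usage sketch with an arbitrary output file name:

    src_se = torch.load('checkpoints/averaged_source_embedding.pth', map_location=device)
    
    tone_color_converter.convert(
        audio_src_path=src_path,   # audio produced by the base / fine-tuned TTS
        src_se=src_se,             # averaged source tone color embedding
        tgt_se=target_se,          # embedding of the voice to clone (step 2)
        output_path=f'{output_dir}/output_finetuned.wav',
        message="@MyShell")
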
To continue talking to Dosu, mention @dosu.
