
Compatibility with Matcha TTS #39

Open
mush42 opened this issue Dec 16, 2023 · 7 comments · May be fixed by #47

Comments

mush42 commented Dec 16, 2023

Hi

The issue

I trained a model based on Matcha TTS, and I tried to use Vocos with it. Unfortunately, vocoding using a checkpoint trained with the default config of Vocos gives a robotic output with very low volume.

The only config values I changed are sample_rate (=22050) and n_mels (=80).

I assumed that there is a mismatch between the Matcha TTS-generated mel spectrogram and the mel spectrogram Vocos expects, in terms of parameters.
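
For context, here is roughly where I think the two pipelines diverge. The Vocos side is an assumption on my part about the stock vocos.feature_extractors.MelSpectrogramFeatures (which wraps torchaudio.transforms.MelSpectrogram and its defaults), so treat this as a sketch rather than a definitive description:

import torchaudio
from librosa.filters import mel as librosa_mel_fn

# Assumed stock Vocos-style features: torchaudio defaults, i.e. HTK mel scale,
# no filterbank normalization, f_max = sample_rate / 2, center padding
vocos_style = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050, n_fft=1024, hop_length=256, n_mels=80, power=1,
)

# Matcha / HiFi-GAN-style filterbank: Slaney scale, Slaney normalization, capped at 8 kHz;
# the STFT additionally uses center=False with manual reflect padding, and a log-clamp at 1e-5
matcha_filterbank = librosa_mel_fn(sr=22050, n_fft=1024, n_mels=80, fmin=0, fmax=8000)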

A new feature extractor

I wrote a feature extractor class to generate mel spectrograms using the same parameters as Matcha TTS. Most of the code is copied directly from Matcha's source code.

MatchaMelSpectrogramFeatures:
import numpy as np
import torch
from librosa.filters import mel as librosa_mel_fn

from vocos.feature_extractors import FeatureExtractor


class MatchaMelSpectrogramFeatures(FeatureExtractor):
    """
    Generate MelSpectrogram from audio using same params
    as Matcha TTS (https://github.com/shivammehta25/Matcha-TTS)
    This is also useful with tacatron, waveglow..etc.
    """

    def __init__(
        self,
        *,
        mel_mean,
        mel_std,
        sample_rate=22050,
        n_fft=1024,
        win_length=1024,
        n_mels=80,
        hop_length=256,
        center=False,
        f_min=0,
        f_max=8000,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.f_min = f_min
        self.f_max = f_max
        # Data-dependent
        self.mel_mean = mel_mean
        self.mel_std = mel_std
        # Cache
        self._mel_basis = {}
        self._hann_window = {}

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        # Accept both (T,) and (B, T) input; mel_spectrogram expects a batch dimension
        if audio.ndim == 1:
            audio = audio.unsqueeze(0)
        mel = self.mel_spectrogram(audio)  # (B, n_mels, n_frames)
        return normalize(mel, self.mel_mean, self.mel_std)

    def mel_spectrogram(self, y):
        mel_basis_key = str(self.f_max) + "_" + str(y.device)
        han_window_key = str(y.device)
        if mel_basis_key not in self._mel_basis:
            mel = librosa_mel_fn(
                sr=self.sample_rate,
                n_fft=self.n_fft,
                n_mels=self.n_mels,
                fmin=self.f_min,
                fmax=self.f_max
            )
            self._mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
            self._hann_window[han_window_key] = torch.hann_window(self.win_length).to(y.device)
        pad_vals = (
            (self.n_fft - self.hop_length) // 2,
            (self.n_fft - self.hop_length) // 2,
        )
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            pad_vals,
            mode="reflect"
        )
        y = y.squeeze(1)
        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self._hann_window[han_window_key],
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
        spec = torch.matmul(self._mel_basis[mel_basis_key], spec)
        spec = spectral_normalize_torch(spec)
        return spec


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)

def normalize(data, mu, std):
    if not isinstance(mu, (float, int)):
        if isinstance(mu, list):
            mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
        elif isinstance(mu, torch.Tensor):
            mu = mu.to(data.device)
        elif isinstance(mu, np.ndarray):
            mu = torch.from_numpy(mu).to(data.device)
        mu = mu.unsqueeze(-1)

    if not isinstance(std, (float, int)):
        if isinstance(std, list):
            std = torch.tensor(std, dtype=data.dtype, device=data.device)
        elif isinstance(std, torch.Tensor):
            std = std.to(data.device)
        elif isinstance(std, np.ndarray):
            std = torch.from_numpy(std).to(data.device)
        std = std.unsqueeze(-1)

    return (data - mu) / std
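
For a quick sanity check, I exercise the extractor on a single utterance like this (sample.wav is just a placeholder path; the mel_mean/mel_std values are the data statistics I pass in the config below):

import torchaudio

audio, sr = torchaudio.load("sample.wav")  # placeholder: any 22.05 kHz mono wav
assert sr == 22050

extractor = MatchaMelSpectrogramFeatures(mel_mean=-6.38385, mel_std=2.541796)
mel = extractor(audio)
print(mel.shape)  # expected: (1, 80, n_frames)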

And I used it with the following config:

vocos-matcha.yaml:
# pytorch_lightning==1.8.6
seed_everything: 4444

data:
  class_path: vocos.dataset.VocosDataModule
  init_args:
    train_params:
      filelist_path: ./datasets/train.txt
      sampling_rate: 22050
      num_samples: 16384
      batch_size: 16
      num_workers: 4

    val_params:
      filelist_path: ./datasets/val.txt
      sampling_rate: 22050
      num_samples: 48384
      batch_size: 16
      num_workers: 4

model:
  class_path: vocos.experiment.VocosExp
  init_args:
    sample_rate: 22050
    initial_learning_rate: 5e-4
    mel_loss_coeff: 45
    mrd_loss_coeff: 0.1
    num_warmup_steps: 0 # Optimizers warmup steps
    pretrain_mel_steps: 0  # 0 means GAN objective from the first iteration
    # automatic evaluation
    evaluate_utmos: true
    evaluate_pesq: true
    evaluate_periodicty: true

    feature_extractor:
      class_path: matcha_feature_extractor.MatchaMelSpectrogramFeatures
      init_args:
        sample_rate: 22050
        n_fft: 1024
        n_mels: 80
        hop_length: 256
        win_length: 1024
        f_min: 0
        f_max: 8000
        center: False
        mel_mean: -6.38385
        mel_std: 2.541796

    backbone:
      class_path: vocos.models.VocosBackbone
      init_args:
        input_channels: 80
        dim: 512
        intermediate_dim: 1536
        num_layers: 8

    head:
      class_path: vocos.heads.ISTFTHead
      init_args:
        dim: 512
        n_fft: 1024
        hop_length: 256
        padding: same

trainer:
  logger:
    class_path: pytorch_lightning.loggers.TensorBoardLogger
    init_args:
      save_dir: /content/drive/MyDrive/vocos/logs
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: 2
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: val_loss
        filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f}
        save_top_k: 2
        save_last: true
    - class_path: vocos.helpers.GradNormCallback

  # Lightning calculates max_steps across all optimizer steps (rather than number of batches)
  # This equals to 1M steps per generator and 1M per discriminator
  max_steps: 2000000
  # You might want to limit val batches when evaluating all the metrics, as they are time-consuming
  limit_val_batches: 128
  accelerator: gpu
  strategy: ddp
  devices: [0]
  log_every_n_steps: 100
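
Note that mel_mean and mel_std are data-dependent statistics (Matcha computes them over the training set), so they have to be recomputed for a new dataset. A rough sketch of one way to do that with the extractor above, assuming the filelist is a plain list of 22.05 kHz wav paths (this is not Matcha's own data-stats script):

import torchaudio


def compute_mel_stats(filelist_path: str) -> tuple[float, float]:
    # Global mean/std of the unnormalized log-mels, using identity normalization (mean=0, std=1)
    extractor = MatchaMelSpectrogramFeatures(mel_mean=0.0, mel_std=1.0)
    total, total_sq, count = 0.0, 0.0, 0
    with open(filelist_path) as f:
        paths = [line.strip() for line in f if line.strip()]
    for path in paths:
        audio, sr = torchaudio.load(path)
        assert sr == 22050
        mel = extractor(audio)
        total += mel.sum().item()
        total_sq += (mel ** 2).sum().item()
        count += mel.numel()
    mean = total / count
    std = (total_sq / count - mean ** 2) ** 0.5
    return mean, std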

Results

I trained Vocos using the above feature extractor and config, but this also fails with even worse vocoding quality and even lower volume.

Questions

  • Did I miss something in the above feature extractor?
  • Does the default Vocos head expect mel spectrograms generated with certain parameters?
  • Any suggestions to resolve this?

Additional notes

I believe many open-source TTS models use the same code to extract mel spectrograms, so resolving this will help with training Vocos for use with these TTS models.

Best

mush42 changed the title from "Compatibility with Matcha TTS generated Melspectogram" to "Compatibility with Matcha TTS" on Dec 16, 2023
mush42 commented Dec 17, 2023

@hubertsiuzdak @alealv

would you mind helping with this? I don't know where to start.


artificalaudio commented Jan 15, 2024

@mush42 Hey, I'm messing around with the same thing currently.

I tried synthesising using Vocos as a head for Matcha TTS. Vocos seems to want 100 mel bins, while Matcha currently outputs spectrograms with 80 bins. I'm not sure of the best way to go: either retrain Matcha with 100 bins, or see if zero padding could work. I tried zero padding from 80 to 100 mel bins earlier and synthesising through the Vocos mel head, but the quality wasn't that great.
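
For reference, by zero padding I mean roughly this (shapes assumed; not a recommendation given how it sounded):

import torch

mel = torch.randn(1, 80, 200)  # stand-in for a Matcha-generated mel: (batch, 80 bins, frames)
# F.pad pads from the last dimension backwards: (0, 0) leaves frames alone, (0, 20) adds 20 empty bins
mel_100 = torch.nn.functional.pad(mel, pad=(0, 0, 0, 20))
print(mel_100.shape)  # torch.Size([1, 100, 200])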

@egorsmkv

You can check out my fork with a config for 22050 Hz Vocos: https://github.com/egorsmkv/vocos


bharathraj-v commented Feb 15, 2024

@egorsmkv, even after training the model with the vocos.yaml config from your repo, the issue seems to persist; the output is still robotic and low in volume.

@hubertsiuzdak @alealv, Any help or guidance regarding this would be really helpful!

@egorsmkv

> @egorsmkv, even after training the model with the vocos.yaml config from your repo, the issue seems to persist; the output is still robotic and low in volume.
>
> @hubertsiuzdak @alealv, Any help or guidance regarding this would be really helpful!

How many steps did you train?


bharathraj-v commented Feb 15, 2024

> How many steps did you train?

15k steps


wetdog commented Feb 17, 2024

@mush42 I took a different approach. I searched for the appropriate parameters of torchaudio.transforms.MelSpectrogram to ensure that the features of Vocos match those of Matcha, which are also the same as those of HiFi-GAN.

The difference was in the frequency limits and the mel scaling that torchaudio uses by default.

mel_spec_transform_mod = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    hop_length=hop_length,
    n_mels=n_mels,
    center=padding == "center",
    power=1,
    f_min=0,
    f_max=8000,
    norm="slaney",
    mel_scale="slaney",
)
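
Note that the transform above only yields linear-magnitude mels; to match Matcha's features the same log compression still has to be applied afterwards (rough sketch building on mel_spec_transform_mod above; the exact clip value in the fork may differ from Vocos's own safe_log):

import torch

audio = torch.randn(1, 22050)                    # placeholder 1-second waveform
mel = mel_spec_transform_mod(audio)              # (1, n_mels, n_frames), linear magnitude
log_mel = torch.log(torch.clamp(mel, min=1e-5))  # HiFi-GAN/Matcha-style dynamic range compression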


I also updated the feature extraction in the reconstruction loss. You can check the changes in this fork: https://github.com/wetdog/vocos/tree/matcha
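
Conceptually, the loss-side change boils down to computing the spectral reconstruction loss on mels extracted with the same matched parameters, roughly like this (a generic sketch under that assumption, not the fork's exact code):

import torch
import torchaudio


class MatchedMelLoss(torch.nn.Module):
    # L1 distance between log-mels of target and reconstructed audio,
    # using the same Slaney / 8 kHz parameters as the feature extractor above.
    def __init__(self):
        super().__init__()
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=22050, n_fft=1024, hop_length=256, n_mels=80,
            power=1, f_min=0, f_max=8000, norm="slaney", mel_scale="slaney",
        )

    def forward(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        log_mel_hat = torch.log(torch.clamp(self.mel(y_hat), min=1e-5))
        log_mel = torch.log(torch.clamp(self.mel(y), min=1e-5))
        return torch.nn.functional.l1_loss(log_mel_hat, log_mel)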

The results sound good after 20 epochs with LibriTTS. We'll publish the checkpoints once the training finishes.

wetdog linked a pull request on Feb 19, 2024 that will close this issue.