Merge pull request #1 from NVIDIA/huggingface_hub_integration
HuggingFace Hub integration with refactor
L0SG authored Jul 16, 2024
2 parents 2d44823 + a34f475 commit a21751e
Showing 12 changed files with 1,208 additions and 1,011 deletions.
91 changes: 69 additions & 22 deletions README.md

<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>

**Paper**: https://arxiv.org/abs/2206.04658

**Code**: https://github.com/NVIDIA/BigVGAN

**Project page**: https://research.nvidia.com/labs/adlr/projects/bigvgan/

**Audio Demo**: https://bigvgan-demo.github.io/

**🤗 Hugging Face Spaces Demo**: https://huggingface.co/spaces/nvidia/BigVGAN

**🤗 Hugging Face Model Collection**: https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a

## News
[Jul 2024 (v2.1)] BigVGAN is now integrated with 🤗 Hugging Face Hub, enabling easy inference from pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.

[Jul 2024 (v2)] We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
* Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our tests show a 1.5 - 3x speedup on a single A100 GPU.
* Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
* Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.

## Installation

```shell
git clone https://github.com/NVIDIA/BigVGAN
cd BigVGAN
pip install -r requirements.txt
```

## Inference Quickstart using 🤗 Hugging Face Hub

The example below shows how to use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute a mel spectrogram from an input waveform, and generate a synthesized waveform using the mel spectrogram as the model's input.

```python
import torch
import bigvgan
import librosa
from meldataset import get_mel_spectrogram

device = 'cuda'

# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)

# remove weight norm in the model and set to eval mode
model.remove_weight_norm()
model = model.eval().to(device)

# load wav file and compute mel spectrogram
wav, sr = librosa.load('/path/to/your/audio.wav', sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]

# compute mel spectrogram from the ground truth audio
mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]

# generate waveform from mel
with torch.inference_mode():
    wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen_float is FloatTensor with shape [1, T_time]

# you can convert the generated waveform to 16 bit linear PCM
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen_int16 is np.ndarray with shape [1, T_time] and int16 dtype
```
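To write the result to disk, one option is `scipy.io.wavfile` (a minimal sketch; `scipy` is not used by the snippet above and is assumed to be available, e.g. as a dependency of `librosa`):

```python
# Sketch: save the generated 16-bit PCM to a WAV file.
# Assumes scipy is available; it is typically pulled in as a librosa dependency.
from scipy.io import wavfile

# wavfile.write expects [T_time] (or [T_time, n_channels]), so drop the channel axis
wavfile.write('generated.wav', model.h.sampling_rate, wav_gen_int16.squeeze(0))
```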

## Training
Create a symbolic link to the root of the dataset. The codebase uses filelists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset:
```shell
cd LibriTTS && \
ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
ln -s /path/to/your/LibriTTS/dev-other dev-other && \
ln -s /path/to/your/LibriTTS/test-clean test-clean && \
ln -s /path/to/your/LibriTTS/test-other test-other && \
cd ..
```

Train BigVGAN model. Below is an example command for training BigVGAN-v2 using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
```shell
python train.py \
--config configs/bigvgan_v2_24khz_100band_256x.json \
--input_wavs_dir LibriTTS \
--input_training_file LibriTTS/train-full.txt \
--input_validation_file LibriTTS/val-full.txt \
--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
```

## Synthesis
Synthesize from BigVGAN model. Below is an example command for generating audio from the model.
It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
```shell
python inference.py \
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
--input_wavs_dir /path/to/your/input_wav \
--output_dir /path/to/your/output_wav
```
Below is an example command for generating audio directly from precomputed mel spectrograms. It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model.
```shell
python inference_e2e.py \
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
--input_mels_dir /path/to/your/input_mel \
--output_dir /path/to/your/output_wav
```
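If you need to prepare such inputs yourself, below is a minimal sketch of precomputing mels with matching STFT hyperparameters, reusing `get_mel_spectrogram` and the loaded `model.h` from the quickstart above (the `.npy` output layout is an assumption; verify against what `inference_e2e.py` actually loads):

```python
# Sketch: precompute mel spectrograms whose STFT hyperparameters match the model
# config. Assumes the quickstart setup (librosa, torch, get_mel_spectrogram, and
# a loaded `model`). The .npy layout is an assumption; check inference_e2e.py.
import os
import numpy as np

def save_mels(input_wavs_dir: str, output_mels_dir: str, h) -> None:
    os.makedirs(output_mels_dir, exist_ok=True)
    for name in os.listdir(input_wavs_dir):
        if not name.endswith('.wav'):
            continue
        wav, _ = librosa.load(os.path.join(input_wavs_dir, name), sr=h.sampling_rate, mono=True)
        mel = get_mel_spectrogram(torch.FloatTensor(wav).unsqueeze(0), h)  # [1, C_mel, T_frame]
        np.save(os.path.join(output_mels_dir, name[:-4] + '.npy'), mel.squeeze(0).numpy())

# Example: save_mels('/path/to/your/input_wav', '/path/to/your/input_mel', model.h)
```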
We recommend running `test_cuda_vs_torch_model.py` first to build the kernel and check its numerical correctness against the plain PyTorch implementation:

```shell
python test_cuda_vs_torch_model.py \
--checkpoint_file /path/to/your/bigvgan_generator.pt
```

```shell
Building extension module anti_alias_activation_cuda...
...
Loading extension module anti_alias_activation_cuda...
...
Loading '/path/to/your/bigvgan_generator.pt'
...
[Success] test CUDA fused vs. plain torch BigVGAN inference
> mean_difference=0.0007238413265440613
```

If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the fused CUDA kernel is producing incorrect output; check that the `nvcc` used to build the extension is compatible with your PyTorch CUDA version.
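A quick way to spot such a mismatch (a convenience sketch, not part of the repository):

```python
# Sketch: compare the CUDA version PyTorch was built with against the local nvcc.
import subprocess
import torch

print('PyTorch built with CUDA:', torch.version.cuda)
print(subprocess.run(['nvcc', '--version'], capture_output=True, text=True).stdout)
```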


## Pretrained Models
We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.

|Model Name|Sampling Rate|Mel bands|fmax (Hz)|Upsampling Ratio|Params|Dataset|Steps|Fine-Tuned|
|------|---|---|---|---|---|------|---|---|
|[bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x)|44 kHz|128|22050|512|122M|Large-scale Compilation|3M|No|
|[bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x)|44 kHz|128|22050|256|112M|Large-scale Compilation|3M|No|
|[bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x)|24 kHz|100|12000|256|112M|Large-scale Compilation|3M|No|
|[bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x)|22 kHz|80|11025|256|112M|Large-scale Compilation|3M|No|
|[bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x)|22 kHz|80|8000|256|112M|Large-scale Compilation|3M|No|
|[bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band)|24 kHz|100|12000|256|112M|LibriTTS|5M|No|
|[bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band)|24 kHz|100|12000|256|14M|LibriTTS|5M|No|
|[bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band)|22 kHz|80|8000|256|112M|LibriTTS + VCTK + LJSpeech|5M|No|
|[bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band)|22 kHz|80|8000|256|14M|LibriTTS + VCTK + LJSpeech|5M|No|

The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
Note that the checkpoints use `snakebeta` activation with log-scale parameterization, which gives the best overall quality.

You can fine-tune the models by:
1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py` (see the download sketch below)
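For example, step 1 could be scripted with `huggingface_hub` (a minimal sketch; the local directory name is illustrative, and `train.py` may expect its own checkpoint naming inside `--checkpoint_path`):

```python
# Sketch: download both checkpoint files of a model for fine-tuning.
# The directory name is illustrative; rename the files if train.py expects
# a different checkpoint naming under --checkpoint_path.
from huggingface_hub import hf_hub_download

repo_id = 'nvidia/bigvgan_v2_24khz_100band_256x'
ckpt_dir = 'exp/finetune_bigvgan_v2_24khz_100band_256x'

for filename in ['bigvgan_generator.pt', 'bigvgan_discriminator_optimizer.pt']:
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=ckpt_dir)
```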

## Training Details of BigVGAN-v2
Compared to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 were trained with `batch_size=32`, a longer `segment_size=65536`, and 8 A100 GPUs.