Skip to content

Commit

Permalink
add MultiBandDiffusion (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsxdalv authored Aug 4, 2023
1 parent 1e01bc3 commit f54576e
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 14 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ Google Colab demo: [![Open In Colab](https://colab.research.google.com/assets/co

## Videos

| **Refining Bark TTS vocals using Demucs & Vocos** | **Demo - How to use RVC with Tortoise** | **How To Get More Voices for Bark TTS** |
| **TTS Generation WebUI - A Tool for Text to Speech and Voice Cloning** | **Text to speech and voice cloning - TTS Generation WebUI** | **AudioGen Unveils New Text-to-Audio Capabilities** |
| :------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------: |
| [![Watch the video](https://img.youtube.com/vi/jCb-8JE7pk8/sddefault.jpg)](https://youtu.be/jCb-8JE7pk8) | [![Watch the video](https://img.youtube.com/vi/mhp_e8WSpxA/sddefault.jpg)](https://youtu.be/mhp_e8WSpxA) | [![Watch the video](https://img.youtube.com/vi/yeC5vJoavOE/sddefault.jpg)](https://youtu.be/yeC5vJoavOE) |
| [![Watch the video](https://img.youtube.com/vi/JXojhFjZ39k/sddefault.jpg)](https://youtu.be/JXojhFjZ39k) | [![Watch the video](https://img.youtube.com/vi/ScN2ypewABc/sddefault.jpg)](https://youtu.be/ScN2ypewABc) | [![Watch the video](https://img.youtube.com/vi/fDqyw9JG6PY/sddefault.jpg)](https://youtu.be/fDqyw9JG6PY) |

## Screenshots

Expand All @@ -34,6 +34,10 @@ Google Colab demo: [![Open In Colab](https://colab.research.google.com/assets/co
https://rsxdalv.github.io/bark-speaker-directory/

## Changelog
Aug 4:
* Add MultiBandDiffusion option to MusicGen https://github.com/rsxdalv/tts-generation-webui/pull/109
* MusicGen/AudioGen save tokens on generation as .npz files.

Aug 3:
* Add AudioGen https://github.com/rsxdalv/tts-generation-webui/pull/105

Expand Down
18 changes: 15 additions & 3 deletions src/bark/npz_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from src.bark.FullGeneration import FullGeneration
import json
import torch


def compress_history(full_generation: FullGeneration):
Expand All @@ -13,10 +14,11 @@ def compress_history(full_generation: FullGeneration):
}


def save_npz(filename: str, full_generation: FullGeneration, metadata: dict[str, Any]):
def pack_metadata(metadata: dict[str, Any]):
return list(json.dumps(metadata))
def pack_metadata(metadata: dict[str, Any]):
return list(json.dumps(metadata))


def save_npz(filename: str, full_generation: FullGeneration, metadata: dict[str, Any]):
np.savez(
filename,
**{
Expand All @@ -26,6 +28,16 @@ def pack_metadata(metadata: dict[str, Any]):
)


def save_npz_musicgen(filename: str, tokens: torch.Tensor, metadata: dict[str, Any]):
np.savez(
filename,
**{
"tokens": tokens.cpu().numpy(),
"metadata": pack_metadata(metadata),
},
)


def load_npz(filename):
def unpack_metadata(metadata: np.ndarray):
def join_list(x: list | np.ndarray):
Expand Down
60 changes: 51 additions & 9 deletions src/musicgen/musicgen_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional, Tuple, TypedDict
import numpy as np
import os
from src.bark.npz_tools import save_npz_musicgen
from src.musicgen.setup_seed_ui_musicgen import setup_seed_ui_musicgen
from src.bark.parse_or_set_seed import parse_or_set_seed
from src.musicgen.audio_array_to_sha256 import audio_array_to_sha256
Expand Down Expand Up @@ -37,6 +38,7 @@ class MusicGenGeneration(TypedDict):
temperature: float
cfg_coef: float
seed: int
use_multi_band_diffusion: bool


def melody_to_sha256(melody: Optional[Tuple[int, np.ndarray]]) -> Optional[str]:
Expand Down Expand Up @@ -76,13 +78,14 @@ def save_generation(
audio_array: np.ndarray,
SAMPLE_RATE: int,
params: MusicGenGeneration,
tokens: torch.Tensor,
):
prompt = params["text"]
date = get_date_string()
title = prompt[:20].replace(" ", "_")
base_filename = create_base_filename(title, "outputs", model="musicgen", date=date)

filename, filename_png, filename_json, _ = get_filenames(base_filename)
filename, filename_png, filename_json, filename_npz = get_filenames(base_filename)
write_wav(filename, SAMPLE_RATE, audio_array)
plot = save_waveform_plot(audio_array, filename_png)

Expand All @@ -93,6 +96,7 @@ def save_generation(
params=params,
audio_array=audio_array,
)
save_npz_musicgen(filename_npz, tokens, metadata)

filename_ogg = filename.replace(".wav", ".ogg")
ext_callback_save_generation_musicgen(
Expand Down Expand Up @@ -165,17 +169,19 @@ def generate(params: MusicGenGeneration, melody_in: Optional[Tuple[int, np.ndarr
if melody.dim() == 2:
melody = melody[None]
melody = melody[..., : int(sr * MODEL.lm.cfg.dataset.segment_duration)] # type: ignore
output = MODEL.generate_with_chroma(
output, tokens = MODEL.generate_with_chroma(
descriptions=[text],
melody_wavs=melody,
melody_sample_rate=sr,
progress=False,
return_tokens=True,
# generator=generator,
)
else:
output = MODEL.generate(
output, tokens = MODEL.generate(
descriptions=[text],
progress=True,
return_tokens=True,
# generator=generator,
)
set_seed(-1)
Expand All @@ -184,12 +190,19 @@ def generate(params: MusicGenGeneration, melody_in: Optional[Tuple[int, np.ndarr
# print time taken
print("Generated in", "{:.3f}".format(elapsed), "seconds")

output = output.detach().cpu().numpy().squeeze()
if params["use_multi_band_diffusion"]:
from audiocraft.models.multibanddiffusion import MultiBandDiffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()
wav_diffusion = mbd.tokens_to_wav(tokens)
output = wav_diffusion.detach().cpu().numpy().squeeze()
else:
output = output.detach().cpu().numpy().squeeze()

filename, plot, _metadata = save_generation(
audio_array=output,
SAMPLE_RATE=MODEL.sample_rate,
params=params,
tokens=tokens,
)

return [
Expand All @@ -215,19 +228,31 @@ def generation_tab_musicgen():
"temperature": 1.0,
"cfg_coef": 3.0,
"seed": -1,
"use_multi_band_diffusion": False,
},
)
# musicgen_atom.render()
gr.Markdown(f"""Audiocraft version: {AUDIOCRAFT_VERSION}""")
with gr.Row():
with gr.Row(equal_height=False):
with gr.Column():
text = gr.Textbox(
label="Prompt", lines=3, placeholder="Enter text here..."
)
model = gr.Radio(
["melody", "medium", "small", "large", "facebook/audiogen-medium"],
[
"facebook/musicgen-melody",
# "musicgen-melody",
"facebook/musicgen-medium",
# "musicgen-medium",
"facebook/musicgen-small",
# "musicgen-small",
"facebook/musicgen-large",
# "musicgen-large",
"facebook/audiogen-medium",
# "audiogen-medium",
],
label="Model",
value="melody",
value="facebook/musicgen-small",
)
melody = gr.Audio(
source="upload",
Expand Down Expand Up @@ -269,6 +294,10 @@ def generation_tab_musicgen():
interactive=True,
step=0.1,
)
use_multi_band_diffusion = gr.Checkbox(
label="Use Multi-Band Diffusion",
value=False,
)
seed, set_old_seed_button, _ = setup_seed_ui_musicgen()

with gr.Column():
Expand All @@ -295,7 +324,18 @@ def generation_tab_musicgen():
outputs=[melody],
)

inputs = [text, melody, model, duration, topk, topp, temperature, cfg_coef, seed]
inputs = [
text,
melody,
model,
duration,
topk,
topp,
temperature,
cfg_coef,
seed,
use_multi_band_diffusion,
]

def update_components(x):
return {
Expand All @@ -308,6 +348,7 @@ def update_components(x):
temperature: x["temperature"],
cfg_coef: x["cfg_coef"],
seed: x["seed"],
use_multi_band_diffusion: x["use_multi_band_diffusion"],
}

musicgen_atom.change(
Expand All @@ -317,7 +358,7 @@ def update_components(x):
)

def update_json(
text, _melody, model, duration, topk, topp, temperature, cfg_coef, seed
text, _melody, model, duration, topk, topp, temperature, cfg_coef, seed, use_multi_band_diffusion
):
return {
"text": text,
Expand All @@ -329,6 +370,7 @@ def update_json(
"temperature": float(temperature),
"cfg_coef": float(cfg_coef),
"seed": int(seed),
"use_multi_band_diffusion": bool(use_multi_band_diffusion),
}

seed_cache = gr.State() # type: ignore
Expand Down

0 comments on commit f54576e

Please sign in to comment.