From 4cfdabcac76460e032f4ebe5d21674e3c1b1332b Mon Sep 17 00:00:00 2001
From: maxardito
Date: Fri, 29 Mar 2024 14:33:56 -0400
Subject: [PATCH 1/3] Remove MultiBandDiffusion and most diffusion code

Drop the MultiBandDiffusion decoder and its supporting pieces: the diffusion
grids, the noise schedule module, the diffusion solver, and the MBD test.
The MusicGen demo keeps its UI but no longer runs the MultiBand_Diffusion
decoding path. DiffusionUnet stays in audiocraft/models/unet.py (reformatted,
no longer re-exported from audiocraft.models).
---
.gitignore | 1 +
audiocraft/__init__.py | 2 -
.../grids/diffusion/4_bands_base_32khz.py | 27 --
audiocraft/grids/diffusion/__init__.py | 6 -
audiocraft/grids/diffusion/_explorers.py | 66 ----
audiocraft/models/__init__.py | 2 -
audiocraft/models/multibanddiffusion.py | 191 -----------
audiocraft/models/unet.py | 132 ++++++--
audiocraft/modules/diffusion_schedule.py | 272 ----------------
audiocraft/solvers/__init__.py | 1 -
audiocraft/solvers/builders.py | 2 -
audiocraft/solvers/diffusion.py | 279 ----------------
demos/musicgen_app.py | 300 ++++++++++--------
tests/models/test_multibanddiffusion.py | 53 ----
utils/export.py | 3 +-
15 files changed, 270 insertions(+), 1067 deletions(-)
delete mode 100644 audiocraft/grids/diffusion/4_bands_base_32khz.py
delete mode 100644 audiocraft/grids/diffusion/__init__.py
delete mode 100644 audiocraft/grids/diffusion/_explorers.py
delete mode 100644 audiocraft/models/multibanddiffusion.py
delete mode 100644 audiocraft/modules/diffusion_schedule.py
delete mode 100644 audiocraft/solvers/diffusion.py
delete mode 100644 tests/models/test_multibanddiffusion.py
diff --git a/.gitignore b/.gitignore
index 581c6caa..fa469b38 100644
--- a/.gitignore
+++ b/.gitignore
@@ -53,6 +53,7 @@ dataset/*
musicgen-deployment.txt
manifests
export
+sessions
# personal notebooks & scripts
*/local_scripts
diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py
index 1226aa02..385d9a00 100644
--- a/audiocraft/__init__.py
+++ b/audiocraft/__init__.py
@@ -14,8 +14,6 @@
- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
neural audio codec which provides an excellent tokenizer for autoregressive language models.
See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
-- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
- improves the perceived quality and reduces the artifacts coming from adversarial decoders.
"""
# flake8: noqa
diff --git a/audiocraft/grids/diffusion/4_bands_base_32khz.py b/audiocraft/grids/diffusion/4_bands_base_32khz.py
deleted file mode 100644
index f7e67bcc..00000000
--- a/audiocraft/grids/diffusion/4_bands_base_32khz.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Training of the 4 diffusion models described in
-"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
-(paper link).
-"""
-
-from ._explorers import DiffusionExplorer
-
-
-@DiffusionExplorer
-def explorer(launcher):
- launcher.slurm_(gpus=4, partition='learnfair')
-
- launcher.bind_({'solver': 'diffusion/default',
- 'dset': 'internal/music_10k_32khz'})
-
- with launcher.job_array():
- launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
- launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
- launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
- launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
diff --git a/audiocraft/grids/diffusion/__init__.py b/audiocraft/grids/diffusion/__init__.py
deleted file mode 100644
index e5737294..00000000
--- a/audiocraft/grids/diffusion/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-"""Diffusion grids."""
diff --git a/audiocraft/grids/diffusion/_explorers.py b/audiocraft/grids/diffusion/_explorers.py
deleted file mode 100644
index 0bf4ca57..00000000
--- a/audiocraft/grids/diffusion/_explorers.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import treetable as tt
-
-from .._base_explorers import BaseExplorer
-
-
-class DiffusionExplorer(BaseExplorer):
- eval_metrics = ["sisnr", "visqol"]
-
- def stages(self):
- return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
-
- def get_grid_meta(self):
- """Returns the list of Meta information to display for each XP/job.
- """
- return [
- tt.leaf("index", align=">"),
- tt.leaf("name", wrap=140),
- tt.leaf("state"),
- tt.leaf("sig", align=">"),
- ]
-
- def get_grid_metrics(self):
- """Return the metrics that should be displayed in the tracking table.
- """
- return [
- tt.group(
- "train",
- [
- tt.leaf("epoch"),
- tt.leaf("loss", ".3%"),
- ],
- align=">",
- ),
- tt.group(
- "valid",
- [
- tt.leaf("loss", ".3%"),
- # tt.leaf("loss_0", ".3%"),
- ],
- align=">",
- ),
- tt.group(
- "valid_ema",
- [
- tt.leaf("loss", ".3%"),
- # tt.leaf("loss_0", ".3%"),
- ],
- align=">",
- ),
- tt.group(
- "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
- tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
- tt.leaf("rvm_3", ".4f"), ], align=">"
- ),
- tt.group(
- "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
- tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
- tt.leaf("rvm_3", ".4f")], align=">"
- ),
- ]
diff --git a/audiocraft/models/__init__.py b/audiocraft/models/__init__.py
index 0628e05a..90008169 100644
--- a/audiocraft/models/__init__.py
+++ b/audiocraft/models/__init__.py
@@ -11,6 +11,4 @@
from .encodec import (CompressionModel, EncodecModel, DAC, HFEncodecModel,
HFEncodecCompressionModel)
from .lm import LMModel
-from .multibanddiffusion import MultiBandDiffusion
from .musicgen import MusicGen
-from .unet import DiffusionUnet
diff --git a/audiocraft/models/multibanddiffusion.py b/audiocraft/models/multibanddiffusion.py
deleted file mode 100644
index 451b5862..00000000
--- a/audiocraft/models/multibanddiffusion.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Multi Band Diffusion models as described in
-"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
-(paper link).
-"""
-
-import typing as tp
-
-import torch
-import julius
-
-from .unet import DiffusionUnet
-from ..modules.diffusion_schedule import NoiseSchedule
-from .encodec import CompressionModel
-from ..solvers.compression import CompressionSolver
-from .loaders import load_compression_model, load_diffusion_models
-
-
-class DiffusionProcess:
- """Sampling for a diffusion Model.
-
- Args:
- model (DiffusionUnet): Diffusion U-Net model.
- noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
- """
- def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
- self.model = model
- self.schedule = noise_schedule
-
- def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
- step_list: tp.Optional[tp.List[int]] = None):
- """Perform one diffusion process to generate one of the bands.
-
- Args:
- condition (torch.Tensor): The embeddings from the compression model.
- initial_noise (torch.Tensor): The initial noise to start the process.
- """
- return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
- condition=condition)
-
-
-class MultiBandDiffusion:
- """Sample from multiple diffusion models.
-
- Args:
- DPs (list of DiffusionProcess): Diffusion processes.
- codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
- """
- def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
- self.DPs = DPs
- self.codec_model = codec_model
- self.device = next(self.codec_model.parameters()).device
-
- @property
- def sample_rate(self) -> int:
- return self.codec_model.sample_rate
-
- @staticmethod
- def get_mbd_musicgen(device=None):
- """Load our diffusion models trained for MusicGen."""
- if device is None:
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- path = 'facebook/multiband-diffusion'
- filename = 'mbd_musicgen_32khz.th'
- name = 'facebook/musicgen-small'
- codec_model = load_compression_model(name, device=device)
- models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
- DPs = []
- for i in range(len(models)):
- schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
- DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
- return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
-
- @staticmethod
- def get_mbd_24khz(bw: float = 3.0,
- device: tp.Optional[tp.Union[torch.device, str]] = None,
- n_q: tp.Optional[int] = None):
- """Get the pretrained Models for MultibandDiffusion.
-
- Args:
- bw (float): Bandwidth of the compression model.
- device (torch.device or str, optional): Device on which the models are loaded.
- n_q (int, optional): Number of quantizers to use within the compression model.
- """
- if device is None:
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
- if n_q is not None:
- assert n_q in [2, 4, 8]
- assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
- f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
- n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
- codec_model = CompressionSolver.model_from_checkpoint(
- '//pretrained/facebook/encodec_24khz', device=device)
- codec_model.set_num_codebooks(n_q)
- codec_model = codec_model.to(device)
- path = 'facebook/multiband-diffusion'
- filename = f'mbd_comp_{n_q}.pt'
- models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
- DPs = []
- for i in range(len(models)):
- schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
- DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
- return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
-
- @torch.no_grad()
- def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
- """Get the conditioning (i.e. latent representations of the compression model) from a waveform.
- Args:
- wav (torch.Tensor): The audio that we want to extract the conditioning from.
- sample_rate (int): Sample rate of the audio."""
- if sample_rate != self.sample_rate:
- wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
- codes, scale = self.codec_model.encode(wav)
- assert scale is None, "Scaled compression models not supported."
- emb = self.get_emb(codes)
- return emb
-
- @torch.no_grad()
- def get_emb(self, codes: torch.Tensor):
- """Get latent representation from the discrete codes.
- Args:
- codes (torch.Tensor): Discrete tokens."""
- emb = self.codec_model.decode_latent(codes)
- return emb
-
- def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
- step_list: tp.Optional[tp.List[int]] = None):
- """Generate waveform audio from the latent embeddings of the compression model.
- Args:
- emb (torch.Tensor): Conditioning embeddings
- size (None, torch.Size): Size of the output
- if None this is computed from the typical upsampling of the model.
- step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step.
- """
- if size is None:
- upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
- size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
- assert size[0] == emb.size(0)
- out = torch.zeros(size).to(self.device)
- for DP in self.DPs:
- out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
- return out
-
- def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
- """Match the eq to the encodec output by matching the standard deviation of some frequency bands.
- Args:
- wav (torch.Tensor): Audio to equalize.
- ref (torch.Tensor): Reference audio from which we match the spectrogram.
- n_bands (int): Number of bands of the eq.
- strictness (float): How strict the matching. 0 is no matching, 1 is exact matching.
- """
- split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
- bands = split(wav)
- bands_ref = split(ref)
- out = torch.zeros_like(ref)
- for i in range(n_bands):
- out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
- return out
-
- def regenerate(self, wav: torch.Tensor, sample_rate: int):
- """Regenerate a waveform through compression and diffusion regeneration.
- Args:
- wav (torch.Tensor): Original 'ground truth' audio.
- sample_rate (int): Sample rate of the input (and output) wav.
- """
- if sample_rate != self.codec_model.sample_rate:
- wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
- emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
- size = wav.size()
- out = self.generate(emb, size=size)
- if sample_rate != self.codec_model.sample_rate:
- out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
- return out
-
- def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
- """Generate Waveform audio with diffusion from the discrete codes.
- Args:
- tokens (torch.Tensor): Discrete codes.
- n_bands (int): Bands for the eq matching.
- """
- wav_encodec = self.codec_model.decode(tokens)
- condition = self.get_emb(tokens)
- wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
- return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
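
For context while reviewing, a rough sketch of the workflow the deleted module
supported, pieced together from the code above and the demo changes further
down; the model name, prompt and duration are illustrative only:

    import torch
    from audiocraft.models import MusicGen, MultiBandDiffusion  # pre-patch import

    musicgen = MusicGen.get_pretrained('facebook/musicgen-small')
    mbd = MultiBandDiffusion.get_mbd_musicgen()        # loads the 32 kHz diffusion decoders

    musicgen.set_generation_params(duration=8)
    wav, tokens = musicgen.generate(['punk rock with loud drums'], return_tokens=True)

    # Decode the same discrete tokens with diffusion instead of the adversarial
    # EnCodec decoder; re_eq() inside tokens_to_wav matches the EQ of both outputs.
    wav_mbd = mbd.tokens_to_wav(tokens)
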
diff --git a/audiocraft/models/unet.py b/audiocraft/models/unet.py
index db4a6df8..0061fafa 100644
--- a/audiocraft/models/unet.py
+++ b/audiocraft/models/unet.py
@@ -3,7 +3,6 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-
"""
Pytorch Unet Module used for diffusion.
"""
@@ -24,15 +23,21 @@ class Output:
def get_model(cfg, channels: int, side: int, num_steps: int):
if cfg.model == 'unet':
- return DiffusionUnet(
- chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
+ return DiffusionUnet(chin=channels,
+ num_steps=num_steps,
+ **cfg.diffusion_unet)
else:
raise RuntimeError('Not Implemented')
class ResBlock(nn.Module):
- def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
- dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+
+ def __init__(self,
+ channels: int,
+ kernel: int = 3,
+ norm_groups: int = 4,
+ dilation: int = 1,
+ activation: tp.Type[nn.Module] = nn.ReLU,
dropout: float = 0.):
super().__init__()
stride = 1
@@ -40,12 +45,22 @@ def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
Conv = nn.Conv1d
Drop = nn.Dropout1d
self.norm1 = nn.GroupNorm(norm_groups, channels)
- self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+ self.conv1 = Conv(channels,
+ channels,
+ kernel,
+ 1,
+ padding,
+ dilation=dilation)
self.activation1 = activation()
self.dropout1 = Drop(dropout)
self.norm2 = nn.GroupNorm(norm_groups, channels)
- self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+ self.conv2 = Conv(channels,
+ channels,
+ kernel,
+ 1,
+ padding,
+ dilation=dilation)
self.activation2 = activation()
self.dropout2 = Drop(dropout)
@@ -56,14 +71,24 @@ def forward(self, x):
class DecoderLayer(nn.Module):
- def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
- norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+
+ def __init__(self,
+ chin: int,
+ chout: int,
+ kernel: int = 4,
+ stride: int = 2,
+ norm_groups: int = 4,
+ res_blocks: int = 1,
+ activation: tp.Type[nn.Module] = nn.ReLU,
dropout: float = 0.):
super().__init__()
padding = (kernel - stride) // 2
- self.res_blocks = nn.Sequential(
- *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
- for idx in range(res_blocks)])
+ self.res_blocks = nn.Sequential(*[
+ ResBlock(chin,
+ norm_groups=norm_groups,
+ dilation=2**idx,
+ dropout=dropout) for idx in range(res_blocks)
+ ])
self.norm = nn.GroupNorm(norm_groups, chin)
ConvTr = nn.ConvTranspose1d
self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
@@ -78,8 +103,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class EncoderLayer(nn.Module):
- def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
- norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+
+ def __init__(self,
+ chin: int,
+ chout: int,
+ kernel: int = 4,
+ stride: int = 2,
+ norm_groups: int = 4,
+ res_blocks: int = 1,
+ activation: tp.Type[nn.Module] = nn.ReLU,
dropout: float = 0.):
super().__init__()
padding = (kernel - stride) // 2
@@ -87,9 +119,12 @@ def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
self.norm = nn.GroupNorm(norm_groups, chout)
self.activation = activation()
- self.res_blocks = nn.Sequential(
- *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
- for idx in range(res_blocks)])
+ self.res_blocks = nn.Sequential(*[
+ ResBlock(chout,
+ norm_groups=norm_groups,
+ dilation=2**idx,
+ dropout=dropout) for idx in range(res_blocks)
+ ])
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, T = x.shape
@@ -107,9 +142,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class BLSTM(nn.Module):
"""BiLSTM with same hidden units as input dim.
"""
+
def __init__(self, dim, layers=2):
super().__init__()
- self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
+ self.lstm = nn.LSTM(bidirectional=True,
+ num_layers=layers,
+ hidden_size=dim,
+ input_size=dim)
self.linear = nn.Linear(2 * dim, dim)
def forward(self, x):
@@ -121,10 +160,20 @@ def forward(self, x):
class DiffusionUnet(nn.Module):
- def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
- max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
- bilstm: bool = False, transformer: bool = False,
- codec_dim: tp.Optional[int] = None, **kwargs):
+
+ def __init__(self,
+ chin: int = 3,
+ hidden: int = 24,
+ depth: int = 3,
+ growth: float = 2.,
+ max_channels: int = 10_000,
+ num_steps: int = 1000,
+ emb_all_layers=False,
+ cross_attention: bool = False,
+ bilstm: bool = False,
+ transformer: bool = False,
+ codec_dim: tp.Optional[int] = None,
+ **kwargs):
super().__init__()
self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
@@ -152,15 +201,23 @@ def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: floa
self.cross_attention = False
if transformer:
self.cross_attention = cross_attention
- self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
- cross_attention=cross_attention)
+ self.transformer = StreamingTransformer(
+ chin,
+ 8,
+ 6,
+ bias_ff=False,
+ bias_attn=False,
+ cross_attention=cross_attention)
self.use_codec = False
if codec_dim is not None:
self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
self.use_codec = True
- def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
+ def forward(self,
+ x: torch.Tensor,
+ step: tp.Union[int, torch.Tensor],
+ condition: tp.Optional[torch.Tensor] = None):
skips = []
bs = x.size(0)
z = x
@@ -168,25 +225,31 @@ def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition:
if type(step) is torch.Tensor:
step_tensor = step
else:
- step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
+ step_tensor = torch.tensor([step],
+ device=x.device,
+ dtype=torch.long).expand(bs)
for idx, encoder in enumerate(self.encoders):
z = encoder(z)
if idx == 0:
- z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
+ z = z + self.embedding(step_tensor).view(
+ bs, -1, *view_args).expand_as(z)
elif self.embeddings is not None:
- z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
+ z = z + self.embeddings[idx - 1](step_tensor).view(
+ bs, -1, *view_args).expand_as(z)
skips.append(z)
if self.use_codec: # insert condition in the bottleneck
assert condition is not None, "Model defined for conditionnal generation"
- condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim
+ condition_emb = self.conv_codec(
+ condition) # reshape to the bottleneck dim
assert condition_emb.size(-1) <= 2 * z.size(-1), \
f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
if not self.cross_attention:
- condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
+ condition_emb = torch.nn.functional.interpolate(
+ condition_emb, z.size(-1))
assert z.size() == condition_emb.size()
z += condition_emb
cross_attention_src = None
@@ -194,10 +257,15 @@ def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition:
cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C
B, T, C = cross_attention_src.shape
positions = torch.arange(T, device=x.device).view(1, -1, 1)
- pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
+ pos_emb = create_sin_embedding(positions,
+ C,
+ max_period=10_000,
+ dtype=cross_attention_src.dtype)
cross_attention_src = cross_attention_src + pos_emb
if self.use_transformer:
- z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
+ z = self.transformer(
+ z.permute(0, 2, 1),
+ cross_attention_src=cross_attention_src).permute(0, 2, 1)
else:
if self.bilstm is None:
z = torch.zeros_like(z)
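
DiffusionUnet survives this patch (unet.py is only reformatted, though it is
no longer re-exported from audiocraft.models). A minimal sketch of calling it
directly, with illustrative shapes and codec_dim set as in the deleted test at
the end of this patch:

    import torch
    from audiocraft.models.unet import DiffusionUnet

    model = DiffusionUnet(chin=1, num_steps=1000, codec_dim=128)

    x = torch.randn(2, 1, 2048)           # noisy waveform batch (B, C, T)
    condition = torch.randn(2, 128, 16)   # codec latents (B, codec_dim, T_codes)
    step = torch.randint(0, 1000, (2,))   # one diffusion step per batch item

    # Returns an Output dataclass; .sample is the noise estimate, same shape as x.
    estimate = model(x, step, condition=condition).sample
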
diff --git a/audiocraft/modules/diffusion_schedule.py b/audiocraft/modules/diffusion_schedule.py
deleted file mode 100644
index 74ca6e3f..00000000
--- a/audiocraft/modules/diffusion_schedule.py
+++ /dev/null
@@ -1,272 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
-"""
-
-from collections import namedtuple
-import random
-import typing as tp
-import julius
-import torch
-
-TrainingItem = namedtuple("TrainingItem", "noisy noise step")
-
-
-def betas_from_alpha_bar(alpha_bar):
- alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
- return 1 - alphas
-
-
-class SampleProcessor(torch.nn.Module):
- def project_sample(self, x: torch.Tensor):
- """Project the original sample to the 'space' where the diffusion will happen."""
- return x
-
- def return_sample(self, z: torch.Tensor):
- """Project back from diffusion space to the actual sample space."""
- return z
-
-
-class MultiBandProcessor(SampleProcessor):
- """
- MultiBand sample processor. The input audio is splitted across
- frequency bands evenly distributed in mel-scale.
-
- Each band will be rescaled to match the power distribution
- of Gaussian noise in that band, using online metrics
- computed on the first few samples.
-
- Args:
- n_bands (int): Number of mel-bands to split the signal over.
- sample_rate (int): Sample rate of the audio.
- num_samples (int): Number of samples to use to fit the rescaling
- for each band. The processor won't be stable
- until it has seen that many samples.
- power_std (float or list/tensor): The rescaling factor computed to match the
- power of Gaussian noise in each band is taken to
- that power, i.e. `1.` means full correction of the energy
- in each band, and values less than `1` means only partial
- correction. Can be used to balance the relative importance
- of low vs. high freq in typical audio signals.
- """
- def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
- num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
- super().__init__()
- self.n_bands = n_bands
- self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
- self.num_samples = num_samples
- self.power_std = power_std
- if isinstance(power_std, list):
- assert len(power_std) == n_bands
- power_std = torch.tensor(power_std)
- self.register_buffer('counts', torch.zeros(1))
- self.register_buffer('sum_x', torch.zeros(n_bands))
- self.register_buffer('sum_x2', torch.zeros(n_bands))
- self.register_buffer('sum_target_x2', torch.zeros(n_bands))
- self.counts: torch.Tensor
- self.sum_x: torch.Tensor
- self.sum_x2: torch.Tensor
- self.sum_target_x2: torch.Tensor
-
- @property
- def mean(self):
- mean = self.sum_x / self.counts
- return mean
-
- @property
- def std(self):
- std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
- return std
-
- @property
- def target_std(self):
- target_std = self.sum_target_x2 / self.counts
- return target_std
-
- def project_sample(self, x: torch.Tensor):
- assert x.dim() == 3
- bands = self.split_bands(x)
- if self.counts.item() < self.num_samples:
- ref_bands = self.split_bands(torch.randn_like(x))
- self.counts += len(x)
- self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
- self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
- self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
- rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
- bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
- return bands.sum(dim=0)
-
- def return_sample(self, x: torch.Tensor):
- assert x.dim() == 3
- bands = self.split_bands(x)
- rescale = (self.std / self.target_std) ** self.power_std
- bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
- return bands.sum(dim=0)
-
-
-class NoiseSchedule:
- """Noise schedule for diffusion.
-
- Args:
- beta_t0 (float): Variance of the first diffusion step.
- beta_t1 (float): Variance of the last diffusion step.
- beta_exp (float): Power schedule exponent
- num_steps (int): Number of diffusion step.
- variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
- clip (float): clipping value for the denoising steps
- rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
- repartition (str): shape of the schedule only power schedule is supported
- sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
- noise_scale (float): Scaling factor for the noise
- """
- def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
- clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
- repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
- sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
-
- self.beta_t0 = beta_t0
- self.beta_t1 = beta_t1
- self.variance = variance
- self.num_steps = num_steps
- self.clip = clip
- self.sample_processor = sample_processor
- self.rescale = rescale
- self.n_bands = n_bands
- self.noise_scale = noise_scale
- assert n_bands is None
- if repartition == "power":
- self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
- device=device, dtype=torch.float) ** beta_exp
- else:
- raise RuntimeError('Not implemented')
- self.rng = random.Random(1234)
-
- def get_beta(self, step: tp.Union[int, torch.Tensor]):
- if self.n_bands is None:
- return self.betas[step]
- else:
- return self.betas[:, step] # [n_bands, len(step)]
-
- def get_initial_noise(self, x: torch.Tensor):
- if self.n_bands is None:
- return torch.randn_like(x)
- return torch.randn((x.size(0), self.n_bands, x.size(2)))
-
- def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
- """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
- if step is None:
- return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands
- if type(step) is int:
- return (1 - self.betas[:step + 1]).prod()
- else:
- return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
-
- def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
- """Create a noisy data item for diffusion model training:
-
- Args:
- x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
- tensor_step (bool): If tensor_step = false, only one step t is sample,
- the whole batch is diffused to the same step and t is int.
- If tensor_step = true, t is a tensor of size (x.size(0),)
- every element of the batch is diffused to a independently sampled.
- """
- step: tp.Union[int, torch.Tensor]
- if tensor_step:
- bs = x.size(0)
- step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
- else:
- step = self.rng.randrange(self.num_steps)
- alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1]
-
- x = self.sample_processor.project_sample(x)
- noise = torch.randn_like(x)
- noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
- return TrainingItem(noisy, noise, step)
-
- def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
- condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
- """Full ddpm reverse process.
-
- Args:
- model (nn.Module): Diffusion model.
- initial (tensor): Initial Noise.
- condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
- return_list (bool): Whether to return the whole process or only the sampled point.
- """
- alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
- current = initial
- iterates = [initial]
- for step in range(self.num_steps)[::-1]:
- with torch.no_grad():
- estimate = model(current, step, condition=condition).sample
- alpha = 1 - self.betas[step]
- previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
- previous_alpha_bar = self.get_alpha_bar(step=step - 1)
- if step == 0:
- sigma2 = 0
- elif self.variance == 'beta':
- sigma2 = 1 - alpha
- elif self.variance == 'beta_tilde':
- sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
- elif self.variance == 'none':
- sigma2 = 0
- else:
- raise ValueError(f'Invalid variance type {self.variance}')
-
- if sigma2 > 0:
- previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
- if self.clip:
- previous = previous.clamp(-self.clip, self.clip)
- current = previous
- alpha_bar = previous_alpha_bar
- if step == 0:
- previous *= self.rescale
- if return_list:
- iterates.append(previous.cpu())
-
- if return_list:
- return iterates
- else:
- return self.sample_processor.return_sample(previous)
-
- def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
- condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
- """Reverse process that only goes through Markov chain states in step_list."""
- if step_list is None:
- step_list = list(range(1000))[::-50] + [0]
- alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
- alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
- betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
- current = initial * self.noise_scale
- iterates = [current]
- for idx, step in enumerate(step_list[:-1]):
- with torch.no_grad():
- estimate = model(current, step, condition=condition).sample * self.noise_scale
- alpha = 1 - betas_subsampled[-1 - idx]
- previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
- previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
- if step == step_list[-2]:
- sigma2 = 0
- previous_alpha_bar = torch.tensor(1.0)
- else:
- sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
- if sigma2 > 0:
- previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
- if self.clip:
- previous = previous.clamp(-self.clip, self.clip)
- current = previous
- alpha_bar = previous_alpha_bar
- if step == 0:
- previous *= self.rescale
- if return_list:
- iterates.append(previous.cpu())
- if return_list:
- return iterates
- else:
- return self.sample_processor.return_sample(previous)
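
A short sketch of how the deleted NoiseSchedule was driven on both sides of
training, pieced together from the module above and the solver removed below;
`model` and `cond` are placeholders for a DiffusionUnet and its conditioning,
so those calls are left commented out:

    import torch
    import torch.nn.functional as F
    from audiocraft.modules.diffusion_schedule import NoiseSchedule  # removed by this patch

    schedule = NoiseSchedule(num_steps=1000, device='cpu')

    x = torch.randn(4, 1, 2048)                          # clean audio batch (B, C, T)
    noisy, noise, step = schedule.get_training_item(x)   # diffuse x to a random step
    # loss = F.mse_loss(model(noisy, step, condition=cond).sample, noise)

    # Reverse process: start from Gaussian noise and denoise over all steps.
    # wav = schedule.generate(model, initial=torch.randn_like(x), condition=cond)
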
diff --git a/audiocraft/solvers/__init__.py b/audiocraft/solvers/__init__.py
index 3e4f3ceb..37bd789f 100644
--- a/audiocraft/solvers/__init__.py
+++ b/audiocraft/solvers/__init__.py
@@ -13,4 +13,3 @@
from .base import StandardSolver
from .compression import CompressionSolver
from .musicgen import MusicGenSolver
-from .diffusion import DiffusionSolver
diff --git a/audiocraft/solvers/builders.py b/audiocraft/solvers/builders.py
index 07dbb33b..bbfe16d1 100644
--- a/audiocraft/solvers/builders.py
+++ b/audiocraft/solvers/builders.py
@@ -41,12 +41,10 @@ def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver:
"""Instantiate solver from config."""
from .compression import CompressionSolver
from .musicgen import MusicGenSolver
- from .diffusion import DiffusionSolver
klass = {
'compression': CompressionSolver,
'musicgen': MusicGenSolver,
'lm': MusicGenSolver, # backward compatibility
- 'diffusion': DiffusionSolver,
}[cfg.solver]
return klass(cfg) # type: ignore
diff --git a/audiocraft/solvers/diffusion.py b/audiocraft/solvers/diffusion.py
deleted file mode 100644
index 93dea252..00000000
--- a/audiocraft/solvers/diffusion.py
+++ /dev/null
@@ -1,279 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import typing as tp
-
-import flashy
-import julius
-import omegaconf
-import torch
-import torch.nn.functional as F
-
-from . import builders
-from . import base
-from .. import models
-from ..modules.diffusion_schedule import NoiseSchedule
-from ..metrics import RelativeVolumeMel
-from ..models.builders import get_processor
-from ..utils.samples.manager import SampleManager
-from ..solvers.compression import CompressionSolver
-
-
-class PerStageMetrics:
- """Handle prompting the metrics per stage.
- It outputs the metrics per range of diffusion states.
- e.g. avg loss when t in [250, 500]
- """
- def __init__(self, num_steps: int, num_stages: int = 4):
- self.num_steps = num_steps
- self.num_stages = num_stages
-
- def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
- if type(step) is int:
- stage = int((step / self.num_steps) * self.num_stages)
- return {f"{name}_{stage}": loss for name, loss in losses.items()}
- elif type(step) is torch.Tensor:
- stage_tensor = ((step / self.num_steps) * self.num_stages).long()
- out: tp.Dict[str, float] = {}
- for stage_idx in range(self.num_stages):
- mask = (stage_tensor == stage_idx)
- N = mask.sum()
- stage_out = {}
- if N > 0: # pass if no elements in the stage
- for name, loss in losses.items():
- stage_loss = (mask * loss).sum() / N
- stage_out[f"{name}_{stage_idx}"] = stage_loss
- out = {**out, **stage_out}
- return out
-
-
-class DataProcess:
- """Apply filtering or resampling.
-
- Args:
- initial_sr (int): Initial sample rate.
- target_sr (int): Target sample rate.
- use_resampling: Whether to use resampling or not.
- use_filter (bool):
- n_bands (int): Number of bands to consider.
- idx_band (int):
- device (torch.device or str):
- cutoffs ():
- boost (bool):
- """
- def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
- use_filter: bool = False, n_bands: int = 4,
- idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
- """Apply filtering or resampling
- Args:
- initial_sr (int): sample rate of the dataset
- target_sr (int): sample rate after resampling
- use_resampling (bool): whether or not performs resampling
- use_filter (bool): when True filter the data to keep only one frequency band
- n_bands (int): Number of bands used
- cuts (none or list): The cutoff frequencies of the band filtering
- if None then we use mel scale bands.
- idx_band (int): index of the frequency band. 0 are lows ... (n_bands - 1) highs
- boost (bool): make the data scale match our music dataset.
- """
- assert idx_band < n_bands
- self.idx_band = idx_band
- if use_filter:
- if cutoffs is not None:
- self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
- else:
- self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
- self.use_filter = use_filter
- self.use_resampling = use_resampling
- self.target_sr = target_sr
- self.initial_sr = initial_sr
- self.boost = boost
-
- def process_data(self, x, metric=False):
- if x is None:
- return None
- if self.boost:
- x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
- x * 0.22
- if self.use_filter and not metric:
- x = self.filter(x)[self.idx_band]
- if self.use_resampling:
- x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
- return x
-
- def inverse_process(self, x):
- """Upsampling only."""
- if self.use_resampling:
- x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.target_sr)
- return x
-
-
-class DiffusionSolver(base.StandardSolver):
- """Solver for compression task.
-
- The diffusion task allows for MultiBand diffusion model training.
-
- Args:
- cfg (DictConfig): Configuration.
- """
- def __init__(self, cfg: omegaconf.DictConfig):
- super().__init__(cfg)
- self.cfg = cfg
- self.device = cfg.device
- self.sample_rate: int = self.cfg.sample_rate
- self.codec_model = CompressionSolver.model_from_checkpoint(
- cfg.compression_model_checkpoint, device=self.device)
-
- self.codec_model.set_num_codebooks(cfg.n_q)
- assert self.codec_model.sample_rate == self.cfg.sample_rate, (
- f"Codec model sample rate is {self.codec_model.sample_rate} but "
- f"Solver sample rate is {self.cfg.sample_rate}."
- )
- assert self.codec_model.sample_rate == self.sample_rate, \
- f"Sample rate of solver {self.sample_rate} and codec {self.codec_model.sample_rate} " \
- "don't match."
-
- self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
- self.register_stateful('sample_processor')
- self.sample_processor.to(self.device)
-
- self.schedule = NoiseSchedule(
- **cfg.schedule, device=self.device, sample_processor=self.sample_processor)
-
- self.eval_metric: tp.Optional[torch.nn.Module] = None
-
- self.rvm = RelativeVolumeMel()
- self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
- use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
- use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
- idx_band=cfg.filter.idx_band, device=self.device)
-
- @property
- def best_metric_name(self) -> tp.Optional[str]:
- if self._current_stage == "evaluate":
- return 'rvm'
- else:
- return 'loss'
-
- @torch.no_grad()
- def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
- codes, scale = self.codec_model.encode(wav)
- assert scale is None, "Scaled compression models not supported."
- emb = self.codec_model.decode_latent(codes)
- return emb
-
- def build_model(self):
- """Build model and optimizer as well as optional Exponential Moving Average of the model.
- """
- # Model and optimizer
- self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
- self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
- self.register_stateful('model', 'optimizer')
- self.register_best_state('model')
- self.register_ema('model')
-
- def build_dataloaders(self):
- """Build audio dataloaders for each stage."""
- self.dataloaders = builders.get_audio_datasets(self.cfg)
-
- def show(self):
- # TODO
- raise NotImplementedError()
-
- def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
- """Perform one training or valid step on a given batch."""
- x = batch.to(self.device)
- loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss
-
- condition = self.get_condition(x) # [bs, 128, T/hop, n_emb]
- sample = self.data_processor.process_data(x)
-
- input_, target, step = self.schedule.get_training_item(sample,
- tensor_step=self.cfg.schedule.variable_step_batch)
- out = self.model(input_, step, condition=condition).sample
-
- base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
- reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
- loss = base_loss / reference_loss ** self.cfg.loss.norm_power
-
- if self.is_training:
- loss.mean().backward()
- flashy.distrib.sync_model(self.model)
- self.optimizer.step()
- self.optimizer.zero_grad()
- metrics = {
- 'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
- }
- metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
- metrics.update({
- 'std_in': input_.std(), 'std_out': out.std()})
- return metrics
-
- def run_epoch(self):
- # reset random seed at the beginning of the epoch
- self.rng = torch.Generator()
- self.rng.manual_seed(1234 + self.epoch)
- self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
- # run epoch
- super().run_epoch()
-
- def evaluate(self):
- """Evaluate stage.
- Runs audio reconstruction evaluation.
- """
- self.model.eval()
- evaluate_stage_name = f'{self.current_stage}'
- loader = self.dataloaders['evaluate']
- updates = len(loader)
- lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)
-
- metrics = {}
- n = 1
- for idx, batch in enumerate(lp):
- x = batch.to(self.device)
- with torch.no_grad():
- y_pred = self.regenerate(x)
-
- y_pred = y_pred.cpu()
- y = batch.cpu() # should already be on CPU but just in case
- rvm = self.rvm(y_pred, y)
- lp.update(**rvm)
- if len(metrics) == 0:
- metrics = rvm
- else:
- for key in rvm.keys():
- metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
- metrics = flashy.distrib.average_metrics(metrics)
- return metrics
-
- @torch.no_grad()
- def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
- """Regenerate the given waveform."""
- condition = self.get_condition(wav)
- initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav)) # sampling rate changes.
- result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
- step_list=step_list)
- result = self.data_processor.inverse_process(result)
- return result
-
- def generate(self):
- """Generate stage."""
- sample_manager = SampleManager(self.xp)
- self.model.eval()
- generate_stage_name = f'{self.current_stage}'
-
- loader = self.dataloaders['generate']
- updates = len(loader)
- lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
-
- for batch in lp:
- reference, _ = batch
- reference = reference.to(self.device)
- estimate = self.regenerate(reference)
- reference = reference.cpu()
- estimate = estimate.cpu()
- sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
- flashy.distrib.barrier()
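
The deleted PerStageMetrics helper simply buckets per-item losses by
diffusion-step range (e.g. average loss for t in [250, 500)); a tiny worked
sketch with made-up values:

    import torch
    from audiocraft.solvers.diffusion import PerStageMetrics  # module removed by this patch

    # With num_steps=1000 and num_stages=4, steps fall into buckets of 250.
    per_stage = PerStageMetrics(num_steps=1000, num_stages=4)

    step = torch.tensor([120, 480, 900])   # per-item diffusion steps
    loss = torch.tensor([0.9, 0.5, 0.2])   # per-item losses
    print(per_stage({'loss': loss}, step))
    # -> {'loss_0': tensor(0.9000), 'loss_1': tensor(0.5000), 'loss_3': tensor(0.2000)}
    #    (stage 2 had no items, so it is omitted)
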
diff --git a/demos/musicgen_app.py b/demos/musicgen_app.py
index 88cd27dc..a283d05b 100644
--- a/demos/musicgen_app.py
+++ b/demos/musicgen_app.py
@@ -28,7 +28,6 @@
from audiocraft.models.encodec import InterleaveStereoCompressionModel
from audiocraft.models import MusicGen, MultiBandDiffusion
-
MODEL = None # Last used model
SPACE_ID = os.environ.get('SPACE_ID', '')
IS_BATCHED = "facebook/MusicGen" in SPACE_ID or 'musicgen-internal/musicgen_dev' in SPACE_ID
@@ -60,6 +59,7 @@ def interrupt():
class FileCleaner:
+
def __init__(self, file_lifetime: float = 3600):
self.file_lifetime = file_lifetime
self.files = []
@@ -77,7 +77,8 @@ def _cleanup(self):
self.files.pop(0)
else:
break
-
+
+
file_cleaner = FileCleaner()
@@ -102,16 +103,15 @@ def load_model(version='facebook/musicgen-melody'):
MODEL = MusicGen.get_pretrained(version)
-def load_diffusion():
- global MBD
- if MBD is None:
- print("loading MBD")
- MBD = MultiBandDiffusion.get_mbd_musicgen()
-
-
-def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=None, **gen_kwargs):
+def _do_predictions(texts,
+ melodies,
+ duration,
+ progress=False,
+ gradio_progress=None,
+ **gen_kwargs):
MODEL.set_generation_params(duration=duration, **gen_kwargs)
- print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
+ print("new batch", len(texts), texts,
+ [None if m is None else (m[0], m[1].shape) for m in melodies])
be = time.time()
processed_melodies = []
target_sr = 32000
@@ -120,7 +120,8 @@ def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=N
if melody is None:
processed_melodies.append(None)
else:
- sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
+ sr, melody = melody[0], torch.from_numpy(melody[1]).to(
+ MODEL.device).float().t()
if melody.dim() == 1:
melody = melody[None]
melody = melody[..., :int(sr * duration)]
@@ -134,32 +135,25 @@ def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=N
melody_wavs=processed_melodies,
melody_sample_rate=target_sr,
progress=progress,
- return_tokens=USE_DIFFUSION
- )
+ return_tokens=False)
else:
- outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
+ outputs = MODEL.generate(texts,
+ progress=progress,
+ return_tokens=False)
except RuntimeError as e:
raise gr.Error("Error while generating " + e.args[0])
- if USE_DIFFUSION:
- if gradio_progress is not None:
- gradio_progress(1, desc='Running MultiBandDiffusion...')
- tokens = outputs[1]
- if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
- left, right = MODEL.compression_model.get_left_right_codes(tokens)
- tokens = torch.cat([left, right])
- outputs_diffusion = MBD.tokens_to_wav(tokens)
- if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
- assert outputs_diffusion.shape[1] == 1 # output is mono
- outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
- outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
outputs = outputs.detach().cpu().float()
pending_videos = []
out_wavs = []
for output in outputs:
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
- audio_write(
- file.name, output, MODEL.sample_rate, strategy="loudness",
- loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+ audio_write(file.name,
+ output,
+ MODEL.sample_rate,
+ strategy="loudness",
+ loudness_headroom_db=16,
+ loudness_compressor=True,
+ add_suffix=False)
pending_videos.append(pool.submit(make_waveform, file.name))
out_wavs.append(file.name)
file_cleaner.add(file.name)
@@ -179,9 +173,18 @@ def predict_batched(texts, melodies):
return res
-def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
+def predict_full(model,
+ model_path,
+ decoder,
+ text,
+ melody,
+ duration,
+ topk,
+ topp,
+ temperature,
+ cfg_coef,
+ progress=gr.Progress()):
global INTERRUPTING
- global USE_DIFFUSION
INTERRUPTING = False
progress(0, desc="Loading model...")
model_path = model_path.strip()
@@ -189,8 +192,9 @@ def predict_full(model, model_path, decoder, text, melody, duration, topk, topp,
if not Path(model_path).exists():
raise gr.Error(f"Model path {model_path} doesn't exist.")
if not Path(model_path).is_dir():
- raise gr.Error(f"Model path {model_path} must be a folder containing "
- "state_dict.bin and compression_state_dict_.bin.")
+ raise gr.Error(
+ f"Model path {model_path} must be a folder containing "
+ "state_dict.bin and compression_state_dict_.bin.")
model = model_path
if temperature < 0:
raise gr.Error("Temperature must be >= 0.")
@@ -200,12 +204,6 @@ def predict_full(model, model_path, decoder, text, melody, duration, topk, topp,
raise gr.Error("Topp must be non-negative.")
topk = int(topk)
- if decoder == "MultiBand_Diffusion":
- USE_DIFFUSION = True
- progress(0, desc="Loading diffusion model...")
- load_diffusion()
- else:
- USE_DIFFUSION = False
load_model(model)
max_generated = 0
@@ -216,14 +214,17 @@ def _progress(generated, to_generate):
progress((min(max_generated, to_generate), to_generate))
if INTERRUPTING:
raise gr.Error("Interrupted.")
+
MODEL.set_custom_progress_callback(_progress)
- videos, wavs = _do_predictions(
- [text], [melody], duration, progress=True,
- top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef,
- gradio_progress=progress)
- if USE_DIFFUSION:
- return videos[0], wavs[0], videos[1], wavs[1]
+ videos, wavs = _do_predictions([text], [melody],
+ duration,
+ progress=True,
+ top_k=topk,
+ top_p=topp,
+ temperature=temperature,
+ cfg_coef=cfg_coef,
+ gradio_progress=progress)
return videos[0], wavs[0], None, None
@@ -234,110 +235,132 @@ def toggle_audio_src(choice):
return gr.update(source="upload", value=None, label="File")
-def toggle_diffusion(choice):
- if choice == "MultiBand_Diffusion":
- return [gr.update(visible=True)] * 2
- else:
- return [gr.update(visible=False)] * 2
-
-
def ui_full(launch_kwargs):
with gr.Blocks() as interface:
- gr.Markdown(
- """
+ gr.Markdown("""
# MusicGen
This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
a simple and controllable model for music generation
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
- """
- )
+ """)
with gr.Row():
with gr.Column():
with gr.Row():
text = gr.Text(label="Input Text", interactive=True)
with gr.Column():
- radio = gr.Radio(["file", "mic"], value="file",
- label="Condition on a melody (optional) File or Mic")
- melody = gr.Audio(sources=["upload"], type="numpy", label="File",
- interactive=True, elem_id="melody-input")
+ radio = gr.Radio(
+ ["file", "mic"],
+ value="file",
+ label="Condition on a melody (optional) File or Mic"
+ )
+ melody = gr.Audio(sources=["upload"],
+ type="numpy",
+ label="File",
+ interactive=True,
+ elem_id="melody-input")
with gr.Row():
submit = gr.Button("Submit")
# Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
_ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
with gr.Row():
- model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
- "facebook/musicgen-large", "facebook/musicgen-melody-large",
- "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium",
- "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large",
- "facebook/musicgen-stereo-melody-large"],
- label="Model", value="facebook/musicgen-stereo-melody", interactive=True)
+ model = gr.Radio([
+ "facebook/musicgen-melody", "facebook/musicgen-medium",
+ "facebook/musicgen-small", "facebook/musicgen-large",
+ "facebook/musicgen-melody-large",
+ "facebook/musicgen-stereo-small",
+ "facebook/musicgen-stereo-medium",
+ "facebook/musicgen-stereo-melody",
+ "facebook/musicgen-stereo-large",
+ "facebook/musicgen-stereo-melody-large"
+ ],
+ label="Model",
+ value="facebook/musicgen-stereo-melody",
+ interactive=True)
model_path = gr.Text(label="Model Path (custom models)")
with gr.Row():
decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
- label="Decoder", value="Default", interactive=True)
+ label="Decoder",
+ value="Default",
+ interactive=True)
with gr.Row():
- duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
+ duration = gr.Slider(minimum=1,
+ maximum=120,
+ value=10,
+ label="Duration",
+ interactive=True)
with gr.Row():
- topk = gr.Number(label="Top-k", value=250, interactive=True)
+ topk = gr.Number(label="Top-k",
+ value=250,
+ interactive=True)
topp = gr.Number(label="Top-p", value=0, interactive=True)
- temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
- cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
+ temperature = gr.Number(label="Temperature",
+ value=1.0,
+ interactive=True)
+ cfg_coef = gr.Number(label="Classifier Free Guidance",
+ value=3.0,
+ interactive=True)
with gr.Column():
output = gr.Video(label="Generated Music")
- audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
- diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
- audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
- submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False,
- show_progress=False).then(predict_full, inputs=[model, model_path, decoder, text, melody, duration, topk, topp,
- temperature, cfg_coef],
- outputs=[output, audio_output, diffusion_output, audio_diffusion])
- radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
+ audio_output = gr.Audio(label="Generated Music (wav)",
+ type='filepath')
+ diffusion_output = gr.Video(
+ label="MultiBand Diffusion Decoder")
+ audio_diffusion = gr.Audio(
+ label="MultiBand Diffusion Decoder (wav)", type='filepath')
+ submit.click(toggle_diffusion,
+ decoder, [diffusion_output, audio_diffusion],
+ queue=False,
+ show_progress=False).then(predict_full,
+ inputs=[
+ model, model_path, decoder,
+ text, melody, duration,
+ topk, topp, temperature,
+ cfg_coef
+ ],
+ outputs=[
+ output, audio_output,
+ diffusion_output,
+ audio_diffusion
+ ])
+ radio.change(toggle_audio_src,
+ radio, [melody],
+ queue=False,
+ show_progress=False)
gr.Examples(
fn=predict_full,
examples=[
[
"An 80s driving pop song with heavy drums and synth pads in the background",
- "./assets/bach.mp3",
- "facebook/musicgen-stereo-melody",
+ "./assets/bach.mp3", "facebook/musicgen-stereo-melody",
"Default"
],
[
"A cheerful country song with acoustic guitars",
"./assets/bolero_ravel.mp3",
- "facebook/musicgen-stereo-melody",
- "Default"
+ "facebook/musicgen-stereo-melody", "Default"
],
[
- "90s rock song with electric guitar and heavy drums",
- None,
- "facebook/musicgen-stereo-medium",
- "Default"
+ "90s rock song with electric guitar and heavy drums", None,
+ "facebook/musicgen-stereo-medium", "Default"
],
[
"a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
- "./assets/bach.mp3",
- "facebook/musicgen-stereo-melody",
+ "./assets/bach.mp3", "facebook/musicgen-stereo-melody",
"Default"
],
[
- "lofi slow bpm electro chill with organic samples",
- None,
- "facebook/musicgen-stereo-medium",
- "Default"
+ "lofi slow bpm electro chill with organic samples", None,
+ "facebook/musicgen-stereo-medium", "Default"
],
[
- "Punk rock with loud drum and power guitar",
- None,
- "facebook/musicgen-stereo-medium",
- "MultiBand_Diffusion"
+ "Punk rock with loud drum and power guitar", None,
+ "facebook/musicgen-stereo-medium", "MultiBand_Diffusion"
],
],
inputs=[text, melody, model, decoder],
- outputs=[output]
- )
- gr.Markdown(
- """
+ outputs=[output])
+ gr.Markdown("""
### More details
The model will generate a short music extract based on the description you provided.
@@ -378,16 +401,14 @@ def ui_full(launch_kwargs):
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md)
for more details.
- """
- )
+ """)
interface.queue().launch(**launch_kwargs)
def ui_batched(launch_kwargs):
with gr.Blocks() as demo:
- gr.Markdown(
- """
+ gr.Markdown("""
# MusicGen
This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md),
@@ -399,25 +420,39 @@ def ui_batched(launch_kwargs):
for longer sequences, more control and no queue.
- """
- )
+ """)
with gr.Row():
with gr.Column():
with gr.Row():
- text = gr.Text(label="Describe your music", lines=2, interactive=True)
+ text = gr.Text(label="Describe your music",
+ lines=2,
+ interactive=True)
with gr.Column():
- radio = gr.Radio(["file", "mic"], value="file",
- label="Condition on a melody (optional) File or Mic")
- melody = gr.Audio(source="upload", type="numpy", label="File",
- interactive=True, elem_id="melody-input")
+ radio = gr.Radio(
+ ["file", "mic"],
+ value="file",
+ label="Condition on a melody (optional) File or Mic"
+ )
+ melody = gr.Audio(source="upload",
+ type="numpy",
+ label="File",
+ interactive=True,
+ elem_id="melody-input")
with gr.Row():
submit = gr.Button("Generate")
with gr.Column():
output = gr.Video(label="Generated Music")
- audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
- submit.click(predict_batched, inputs=[text, melody],
- outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
- radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
+ audio_output = gr.Audio(label="Generated Music (wav)",
+ type='filepath')
+ submit.click(predict_batched,
+ inputs=[text, melody],
+ outputs=[output, audio_output],
+ batch=True,
+ max_batch_size=MAX_BATCH_SIZE)
+ radio.change(toggle_audio_src,
+ radio, [melody],
+ queue=False,
+ show_progress=False)
gr.Examples(
fn=predict_batched,
examples=[
@@ -443,8 +478,7 @@ def ui_batched(launch_kwargs):
],
],
inputs=[text, melody],
- outputs=[output]
- )
+ outputs=[output])
gr.Markdown("""
### More details
@@ -482,24 +516,26 @@ def ui_batched(launch_kwargs):
default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
help='IP to listen on for connections to Gradio',
)
- parser.add_argument(
- '--username', type=str, default='', help='Username for authentication'
- )
- parser.add_argument(
- '--password', type=str, default='', help='Password for authentication'
- )
+ parser.add_argument('--username',
+ type=str,
+ default='',
+ help='Username for authentication')
+ parser.add_argument('--password',
+ type=str,
+ default='',
+ help='Password for authentication')
parser.add_argument(
'--server_port',
type=int,
default=0,
help='Port to run the server listener on',
)
- parser.add_argument(
- '--inbrowser', action='store_true', help='Open in browser'
- )
- parser.add_argument(
- '--share', action='store_true', help='Share the gradio UI'
- )
+ parser.add_argument('--inbrowser',
+ action='store_true',
+ help='Open in browser')
+ parser.add_argument('--share',
+ action='store_true',
+ help='Share the gradio UI')
args = parser.parse_args()
@@ -519,8 +555,6 @@ def ui_batched(launch_kwargs):
# Show the interface
if IS_BATCHED:
- global USE_DIFFUSION
- USE_DIFFUSION = False
ui_batched(launch_kwargs)
else:
ui_full(launch_kwargs)
diff --git a/tests/models/test_multibanddiffusion.py b/tests/models/test_multibanddiffusion.py
deleted file mode 100644
index 2702a3cb..00000000
--- a/tests/models/test_multibanddiffusion.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import random
-
-import numpy as np
-import torch
-from audiocraft.models.multibanddiffusion import MultiBandDiffusion, DiffusionProcess
-from audiocraft.models import EncodecModel, DiffusionUnet
-from audiocraft.modules import SEANetEncoder, SEANetDecoder
-from audiocraft.modules.diffusion_schedule import NoiseSchedule
-from audiocraft.quantization import DummyQuantizer
-
-
-class TestMBD:
-
- def _create_mbd(self,
- sample_rate: int,
- channels: int,
- n_filters: int = 3,
- n_residual_layers: int = 1,
- ratios: list = [5, 4, 3, 2],
- num_steps: int = 1000,
- codec_dim: int = 128,
- **kwargs):
- frame_rate = np.prod(ratios)
- encoder = SEANetEncoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
- n_residual_layers=n_residual_layers, ratios=ratios)
- decoder = SEANetDecoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
- n_residual_layers=n_residual_layers, ratios=ratios)
- quantizer = DummyQuantizer()
- compression_model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
- sample_rate=sample_rate, channels=channels, **kwargs)
- diffusion_model = DiffusionUnet(chin=channels, num_steps=num_steps, codec_dim=codec_dim)
- schedule = NoiseSchedule(device='cpu', num_steps=num_steps)
- DP = DiffusionProcess(model=diffusion_model, noise_schedule=schedule)
- mbd = MultiBandDiffusion(DPs=[DP], codec_model=compression_model)
- return mbd
-
- def test_model(self):
- random.seed(1234)
- sample_rate = 24_000
- channels = 1
- codec_dim = 128
- mbd = self._create_mbd(sample_rate=sample_rate, channels=channels, codec_dim=codec_dim)
- for _ in range(10):
- length = random.randrange(1, 10_000)
- x = torch.randn(2, channels, length)
- res = mbd.regenerate(x, sample_rate)
- assert res.shape == x.shape
diff --git a/utils/export.py b/utils/export.py
index e437a507..1161dabc 100644
--- a/utils/export.py
+++ b/utils/export.py
@@ -6,4 +6,5 @@
export.export_lm(xp.folder / 'checkpoint.th', '../export/state_dict.bin')
# Export the pretrained EnCodec model
-export.export_pretrained_compression_model('facebook/encodec_32khz', '/home/maxwell/grimes_keys_model/compression_state_dict.bin')
+export.export_pretrained_compression_model(
+ 'facebook/encodec_32khz', '../export/compression_state_dict.bin')
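Note on the Gradio rework in demos/musicgen_app.py above: the submit handler is now a chained event, where a lightweight visibility toggle runs first (unqueued, no progress bar) and the generation callback runs in a follow-up .then() step with the full input list. Below is a minimal, self-contained sketch of that pattern, assuming only that gradio is installed; the component and handler names (toggle_extra, generate, "alt", etc.) are illustrative and are not taken from the patch.

import gradio as gr

def toggle_extra(choice):
    # Lightweight first step: only flips visibility of the extra output.
    return gr.update(visible=(choice == "alt"))

def generate(choice, prompt):
    # Stand-in for the long-running generation callback.
    return f"[{choice}] {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Text(label="Prompt", lines=2, interactive=True)
    choice = gr.Radio(["default", "alt"], value="default", label="Decoder")
    btn = gr.Button("Generate")
    out = gr.Text(label="Output")
    extra = gr.Text(label="Extra output", visible=False)
    # Chained event: the toggle runs immediately (queue=False), then the
    # generation step runs with its own inputs once the toggle has finished.
    btn.click(toggle_extra, choice, extra,
              queue=False, show_progress=False).then(
        generate, inputs=[choice, prompt], outputs=out)

if __name__ == "__main__":
    demo.launch()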
From b170ee0fe9e9c7aa0914f25d5e491dd38532a516 Mon Sep 17 00:00:00 2001
From: maxardito
Date: Tue, 9 Apr 2024 13:35:53 -0400
Subject: [PATCH 2/3] Removed diffusion processor
---
audiocraft/models/builders.py | 144 ++++++++++++++++++++--------------
1 file changed, 83 insertions(+), 61 deletions(-)
diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py
index 46a4d96f..ae286c98 100644
--- a/audiocraft/models/builders.py
+++ b/audiocraft/models/builders.py
@@ -3,7 +3,6 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-
"""
All the functions to build the relevant models and modules
from the Hydra config.
@@ -37,10 +36,10 @@
from .unet import DiffusionUnet
from .. import quantization as qt
from ..utils.utils import dict_from_config
-from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
-def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
+def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig,
+ dimension: int) -> qt.BaseQuantizer:
klass = {
'no_quant': qt.DummyQuantizer,
'rvq': qt.ResidualVectorQuantizer
@@ -77,8 +76,12 @@ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
renormalize = kwargs.pop('renormalize', False)
# deprecated params
kwargs.pop('renorm', None)
- return EncodecModel(encoder, decoder, quantizer,
- frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
+ return EncodecModel(encoder,
+ decoder,
+ quantizer,
+ frame_rate=frame_rate,
+ renormalize=renormalize,
+ **kwargs).to(cfg.device)
else:
raise KeyError(f"Unexpected compression model {cfg.compression_model}")
@@ -91,37 +94,44 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
q_modeling = kwargs.pop('q_modeling', None)
codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
- cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
- cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
+ cls_free_guidance = dict_from_config(
+ getattr(cfg, 'classifier_free_guidance'))
+ cfg_prob, cfg_coef = cls_free_guidance[
+ 'training_dropout'], cls_free_guidance['inference_coef']
fuser = get_condition_fuser(cfg)
- condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
- if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
+ condition_provider = get_conditioner_provider(kwargs["dim"],
+ cfg).to(cfg.device)
+ if len(fuser.fuse2cond['cross']
+ ) > 0: # enforce cross-att programmatically
kwargs['cross_attention'] = True
if codebooks_pattern_cfg.modeling is None:
assert q_modeling is not None, \
"LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
- codebooks_pattern_cfg = omegaconf.OmegaConf.create(
- {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
- )
+ codebooks_pattern_cfg = omegaconf.OmegaConf.create({
+ 'modeling': q_modeling,
+ 'delay': {
+ 'delays': list(range(n_q))
+ }
+ })
- pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
+ pattern_provider = get_codebooks_pattern_provider(
+ n_q, codebooks_pattern_cfg)
lm_class = LMModel
- return lm_class(
- pattern_provider=pattern_provider,
- condition_provider=condition_provider,
- fuser=fuser,
- cfg_dropout=cfg_prob,
- cfg_coef=cfg_coef,
- attribute_dropout=attribute_dropout,
- dtype=getattr(torch, cfg.dtype),
- device=cfg.device,
- **kwargs
- ).to(cfg.device)
+ return lm_class(pattern_provider=pattern_provider,
+ condition_provider=condition_provider,
+ fuser=fuser,
+ cfg_dropout=cfg_prob,
+ cfg_coef=cfg_coef,
+ attribute_dropout=attribute_dropout,
+ dtype=getattr(torch, cfg.dtype),
+ device=cfg.device,
+ **kwargs).to(cfg.device)
else:
raise KeyError(f"Unexpected LM model {cfg.lm_model}")
-def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
+def get_conditioner_provider(
+ output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
"""Instantiate a conditioning model."""
device = cfg.device
duration = cfg.dataset.segment_duration
@@ -136,25 +146,26 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
model_type = cond_cfg['model']
model_args = cond_cfg[model_type]
if model_type == 't5':
- conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
+ conditioners[str(cond)] = T5Conditioner(output_dim=output_dim,
+ device=device,
+ **model_args)
elif model_type == 'lut':
- conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
+ conditioners[str(cond)] = LUTConditioner(output_dim=output_dim,
+ **model_args)
elif model_type == 'chroma_stem':
conditioners[str(cond)] = ChromaStemConditioner(
output_dim=output_dim,
duration=duration,
device=device,
- **model_args
- )
+ **model_args)
elif model_type == 'clap':
conditioners[str(cond)] = CLAPEmbeddingConditioner(
- output_dim=output_dim,
- device=device,
- **model_args
- )
+ output_dim=output_dim, device=device, **model_args)
else:
raise ValueError(f"Unrecognized conditioning model: {model_type}")
- conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
+ conditioner = ConditioningProvider(conditioners,
+ device=device,
+ **condition_provider_args)
return conditioner
@@ -168,7 +179,8 @@ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
return fuser
-def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
+def get_codebooks_pattern_provider(
+ n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
"""Instantiate a codebooks pattern provider object."""
pattern_providers = {
'parallel': ParallelPatternProvider,
@@ -185,7 +197,9 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb
def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
"""Instantiate a debug compression model to be used for unit tests."""
- assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
+ assert sample_rate in [
+ 16000, 32000
+ ], "unsupported sample rate for debug compression model"
model_ratios = {
16000: [10, 8, 8], # 25 Hz at 16kHz
32000: [10, 8, 16] # 25 Hz at 32kHz
@@ -203,9 +217,12 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
init_x = torch.randn(8, 32, 128)
quantizer(init_x, 1) # initialize kmeans etc.
- compression_model = EncodecModel(
- encoder, decoder, quantizer,
- frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
+ compression_model = EncodecModel(encoder,
+ decoder,
+ quantizer,
+ frame_rate=frame_rate,
+ sample_rate=sample_rate,
+ channels=1).to(device)
return compression_model.eval()
@@ -213,19 +230,9 @@ def get_diffusion_model(cfg: omegaconf.DictConfig):
# TODO Find a way to infer the channels from dset
channels = cfg.channels
num_steps = cfg.schedule.num_steps
- return DiffusionUnet(
- chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
-
-
-def get_processor(cfg, sample_rate: int = 24000):
- sample_processor = SampleProcessor()
- if cfg.use:
- kw = dict(cfg)
- kw.pop('use')
- kw.pop('name')
- if cfg.name == "multi_band_processor":
- sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
- return sample_processor
+ return DiffusionUnet(chin=channels,
+ num_steps=num_steps,
+ **cfg.diffusion_unet)
def get_debug_lm_model(device='cpu'):
@@ -233,16 +240,30 @@ def get_debug_lm_model(device='cpu'):
pattern = DelayedPatternProvider(n_q=4)
dim = 16
providers = {
- 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
+ 'description':
+ LUTConditioner(n_bins=128,
+ dim=dim,
+ output_dim=dim,
+ tokenizer="whitespace"),
}
condition_provider = ConditioningProvider(providers)
- fuser = ConditionFuser(
- {'cross': ['description'], 'prepend': [],
- 'sum': [], 'input_interpolate': []})
- lm = LMModel(
- pattern, condition_provider, fuser,
- n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
- cross_attention=True, causal=True)
+ fuser = ConditionFuser({
+ 'cross': ['description'],
+ 'prepend': [],
+ 'sum': [],
+ 'input_interpolate': []
+ })
+ lm = LMModel(pattern,
+ condition_provider,
+ fuser,
+ n_q=4,
+ card=400,
+ dim=dim,
+ num_heads=4,
+ custom=True,
+ num_layers=2,
+ cross_attention=True,
+ causal=True)
return lm.to(device).eval()
@@ -253,7 +274,8 @@ def get_wrapped_compression_model(
if cfg.interleave_stereo_codebooks.use:
kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
kwargs.pop('use')
- compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
+ compression_model = InterleaveStereoCompressionModel(
+ compression_model, **kwargs)
if hasattr(cfg, 'compression_model_n_q'):
if cfg.compression_model_n_q is not None:
compression_model.set_num_codebooks(cfg.compression_model_n_q)
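Note on the reflowed audiocraft/models/builders.py above: the builders read nested Hydra/OmegaConf configs (see the omegaconf.OmegaConf.create({...}) fallback for the codebook pattern in the hunk) and turn them into plain kwargs before instantiating models. A minimal sketch of that config pattern follows, assuming only that omegaconf is installed; the keys are illustrative and not the actual audiocraft config schema.

from omegaconf import OmegaConf

# Nested config, analogous to the codebooks_pattern fallback built above.
cfg = OmegaConf.create({
    'modeling': 'delay',
    'delay': {'delays': list(range(4))},
})

# Attribute-style access, as the builders use for cfg.device, cfg.schedule.num_steps, ...
print(cfg.modeling)            # delay
print(list(cfg.delay.delays))  # [0, 1, 2, 3]

# Plain-dict view for **kwargs-style instantiation (the role dict_from_config plays).
kwargs = OmegaConf.to_container(cfg, resolve=True)
assert isinstance(kwargs, dict)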
From 28db0cb250f683ae013f66a2cf730ae20f239b08 Mon Sep 17 00:00:00 2001
From: maxardito
Date: Tue, 9 Apr 2024 13:44:01 -0400
Subject: [PATCH 3/3] Removed UNet
---
audiocraft/models/builders.py | 10 --
audiocraft/models/unet.py | 282 ----------------------------------
config/model/score/basic.yaml | 17 --
3 files changed, 309 deletions(-)
delete mode 100644 audiocraft/models/unet.py
delete mode 100644 config/model/score/basic.yaml
diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py
index ae286c98..f8da9b72 100644
--- a/audiocraft/models/builders.py
+++ b/audiocraft/models/builders.py
@@ -33,7 +33,6 @@
LUTConditioner,
T5Conditioner,
)
-from .unet import DiffusionUnet
from .. import quantization as qt
from ..utils.utils import dict_from_config
@@ -226,15 +225,6 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
return compression_model.eval()
-def get_diffusion_model(cfg: omegaconf.DictConfig):
- # TODO Find a way to infer the channels from dset
- channels = cfg.channels
- num_steps = cfg.schedule.num_steps
- return DiffusionUnet(chin=channels,
- num_steps=num_steps,
- **cfg.diffusion_unet)
-
-
def get_debug_lm_model(device='cpu'):
"""Instantiate a debug LM to be used for unit tests."""
pattern = DelayedPatternProvider(n_q=4)
diff --git a/audiocraft/models/unet.py b/audiocraft/models/unet.py
deleted file mode 100644
index 0061fafa..00000000
--- a/audiocraft/models/unet.py
+++ /dev/null
@@ -1,282 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-"""
-Pytorch Unet Module used for diffusion.
-"""
-
-from dataclasses import dataclass
-import typing as tp
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
-
-
-@dataclass
-class Output:
- sample: torch.Tensor
-
-
-def get_model(cfg, channels: int, side: int, num_steps: int):
- if cfg.model == 'unet':
- return DiffusionUnet(chin=channels,
- num_steps=num_steps,
- **cfg.diffusion_unet)
- else:
- raise RuntimeError('Not Implemented')
-
-
-class ResBlock(nn.Module):
-
- def __init__(self,
- channels: int,
- kernel: int = 3,
- norm_groups: int = 4,
- dilation: int = 1,
- activation: tp.Type[nn.Module] = nn.ReLU,
- dropout: float = 0.):
- super().__init__()
- stride = 1
- padding = dilation * (kernel - stride) // 2
- Conv = nn.Conv1d
- Drop = nn.Dropout1d
- self.norm1 = nn.GroupNorm(norm_groups, channels)
- self.conv1 = Conv(channels,
- channels,
- kernel,
- 1,
- padding,
- dilation=dilation)
- self.activation1 = activation()
- self.dropout1 = Drop(dropout)
-
- self.norm2 = nn.GroupNorm(norm_groups, channels)
- self.conv2 = Conv(channels,
- channels,
- kernel,
- 1,
- padding,
- dilation=dilation)
- self.activation2 = activation()
- self.dropout2 = Drop(dropout)
-
- def forward(self, x):
- h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
- h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
- return x + h
-
-
-class DecoderLayer(nn.Module):
-
- def __init__(self,
- chin: int,
- chout: int,
- kernel: int = 4,
- stride: int = 2,
- norm_groups: int = 4,
- res_blocks: int = 1,
- activation: tp.Type[nn.Module] = nn.ReLU,
- dropout: float = 0.):
- super().__init__()
- padding = (kernel - stride) // 2
- self.res_blocks = nn.Sequential(*[
- ResBlock(chin,
- norm_groups=norm_groups,
- dilation=2**idx,
- dropout=dropout) for idx in range(res_blocks)
- ])
- self.norm = nn.GroupNorm(norm_groups, chin)
- ConvTr = nn.ConvTranspose1d
- self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
- self.activation = activation()
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = self.res_blocks(x)
- x = self.norm(x)
- x = self.activation(x)
- x = self.convtr(x)
- return x
-
-
-class EncoderLayer(nn.Module):
-
- def __init__(self,
- chin: int,
- chout: int,
- kernel: int = 4,
- stride: int = 2,
- norm_groups: int = 4,
- res_blocks: int = 1,
- activation: tp.Type[nn.Module] = nn.ReLU,
- dropout: float = 0.):
- super().__init__()
- padding = (kernel - stride) // 2
- Conv = nn.Conv1d
- self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
- self.norm = nn.GroupNorm(norm_groups, chout)
- self.activation = activation()
- self.res_blocks = nn.Sequential(*[
- ResBlock(chout,
- norm_groups=norm_groups,
- dilation=2**idx,
- dropout=dropout) for idx in range(res_blocks)
- ])
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- B, C, T = x.shape
- stride, = self.conv.stride
- pad = (stride - (T % stride)) % stride
- x = F.pad(x, (0, pad))
-
- x = self.conv(x)
- x = self.norm(x)
- x = self.activation(x)
- x = self.res_blocks(x)
- return x
-
-
-class BLSTM(nn.Module):
- """BiLSTM with same hidden units as input dim.
- """
-
- def __init__(self, dim, layers=2):
- super().__init__()
- self.lstm = nn.LSTM(bidirectional=True,
- num_layers=layers,
- hidden_size=dim,
- input_size=dim)
- self.linear = nn.Linear(2 * dim, dim)
-
- def forward(self, x):
- x = x.permute(2, 0, 1)
- x = self.lstm(x)[0]
- x = self.linear(x)
- x = x.permute(1, 2, 0)
- return x
-
-
-class DiffusionUnet(nn.Module):
-
- def __init__(self,
- chin: int = 3,
- hidden: int = 24,
- depth: int = 3,
- growth: float = 2.,
- max_channels: int = 10_000,
- num_steps: int = 1000,
- emb_all_layers=False,
- cross_attention: bool = False,
- bilstm: bool = False,
- transformer: bool = False,
- codec_dim: tp.Optional[int] = None,
- **kwargs):
- super().__init__()
- self.encoders = nn.ModuleList()
- self.decoders = nn.ModuleList()
- self.embeddings: tp.Optional[nn.ModuleList] = None
- self.embedding = nn.Embedding(num_steps, hidden)
- if emb_all_layers:
- self.embeddings = nn.ModuleList()
- self.condition_embedding: tp.Optional[nn.Module] = None
- for d in range(depth):
- encoder = EncoderLayer(chin, hidden, **kwargs)
- decoder = DecoderLayer(hidden, chin, **kwargs)
- self.encoders.append(encoder)
- self.decoders.insert(0, decoder)
- if emb_all_layers and d > 0:
- assert self.embeddings is not None
- self.embeddings.append(nn.Embedding(num_steps, hidden))
- chin = hidden
- hidden = min(int(chin * growth), max_channels)
- self.bilstm: tp.Optional[nn.Module]
- if bilstm:
- self.bilstm = BLSTM(chin)
- else:
- self.bilstm = None
- self.use_transformer = transformer
- self.cross_attention = False
- if transformer:
- self.cross_attention = cross_attention
- self.transformer = StreamingTransformer(
- chin,
- 8,
- 6,
- bias_ff=False,
- bias_attn=False,
- cross_attention=cross_attention)
-
- self.use_codec = False
- if codec_dim is not None:
- self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
- self.use_codec = True
-
- def forward(self,
- x: torch.Tensor,
- step: tp.Union[int, torch.Tensor],
- condition: tp.Optional[torch.Tensor] = None):
- skips = []
- bs = x.size(0)
- z = x
- view_args = [1]
- if type(step) is torch.Tensor:
- step_tensor = step
- else:
- step_tensor = torch.tensor([step],
- device=x.device,
- dtype=torch.long).expand(bs)
-
- for idx, encoder in enumerate(self.encoders):
- z = encoder(z)
- if idx == 0:
- z = z + self.embedding(step_tensor).view(
- bs, -1, *view_args).expand_as(z)
- elif self.embeddings is not None:
- z = z + self.embeddings[idx - 1](step_tensor).view(
- bs, -1, *view_args).expand_as(z)
-
- skips.append(z)
-
- if self.use_codec: # insert condition in the bottleneck
- assert condition is not None, "Model defined for conditionnal generation"
- condition_emb = self.conv_codec(
- condition) # reshape to the bottleneck dim
- assert condition_emb.size(-1) <= 2 * z.size(-1), \
- f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
- if not self.cross_attention:
-
- condition_emb = torch.nn.functional.interpolate(
- condition_emb, z.size(-1))
- assert z.size() == condition_emb.size()
- z += condition_emb
- cross_attention_src = None
- else:
- cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C
- B, T, C = cross_attention_src.shape
- positions = torch.arange(T, device=x.device).view(1, -1, 1)
- pos_emb = create_sin_embedding(positions,
- C,
- max_period=10_000,
- dtype=cross_attention_src.dtype)
- cross_attention_src = cross_attention_src + pos_emb
- if self.use_transformer:
- z = self.transformer(
- z.permute(0, 2, 1),
- cross_attention_src=cross_attention_src).permute(0, 2, 1)
- else:
- if self.bilstm is None:
- z = torch.zeros_like(z)
- else:
- z = self.bilstm(z)
-
- for decoder in self.decoders:
- s = skips.pop(-1)
- z = z[:, :, :s.shape[2]]
- z = z + s
- z = decoder(z)
-
- z = z[:, :, :x.shape[2]]
- return Output(z)
diff --git a/config/model/score/basic.yaml b/config/model/score/basic.yaml
deleted file mode 100644
index 75fbc378..00000000
--- a/config/model/score/basic.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-# @package _global_
-
-diffusion_unet:
- hidden: 48
- depth: 4
- res_blocks: 1
- norm_groups: 4
- kernel: 8
- stride: 4
- growth: 4
- max_channels: 10_000
- dropout: 0.
- emb_all_layers: true
- bilstm: false
- codec_dim: null
- transformer: false
- cross_attention: false
\ No newline at end of file