From 4cfdabcac76460e032f4ebe5d21674e3c1b1332b Mon Sep 17 00:00:00 2001 From: maxardito Date: Fri, 29 Mar 2024 14:33:56 -0400 Subject: [PATCH 1/3] removed most of diffusion --- .gitignore | 1 + audiocraft/__init__.py | 2 - .../grids/diffusion/4_bands_base_32khz.py | 27 -- audiocraft/grids/diffusion/__init__.py | 6 - audiocraft/grids/diffusion/_explorers.py | 66 ---- audiocraft/models/__init__.py | 2 - audiocraft/models/multibanddiffusion.py | 191 ----------- audiocraft/models/unet.py | 132 ++++++-- audiocraft/modules/diffusion_schedule.py | 272 ---------------- audiocraft/solvers/__init__.py | 1 - audiocraft/solvers/builders.py | 2 - audiocraft/solvers/diffusion.py | 279 ---------------- demos/musicgen_app.py | 300 ++++++++++-------- tests/models/test_multibanddiffusion.py | 53 ---- utils/export.py | 3 +- 15 files changed, 270 insertions(+), 1067 deletions(-) delete mode 100644 audiocraft/grids/diffusion/4_bands_base_32khz.py delete mode 100644 audiocraft/grids/diffusion/__init__.py delete mode 100644 audiocraft/grids/diffusion/_explorers.py delete mode 100644 audiocraft/models/multibanddiffusion.py delete mode 100644 audiocraft/modules/diffusion_schedule.py delete mode 100644 audiocraft/solvers/diffusion.py delete mode 100644 tests/models/test_multibanddiffusion.py diff --git a/.gitignore b/.gitignore index 581c6caa..fa469b38 100644 --- a/.gitignore +++ b/.gitignore @@ -53,6 +53,7 @@ dataset/* musicgen-deployment.txt manifests export +sessions # personal notebooks & scripts */local_scripts diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py index 1226aa02..385d9a00 100644 --- a/audiocraft/__init__.py +++ b/audiocraft/__init__.py @@ -14,8 +14,6 @@ - [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity neural audio codec which provides an excellent tokenizer for autoregressive language models. See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`. -- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that - improves the perceived quality and reduces the artifacts coming from adversarial decoders. """ # flake8: noqa diff --git a/audiocraft/grids/diffusion/4_bands_base_32khz.py b/audiocraft/grids/diffusion/4_bands_base_32khz.py deleted file mode 100644 index f7e67bcc..00000000 --- a/audiocraft/grids/diffusion/4_bands_base_32khz.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Training of the 4 diffusion models described in -"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" -(paper link). 
-""" - -from ._explorers import DiffusionExplorer - - -@DiffusionExplorer -def explorer(launcher): - launcher.slurm_(gpus=4, partition='learnfair') - - launcher.bind_({'solver': 'diffusion/default', - 'dset': 'internal/music_10k_32khz'}) - - with launcher.job_array(): - launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4}) - launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4}) - launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4}) - launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75}) diff --git a/audiocraft/grids/diffusion/__init__.py b/audiocraft/grids/diffusion/__init__.py deleted file mode 100644 index e5737294..00000000 --- a/audiocraft/grids/diffusion/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Diffusion grids.""" diff --git a/audiocraft/grids/diffusion/_explorers.py b/audiocraft/grids/diffusion/_explorers.py deleted file mode 100644 index 0bf4ca57..00000000 --- a/audiocraft/grids/diffusion/_explorers.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import treetable as tt - -from .._base_explorers import BaseExplorer - - -class DiffusionExplorer(BaseExplorer): - eval_metrics = ["sisnr", "visqol"] - - def stages(self): - return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"] - - def get_grid_meta(self): - """Returns the list of Meta information to display for each XP/job. - """ - return [ - tt.leaf("index", align=">"), - tt.leaf("name", wrap=140), - tt.leaf("state"), - tt.leaf("sig", align=">"), - ] - - def get_grid_metrics(self): - """Return the metrics that should be displayed in the tracking table. - """ - return [ - tt.group( - "train", - [ - tt.leaf("epoch"), - tt.leaf("loss", ".3%"), - ], - align=">", - ), - tt.group( - "valid", - [ - tt.leaf("loss", ".3%"), - # tt.leaf("loss_0", ".3%"), - ], - align=">", - ), - tt.group( - "valid_ema", - [ - tt.leaf("loss", ".3%"), - # tt.leaf("loss_0", ".3%"), - ], - align=">", - ), - tt.group( - "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"), - tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"), - tt.leaf("rvm_3", ".4f"), ], align=">" - ), - tt.group( - "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"), - tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"), - tt.leaf("rvm_3", ".4f")], align=">" - ), - ] diff --git a/audiocraft/models/__init__.py b/audiocraft/models/__init__.py index 0628e05a..90008169 100644 --- a/audiocraft/models/__init__.py +++ b/audiocraft/models/__init__.py @@ -11,6 +11,4 @@ from .encodec import (CompressionModel, EncodecModel, DAC, HFEncodecModel, HFEncodecCompressionModel) from .lm import LMModel -from .multibanddiffusion import MultiBandDiffusion from .musicgen import MusicGen -from .unet import DiffusionUnet diff --git a/audiocraft/models/multibanddiffusion.py b/audiocraft/models/multibanddiffusion.py deleted file mode 100644 index 451b5862..00000000 --- a/audiocraft/models/multibanddiffusion.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Multi Band Diffusion models as described in -"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" -(paper link). -""" - -import typing as tp - -import torch -import julius - -from .unet import DiffusionUnet -from ..modules.diffusion_schedule import NoiseSchedule -from .encodec import CompressionModel -from ..solvers.compression import CompressionSolver -from .loaders import load_compression_model, load_diffusion_models - - -class DiffusionProcess: - """Sampling for a diffusion Model. - - Args: - model (DiffusionUnet): Diffusion U-Net model. - noise_schedule (NoiseSchedule): Noise schedule for diffusion process. - """ - def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None: - self.model = model - self.schedule = noise_schedule - - def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor, - step_list: tp.Optional[tp.List[int]] = None): - """Perform one diffusion process to generate one of the bands. - - Args: - condition (torch.Tensor): The embeddings from the compression model. - initial_noise (torch.Tensor): The initial noise to start the process. - """ - return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list, - condition=condition) - - -class MultiBandDiffusion: - """Sample from multiple diffusion models. - - Args: - DPs (list of DiffusionProcess): Diffusion processes. - codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens. - """ - def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None: - self.DPs = DPs - self.codec_model = codec_model - self.device = next(self.codec_model.parameters()).device - - @property - def sample_rate(self) -> int: - return self.codec_model.sample_rate - - @staticmethod - def get_mbd_musicgen(device=None): - """Load our diffusion models trained for MusicGen.""" - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - path = 'facebook/multiband-diffusion' - filename = 'mbd_musicgen_32khz.th' - name = 'facebook/musicgen-small' - codec_model = load_compression_model(name, device=device) - models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device) - DPs = [] - for i in range(len(models)): - schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device) - DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule)) - return MultiBandDiffusion(DPs=DPs, codec_model=codec_model) - - @staticmethod - def get_mbd_24khz(bw: float = 3.0, - device: tp.Optional[tp.Union[torch.device, str]] = None, - n_q: tp.Optional[int] = None): - """Get the pretrained Models for MultibandDiffusion. - - Args: - bw (float): Bandwidth of the compression model. - device (torch.device or str, optional): Device on which the models are loaded. - n_q (int, optional): Number of quantizers to use within the compression model. 
- """ - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available" - if n_q is not None: - assert n_q in [2, 4, 8] - assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \ - f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}" - n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw] - codec_model = CompressionSolver.model_from_checkpoint( - '//pretrained/facebook/encodec_24khz', device=device) - codec_model.set_num_codebooks(n_q) - codec_model = codec_model.to(device) - path = 'facebook/multiband-diffusion' - filename = f'mbd_comp_{n_q}.pt' - models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device) - DPs = [] - for i in range(len(models)): - schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device) - DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule)) - return MultiBandDiffusion(DPs=DPs, codec_model=codec_model) - - @torch.no_grad() - def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: - """Get the conditioning (i.e. latent representations of the compression model) from a waveform. - Args: - wav (torch.Tensor): The audio that we want to extract the conditioning from. - sample_rate (int): Sample rate of the audio.""" - if sample_rate != self.sample_rate: - wav = julius.resample_frac(wav, sample_rate, self.sample_rate) - codes, scale = self.codec_model.encode(wav) - assert scale is None, "Scaled compression models not supported." - emb = self.get_emb(codes) - return emb - - @torch.no_grad() - def get_emb(self, codes: torch.Tensor): - """Get latent representation from the discrete codes. - Args: - codes (torch.Tensor): Discrete tokens.""" - emb = self.codec_model.decode_latent(codes) - return emb - - def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None, - step_list: tp.Optional[tp.List[int]] = None): - """Generate waveform audio from the latent embeddings of the compression model. - Args: - emb (torch.Tensor): Conditioning embeddings - size (None, torch.Size): Size of the output - if None this is computed from the typical upsampling of the model. - step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step. - """ - if size is None: - upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate) - size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling]) - assert size[0] == emb.size(0) - out = torch.zeros(size).to(self.device) - for DP in self.DPs: - out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out)) - return out - - def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1): - """Match the eq to the encodec output by matching the standard deviation of some frequency bands. - Args: - wav (torch.Tensor): Audio to equalize. - ref (torch.Tensor): Reference audio from which we match the spectrogram. - n_bands (int): Number of bands of the eq. - strictness (float): How strict the matching. 0 is no matching, 1 is exact matching. 
- """ - split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device) - bands = split(wav) - bands_ref = split(ref) - out = torch.zeros_like(ref) - for i in range(n_bands): - out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness - return out - - def regenerate(self, wav: torch.Tensor, sample_rate: int): - """Regenerate a waveform through compression and diffusion regeneration. - Args: - wav (torch.Tensor): Original 'ground truth' audio. - sample_rate (int): Sample rate of the input (and output) wav. - """ - if sample_rate != self.codec_model.sample_rate: - wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate) - emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate) - size = wav.size() - out = self.generate(emb, size=size) - if sample_rate != self.codec_model.sample_rate: - out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate) - return out - - def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32): - """Generate Waveform audio with diffusion from the discrete codes. - Args: - tokens (torch.Tensor): Discrete codes. - n_bands (int): Bands for the eq matching. - """ - wav_encodec = self.codec_model.decode(tokens) - condition = self.get_emb(tokens) - wav_diffusion = self.generate(emb=condition, size=wav_encodec.size()) - return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands) diff --git a/audiocraft/models/unet.py b/audiocraft/models/unet.py index db4a6df8..0061fafa 100644 --- a/audiocraft/models/unet.py +++ b/audiocraft/models/unet.py @@ -3,7 +3,6 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - """ Pytorch Unet Module used for diffusion. """ @@ -24,15 +23,21 @@ class Output: def get_model(cfg, channels: int, side: int, num_steps: int): if cfg.model == 'unet': - return DiffusionUnet( - chin=channels, num_steps=num_steps, **cfg.diffusion_unet) + return DiffusionUnet(chin=channels, + num_steps=num_steps, + **cfg.diffusion_unet) else: raise RuntimeError('Not Implemented') class ResBlock(nn.Module): - def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4, - dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + + def __init__(self, + channels: int, + kernel: int = 3, + norm_groups: int = 4, + dilation: int = 1, + activation: tp.Type[nn.Module] = nn.ReLU, dropout: float = 0.): super().__init__() stride = 1 @@ -40,12 +45,22 @@ def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4, Conv = nn.Conv1d Drop = nn.Dropout1d self.norm1 = nn.GroupNorm(norm_groups, channels) - self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) + self.conv1 = Conv(channels, + channels, + kernel, + 1, + padding, + dilation=dilation) self.activation1 = activation() self.dropout1 = Drop(dropout) self.norm2 = nn.GroupNorm(norm_groups, channels) - self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) + self.conv2 = Conv(channels, + channels, + kernel, + 1, + padding, + dilation=dilation) self.activation2 = activation() self.dropout2 = Drop(dropout) @@ -56,14 +71,24 @@ def forward(self, x): class DecoderLayer(nn.Module): - def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, - norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + + def __init__(self, + chin: int, + chout: int, + kernel: int = 4, + stride: int = 2, + norm_groups: int = 4, + res_blocks: int = 1, + activation: 
tp.Type[nn.Module] = nn.ReLU, dropout: float = 0.): super().__init__() padding = (kernel - stride) // 2 - self.res_blocks = nn.Sequential( - *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) - for idx in range(res_blocks)]) + self.res_blocks = nn.Sequential(*[ + ResBlock(chin, + norm_groups=norm_groups, + dilation=2**idx, + dropout=dropout) for idx in range(res_blocks) + ]) self.norm = nn.GroupNorm(norm_groups, chin) ConvTr = nn.ConvTranspose1d self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False) @@ -78,8 +103,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class EncoderLayer(nn.Module): - def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, - norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + + def __init__(self, + chin: int, + chout: int, + kernel: int = 4, + stride: int = 2, + norm_groups: int = 4, + res_blocks: int = 1, + activation: tp.Type[nn.Module] = nn.ReLU, dropout: float = 0.): super().__init__() padding = (kernel - stride) // 2 @@ -87,9 +119,12 @@ def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, self.conv = Conv(chin, chout, kernel, stride, padding, bias=False) self.norm = nn.GroupNorm(norm_groups, chout) self.activation = activation() - self.res_blocks = nn.Sequential( - *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) - for idx in range(res_blocks)]) + self.res_blocks = nn.Sequential(*[ + ResBlock(chout, + norm_groups=norm_groups, + dilation=2**idx, + dropout=dropout) for idx in range(res_blocks) + ]) def forward(self, x: torch.Tensor) -> torch.Tensor: B, C, T = x.shape @@ -107,9 +142,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class BLSTM(nn.Module): """BiLSTM with same hidden units as input dim. 
""" + def __init__(self, dim, layers=2): super().__init__() - self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.lstm = nn.LSTM(bidirectional=True, + num_layers=layers, + hidden_size=dim, + input_size=dim) self.linear = nn.Linear(2 * dim, dim) def forward(self, x): @@ -121,10 +160,20 @@ def forward(self, x): class DiffusionUnet(nn.Module): - def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2., - max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False, - bilstm: bool = False, transformer: bool = False, - codec_dim: tp.Optional[int] = None, **kwargs): + + def __init__(self, + chin: int = 3, + hidden: int = 24, + depth: int = 3, + growth: float = 2., + max_channels: int = 10_000, + num_steps: int = 1000, + emb_all_layers=False, + cross_attention: bool = False, + bilstm: bool = False, + transformer: bool = False, + codec_dim: tp.Optional[int] = None, + **kwargs): super().__init__() self.encoders = nn.ModuleList() self.decoders = nn.ModuleList() @@ -152,15 +201,23 @@ def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: floa self.cross_attention = False if transformer: self.cross_attention = cross_attention - self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False, - cross_attention=cross_attention) + self.transformer = StreamingTransformer( + chin, + 8, + 6, + bias_ff=False, + bias_attn=False, + cross_attention=cross_attention) self.use_codec = False if codec_dim is not None: self.conv_codec = nn.Conv1d(codec_dim, chin, 1) self.use_codec = True - def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None): + def forward(self, + x: torch.Tensor, + step: tp.Union[int, torch.Tensor], + condition: tp.Optional[torch.Tensor] = None): skips = [] bs = x.size(0) z = x @@ -168,25 +225,31 @@ def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: if type(step) is torch.Tensor: step_tensor = step else: - step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs) + step_tensor = torch.tensor([step], + device=x.device, + dtype=torch.long).expand(bs) for idx, encoder in enumerate(self.encoders): z = encoder(z) if idx == 0: - z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z) + z = z + self.embedding(step_tensor).view( + bs, -1, *view_args).expand_as(z) elif self.embeddings is not None: - z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z) + z = z + self.embeddings[idx - 1](step_tensor).view( + bs, -1, *view_args).expand_as(z) skips.append(z) if self.use_codec: # insert condition in the bottleneck assert condition is not None, "Model defined for conditionnal generation" - condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim + condition_emb = self.conv_codec( + condition) # reshape to the bottleneck dim assert condition_emb.size(-1) <= 2 * z.size(-1), \ f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}" if not self.cross_attention: - condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1)) + condition_emb = torch.nn.functional.interpolate( + condition_emb, z.size(-1)) assert z.size() == condition_emb.size() z += condition_emb cross_attention_src = None @@ -194,10 +257,15 @@ def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: cross_attention_src = 
condition_emb.permute(0, 2, 1) # B, T, C B, T, C = cross_attention_src.shape positions = torch.arange(T, device=x.device).view(1, -1, 1) - pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype) + pos_emb = create_sin_embedding(positions, + C, + max_period=10_000, + dtype=cross_attention_src.dtype) cross_attention_src = cross_attention_src + pos_emb if self.use_transformer: - z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1) + z = self.transformer( + z.permute(0, 2, 1), + cross_attention_src=cross_attention_src).permute(0, 2, 1) else: if self.bilstm is None: z = torch.zeros_like(z) diff --git a/audiocraft/modules/diffusion_schedule.py b/audiocraft/modules/diffusion_schedule.py deleted file mode 100644 index 74ca6e3f..00000000 --- a/audiocraft/modules/diffusion_schedule.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Functions for Noise Schedule, defines diffusion process, reverse process and data processor. -""" - -from collections import namedtuple -import random -import typing as tp -import julius -import torch - -TrainingItem = namedtuple("TrainingItem", "noisy noise step") - - -def betas_from_alpha_bar(alpha_bar): - alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]]) - return 1 - alphas - - -class SampleProcessor(torch.nn.Module): - def project_sample(self, x: torch.Tensor): - """Project the original sample to the 'space' where the diffusion will happen.""" - return x - - def return_sample(self, z: torch.Tensor): - """Project back from diffusion space to the actual sample space.""" - return z - - -class MultiBandProcessor(SampleProcessor): - """ - MultiBand sample processor. The input audio is splitted across - frequency bands evenly distributed in mel-scale. - - Each band will be rescaled to match the power distribution - of Gaussian noise in that band, using online metrics - computed on the first few samples. - - Args: - n_bands (int): Number of mel-bands to split the signal over. - sample_rate (int): Sample rate of the audio. - num_samples (int): Number of samples to use to fit the rescaling - for each band. The processor won't be stable - until it has seen that many samples. - power_std (float or list/tensor): The rescaling factor computed to match the - power of Gaussian noise in each band is taken to - that power, i.e. `1.` means full correction of the energy - in each band, and values less than `1` means only partial - correction. Can be used to balance the relative importance - of low vs. high freq in typical audio signals. 
- """ - def __init__(self, n_bands: int = 8, sample_rate: float = 24_000, - num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.): - super().__init__() - self.n_bands = n_bands - self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands) - self.num_samples = num_samples - self.power_std = power_std - if isinstance(power_std, list): - assert len(power_std) == n_bands - power_std = torch.tensor(power_std) - self.register_buffer('counts', torch.zeros(1)) - self.register_buffer('sum_x', torch.zeros(n_bands)) - self.register_buffer('sum_x2', torch.zeros(n_bands)) - self.register_buffer('sum_target_x2', torch.zeros(n_bands)) - self.counts: torch.Tensor - self.sum_x: torch.Tensor - self.sum_x2: torch.Tensor - self.sum_target_x2: torch.Tensor - - @property - def mean(self): - mean = self.sum_x / self.counts - return mean - - @property - def std(self): - std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() - return std - - @property - def target_std(self): - target_std = self.sum_target_x2 / self.counts - return target_std - - def project_sample(self, x: torch.Tensor): - assert x.dim() == 3 - bands = self.split_bands(x) - if self.counts.item() < self.num_samples: - ref_bands = self.split_bands(torch.randn_like(x)) - self.counts += len(x) - self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1) - self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1) - self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1) - rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size - bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1) - return bands.sum(dim=0) - - def return_sample(self, x: torch.Tensor): - assert x.dim() == 3 - bands = self.split_bands(x) - rescale = (self.std / self.target_std) ** self.power_std - bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1) - return bands.sum(dim=0) - - -class NoiseSchedule: - """Noise schedule for diffusion. - - Args: - beta_t0 (float): Variance of the first diffusion step. - beta_t1 (float): Variance of the last diffusion step. - beta_exp (float): Power schedule exponent - num_steps (int): Number of diffusion step. - variance (str): choice of the sigma value for the denoising eq. 
Choices: "beta" or "beta_tilde" - clip (float): clipping value for the denoising steps - rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1) - repartition (str): shape of the schedule only power schedule is supported - sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution - noise_scale (float): Scaling factor for the noise - """ - def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta', - clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1, - repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None, - sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs): - - self.beta_t0 = beta_t0 - self.beta_t1 = beta_t1 - self.variance = variance - self.num_steps = num_steps - self.clip = clip - self.sample_processor = sample_processor - self.rescale = rescale - self.n_bands = n_bands - self.noise_scale = noise_scale - assert n_bands is None - if repartition == "power": - self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps, - device=device, dtype=torch.float) ** beta_exp - else: - raise RuntimeError('Not implemented') - self.rng = random.Random(1234) - - def get_beta(self, step: tp.Union[int, torch.Tensor]): - if self.n_bands is None: - return self.betas[step] - else: - return self.betas[:, step] # [n_bands, len(step)] - - def get_initial_noise(self, x: torch.Tensor): - if self.n_bands is None: - return torch.randn_like(x) - return torch.randn((x.size(0), self.n_bands, x.size(2))) - - def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor: - """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step.""" - if step is None: - return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands - if type(step) is int: - return (1 - self.betas[:step + 1]).prod() - else: - return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1) - - def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem: - """Create a noisy data item for diffusion model training: - - Args: - x (torch.Tensor): clean audio data torch.tensor(bs, 1, T) - tensor_step (bool): If tensor_step = false, only one step t is sample, - the whole batch is diffused to the same step and t is int. - If tensor_step = true, t is a tensor of size (x.size(0),) - every element of the batch is diffused to a independently sampled. - """ - step: tp.Union[int, torch.Tensor] - if tensor_step: - bs = x.size(0) - step = torch.randint(0, self.num_steps, size=(bs,), device=x.device) - else: - step = self.rng.randrange(self.num_steps) - alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1] - - x = self.sample_processor.project_sample(x) - noise = torch.randn_like(x) - noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale - return TrainingItem(noisy, noise, step) - - def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None, - condition: tp.Optional[torch.Tensor] = None, return_list: bool = False): - """Full ddpm reverse process. - - Args: - model (nn.Module): Diffusion model. - initial (tensor): Initial Noise. - condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation). - return_list (bool): Whether to return the whole process or only the sampled point. 
- """ - alpha_bar = self.get_alpha_bar(step=self.num_steps - 1) - current = initial - iterates = [initial] - for step in range(self.num_steps)[::-1]: - with torch.no_grad(): - estimate = model(current, step, condition=condition).sample - alpha = 1 - self.betas[step] - previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt() - previous_alpha_bar = self.get_alpha_bar(step=step - 1) - if step == 0: - sigma2 = 0 - elif self.variance == 'beta': - sigma2 = 1 - alpha - elif self.variance == 'beta_tilde': - sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha) - elif self.variance == 'none': - sigma2 = 0 - else: - raise ValueError(f'Invalid variance type {self.variance}') - - if sigma2 > 0: - previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale - if self.clip: - previous = previous.clamp(-self.clip, self.clip) - current = previous - alpha_bar = previous_alpha_bar - if step == 0: - previous *= self.rescale - if return_list: - iterates.append(previous.cpu()) - - if return_list: - return iterates - else: - return self.sample_processor.return_sample(previous) - - def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None, - condition: tp.Optional[torch.Tensor] = None, return_list: bool = False): - """Reverse process that only goes through Markov chain states in step_list.""" - if step_list is None: - step_list = list(range(1000))[::-50] + [0] - alpha_bar = self.get_alpha_bar(step=self.num_steps - 1) - alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu() - betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled) - current = initial * self.noise_scale - iterates = [current] - for idx, step in enumerate(step_list[:-1]): - with torch.no_grad(): - estimate = model(current, step, condition=condition).sample * self.noise_scale - alpha = 1 - betas_subsampled[-1 - idx] - previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt() - previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1]) - if step == step_list[-2]: - sigma2 = 0 - previous_alpha_bar = torch.tensor(1.0) - else: - sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha) - if sigma2 > 0: - previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale - if self.clip: - previous = previous.clamp(-self.clip, self.clip) - current = previous - alpha_bar = previous_alpha_bar - if step == 0: - previous *= self.rescale - if return_list: - iterates.append(previous.cpu()) - if return_list: - return iterates - else: - return self.sample_processor.return_sample(previous) diff --git a/audiocraft/solvers/__init__.py b/audiocraft/solvers/__init__.py index 3e4f3ceb..37bd789f 100644 --- a/audiocraft/solvers/__init__.py +++ b/audiocraft/solvers/__init__.py @@ -13,4 +13,3 @@ from .base import StandardSolver from .compression import CompressionSolver from .musicgen import MusicGenSolver -from .diffusion import DiffusionSolver diff --git a/audiocraft/solvers/builders.py b/audiocraft/solvers/builders.py index 07dbb33b..bbfe16d1 100644 --- a/audiocraft/solvers/builders.py +++ b/audiocraft/solvers/builders.py @@ -41,12 +41,10 @@ def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: """Instantiate solver from config.""" from .compression import CompressionSolver from .musicgen import MusicGenSolver - from .diffusion import DiffusionSolver klass = { 'compression': CompressionSolver, 'musicgen': MusicGenSolver, 'lm': MusicGenSolver, # backward compatibility - 
'diffusion': DiffusionSolver, }[cfg.solver] return klass(cfg) # type: ignore diff --git a/audiocraft/solvers/diffusion.py b/audiocraft/solvers/diffusion.py deleted file mode 100644 index 93dea252..00000000 --- a/audiocraft/solvers/diffusion.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp - -import flashy -import julius -import omegaconf -import torch -import torch.nn.functional as F - -from . import builders -from . import base -from .. import models -from ..modules.diffusion_schedule import NoiseSchedule -from ..metrics import RelativeVolumeMel -from ..models.builders import get_processor -from ..utils.samples.manager import SampleManager -from ..solvers.compression import CompressionSolver - - -class PerStageMetrics: - """Handle prompting the metrics per stage. - It outputs the metrics per range of diffusion states. - e.g. avg loss when t in [250, 500] - """ - def __init__(self, num_steps: int, num_stages: int = 4): - self.num_steps = num_steps - self.num_stages = num_stages - - def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]): - if type(step) is int: - stage = int((step / self.num_steps) * self.num_stages) - return {f"{name}_{stage}": loss for name, loss in losses.items()} - elif type(step) is torch.Tensor: - stage_tensor = ((step / self.num_steps) * self.num_stages).long() - out: tp.Dict[str, float] = {} - for stage_idx in range(self.num_stages): - mask = (stage_tensor == stage_idx) - N = mask.sum() - stage_out = {} - if N > 0: # pass if no elements in the stage - for name, loss in losses.items(): - stage_loss = (mask * loss).sum() / N - stage_out[f"{name}_{stage_idx}"] = stage_loss - out = {**out, **stage_out} - return out - - -class DataProcess: - """Apply filtering or resampling. - - Args: - initial_sr (int): Initial sample rate. - target_sr (int): Target sample rate. - use_resampling: Whether to use resampling or not. - use_filter (bool): - n_bands (int): Number of bands to consider. - idx_band (int): - device (torch.device or str): - cutoffs (): - boost (bool): - """ - def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False, - use_filter: bool = False, n_bands: int = 4, - idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False): - """Apply filtering or resampling - Args: - initial_sr (int): sample rate of the dataset - target_sr (int): sample rate after resampling - use_resampling (bool): whether or not performs resampling - use_filter (bool): when True filter the data to keep only one frequency band - n_bands (int): Number of bands used - cuts (none or list): The cutoff frequencies of the band filtering - if None then we use mel scale bands. - idx_band (int): index of the frequency band. 0 are lows ... (n_bands - 1) highs - boost (bool): make the data scale match our music dataset. 
- """ - assert idx_band < n_bands - self.idx_band = idx_band - if use_filter: - if cutoffs is not None: - self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device) - else: - self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device) - self.use_filter = use_filter - self.use_resampling = use_resampling - self.target_sr = target_sr - self.initial_sr = initial_sr - self.boost = boost - - def process_data(self, x, metric=False): - if x is None: - return None - if self.boost: - x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4) - x * 0.22 - if self.use_filter and not metric: - x = self.filter(x)[self.idx_band] - if self.use_resampling: - x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr) - return x - - def inverse_process(self, x): - """Upsampling only.""" - if self.use_resampling: - x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.target_sr) - return x - - -class DiffusionSolver(base.StandardSolver): - """Solver for compression task. - - The diffusion task allows for MultiBand diffusion model training. - - Args: - cfg (DictConfig): Configuration. - """ - def __init__(self, cfg: omegaconf.DictConfig): - super().__init__(cfg) - self.cfg = cfg - self.device = cfg.device - self.sample_rate: int = self.cfg.sample_rate - self.codec_model = CompressionSolver.model_from_checkpoint( - cfg.compression_model_checkpoint, device=self.device) - - self.codec_model.set_num_codebooks(cfg.n_q) - assert self.codec_model.sample_rate == self.cfg.sample_rate, ( - f"Codec model sample rate is {self.codec_model.sample_rate} but " - f"Solver sample rate is {self.cfg.sample_rate}." - ) - assert self.codec_model.sample_rate == self.sample_rate, \ - f"Sample rate of solver {self.sample_rate} and codec {self.codec_model.sample_rate} " \ - "don't match." - - self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate) - self.register_stateful('sample_processor') - self.sample_processor.to(self.device) - - self.schedule = NoiseSchedule( - **cfg.schedule, device=self.device, sample_processor=self.sample_processor) - - self.eval_metric: tp.Optional[torch.nn.Module] = None - - self.rvm = RelativeVolumeMel() - self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr, - use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs, - use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands, - idx_band=cfg.filter.idx_band, device=self.device) - - @property - def best_metric_name(self) -> tp.Optional[str]: - if self._current_stage == "evaluate": - return 'rvm' - else: - return 'loss' - - @torch.no_grad() - def get_condition(self, wav: torch.Tensor) -> torch.Tensor: - codes, scale = self.codec_model.encode(wav) - assert scale is None, "Scaled compression models not supported." - emb = self.codec_model.decode_latent(codes) - return emb - - def build_model(self): - """Build model and optimizer as well as optional Exponential Moving Average of the model. 
- """ - # Model and optimizer - self.model = models.builders.get_diffusion_model(self.cfg).to(self.device) - self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim) - self.register_stateful('model', 'optimizer') - self.register_best_state('model') - self.register_ema('model') - - def build_dataloaders(self): - """Build audio dataloaders for each stage.""" - self.dataloaders = builders.get_audio_datasets(self.cfg) - - def show(self): - # TODO - raise NotImplementedError() - - def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): - """Perform one training or valid step on a given batch.""" - x = batch.to(self.device) - loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss - - condition = self.get_condition(x) # [bs, 128, T/hop, n_emb] - sample = self.data_processor.process_data(x) - - input_, target, step = self.schedule.get_training_item(sample, - tensor_step=self.cfg.schedule.variable_step_batch) - out = self.model(input_, step, condition=condition).sample - - base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2)) - reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2)) - loss = base_loss / reference_loss ** self.cfg.loss.norm_power - - if self.is_training: - loss.mean().backward() - flashy.distrib.sync_model(self.model) - self.optimizer.step() - self.optimizer.zero_grad() - metrics = { - 'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(), - } - metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step)) - metrics.update({ - 'std_in': input_.std(), 'std_out': out.std()}) - return metrics - - def run_epoch(self): - # reset random seed at the beginning of the epoch - self.rng = torch.Generator() - self.rng.manual_seed(1234 + self.epoch) - self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage) - # run epoch - super().run_epoch() - - def evaluate(self): - """Evaluate stage. - Runs audio reconstruction evaluation. - """ - self.model.eval() - evaluate_stage_name = f'{self.current_stage}' - loader = self.dataloaders['evaluate'] - updates = len(loader) - lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates) - - metrics = {} - n = 1 - for idx, batch in enumerate(lp): - x = batch.to(self.device) - with torch.no_grad(): - y_pred = self.regenerate(x) - - y_pred = y_pred.cpu() - y = batch.cpu() # should already be on CPU but just in case - rvm = self.rvm(y_pred, y) - lp.update(**rvm) - if len(metrics) == 0: - metrics = rvm - else: - for key in rvm.keys(): - metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1) - metrics = flashy.distrib.average_metrics(metrics) - return metrics - - @torch.no_grad() - def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None): - """Regenerate the given waveform.""" - condition = self.get_condition(wav) - initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav)) # sampling rate changes. 
- result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition, - step_list=step_list) - result = self.data_processor.inverse_process(result) - return result - - def generate(self): - """Generate stage.""" - sample_manager = SampleManager(self.xp) - self.model.eval() - generate_stage_name = f'{self.current_stage}' - - loader = self.dataloaders['generate'] - updates = len(loader) - lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) - - for batch in lp: - reference, _ = batch - reference = reference.to(self.device) - estimate = self.regenerate(reference) - reference = reference.cpu() - estimate = estimate.cpu() - sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference) - flashy.distrib.barrier() diff --git a/demos/musicgen_app.py b/demos/musicgen_app.py index 88cd27dc..a283d05b 100644 --- a/demos/musicgen_app.py +++ b/demos/musicgen_app.py @@ -28,7 +28,6 @@ from audiocraft.models.encodec import InterleaveStereoCompressionModel from audiocraft.models import MusicGen, MultiBandDiffusion - MODEL = None # Last used model SPACE_ID = os.environ.get('SPACE_ID', '') IS_BATCHED = "facebook/MusicGen" in SPACE_ID or 'musicgen-internal/musicgen_dev' in SPACE_ID @@ -60,6 +59,7 @@ def interrupt(): class FileCleaner: + def __init__(self, file_lifetime: float = 3600): self.file_lifetime = file_lifetime self.files = [] @@ -77,7 +77,8 @@ def _cleanup(self): self.files.pop(0) else: break - + + file_cleaner = FileCleaner() @@ -102,16 +103,15 @@ def load_model(version='facebook/musicgen-melody'): MODEL = MusicGen.get_pretrained(version) -def load_diffusion(): - global MBD - if MBD is None: - print("loading MBD") - MBD = MultiBandDiffusion.get_mbd_musicgen() - - -def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=None, **gen_kwargs): +def _do_predictions(texts, + melodies, + duration, + progress=False, + gradio_progress=None, + **gen_kwargs): MODEL.set_generation_params(duration=duration, **gen_kwargs) - print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies]) + print("new batch", len(texts), texts, + [None if m is None else (m[0], m[1].shape) for m in melodies]) be = time.time() processed_melodies = [] target_sr = 32000 @@ -120,7 +120,8 @@ def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=N if melody is None: processed_melodies.append(None) else: - sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t() + sr, melody = melody[0], torch.from_numpy(melody[1]).to( + MODEL.device).float().t() if melody.dim() == 1: melody = melody[None] melody = melody[..., :int(sr * duration)] @@ -134,32 +135,25 @@ def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=N melody_wavs=processed_melodies, melody_sample_rate=target_sr, progress=progress, - return_tokens=USE_DIFFUSION - ) + return_tokens=False) else: - outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION) + outputs = MODEL.generate(texts, + progress=progress, + return_tokens=False) except RuntimeError as e: raise gr.Error("Error while generating " + e.args[0]) - if USE_DIFFUSION: - if gradio_progress is not None: - gradio_progress(1, desc='Running MultiBandDiffusion...') - tokens = outputs[1] - if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel): - left, right = MODEL.compression_model.get_left_right_codes(tokens) - tokens = torch.cat([left, right]) - outputs_diffusion = 
MBD.tokens_to_wav(tokens) - if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel): - assert outputs_diffusion.shape[1] == 1 # output is mono - outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2) - outputs = torch.cat([outputs[0], outputs_diffusion], dim=0) outputs = outputs.detach().cpu().float() pending_videos = [] out_wavs = [] for output in outputs: with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: - audio_write( - file.name, output, MODEL.sample_rate, strategy="loudness", - loudness_headroom_db=16, loudness_compressor=True, add_suffix=False) + audio_write(file.name, + output, + MODEL.sample_rate, + strategy="loudness", + loudness_headroom_db=16, + loudness_compressor=True, + add_suffix=False) pending_videos.append(pool.submit(make_waveform, file.name)) out_wavs.append(file.name) file_cleaner.add(file.name) @@ -179,9 +173,18 @@ def predict_batched(texts, melodies): return res -def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()): +def predict_full(model, + model_path, + decoder, + text, + melody, + duration, + topk, + topp, + temperature, + cfg_coef, + progress=gr.Progress()): global INTERRUPTING - global USE_DIFFUSION INTERRUPTING = False progress(0, desc="Loading model...") model_path = model_path.strip() @@ -189,8 +192,9 @@ def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, if not Path(model_path).exists(): raise gr.Error(f"Model path {model_path} doesn't exist.") if not Path(model_path).is_dir(): - raise gr.Error(f"Model path {model_path} must be a folder containing " - "state_dict.bin and compression_state_dict_.bin.") + raise gr.Error( + f"Model path {model_path} must be a folder containing " + "state_dict.bin and compression_state_dict_.bin.") model = model_path if temperature < 0: raise gr.Error("Temperature must be >= 0.") @@ -200,12 +204,6 @@ def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, raise gr.Error("Topp must be non-negative.") topk = int(topk) - if decoder == "MultiBand_Diffusion": - USE_DIFFUSION = True - progress(0, desc="Loading diffusion model...") - load_diffusion() - else: - USE_DIFFUSION = False load_model(model) max_generated = 0 @@ -216,14 +214,17 @@ def _progress(generated, to_generate): progress((min(max_generated, to_generate), to_generate)) if INTERRUPTING: raise gr.Error("Interrupted.") + MODEL.set_custom_progress_callback(_progress) - videos, wavs = _do_predictions( - [text], [melody], duration, progress=True, - top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef, - gradio_progress=progress) - if USE_DIFFUSION: - return videos[0], wavs[0], videos[1], wavs[1] + videos, wavs = _do_predictions([text], [melody], + duration, + progress=True, + top_k=topk, + top_p=topp, + temperature=temperature, + cfg_coef=cfg_coef, + gradio_progress=progress) return videos[0], wavs[0], None, None @@ -234,110 +235,132 @@ def toggle_audio_src(choice): return gr.update(source="upload", value=None, label="File") -def toggle_diffusion(choice): - if choice == "MultiBand_Diffusion": - return [gr.update(visible=True)] * 2 - else: - return [gr.update(visible=False)] * 2 - - def ui_full(launch_kwargs): with gr.Blocks() as interface: - gr.Markdown( - """ + gr.Markdown(""" # MusicGen This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation presented at: ["Simple and Controllable Music 
Generation"](https://huggingface.co/papers/2306.05284) - """ - ) + """) with gr.Row(): with gr.Column(): with gr.Row(): text = gr.Text(label="Input Text", interactive=True) with gr.Column(): - radio = gr.Radio(["file", "mic"], value="file", - label="Condition on a melody (optional) File or Mic") - melody = gr.Audio(sources=["upload"], type="numpy", label="File", - interactive=True, elem_id="melody-input") + radio = gr.Radio( + ["file", "mic"], + value="file", + label="Condition on a melody (optional) File or Mic" + ) + melody = gr.Audio(sources=["upload"], + type="numpy", + label="File", + interactive=True, + elem_id="melody-input") with gr.Row(): submit = gr.Button("Submit") # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license. _ = gr.Button("Interrupt").click(fn=interrupt, queue=False) with gr.Row(): - model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small", - "facebook/musicgen-large", "facebook/musicgen-melody-large", - "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium", - "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large", - "facebook/musicgen-stereo-melody-large"], - label="Model", value="facebook/musicgen-stereo-melody", interactive=True) + model = gr.Radio([ + "facebook/musicgen-melody", "facebook/musicgen-medium", + "facebook/musicgen-small", "facebook/musicgen-large", + "facebook/musicgen-melody-large", + "facebook/musicgen-stereo-small", + "facebook/musicgen-stereo-medium", + "facebook/musicgen-stereo-melody", + "facebook/musicgen-stereo-large", + "facebook/musicgen-stereo-melody-large" + ], + label="Model", + value="facebook/musicgen-stereo-melody", + interactive=True) model_path = gr.Text(label="Model Path (custom models)") with gr.Row(): decoder = gr.Radio(["Default", "MultiBand_Diffusion"], - label="Decoder", value="Default", interactive=True) + label="Decoder", + value="Default", + interactive=True) with gr.Row(): - duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True) + duration = gr.Slider(minimum=1, + maximum=120, + value=10, + label="Duration", + interactive=True) with gr.Row(): - topk = gr.Number(label="Top-k", value=250, interactive=True) + topk = gr.Number(label="Top-k", + value=250, + interactive=True) topp = gr.Number(label="Top-p", value=0, interactive=True) - temperature = gr.Number(label="Temperature", value=1.0, interactive=True) - cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True) + temperature = gr.Number(label="Temperature", + value=1.0, + interactive=True) + cfg_coef = gr.Number(label="Classifier Free Guidance", + value=3.0, + interactive=True) with gr.Column(): output = gr.Video(label="Generated Music") - audio_output = gr.Audio(label="Generated Music (wav)", type='filepath') - diffusion_output = gr.Video(label="MultiBand Diffusion Decoder") - audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath') - submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False, - show_progress=False).then(predict_full, inputs=[model, model_path, decoder, text, melody, duration, topk, topp, - temperature, cfg_coef], - outputs=[output, audio_output, diffusion_output, audio_diffusion]) - radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False) + audio_output = gr.Audio(label="Generated Music (wav)", + type='filepath') + diffusion_output = gr.Video( + label="MultiBand Diffusion Decoder") + audio_diffusion = gr.Audio( + 
label="MultiBand Diffusion Decoder (wav)", type='filepath') + submit.click(toggle_diffusion, + decoder, [diffusion_output, audio_diffusion], + queue=False, + show_progress=False).then(predict_full, + inputs=[ + model, model_path, decoder, + text, melody, duration, + topk, topp, temperature, + cfg_coef + ], + outputs=[ + output, audio_output, + diffusion_output, + audio_diffusion + ]) + radio.change(toggle_audio_src, + radio, [melody], + queue=False, + show_progress=False) gr.Examples( fn=predict_full, examples=[ [ "An 80s driving pop song with heavy drums and synth pads in the background", - "./assets/bach.mp3", - "facebook/musicgen-stereo-melody", + "./assets/bach.mp3", "facebook/musicgen-stereo-melody", "Default" ], [ "A cheerful country song with acoustic guitars", "./assets/bolero_ravel.mp3", - "facebook/musicgen-stereo-melody", - "Default" + "facebook/musicgen-stereo-melody", "Default" ], [ - "90s rock song with electric guitar and heavy drums", - None, - "facebook/musicgen-stereo-medium", - "Default" + "90s rock song with electric guitar and heavy drums", None, + "facebook/musicgen-stereo-medium", "Default" ], [ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions", - "./assets/bach.mp3", - "facebook/musicgen-stereo-melody", + "./assets/bach.mp3", "facebook/musicgen-stereo-melody", "Default" ], [ - "lofi slow bpm electro chill with organic samples", - None, - "facebook/musicgen-stereo-medium", - "Default" + "lofi slow bpm electro chill with organic samples", None, + "facebook/musicgen-stereo-medium", "Default" ], [ - "Punk rock with loud drum and power guitar", - None, - "facebook/musicgen-stereo-medium", - "MultiBand_Diffusion" + "Punk rock with loud drum and power guitar", None, + "facebook/musicgen-stereo-medium", "MultiBand_Diffusion" ], ], inputs=[text, melody, model, decoder], - outputs=[output] - ) - gr.Markdown( - """ + outputs=[output]) + gr.Markdown(""" ### More details The model will generate a short music extract based on the description you provided. @@ -378,16 +401,14 @@ def ui_full(launch_kwargs): See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) for more details. - """ - ) + """) interface.queue().launch(**launch_kwargs) def ui_batched(launch_kwargs): with gr.Blocks() as demo: - gr.Markdown( - """ + gr.Markdown(""" # MusicGen This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md), @@ -399,25 +420,39 @@ def ui_batched(launch_kwargs): Duplicate Space for longer sequences, more control and no queue.
- """ - ) + """) with gr.Row(): with gr.Column(): with gr.Row(): - text = gr.Text(label="Describe your music", lines=2, interactive=True) + text = gr.Text(label="Describe your music", + lines=2, + interactive=True) with gr.Column(): - radio = gr.Radio(["file", "mic"], value="file", - label="Condition on a melody (optional) File or Mic") - melody = gr.Audio(source="upload", type="numpy", label="File", - interactive=True, elem_id="melody-input") + radio = gr.Radio( + ["file", "mic"], + value="file", + label="Condition on a melody (optional) File or Mic" + ) + melody = gr.Audio(source="upload", + type="numpy", + label="File", + interactive=True, + elem_id="melody-input") with gr.Row(): submit = gr.Button("Generate") with gr.Column(): output = gr.Video(label="Generated Music") - audio_output = gr.Audio(label="Generated Music (wav)", type='filepath') - submit.click(predict_batched, inputs=[text, melody], - outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE) - radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False) + audio_output = gr.Audio(label="Generated Music (wav)", + type='filepath') + submit.click(predict_batched, + inputs=[text, melody], + outputs=[output, audio_output], + batch=True, + max_batch_size=MAX_BATCH_SIZE) + radio.change(toggle_audio_src, + radio, [melody], + queue=False, + show_progress=False) gr.Examples( fn=predict_batched, examples=[ @@ -443,8 +478,7 @@ def ui_batched(launch_kwargs): ], ], inputs=[text, melody], - outputs=[output] - ) + outputs=[output]) gr.Markdown(""" ### More details @@ -482,24 +516,26 @@ def ui_batched(launch_kwargs): default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1', help='IP to listen on for connections to Gradio', ) - parser.add_argument( - '--username', type=str, default='', help='Username for authentication' - ) - parser.add_argument( - '--password', type=str, default='', help='Password for authentication' - ) + parser.add_argument('--username', + type=str, + default='', + help='Username for authentication') + parser.add_argument('--password', + type=str, + default='', + help='Password for authentication') parser.add_argument( '--server_port', type=int, default=0, help='Port to run the server listener on', ) - parser.add_argument( - '--inbrowser', action='store_true', help='Open in browser' - ) - parser.add_argument( - '--share', action='store_true', help='Share the gradio UI' - ) + parser.add_argument('--inbrowser', + action='store_true', + help='Open in browser') + parser.add_argument('--share', + action='store_true', + help='Share the gradio UI') args = parser.parse_args() @@ -519,8 +555,6 @@ def ui_batched(launch_kwargs): # Show the interface if IS_BATCHED: - global USE_DIFFUSION - USE_DIFFUSION = False ui_batched(launch_kwargs) else: ui_full(launch_kwargs) diff --git a/tests/models/test_multibanddiffusion.py b/tests/models/test_multibanddiffusion.py deleted file mode 100644 index 2702a3cb..00000000 --- a/tests/models/test_multibanddiffusion.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-
-import random
-
-import numpy as np
-import torch
-from audiocraft.models.multibanddiffusion import MultiBandDiffusion, DiffusionProcess
-from audiocraft.models import EncodecModel, DiffusionUnet
-from audiocraft.modules import SEANetEncoder, SEANetDecoder
-from audiocraft.modules.diffusion_schedule import NoiseSchedule
-from audiocraft.quantization import DummyQuantizer
-
-
-class TestMBD:
-
-    def _create_mbd(self,
-                    sample_rate: int,
-                    channels: int,
-                    n_filters: int = 3,
-                    n_residual_layers: int = 1,
-                    ratios: list = [5, 4, 3, 2],
-                    num_steps: int = 1000,
-                    codec_dim: int = 128,
-                    **kwargs):
-        frame_rate = np.prod(ratios)
-        encoder = SEANetEncoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
-                                n_residual_layers=n_residual_layers, ratios=ratios)
-        decoder = SEANetDecoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
-                                n_residual_layers=n_residual_layers, ratios=ratios)
-        quantizer = DummyQuantizer()
-        compression_model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
-                                         sample_rate=sample_rate, channels=channels, **kwargs)
-        diffusion_model = DiffusionUnet(chin=channels, num_steps=num_steps, codec_dim=codec_dim)
-        schedule = NoiseSchedule(device='cpu', num_steps=num_steps)
-        DP = DiffusionProcess(model=diffusion_model, noise_schedule=schedule)
-        mbd = MultiBandDiffusion(DPs=[DP], codec_model=compression_model)
-        return mbd
-
-    def test_model(self):
-        random.seed(1234)
-        sample_rate = 24_000
-        channels = 1
-        codec_dim = 128
-        mbd = self._create_mbd(sample_rate=sample_rate, channels=channels, codec_dim=codec_dim)
-        for _ in range(10):
-            length = random.randrange(1, 10_000)
-            x = torch.randn(2, channels, length)
-            res = mbd.regenerate(x, sample_rate)
-            assert res.shape == x.shape
diff --git a/utils/export.py b/utils/export.py
index e437a507..1161dabc 100644
--- a/utils/export.py
+++ b/utils/export.py
@@ -6,4 +6,5 @@
 export.export_lm(xp.folder / 'checkpoint.th', '../export/state_dict.bin')
 
 # Export the pretrained EnCodec model
-export.export_pretrained_compression_model('facebook/encodec_32khz', '/home/maxwell/grimes_keys_model/compression_state_dict.bin')
+export.export_pretrained_compression_model(
+    'facebook/encodec_32khz', '../export/compression_state_dict.bin')

From b170ee0fe9e9c7aa0914f25d5e491dd38532a516 Mon Sep 17 00:00:00 2001
From: maxardito
Date: Tue, 9 Apr 2024 13:35:53 -0400
Subject: [PATCH 2/3] Removed diffusion processor

---
 audiocraft/models/builders.py | 144 ++++++++++++++++++++--------------
 1 file changed, 83 insertions(+), 61 deletions(-)

diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py
index 46a4d96f..ae286c98 100644
--- a/audiocraft/models/builders.py
+++ b/audiocraft/models/builders.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-
 """
 All the functions to build the relevant models and modules
 from the Hydra config.
@@ -37,10 +36,10 @@
 from .unet import DiffusionUnet
 from .. import quantization as qt
 from ..utils.utils import dict_from_config
-from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
 
 
-def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
+def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig,
+                  dimension: int) -> qt.BaseQuantizer:
     klass = {
         'no_quant': qt.DummyQuantizer,
         'rvq': qt.ResidualVectorQuantizer
@@ -77,8 +76,12 @@ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
         renormalize = kwargs.pop('renormalize', False)
         # deprecated params
         kwargs.pop('renorm', None)
-        return EncodecModel(encoder, decoder, quantizer,
-                            frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
+        return EncodecModel(encoder,
+                            decoder,
+                            quantizer,
+                            frame_rate=frame_rate,
+                            renormalize=renormalize,
+                            **kwargs).to(cfg.device)
     else:
         raise KeyError(f"Unexpected compression model {cfg.compression_model}")
 
@@ -91,37 +94,44 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
         q_modeling = kwargs.pop('q_modeling', None)
         codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
         attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
-        cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
-        cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
+        cls_free_guidance = dict_from_config(
+            getattr(cfg, 'classifier_free_guidance'))
+        cfg_prob, cfg_coef = cls_free_guidance[
+            'training_dropout'], cls_free_guidance['inference_coef']
         fuser = get_condition_fuser(cfg)
-        condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
-        if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programmatically
+        condition_provider = get_conditioner_provider(kwargs["dim"],
+                                                      cfg).to(cfg.device)
+        if len(fuser.fuse2cond['cross']
+               ) > 0:  # enforce cross-att programmatically
             kwargs['cross_attention'] = True
         if codebooks_pattern_cfg.modeling is None:
             assert q_modeling is not None, \
                 "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
-            codebooks_pattern_cfg = omegaconf.OmegaConf.create(
-                {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
-            )
+            codebooks_pattern_cfg = omegaconf.OmegaConf.create({
+                'modeling': q_modeling,
+                'delay': {
+                    'delays': list(range(n_q))
+                }
+            })
 
-        pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
+        pattern_provider = get_codebooks_pattern_provider(
+            n_q, codebooks_pattern_cfg)
         lm_class = LMModel
-        return lm_class(
-            pattern_provider=pattern_provider,
-            condition_provider=condition_provider,
-            fuser=fuser,
-            cfg_dropout=cfg_prob,
-            cfg_coef=cfg_coef,
-            attribute_dropout=attribute_dropout,
-            dtype=getattr(torch, cfg.dtype),
-            device=cfg.device,
-            **kwargs
-        ).to(cfg.device)
+        return lm_class(pattern_provider=pattern_provider,
+                        condition_provider=condition_provider,
+                        fuser=fuser,
+                        cfg_dropout=cfg_prob,
+                        cfg_coef=cfg_coef,
+                        attribute_dropout=attribute_dropout,
+                        dtype=getattr(torch, cfg.dtype),
+                        device=cfg.device,
+                        **kwargs).to(cfg.device)
     else:
         raise KeyError(f"Unexpected LM model {cfg.lm_model}")
 
 
-def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
+def get_conditioner_provider(
+        output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
     """Instantiate a conditioning model."""
     device = cfg.device
     duration = cfg.dataset.segment_duration
@@ -136,25 +146,26 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
         model_type = cond_cfg['model']
         model_args = cond_cfg[model_type]
         if model_type == 't5':
-            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
+            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim,
+                                                    device=device,
+                                                    **model_args)
         elif model_type == 'lut':
-            conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
+            conditioners[str(cond)] = LUTConditioner(output_dim=output_dim,
+                                                     **model_args)
         elif model_type == 'chroma_stem':
             conditioners[str(cond)] = ChromaStemConditioner(
                 output_dim=output_dim,
                 duration=duration,
                 device=device,
-                **model_args
-            )
+                **model_args)
         elif model_type == 'clap':
             conditioners[str(cond)] = CLAPEmbeddingConditioner(
-                output_dim=output_dim,
-                device=device,
-                **model_args
-            )
+                output_dim=output_dim, device=device, **model_args)
         else:
             raise ValueError(f"Unrecognized conditioning model: {model_type}")
-    conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
+    conditioner = ConditioningProvider(conditioners,
+                                       device=device,
+                                       **condition_provider_args)
     return conditioner
 
 
@@ -168,7 +179,8 @@ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
     return fuser
 
 
-def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
+def get_codebooks_pattern_provider(
+        n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
     """Instantiate a codebooks pattern provider object."""
     pattern_providers = {
         'parallel': ParallelPatternProvider,
@@ -185,7 +197,9 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb
 
 def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
     """Instantiate a debug compression model to be used for unit tests."""
-    assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
+    assert sample_rate in [
+        16000, 32000
+    ], "unsupported sample rate for debug compression model"
     model_ratios = {
         16000: [10, 8, 8],  # 25 Hz at 16kHz
         32000: [10, 8, 16]  # 25 Hz at 32kHz
@@ -203,9 +217,12 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
     quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
     init_x = torch.randn(8, 32, 128)
     quantizer(init_x, 1)  # initialize kmeans etc.
-    compression_model = EncodecModel(
-        encoder, decoder, quantizer,
-        frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
+    compression_model = EncodecModel(encoder,
+                                     decoder,
+                                     quantizer,
+                                     frame_rate=frame_rate,
+                                     sample_rate=sample_rate,
+                                     channels=1).to(device)
     return compression_model.eval()
 
 
@@ -213,19 +230,9 @@ def get_diffusion_model(cfg: omegaconf.DictConfig):
     # TODO Find a way to infer the channels from dset
     channels = cfg.channels
     num_steps = cfg.schedule.num_steps
-    return DiffusionUnet(
-        chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
-
-
-def get_processor(cfg, sample_rate: int = 24000):
-    sample_processor = SampleProcessor()
-    if cfg.use:
-        kw = dict(cfg)
-        kw.pop('use')
-        kw.pop('name')
-        if cfg.name == "multi_band_processor":
-            sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
-    return sample_processor
+    return DiffusionUnet(chin=channels,
+                         num_steps=num_steps,
+                         **cfg.diffusion_unet)
 
 
 def get_debug_lm_model(device='cpu'):
@@ -233,16 +240,30 @@ def get_debug_lm_model(device='cpu'):
     pattern = DelayedPatternProvider(n_q=4)
     dim = 16
     providers = {
-        'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
+        'description':
+        LUTConditioner(n_bins=128,
+                       dim=dim,
+                       output_dim=dim,
+                       tokenizer="whitespace"),
     }
     condition_provider = ConditioningProvider(providers)
-    fuser = ConditionFuser(
-        {'cross': ['description'], 'prepend': [],
-         'sum': [], 'input_interpolate': []})
-    lm = LMModel(
-        pattern, condition_provider, fuser,
-        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
-        cross_attention=True, causal=True)
+    fuser = ConditionFuser({
+        'cross': ['description'],
+        'prepend': [],
+        'sum': [],
+        'input_interpolate': []
+    })
+    lm = LMModel(pattern,
+                 condition_provider,
+                 fuser,
+                 n_q=4,
+                 card=400,
+                 dim=dim,
+                 num_heads=4,
+                 custom=True,
+                 num_layers=2,
+                 cross_attention=True,
+                 causal=True)
     return lm.to(device).eval()
 
 
@@ -253,7 +274,8 @@ def get_wrapped_compression_model(
     if cfg.interleave_stereo_codebooks.use:
         kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
         kwargs.pop('use')
-        compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
+        compression_model = InterleaveStereoCompressionModel(
+            compression_model, **kwargs)
     if hasattr(cfg, 'compression_model_n_q'):
         if cfg.compression_model_n_q is not None:
             compression_model.set_num_codebooks(cfg.compression_model_n_q)

From 28db0cb250f683ae013f66a2cf730ae20f239b08 Mon Sep 17 00:00:00 2001
From: maxardito
Date: Tue, 9 Apr 2024 13:44:01 -0400
Subject: [PATCH 3/3] Removed UNet

---
 audiocraft/models/builders.py |  10 --
 audiocraft/models/unet.py     | 282 ----------------------------------
 config/model/score/basic.yaml |  17 --
 3 files changed, 309 deletions(-)
 delete mode 100644 audiocraft/models/unet.py
 delete mode 100644 config/model/score/basic.yaml

diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py
index ae286c98..f8da9b72 100644
--- a/audiocraft/models/builders.py
+++ b/audiocraft/models/builders.py
@@ -33,7 +33,6 @@
     LUTConditioner,
     T5Conditioner,
 )
-from .unet import DiffusionUnet
 from .. import quantization as qt
 from ..utils.utils import dict_from_config
 
@@ -226,15 +225,6 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
     return compression_model.eval()
 
 
-def get_diffusion_model(cfg: omegaconf.DictConfig):
-    # TODO Find a way to infer the channels from dset
-    channels = cfg.channels
-    num_steps = cfg.schedule.num_steps
-    return DiffusionUnet(chin=channels,
-                         num_steps=num_steps,
-                         **cfg.diffusion_unet)
-
-
 def get_debug_lm_model(device='cpu'):
     """Instantiate a debug LM to be used for unit tests."""
     pattern = DelayedPatternProvider(n_q=4)
diff --git a/audiocraft/models/unet.py b/audiocraft/models/unet.py
deleted file mode 100644
index 0061fafa..00000000
--- a/audiocraft/models/unet.py
+++ /dev/null
@@ -1,282 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-"""
-Pytorch Unet Module used for diffusion.
-"""
-
-from dataclasses import dataclass
-import typing as tp
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
-
-
-@dataclass
-class Output:
-    sample: torch.Tensor
-
-
-def get_model(cfg, channels: int, side: int, num_steps: int):
-    if cfg.model == 'unet':
-        return DiffusionUnet(chin=channels,
-                             num_steps=num_steps,
-                             **cfg.diffusion_unet)
-    else:
-        raise RuntimeError('Not Implemented')
-
-
-class ResBlock(nn.Module):
-
-    def __init__(self,
-                 channels: int,
-                 kernel: int = 3,
-                 norm_groups: int = 4,
-                 dilation: int = 1,
-                 activation: tp.Type[nn.Module] = nn.ReLU,
-                 dropout: float = 0.):
-        super().__init__()
-        stride = 1
-        padding = dilation * (kernel - stride) // 2
-        Conv = nn.Conv1d
-        Drop = nn.Dropout1d
-        self.norm1 = nn.GroupNorm(norm_groups, channels)
-        self.conv1 = Conv(channels,
-                          channels,
-                          kernel,
-                          1,
-                          padding,
-                          dilation=dilation)
-        self.activation1 = activation()
-        self.dropout1 = Drop(dropout)
-
-        self.norm2 = nn.GroupNorm(norm_groups, channels)
-        self.conv2 = Conv(channels,
-                          channels,
-                          kernel,
-                          1,
-                          padding,
-                          dilation=dilation)
-        self.activation2 = activation()
-        self.dropout2 = Drop(dropout)
-
-    def forward(self, x):
-        h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
-        h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
-        return x + h
-
-
-class DecoderLayer(nn.Module):
-
-    def __init__(self,
-                 chin: int,
-                 chout: int,
-                 kernel: int = 4,
-                 stride: int = 2,
-                 norm_groups: int = 4,
-                 res_blocks: int = 1,
-                 activation: tp.Type[nn.Module] = nn.ReLU,
-                 dropout: float = 0.):
-        super().__init__()
-        padding = (kernel - stride) // 2
-        self.res_blocks = nn.Sequential(*[
-            ResBlock(chin,
-                     norm_groups=norm_groups,
-                     dilation=2**idx,
-                     dropout=dropout) for idx in range(res_blocks)
-        ])
-        self.norm = nn.GroupNorm(norm_groups, chin)
-        ConvTr = nn.ConvTranspose1d
-        self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
-        self.activation = activation()
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.res_blocks(x)
-        x = self.norm(x)
-        x = self.activation(x)
-        x = self.convtr(x)
-        return x
-
-
-class EncoderLayer(nn.Module):
-
-    def __init__(self,
-                 chin: int,
-                 chout: int,
-                 kernel: int = 4,
-                 stride: int = 2,
-                 norm_groups: int = 4,
-                 res_blocks: int = 1,
-                 activation: tp.Type[nn.Module] = nn.ReLU,
-                 dropout: float = 0.):
-        super().__init__()
-        padding = (kernel - stride) // 2
-        Conv = nn.Conv1d
-        self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
-        self.norm = nn.GroupNorm(norm_groups, chout)
-        self.activation = activation()
-        self.res_blocks = nn.Sequential(*[
-            ResBlock(chout,
-                     norm_groups=norm_groups,
-                     dilation=2**idx,
-                     dropout=dropout) for idx in range(res_blocks)
-        ])
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, C, T = x.shape
-        stride, = self.conv.stride
-        pad = (stride - (T % stride)) % stride
-        x = F.pad(x, (0, pad))
-
-        x = self.conv(x)
-        x = self.norm(x)
-        x = self.activation(x)
-        x = self.res_blocks(x)
-        return x
-
-
-class BLSTM(nn.Module):
-    """BiLSTM with same hidden units as input dim.
-    """
-
-    def __init__(self, dim, layers=2):
-        super().__init__()
-        self.lstm = nn.LSTM(bidirectional=True,
-                            num_layers=layers,
-                            hidden_size=dim,
-                            input_size=dim)
-        self.linear = nn.Linear(2 * dim, dim)
-
-    def forward(self, x):
-        x = x.permute(2, 0, 1)
-        x = self.lstm(x)[0]
-        x = self.linear(x)
-        x = x.permute(1, 2, 0)
-        return x
-
-
-class DiffusionUnet(nn.Module):
-
-    def __init__(self,
-                 chin: int = 3,
-                 hidden: int = 24,
-                 depth: int = 3,
-                 growth: float = 2.,
-                 max_channels: int = 10_000,
-                 num_steps: int = 1000,
-                 emb_all_layers=False,
-                 cross_attention: bool = False,
-                 bilstm: bool = False,
-                 transformer: bool = False,
-                 codec_dim: tp.Optional[int] = None,
-                 **kwargs):
-        super().__init__()
-        self.encoders = nn.ModuleList()
-        self.decoders = nn.ModuleList()
-        self.embeddings: tp.Optional[nn.ModuleList] = None
-        self.embedding = nn.Embedding(num_steps, hidden)
-        if emb_all_layers:
-            self.embeddings = nn.ModuleList()
-        self.condition_embedding: tp.Optional[nn.Module] = None
-        for d in range(depth):
-            encoder = EncoderLayer(chin, hidden, **kwargs)
-            decoder = DecoderLayer(hidden, chin, **kwargs)
-            self.encoders.append(encoder)
-            self.decoders.insert(0, decoder)
-            if emb_all_layers and d > 0:
-                assert self.embeddings is not None
-                self.embeddings.append(nn.Embedding(num_steps, hidden))
-            chin = hidden
-            hidden = min(int(chin * growth), max_channels)
-        self.bilstm: tp.Optional[nn.Module]
-        if bilstm:
-            self.bilstm = BLSTM(chin)
-        else:
-            self.bilstm = None
-        self.use_transformer = transformer
-        self.cross_attention = False
-        if transformer:
-            self.cross_attention = cross_attention
-            self.transformer = StreamingTransformer(
-                chin,
-                8,
-                6,
-                bias_ff=False,
-                bias_attn=False,
-                cross_attention=cross_attention)
-
-        self.use_codec = False
-        if codec_dim is not None:
-            self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
-            self.use_codec = True
-
-    def forward(self,
-                x: torch.Tensor,
-                step: tp.Union[int, torch.Tensor],
-                condition: tp.Optional[torch.Tensor] = None):
-        skips = []
-        bs = x.size(0)
-        z = x
-        view_args = [1]
-        if type(step) is torch.Tensor:
-            step_tensor = step
-        else:
-            step_tensor = torch.tensor([step],
-                                       device=x.device,
-                                       dtype=torch.long).expand(bs)
-
-        for idx, encoder in enumerate(self.encoders):
-            z = encoder(z)
-            if idx == 0:
-                z = z + self.embedding(step_tensor).view(
-                    bs, -1, *view_args).expand_as(z)
-            elif self.embeddings is not None:
-                z = z + self.embeddings[idx - 1](step_tensor).view(
-                    bs, -1, *view_args).expand_as(z)
-
-            skips.append(z)
-
-        if self.use_codec:  # insert condition in the bottleneck
-            assert condition is not None, "Model defined for conditionnal generation"
-            condition_emb = self.conv_codec(
-                condition)  # reshape to the bottleneck dim
-            assert condition_emb.size(-1) <= 2 * z.size(-1), \
-                f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
-            if not self.cross_attention:
-
-                condition_emb = torch.nn.functional.interpolate(
-                    condition_emb, z.size(-1))
-                assert z.size() == condition_emb.size()
-                z += condition_emb
-                cross_attention_src = None
-            else:
-                cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
-                B, T, C = cross_attention_src.shape
-                positions = torch.arange(T, device=x.device).view(1, -1, 1)
-                pos_emb = create_sin_embedding(positions,
-                                               C,
-                                               max_period=10_000,
-                                               dtype=cross_attention_src.dtype)
-                cross_attention_src = cross_attention_src + pos_emb
-        if self.use_transformer:
-            z = self.transformer(
-                z.permute(0, 2, 1),
-                cross_attention_src=cross_attention_src).permute(0, 2, 1)
-        else:
-            if self.bilstm is None:
-                z = torch.zeros_like(z)
-            else:
-                z = self.bilstm(z)
-
-        for decoder in self.decoders:
-            s = skips.pop(-1)
-            z = z[:, :, :s.shape[2]]
-            z = z + s
-            z = decoder(z)
-
-        z = z[:, :, :x.shape[2]]
-        return Output(z)
diff --git a/config/model/score/basic.yaml b/config/model/score/basic.yaml
deleted file mode 100644
index 75fbc378..00000000
--- a/config/model/score/basic.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-# @package _global_
-
-diffusion_unet:
-  hidden: 48
-  depth: 4
-  res_blocks: 1
-  norm_groups: 4
-  kernel: 8
-  stride: 4
-  growth: 4
-  max_channels: 10_000
-  dropout: 0.
-  emb_all_layers: true
-  bilstm: false
-  codec_dim: null
-  transformer: false
-  cross_attention: false
\ No newline at end of file