Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diffusion Removal #4

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dataset/*
musicgen-deployment.txt
manifests
export
sessions

# personal notebooks & scripts
*/local_scripts
Expand Down
2 changes: 0 additions & 2 deletions audiocraft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
neural audio codec which provides an excellent tokenizer for autoregressive language models.
See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
improves the perceived quality and reduces the artifacts coming from adversarial decoders.
"""

# flake8: noqa
Expand Down
27 changes: 0 additions & 27 deletions audiocraft/grids/diffusion/4_bands_base_32khz.py

This file was deleted.

6 changes: 0 additions & 6 deletions audiocraft/grids/diffusion/__init__.py

This file was deleted.

66 changes: 0 additions & 66 deletions audiocraft/grids/diffusion/_explorers.py

This file was deleted.

2 changes: 0 additions & 2 deletions audiocraft/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,4 @@
from .encodec import (CompressionModel, EncodecModel, DAC, HFEncodecModel,
HFEncodecCompressionModel)
from .lm import LMModel
from .multibanddiffusion import MultiBandDiffusion
from .musicgen import MusicGen
from .unet import DiffusionUnet
148 changes: 80 additions & 68 deletions audiocraft/models/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
All the functions to build the relevant models and modules
from the Hydra config.
Expand Down Expand Up @@ -34,13 +33,12 @@
LUTConditioner,
T5Conditioner,
)
from .unet import DiffusionUnet
from .. import quantization as qt
from ..utils.utils import dict_from_config
from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor


def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig,
dimension: int) -> qt.BaseQuantizer:
klass = {
'no_quant': qt.DummyQuantizer,
'rvq': qt.ResidualVectorQuantizer
Expand Down Expand Up @@ -77,8 +75,12 @@ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
renormalize = kwargs.pop('renormalize', False)
# deprecated params
kwargs.pop('renorm', None)
return EncodecModel(encoder, decoder, quantizer,
frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
return EncodecModel(encoder,
decoder,
quantizer,
frame_rate=frame_rate,
renormalize=renormalize,
**kwargs).to(cfg.device)
else:
raise KeyError(f"Unexpected compression model {cfg.compression_model}")

Expand All @@ -91,37 +93,44 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
q_modeling = kwargs.pop('q_modeling', None)
codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
cls_free_guidance = dict_from_config(
getattr(cfg, 'classifier_free_guidance'))
cfg_prob, cfg_coef = cls_free_guidance[
'training_dropout'], cls_free_guidance['inference_coef']
fuser = get_condition_fuser(cfg)
condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
condition_provider = get_conditioner_provider(kwargs["dim"],
cfg).to(cfg.device)
if len(fuser.fuse2cond['cross']
) > 0: # enforce cross-att programmatically
kwargs['cross_attention'] = True
if codebooks_pattern_cfg.modeling is None:
assert q_modeling is not None, \
"LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
codebooks_pattern_cfg = omegaconf.OmegaConf.create(
{'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
)
codebooks_pattern_cfg = omegaconf.OmegaConf.create({
'modeling': q_modeling,
'delay': {
'delays': list(range(n_q))
}
})

pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
pattern_provider = get_codebooks_pattern_provider(
n_q, codebooks_pattern_cfg)
lm_class = LMModel
return lm_class(
pattern_provider=pattern_provider,
condition_provider=condition_provider,
fuser=fuser,
cfg_dropout=cfg_prob,
cfg_coef=cfg_coef,
attribute_dropout=attribute_dropout,
dtype=getattr(torch, cfg.dtype),
device=cfg.device,
**kwargs
).to(cfg.device)
return lm_class(pattern_provider=pattern_provider,
condition_provider=condition_provider,
fuser=fuser,
cfg_dropout=cfg_prob,
cfg_coef=cfg_coef,
attribute_dropout=attribute_dropout,
dtype=getattr(torch, cfg.dtype),
device=cfg.device,
**kwargs).to(cfg.device)
else:
raise KeyError(f"Unexpected LM model {cfg.lm_model}")


def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
def get_conditioner_provider(
output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
"""Instantiate a conditioning model."""
device = cfg.device
duration = cfg.dataset.segment_duration
Expand All @@ -136,25 +145,26 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
model_type = cond_cfg['model']
model_args = cond_cfg[model_type]
if model_type == 't5':
conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
conditioners[str(cond)] = T5Conditioner(output_dim=output_dim,
device=device,
**model_args)
elif model_type == 'lut':
conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
conditioners[str(cond)] = LUTConditioner(output_dim=output_dim,
**model_args)
elif model_type == 'chroma_stem':
conditioners[str(cond)] = ChromaStemConditioner(
output_dim=output_dim,
duration=duration,
device=device,
**model_args
)
**model_args)
elif model_type == 'clap':
conditioners[str(cond)] = CLAPEmbeddingConditioner(
output_dim=output_dim,
device=device,
**model_args
)
output_dim=output_dim, device=device, **model_args)
else:
raise ValueError(f"Unrecognized conditioning model: {model_type}")
conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
conditioner = ConditioningProvider(conditioners,
device=device,
**condition_provider_args)
return conditioner


Expand All @@ -168,7 +178,8 @@ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
return fuser


def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
def get_codebooks_pattern_provider(
n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
"""Instantiate a codebooks pattern provider object."""
pattern_providers = {
'parallel': ParallelPatternProvider,
Expand All @@ -185,7 +196,9 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb

def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
"""Instantiate a debug compression model to be used for unit tests."""
assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
assert sample_rate in [
16000, 32000
], "unsupported sample rate for debug compression model"
model_ratios = {
16000: [10, 8, 8], # 25 Hz at 16kHz
32000: [10, 8, 16] # 25 Hz at 32kHz
Expand All @@ -203,46 +216,44 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
init_x = torch.randn(8, 32, 128)
quantizer(init_x, 1) # initialize kmeans etc.
compression_model = EncodecModel(
encoder, decoder, quantizer,
frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
compression_model = EncodecModel(encoder,
decoder,
quantizer,
frame_rate=frame_rate,
sample_rate=sample_rate,
channels=1).to(device)
return compression_model.eval()


def get_diffusion_model(cfg: omegaconf.DictConfig):
    """Build a ``DiffusionUnet`` from the Hydra config.

    Reads ``cfg.channels`` and ``cfg.schedule.num_steps`` and forwards the
    remaining ``cfg.diffusion_unet`` options verbatim to the model constructor.
    """
    # TODO Find a way to infer the channels from dset
    return DiffusionUnet(chin=cfg.channels,
                         num_steps=cfg.schedule.num_steps,
                         **cfg.diffusion_unet)


def get_processor(cfg, sample_rate: int = 24000):
    """Return the sample processor selected by ``cfg``.

    Defaults to a plain ``SampleProcessor``; when ``cfg.use`` is true and
    ``cfg.name == "multi_band_processor"``, a ``MultiBandProcessor`` is built
    from the remaining config entries (everything except ``use``/``name``).
    """
    # Always construct the default first so the fallback path is identical
    # whether or not cfg.use is set.
    processor = SampleProcessor()
    if cfg.use:
        kwargs = dict(cfg)
        kwargs.pop('use')
        kwargs.pop('name')
        if cfg.name == "multi_band_processor":
            processor = MultiBandProcessor(sample_rate=sample_rate, **kwargs)
    return processor


def get_debug_lm_model(device='cpu'):
    """Instantiate a tiny debug LM to be used for unit tests."""
    dim = 16
    pattern = DelayedPatternProvider(n_q=4)
    # Single whitespace-tokenized text conditioner is enough for tests.
    conditioners = {
        'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim,
                                      tokenizer="whitespace"),
    }
    condition_provider = ConditioningProvider(conditioners)
    fuser = ConditionFuser({'cross': ['description'],
                            'prepend': [],
                            'sum': [],
                            'input_interpolate': []})
    lm = LMModel(pattern, condition_provider, fuser,
                 n_q=4, card=400, dim=dim, num_heads=4, custom=True,
                 num_layers=2, cross_attention=True, causal=True)
    return lm.to(device).eval()


Expand All @@ -253,7 +264,8 @@ def get_wrapped_compression_model(
if cfg.interleave_stereo_codebooks.use:
kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
kwargs.pop('use')
compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
compression_model = InterleaveStereoCompressionModel(
compression_model, **kwargs)
if hasattr(cfg, 'compression_model_n_q'):
if cfg.compression_model_n_q is not None:
compression_model.set_num_codebooks(cfg.compression_model_n_q)
Expand Down
Loading