MusicGen Update (huggingface#27084)
* [MusicGen] Add stereo model

* safe serialization

* Update src/transformers/models/musicgen/modeling_musicgen.py

* split over 2 lines

* fix slow tests on cuda
sanchit-gandhi authored and EduardoPach committed Nov 19, 2023
1 parent 90bbda1 commit 3325dfe
Showing 5 changed files with 244 additions and 28 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/musicgen.md
@@ -57,6 +57,11 @@ Generation is limited by the sinusoidal positional embeddings to 30 second inputs, meaning MusicGen cannot generate more
than 30 seconds of audio (1503 tokens), and input audio passed by Audio-Prompted Generation contributes to this limit so,
given an input of 20 seconds of audio, MusicGen cannot generate more than 10 seconds of additional audio.
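
A quick back-of-the-envelope check of this budget (a sketch; the ~50 tokens/second rate is derived from the 1503-token / 30-second figure above):

```python
# MusicGen's audio codec runs at roughly 50 tokens per second (1503 tokens ≈ 30 s)
tokens_per_second = 1503 / 30
prompt_seconds = 20
remaining_tokens = 1503 - prompt_seconds * tokens_per_second
print(remaining_tokens / tokens_per_second)  # ≈ 10 seconds of additional audio
```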

+Transformers supports both mono (1-channel) and stereo (2-channel) variants of MusicGen. The mono channel versions
+generate a single set of codebooks. The stereo versions generate 2 sets of codebooks, 1 for each channel (left/right),
+and each set of codebooks is decoded independently through the audio compression model. The audio streams for each
+channel are combined to give the final stereo output.
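
For illustration, stereo generation looks identical to mono from the user's side; a minimal sketch along the lines of the existing docs examples (the stereo checkpoint name matches the conversion script below):

```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-stereo-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small")

inputs = processor(text=["80s pop track with bassy drums and synth"], padding=True, return_tensors="pt")
# audio_values has shape (batch_size, num_channels, sequence_length), with num_channels == 2 for stereo
audio_values = model.generate(**inputs, do_sample=True, max_new_tokens=256)
```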

### Unconditional Generation

The inputs for unconditional (or 'null') generation can be obtained through the method
9 changes: 9 additions & 0 deletions src/transformers/models/musicgen/configuration_musicgen.py
@@ -75,6 +75,9 @@ class MusicgenDecoderConfig(PretrainedConfig):
The number of parallel codebooks forwarded to the model.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether input and output word embeddings should be tied.
+        audio_channels (`int`, *optional*, defaults to 1):
+            Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate
+            audio stream for the left/right output channels. Mono models generate a single audio stream output.
"""
model_type = "musicgen_decoder"
keys_to_ignore_at_inference = ["past_key_values"]
@@ -96,6 +99,7 @@ def __init__(
initializer_factor=0.02,
scale_embedding=False,
num_codebooks=4,
+        audio_channels=1,
pad_token_id=2048,
bos_token_id=2048,
eos_token_id=None,
@@ -117,6 +121,11 @@ def __init__(
self.use_cache = use_cache
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.num_codebooks = num_codebooks

+        if audio_channels not in [1, 2]:
+            raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.")
+        self.audio_channels = audio_channels

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
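
For illustration, a minimal sketch of the new `audio_channels` argument and its validation (variable names are ours, not part of the diff):

```python
from transformers import MusicgenDecoderConfig

# stereo decoder: twice the codebooks, one interleaved set per channel
stereo_config = MusicgenDecoderConfig(num_codebooks=8, audio_channels=2)

# anything other than 1 (mono) or 2 (stereo) channels is rejected
try:
    MusicgenDecoderConfig(audio_channels=3)
except ValueError as err:
    print(err)  # Expected 1 (mono) or 2 (stereo) audio channels, got 3 channels.
```
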
50 changes: 38 additions & 12 deletions src/transformers/models/musicgen/convert_musicgen_transformers.py
@@ -88,32 +88,48 @@ def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict, Dict]:


def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
if checkpoint == "small":
if checkpoint == "small" or checkpoint == "facebook/musicgen-stereo-small":
# default config values
hidden_size = 1024
num_hidden_layers = 24
num_attention_heads = 16
elif checkpoint == "medium":
elif checkpoint == "medium" or checkpoint == "facebook/musicgen-stereo-medium":
hidden_size = 1536
num_hidden_layers = 48
num_attention_heads = 24
elif checkpoint == "large":
elif checkpoint == "large" or checkpoint == "facebook/musicgen-stereo-large":
hidden_size = 2048
num_hidden_layers = 48
num_attention_heads = 32
else:
raise ValueError(f"Checkpoint should be one of `['small', 'medium', 'large']`, got {checkpoint}.")
raise ValueError(
"Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, "
"or `['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
f"for the stereo checkpoints, got {checkpoint}."
)

if "stereo" in checkpoint:
audio_channels = 2
num_codebooks = 8
else:
audio_channels = 1
num_codebooks = 4

config = MusicgenDecoderConfig(
hidden_size=hidden_size,
ffn_dim=hidden_size * 4,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
+        num_codebooks=num_codebooks,
+        audio_channels=audio_channels,
)
return config
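
# Illustration only (not part of the diff): exercising the checkpoint -> config mapping above.
# The assertions follow directly from the branches of decoder_config_from_checkpoint.
example_config = decoder_config_from_checkpoint("facebook/musicgen-stereo-small")
assert example_config.hidden_size == 1024
assert example_config.num_codebooks == 8 and example_config.audio_channels == 2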


@torch.no_grad()
-def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
+def convert_musicgen_checkpoint(
+    checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu", safe_serialization=False
+):
fairseq_model = MusicGen.get_pretrained(checkpoint, device=device)
decoder_config = decoder_config_from_checkpoint(checkpoint)

@@ -146,18 +162,20 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict)

# check we can do a forward pass
-    input_ids = torch.arange(0, 8, dtype=torch.long).reshape(2, -1)
-    decoder_input_ids = input_ids.reshape(2 * 4, -1)
+    input_ids = torch.arange(0, 2 * decoder_config.num_codebooks, dtype=torch.long).reshape(2, -1)
+    decoder_input_ids = input_ids.reshape(2 * decoder_config.num_codebooks, -1)

with torch.no_grad():
logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

-    if logits.shape != (8, 1, 2048):
+    if logits.shape != (2 * decoder_config.num_codebooks, 1, 2048):
raise ValueError("Incorrect shape for logits")

# now construct the processor
tokenizer = AutoTokenizer.from_pretrained("t5-base")
-    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz", padding_side="left")
+    feature_extractor = AutoFeatureExtractor.from_pretrained(
+        "facebook/encodec_32khz", padding_side="left", feature_size=decoder_config.audio_channels
+    )

processor = MusicgenProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)

@@ -173,12 +191,12 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
if pytorch_dump_folder is not None:
Path(pytorch_dump_folder).mkdir(exist_ok=True)
logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}")
-        model.save_pretrained(pytorch_dump_folder)
+        model.save_pretrained(pytorch_dump_folder, safe_serialization=safe_serialization)
processor.save_pretrained(pytorch_dump_folder)

if repo_id:
logger.info(f"Pushing model {checkpoint} to {repo_id}")
-        model.push_to_hub(repo_id)
+        model.push_to_hub(repo_id, safe_serialization=safe_serialization)
processor.push_to_hub(repo_id)


@@ -189,7 +207,10 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
"--checkpoint",
default="small",
type=str,
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: `['small', 'medium', 'large']`.",
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: "
"`['small', 'medium', 'large']` for the mono checkpoints, or "
"`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
"for the stereo checkpoints.",
)
parser.add_argument(
"--pytorch_dump_folder",
@@ -204,6 +225,11 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
parser.add_argument(
"--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
)
+    parser.add_argument(
+        "--safe_serialization",
+        action="store_true",
+        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).",
+    )

args = parser.parse_args()
    convert_musicgen_checkpoint(
        args.checkpoint, args.pytorch_dump_folder, args.push_to_hub, args.device, args.safe_serialization
    )
88 changes: 72 additions & 16 deletions src/transformers/models/musicgen/modeling_musicgen.py
@@ -1077,21 +1077,33 @@ def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None):
torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
)

+        channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
# we only apply the mask if we have a large enough seq len - otherwise we return as is
-        if max_length < 2 * num_codebooks - 1:
+        if max_length < 2 * channel_codebooks - 1:
return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)

# fill the shifted ids with the prompt entries, offset by the codebook idx
-        for codebook in range(num_codebooks):
-            input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+        for codebook in range(channel_codebooks):
+            if self.config.audio_channels == 1:
+                # mono channel - loop over the codebooks one-by-one
+                input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+            else:
+                # left/right channels are interleaved in the generated codebooks, so handle one then the other
+                input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
+                input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]

# construct a pattern mask that indicates the positions of padding tokens for each codebook
# first fill the upper triangular part (the EOS padding)
delay_pattern = torch.triu(
-            torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
+            torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
)
# then fill the lower triangular part (the BOS padding)
-        delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
+        delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))

+        if self.config.audio_channels == 2:
+            # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
+            delay_pattern = delay_pattern.repeat_interleave(2, dim=0)

mask = ~delay_pattern.to(input_ids.device)
input_ids = mask * input_ids_shifted + ~mask * pad_token_id
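
# Illustration only (not part of the diff): a toy delay pattern for channel_codebooks=2,
# max_length=6, showing how repeat_interleave duplicates each row for the left/right pair.
_channel_codebooks, _max_length = 2, 6
_pattern = torch.triu(
    torch.ones((_channel_codebooks, _max_length), dtype=torch.bool),
    diagonal=_max_length - _channel_codebooks + 1,
)
_pattern = _pattern + torch.tril(torch.ones((_channel_codebooks, _max_length), dtype=torch.bool))
_stereo_pattern = _pattern.repeat_interleave(2, dim=0)  # rows: cb0-L, cb0-R, cb1-L, cb1-R
assert _stereo_pattern.shape == (4, 6)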

@@ -1856,6 +1868,11 @@ def forward(
f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
"disabled by setting `chunk_length=None` in the audio encoder."
)

+        if self.config.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2:
+            # mono input through encodec that we convert to stereo
+            audio_codes = audio_codes.repeat_interleave(2, dim=2)

decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)

# Decode
@@ -2074,12 +2091,42 @@ def _prepare_audio_encoder_kwargs_for_generation(
# 3. make sure that encoder returns `ModelOutput`
model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name
encoder_kwargs["return_dict"] = True
-        encoder_kwargs[model_input_name] = input_values

-        audio_encoder_outputs = encoder.encode(**encoder_kwargs)
+        if self.decoder.config.audio_channels == 1:
+            encoder_kwargs[model_input_name] = input_values
+            audio_encoder_outputs = encoder.encode(**encoder_kwargs)
+            audio_codes = audio_encoder_outputs.audio_codes
+            audio_scales = audio_encoder_outputs.audio_scales

-        audio_codes = audio_encoder_outputs.audio_codes
-        frames, bsz, codebooks, seq_len = audio_codes.shape
+            frames, bsz, codebooks, seq_len = audio_codes.shape

+        else:
+            if input_values.shape[1] != 2:
+                raise ValueError(
+                    f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel."
+                )
+
+            encoder_kwargs[model_input_name] = input_values[:, :1, :]
+            audio_encoder_outputs_left = encoder.encode(**encoder_kwargs)
+            audio_codes_left = audio_encoder_outputs_left.audio_codes
+            audio_scales_left = audio_encoder_outputs_left.audio_scales
+
+            encoder_kwargs[model_input_name] = input_values[:, 1:, :]
+            audio_encoder_outputs_right = encoder.encode(**encoder_kwargs)
+            audio_codes_right = audio_encoder_outputs_right.audio_codes
+            audio_scales_right = audio_encoder_outputs_right.audio_scales
+
+            frames, bsz, codebooks, seq_len = audio_codes_left.shape
+            # copy alternating left/right channel codes into stereo codebook
+            audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len))
+
+            audio_codes[:, :, ::2, :] = audio_codes_left
+            audio_codes[:, :, 1::2, :] = audio_codes_right
+
+            if audio_scales_left != [None] or audio_scales_right != [None]:
+                audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1)
+            else:
+                audio_scales = [None] * bsz

if frames != 1:
raise ValueError(
@@ -2090,7 +2137,7 @@ def _prepare_audio_encoder_kwargs_for_generation(
decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)

model_kwargs["decoder_input_ids"] = decoder_input_ids
model_kwargs["audio_scales"] = audio_encoder_outputs.audio_scales
model_kwargs["audio_scales"] = audio_scales
return model_kwargs
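
# Illustration only (not part of the diff): how the left/right codebooks interleave along
# dim 2, with toy tensors standing in for the two encoder outputs.
_frames, _bsz, _codebooks, _seq_len = 1, 1, 4, 10
_left = torch.zeros((_frames, _bsz, _codebooks, _seq_len), dtype=torch.long)
_right = torch.ones((_frames, _bsz, _codebooks, _seq_len), dtype=torch.long)
_codes = _left.new_ones((_frames, _bsz, 2 * _codebooks, _seq_len))
_codes[:, :, ::2, :] = _left    # even rows hold the left channel
_codes[:, :, 1::2, :] = _right  # odd rows hold the right channel
assert _codes[0, 0, 0, 0] == 0 and _codes[0, 0, 1, 0] == 1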

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
@@ -2433,16 +2480,25 @@ def generate(
if audio_scales is None:
audio_scales = [None] * batch_size

-        output_values = self.audio_encoder.decode(
-            output_ids,
-            audio_scales=audio_scales,
-        )
+        if self.decoder.config.audio_channels == 1:
+            output_values = self.audio_encoder.decode(
+                output_ids,
+                audio_scales=audio_scales,
+            ).audio_values
+        else:
+            codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
+            output_values_left = codec_outputs_left.audio_values
+
+            codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
+            output_values_right = codec_outputs_right.audio_values
+
+            output_values = torch.cat([output_values_left, output_values_right], dim=1)

if generation_config.return_dict_in_generate:
-            outputs.sequences = output_values.audio_values
+            outputs.sequences = output_values
return outputs
else:
-            return output_values.audio_values
+            return output_values

def get_unconditional_inputs(self, num_samples=1):
"""