diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py
index 94334e76ef4b17..bbdaaec473fa78 100644
--- a/src/transformers/models/speecht5/modeling_speecht5.py
+++ b/src/transformers/models/speecht5/modeling_speecht5.py
@@ -664,13 +664,11 @@ def __init__(self, config):
         )
 
         self.final_layer = nn.Linear(config.speech_decoder_prenet_units, config.hidden_size)
-
         self.encode_positions = SpeechT5ScaledPositionalEncoding(
             config.positional_dropout,
             config.hidden_size,
             config.max_speech_positions,
         )
-
         self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size)
 
     def _consistent_dropout(self, inputs_embeds, p):
@@ -695,9 +693,7 @@ def forward(
 
         if speaker_embeddings is not None:
             speaker_embeddings = nn.functional.normalize(speaker_embeddings)
-            speaker_embeddings = speaker_embeddings.unsqueeze(1)
-            speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1)
-            speaker_embeddings = speaker_embeddings.repeat(inputs_embeds.size(0), 1, 1)
+            speaker_embeddings = speaker_embeddings.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1)
             inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
             inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))
 
@@ -2825,6 +2821,16 @@ def generate(
            `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
            output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
        """
+        if speaker_embeddings is not None:
+            batch_size = input_ids.size(0)
+            if speaker_embeddings.size(0) != batch_size:
+                if speaker_embeddings.size(0) == 1:
+                    speaker_embeddings = speaker_embeddings.repeat(batch_size, 1)
+                else:
+                    raise ValueError(
+                        "The first dimension of speaker_embeddings must be either 1 or the same as the batch size."
+                    )
+
         return _generate_speech(
             self,
             input_ids,
@@ -2911,6 +2917,16 @@ def generate_speech(
            `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
            output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
        """
+        if speaker_embeddings is not None:
+            batch_size = input_ids.size(0)
+            if speaker_embeddings.size(0) != batch_size:
+                if speaker_embeddings.size(0) == 1:
+                    speaker_embeddings = speaker_embeddings.repeat(batch_size, 1)
+                else:
+                    raise ValueError(
+                        "The first dimension of speaker_embeddings must be either 1 or the same as the batch size."
+                    )
+
         return _generate_speech(
             self,
             input_ids,
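The modeling change above makes `generate` and `generate_speech` broadcast a single speaker embedding across a batched input, and raise a ValueError for any other first-dimension mismatch. A minimal usage sketch of the new behavior (not part of the patch; the checkpoints are the ones the tests below use, and the input texts are placeholders):

    import torch
    from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

    texts = ["an example sentence", "and a second, longer example sentence"]
    inputs = processor(text=texts, padding="max_length", max_length=128, return_tensors="pt")

    # A single (1, 512) x-vector is now broadcast to every item in the batch.
    speaker_embeddings = torch.zeros((1, 512))
    waveforms, waveform_lengths = model.generate_speech(
        input_ids=inputs["input_ids"],
        speaker_embeddings=speaker_embeddings,
        attention_mask=inputs["attention_mask"],
        vocoder=vocoder,
        return_output_lengths=True,
    )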
diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py
index c6b4b24873a2fe..7849b59d2935a7 100644
--- a/tests/models/speecht5/test_modeling_speecht5.py
+++ b/tests/models/speecht5/test_modeling_speecht5.py
@@ -1029,7 +1029,7 @@ def _mock_init_weights(self, module):
 class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
     @cached_property
     def default_model(self):
-        return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+        return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(torch_device)
 
     @cached_property
     def default_processor(self):
@@ -1037,37 +1037,40 @@ def default_processor(self):
 
     @cached_property
     def default_vocoder(self):
-        return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
+        return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(torch_device)
 
     def test_generation(self):
         model = self.default_model
-        model.to(torch_device)
         processor = self.default_processor
 
-        set_seed(555)  # make deterministic
-
-        speaker_embeddings = torch.zeros((1, 512)).to(torch_device)
-
-        input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
+        input_text = "Mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
         input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
+        speaker_embeddings = torch.zeros((1, 512), device=torch_device)
 
+        # Generate speech and validate output dimensions
+        set_seed(555)  # Ensure deterministic behavior
         generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
-        self.assertEqual(generated_speech.shape, (230, model.config.num_mel_bins))
-
-        set_seed(555)  # make deterministic
+        num_mel_bins = model.config.num_mel_bins
+        self.assertEqual(
+            generated_speech.shape[1], num_mel_bins, "Generated speech output has an unexpected number of mel bins."
+        )
 
-        # test model.generate, same method than generate_speech but with additional kwargs to absorb kwargs such as attention_mask
+        # Validate generation with additional kwargs using model.generate,
+        # which is the same method as generate_speech
+        set_seed(555)  # Reset seed for consistent results
         generated_speech_with_generate = model.generate(
             input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
         )
-        self.assertEqual(generated_speech_with_generate.shape, (230, model.config.num_mel_bins))
+        self.assertEqual(
+            generated_speech_with_generate.shape,
+            generated_speech.shape,
+            "Shape mismatch between generate_speech and generate methods.",
+        )
 
-    def test_batch_generation(self):
+    def test_one_to_many_generation(self):
         model = self.default_model
-        model.to(torch_device)
         processor = self.default_processor
         vocoder = self.default_vocoder
-        set_seed(555)  # make deterministic
 
         input_text = [
             "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
@@ -1075,20 +1078,32 @@ def test_batch_generation(self):
             "he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
         ]
         inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
-
         speaker_embeddings = torch.zeros((1, 512), device=torch_device)
+
+        # Generate spectrograms
+        set_seed(555)  # Ensure deterministic behavior
         spectrograms, spectrogram_lengths = model.generate_speech(
             input_ids=inputs["input_ids"],
             speaker_embeddings=speaker_embeddings,
             attention_mask=inputs["attention_mask"],
             return_output_lengths=True,
         )
-        self.assertEqual(spectrograms.shape, (3, 262, model.config.num_mel_bins))
+
+        # Validate generated spectrogram dimensions
+        expected_batch_size = len(input_text)
+        num_mel_bins = model.config.num_mel_bins
+        actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
+        self.assertEqual(actual_batch_size, expected_batch_size, "Batch size of generated spectrograms is incorrect.")
+        self.assertEqual(
+            actual_num_mel_bins, num_mel_bins, "Number of mel bins in batch generated spectrograms is incorrect."
+        )
+
+        # Generate waveforms using the vocoder
         waveforms = vocoder(spectrograms)
         waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
 
-        # Check waveform results are the same with or without using vocder
-        set_seed(555)
+        # Validate generation with integrated vocoder
+        set_seed(555)  # Reset seed for consistent results
         waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
             input_ids=inputs["input_ids"],
             speaker_embeddings=speaker_embeddings,
@@ -1096,11 +1111,20 @@ def test_batch_generation(self):
             vocoder=vocoder,
             return_output_lengths=True,
         )
-        self.assertTrue(torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8))
-        self.assertEqual(waveform_lengths, waveform_lengths_with_vocoder)
-        # Check waveform results are the same with return_concrete_lengths=True/False
-        set_seed(555)
+
+        # Check consistency between waveforms generated with and without standalone vocoder
+        self.assertTrue(
+            torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
+            "Mismatch in waveforms generated with and without the standalone vocoder.",
+        )
+        self.assertEqual(
+            waveform_lengths,
+            waveform_lengths_with_vocoder,
+            "Waveform lengths differ between standalone and integrated vocoder generation.",
+        )
+
+        # Test generation consistency without returning lengths
+        set_seed(555)  # Reset seed for consistent results
         waveforms_with_vocoder_no_lengths = model.generate_speech(
             input_ids=inputs["input_ids"],
             speaker_embeddings=speaker_embeddings,
@@ -1108,28 +1132,164 @@ def test_batch_generation(self):
             vocoder=vocoder,
             return_output_lengths=False,
         )
-        self.assertTrue(torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8))
 
-        # Check results when batching are consistent with results without batching
+        # Validate waveform consistency without length information
+        self.assertTrue(
+            torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
+            "Waveforms differ when generated with and without length information.",
+        )
+
+        # Validate batch vs. single instance generation consistency
         for i, text in enumerate(input_text):
-            set_seed(555)  # make deterministic
             inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
+            set_seed(555)  # Reset seed for consistent results
             spectrogram = model.generate_speech(
                 input_ids=inputs["input_ids"],
                 speaker_embeddings=speaker_embeddings,
             )
-            self.assertEqual(spectrogram.shape, spectrograms[i][: spectrogram_lengths[i]].shape)
-            self.assertTrue(torch.allclose(spectrogram, spectrograms[i][: spectrogram_lengths[i]], atol=5e-3))
+
+            # Check spectrogram shape consistency
+            self.assertEqual(
+                spectrogram.shape,
+                spectrograms[i][: spectrogram_lengths[i]].shape,
+                "Mismatch in spectrogram shape between batch and single instance generation.",
+            )
+
+            # Generate and validate waveform for single instance
             waveform = vocoder(spectrogram)
-            self.assertEqual(waveform.shape, waveforms[i][: waveform_lengths[i]].shape)
-            # Check whether waveforms are the same with/without passing vocoder
-            set_seed(555)
-            waveform_with_vocoder = model.generate_speech(
+            self.assertEqual(
+                waveform.shape,
+                waveforms[i][: waveform_lengths[i]].shape,
+                "Mismatch in waveform shape between batch and single instance generation.",
+            )
+
+            # Check waveform consistency with integrated vocoder
+            set_seed(555)  # Reset seed for consistent results
+            waveform_with_integrated_vocoder = model.generate_speech(
                 input_ids=inputs["input_ids"],
                 speaker_embeddings=speaker_embeddings,
                 vocoder=vocoder,
             )
-            self.assertTrue(torch.allclose(waveform, waveform_with_vocoder, atol=1e-8))
+            self.assertTrue(
+                torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
+                "Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
+            )
+
+    def test_batch_generation(self):
+        model = self.default_model
+        processor = self.default_processor
+        vocoder = self.default_vocoder
+
+        input_text = [
+            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
+            "nor is mister quilter's manner less interesting than his matter",
+            "he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
+        ]
+        inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
+        set_seed(555)  # Ensure deterministic speaker embeddings
+        speaker_embeddings = torch.randn((len(input_text), 512), device=torch_device)
+
+        # Generate spectrograms
+        set_seed(555)  # Reset seed for consistent results
+        spectrograms, spectrogram_lengths = model.generate_speech(
+            input_ids=inputs["input_ids"],
+            speaker_embeddings=speaker_embeddings,
+            attention_mask=inputs["attention_mask"],
+            return_output_lengths=True,
+        )
+
+        # Validate generated spectrogram dimensions
+        expected_batch_size = len(input_text)
+        num_mel_bins = model.config.num_mel_bins
+        actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
+        self.assertEqual(
+            actual_batch_size,
+            expected_batch_size,
+            "Batch size of generated spectrograms is incorrect.",
+        )
+        self.assertEqual(
+            actual_num_mel_bins,
+            num_mel_bins,
+            "Number of mel bins in batch generated spectrograms is incorrect.",
+        )
+
+        # Generate waveforms using the vocoder
+        waveforms = vocoder(spectrograms)
+        waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
+
+        # Validate generation with integrated vocoder
+        set_seed(555)  # Reset seed for consistent results
+        waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
+            input_ids=inputs["input_ids"],
+            speaker_embeddings=speaker_embeddings,
+            attention_mask=inputs["attention_mask"],
+            vocoder=vocoder,
+            return_output_lengths=True,
+        )
+
+        # Check consistency between waveforms generated with and without standalone vocoder
+        self.assertTrue(
+            torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
+            "Mismatch in waveforms generated with and without the standalone vocoder.",
+        )
+        self.assertEqual(
+            waveform_lengths,
+            waveform_lengths_with_vocoder,
+            "Waveform lengths differ between standalone and integrated vocoder generation.",
+        )
+
+        # Test generation consistency without returning lengths
+        set_seed(555)  # Reset seed for consistent results
+        waveforms_with_vocoder_no_lengths = model.generate_speech(
+            input_ids=inputs["input_ids"],
+            speaker_embeddings=speaker_embeddings,
+            attention_mask=inputs["attention_mask"],
+            vocoder=vocoder,
+            return_output_lengths=False,
+        )
+
+        # Validate waveform consistency without length information
+        self.assertTrue(
+            torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
+            "Waveforms differ when generated with and without length information.",
+        )
+
+        # Validate batch vs. single instance generation consistency
+        for i, text in enumerate(input_text):
+            inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
+            current_speaker_embedding = speaker_embeddings[i].unsqueeze(0)
+            set_seed(555)  # Reset seed for consistent results
+            spectrogram = model.generate_speech(
+                input_ids=inputs["input_ids"],
+                speaker_embeddings=current_speaker_embedding,
+            )
+
+            # Check spectrogram shape consistency
+            self.assertEqual(
+                spectrogram.shape,
+                spectrograms[i][: spectrogram_lengths[i]].shape,
+                "Mismatch in spectrogram shape between batch and single instance generation.",
+            )
+
+            # Generate and validate waveform for single instance
+            waveform = vocoder(spectrogram)
+            self.assertEqual(
+                waveform.shape,
+                waveforms[i][: waveform_lengths[i]].shape,
+                "Mismatch in waveform shape between batch and single instance generation.",
+            )
+
+            # Check waveform consistency with integrated vocoder
+            set_seed(555)  # Reset seed for consistent results
+            waveform_with_integrated_vocoder = model.generate_speech(
+                input_ids=inputs["input_ids"],
+                speaker_embeddings=current_speaker_embedding,
+                vocoder=vocoder,
+            )
+            self.assertTrue(
+                torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
+                "Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
+            )
 
 
 @require_torch
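For the per-speaker case that the new test_batch_generation exercises, one embedding per batch item is passed through unchanged, while any other leading dimension is rejected by the new guard. A hedged sketch under the same assumptions as the one above (illustrative texts and shapes, not part of the patch):

    import torch
    from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor

    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

    texts = ["first example", "second example", "third example"]
    inputs = processor(text=texts, padding="max_length", max_length=128, return_tensors="pt")

    # One (batch_size, 512) embedding matrix is accepted as-is.
    per_speaker_embeddings = torch.randn((len(texts), 512))
    spectrograms, spectrogram_lengths = model.generate_speech(
        input_ids=inputs["input_ids"],
        speaker_embeddings=per_speaker_embeddings,
        attention_mask=inputs["attention_mask"],
        return_output_lengths=True,
    )

    # A leading dimension that is neither 1 nor the batch size raises a ValueError.
    try:
        model.generate_speech(inputs["input_ids"], torch.randn((2, 512)))
    except ValueError as err:
        print(err)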