Skip to content

Commit

Permalink
Fix/speecht5 bug (#28481)
Browse files Browse the repository at this point in the history
* Fix bug in SpeechT5 speech decoder prenet's forward method

- Removed redundant `repeat` operation on speaker_embeddings in the forward method. This line was erroneously duplicating the embeddings, leading to incorrect input size for concatenation and performance issues.
- Maintained original functionality of the method, ensuring the integrity of the speech decoder prenet's forward pass remains intact.
- This change resolves a critical bug affecting the model's performance in handling speaker embeddings.

* Refactor SpeechT5 text to speech integration tests

- Updated SpeechT5ForTextToSpeechIntegrationTests to accommodate the variability in sequence lengths due to dropout in the speech decoder pre-net. This change ensures that our tests are robust against random variations in generated speech, enhancing the reliability of our test suite.
- Removed hardcoded dimensions in test assertions. Replaced with dynamic checks based on model configuration and seed settings, ensuring tests remain valid across different runs and configurations.
- Added new test cases to thoroughly validate the shapes of generated spectrograms and waveforms. These tests leverage seed settings to ensure consistent and predictable behavior in testing, addressing potential issues in speech generation and vocoder processing.
- Fixed existing test cases where incorrect assumptions about output shapes led to potential errors.

* Fix bug in SpeechT5 speech decoder prenet's forward method

- Removed redundant `repeat` operation on speaker_embeddings in the forward method. This line was erroneously duplicating the embeddings, leading to incorrect input size for concatenation and performance issues.
- Maintained original functionality of the method, ensuring the integrity of the speech decoder prenet's forward pass remains intact.
- This change resolves a critical bug affecting the model's performance in handling speaker embeddings.

* Refactor SpeechT5 text to speech integration tests

- Updated SpeechT5ForTextToSpeechIntegrationTests to accommodate the variability in sequence lengths due to dropout in the speech decoder pre-net. This change ensures that our tests are robust against random variations in generated speech, enhancing the reliability of our test suite.
- Removed hardcoded dimensions in test assertions. Replaced with dynamic checks based on model configuration and seed settings, ensuring tests remain valid across different runs and configurations.
- Added new test cases to thoroughly validate the shapes of generated spectrograms and waveforms. These tests leverage seed settings to ensure consistent and predictable behavior in testing, addressing potential issues in speech generation and vocoder processing.
- Fixed existing test cases where incorrect assumptions about output shapes led to potential errors.

* Enhance handling of speaker embeddings in SpeechT5

- Refined the generate and generate_speech functions in the SpeechT5 class to robustly handle two scenarios for speaker embeddings: matching the batch size (one embedding per sample) and one-to-many (a single embedding for all samples in the batch).
- The update includes logic to repeat the speaker embedding when a single embedding is provided for multiple samples, and a ValueError is raised for any mismatched dimensions.
- Also added corresponding test cases to validate both scenarios, ensuring complete coverage and functionality for diverse speaker embedding situations.

* Improve Test Robustness with Randomized Speaker Embeddings
  • Loading branch information
NimaYaqmuri committed Jan 16, 2024
1 parent 66db33d commit 07ae53e
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 39 deletions.
26 changes: 21 additions & 5 deletions src/transformers/models/speecht5/modeling_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,13 +664,11 @@ def __init__(self, config):
)

self.final_layer = nn.Linear(config.speech_decoder_prenet_units, config.hidden_size)

self.encode_positions = SpeechT5ScaledPositionalEncoding(
config.positional_dropout,
config.hidden_size,
config.max_speech_positions,
)

self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size)

def _consistent_dropout(self, inputs_embeds, p):
Expand All @@ -695,9 +693,7 @@ def forward(

if speaker_embeddings is not None:
speaker_embeddings = nn.functional.normalize(speaker_embeddings)
speaker_embeddings = speaker_embeddings.unsqueeze(1)
speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1)
speaker_embeddings = speaker_embeddings.repeat(inputs_embeds.size(0), 1, 1)
speaker_embeddings = speaker_embeddings.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1)
inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))

Expand Down Expand Up @@ -2825,6 +2821,16 @@ def generate(
`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
"""
if speaker_embeddings is not None:
batch_size = input_ids.size(0)
if speaker_embeddings.size(0) != batch_size:
if speaker_embeddings.size(0) == 1:
speaker_embeddings = speaker_embeddings.repeat(batch_size, 1)
else:
raise ValueError(
"The first dimension of speaker_embeddings must be either 1 or the same as batch_size."
)

return _generate_speech(
self,
input_ids,
Expand Down Expand Up @@ -2911,6 +2917,16 @@ def generate_speech(
`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
"""
if speaker_embeddings is not None:
batch_size = input_ids.size(0)
if speaker_embeddings.size(0) != batch_size:
if speaker_embeddings.size(0) == 1:
speaker_embeddings = speaker_embeddings.repeat(batch_size, 1)
else:
raise ValueError(
"The first dimension of speaker_embeddings must be either 1 or the same as batch size."
)

return _generate_speech(
self,
input_ids,
Expand Down
228 changes: 194 additions & 34 deletions tests/models/speecht5/test_modeling_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,107 +1029,267 @@ def _mock_init_weights(self, module):
class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
@cached_property
def default_model(self):
return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(torch_device)

@cached_property
def default_processor(self):
return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")

@cached_property
def default_vocoder(self):
return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(torch_device)

def test_generation(self):
model = self.default_model
model.to(torch_device)
processor = self.default_processor

set_seed(555) # make deterministic

speaker_embeddings = torch.zeros((1, 512)).to(torch_device)

input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
input_text = "Mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
speaker_embeddings = torch.zeros((1, 512), device=torch_device)

# Generate speech and validate output dimensions
set_seed(555) # Ensure deterministic behavior
generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
self.assertEqual(generated_speech.shape, (230, model.config.num_mel_bins))

set_seed(555) # make deterministic
num_mel_bins = model.config.num_mel_bins
self.assertEqual(
generated_speech.shape[1], num_mel_bins, "Generated speech output has an unexpected number of mel bins."
)

# test model.generate, same method than generate_speech but with additional kwargs to absorb kwargs such as attention_mask
# Validate generation with additional kwargs using model.generate;
# same method than generate_speech
set_seed(555) # Reset seed for consistent results
generated_speech_with_generate = model.generate(
input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
)
self.assertEqual(generated_speech_with_generate.shape, (230, model.config.num_mel_bins))
self.assertEqual(
generated_speech_with_generate.shape,
generated_speech.shape,
"Shape mismatch between generate_speech and generate methods.",
)

def test_batch_generation(self):
def test_one_to_many_generation(self):
model = self.default_model
model.to(torch_device)
processor = self.default_processor
vocoder = self.default_vocoder
set_seed(555) # make deterministic

input_text = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"nor is mister quilter's manner less interesting than his matter",
"he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
]
inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)

speaker_embeddings = torch.zeros((1, 512), device=torch_device)

# Generate spectrograms
set_seed(555) # Ensure deterministic behavior
spectrograms, spectrogram_lengths = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
return_output_lengths=True,
)
self.assertEqual(spectrograms.shape, (3, 262, model.config.num_mel_bins))

# Validate generated spectrogram dimensions
expected_batch_size = len(input_text)
num_mel_bins = model.config.num_mel_bins
actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
self.assertEqual(actual_batch_size, expected_batch_size, "Batch size of generated spectrograms is incorrect.")
self.assertEqual(
actual_num_mel_bins, num_mel_bins, "Number of mel bins in batch generated spectrograms is incorrect."
)

# Generate waveforms using the vocoder
waveforms = vocoder(spectrograms)
waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]

# Check waveform results are the same with or without using vocder
set_seed(555)
# Validate generation with integrated vocoder
set_seed(555) # Reset seed for consistent results
waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
vocoder=vocoder,
return_output_lengths=True,
)
self.assertTrue(torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8))
self.assertEqual(waveform_lengths, waveform_lengths_with_vocoder)

# Check waveform results are the same with return_concrete_lengths=True/False
set_seed(555)
# Check consistency between waveforms generated with and without standalone vocoder
self.assertTrue(
torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
"Mismatch in waveforms generated with and without the standalone vocoder.",
)
self.assertEqual(
waveform_lengths,
waveform_lengths_with_vocoder,
"Waveform lengths differ between standalone and integrated vocoder generation.",
)

# Test generation consistency without returning lengths
set_seed(555) # Reset seed for consistent results
waveforms_with_vocoder_no_lengths = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
vocoder=vocoder,
return_output_lengths=False,
)
self.assertTrue(torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8))

# Check results when batching are consistent with results without batching
# Validate waveform consistency without length information
self.assertTrue(
torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
"Waveforms differ when generated with and without length information.",
)

# Validate batch vs. single instance generation consistency
for i, text in enumerate(input_text):
set_seed(555) # make deterministic
inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
set_seed(555) # Reset seed for consistent results
spectrogram = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
)
self.assertEqual(spectrogram.shape, spectrograms[i][: spectrogram_lengths[i]].shape)
self.assertTrue(torch.allclose(spectrogram, spectrograms[i][: spectrogram_lengths[i]], atol=5e-3))

# Check spectrogram shape consistency
self.assertEqual(
spectrogram.shape,
spectrograms[i][: spectrogram_lengths[i]].shape,
"Mismatch in spectrogram shape between batch and single instance generation.",
)

# Generate and validate waveform for single instance
waveform = vocoder(spectrogram)
self.assertEqual(waveform.shape, waveforms[i][: waveform_lengths[i]].shape)
# Check whether waveforms are the same with/without passing vocoder
set_seed(555)
waveform_with_vocoder = model.generate_speech(
self.assertEqual(
waveform.shape,
waveforms[i][: waveform_lengths[i]].shape,
"Mismatch in waveform shape between batch and single instance generation.",
)

# Check waveform consistency with integrated vocoder
set_seed(555) # Reset seed for consistent results
waveform_with_integrated_vocoder = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
vocoder=vocoder,
)
self.assertTrue(torch.allclose(waveform, waveform_with_vocoder, atol=1e-8))
self.assertTrue(
torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
"Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
)

def test_batch_generation(self):
model = self.default_model
processor = self.default_processor
vocoder = self.default_vocoder

input_text = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"nor is mister quilter's manner less interesting than his matter",
"he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
]
inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
set_seed(555) # Ensure deterministic behavior
speaker_embeddings = torch.randn((len(input_text), 512), device=torch_device)

# Generate spectrograms
set_seed(555) # Reset seed for consistent results
spectrograms, spectrogram_lengths = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
return_output_lengths=True,
)

# Validate generated spectrogram dimensions
expected_batch_size = len(input_text)
num_mel_bins = model.config.num_mel_bins
actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
self.assertEqual(
actual_batch_size,
expected_batch_size,
"Batch size of generated spectrograms is incorrect.",
)
self.assertEqual(
actual_num_mel_bins,
num_mel_bins,
"Number of mel bins in batch generated spectrograms is incorrect.",
)

# Generate waveforms using the vocoder
waveforms = vocoder(spectrograms)
waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]

# Validate generation with integrated vocoder
set_seed(555) # Reset seed for consistent results
waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
vocoder=vocoder,
return_output_lengths=True,
)

# Check consistency between waveforms generated with and without standalone vocoder
self.assertTrue(
torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
"Mismatch in waveforms generated with and without the standalone vocoder.",
)
self.assertEqual(
waveform_lengths,
waveform_lengths_with_vocoder,
"Waveform lengths differ between standalone and integrated vocoder generation.",
)

# Test generation consistency without returning lengths
set_seed(555) # Reset seed for consistent results
waveforms_with_vocoder_no_lengths = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
vocoder=vocoder,
return_output_lengths=False,
)

# Validate waveform consistency without length information
self.assertTrue(
torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
"Waveforms differ when generated with and without length information.",
)

# Validate batch vs. single instance generation consistency
for i, text in enumerate(input_text):
inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
current_speaker_embedding = speaker_embeddings[i].unsqueeze(0)
set_seed(555) # Reset seed for consistent results
spectrogram = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=current_speaker_embedding,
)

# Check spectrogram shape consistency
self.assertEqual(
spectrogram.shape,
spectrograms[i][: spectrogram_lengths[i]].shape,
"Mismatch in spectrogram shape between batch and single instance generation.",
)

# Generate and validate waveform for single instance
waveform = vocoder(spectrogram)
self.assertEqual(
waveform.shape,
waveforms[i][: waveform_lengths[i]].shape,
"Mismatch in waveform shape between batch and single instance generation.",
)

# Check waveform consistency with integrated vocoder
set_seed(555) # Reset seed for consistent results
waveform_with_integrated_vocoder = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=current_speaker_embedding,
vocoder=vocoder,
)
self.assertTrue(
torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
"Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
)


@require_torch
Expand Down

0 comments on commit 07ae53e

Please sign in to comment.