From c66415ab848527d5332f1f084aeb3da33fc65c7c Mon Sep 17 00:00:00 2001
From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Date: Wed, 9 Aug 2023 18:28:02 +0200
Subject: [PATCH] Update Bark generation configs and tests (#25409)

* update bark generation configs for more coherent parameter

* make style

* update bark hub repo
---
 .../bark/generation_configuration_bark.py      | 27 ++++++++++++-------
 src/transformers/models/bark/modeling_bark.py  |  6 ++---
 tests/models/bark/test_modeling_bark.py        | 24 ++++++++++++-----
 tests/models/bark/test_processor_bark.py       |  2 +-
 4 files changed, 39 insertions(+), 20 deletions(-)

diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py
index 9a280b99898da7..ea00d6f0516ab1 100644
--- a/src/transformers/models/bark/generation_configuration_bark.py
+++ b/src/transformers/models/bark/generation_configuration_bark.py
@@ -36,8 +36,8 @@ def __init__(
         return_dict_in_generate=False,
         output_hidden_states=False,
         output_attentions=False,
-        temperature=0.7,
-        do_sample=True,
+        temperature=1.0,
+        do_sample=False,
         text_encoding_offset=10_048,
         text_pad_token=129_595,
         semantic_infer_token=129_599,
@@ -70,9 +70,9 @@ def __init__(
             output_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more details.
-            temperature (`float`, *optional*, defaults to 0.7):
+            temperature (`float`, *optional*, defaults to 1.0):
                 The value used to modulate the next token probabilities.
-            do_sample (`bool`, *optional*, defaults to `True`):
+            do_sample (`bool`, *optional*, defaults to `False`):
                 Whether or not to use sampling ; use greedy decoding otherwise.
             text_encoding_offset (`int`, *optional*, defaults to 10_048):
                 Text encoding offset.
@@ -119,8 +119,8 @@ def __init__(
         return_dict_in_generate=False,
         output_hidden_states=False,
         output_attentions=False,
-        temperature=0.7,
-        do_sample=True,
+        temperature=1.0,
+        do_sample=False,
         coarse_semantic_pad_token=12_048,
         coarse_rate_hz=75,
         n_coarse_codebooks=2,
@@ -150,9 +150,9 @@ def __init__(
             output_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more details.
-            temperature (`float`, *optional*, defaults to 0.7):
+            temperature (`float`, *optional*, defaults to 1.0):
                 The value used to modulate the next token probabilities.
-            do_sample (`bool`, *optional*, defaults to `True`):
+            do_sample (`bool`, *optional*, defaults to `False`):
                 Whether or not to use sampling ; use greedy decoding otherwise.
             coarse_semantic_pad_token (`int`, *optional*, defaults to 12_048):
                 Coarse semantic pad token.
@@ -194,7 +194,7 @@ class BarkFineGenerationConfig(GenerationConfig):
 
     def __init__(
         self,
-        temperature=0.5,
+        temperature=1.0,
         max_fine_history_length=512,
         max_fine_input_length=1024,
         n_fine_codebooks=8,
@@ -209,7 +209,7 @@ def __init__(
         documentation from [`GenerationConfig`] for more information.
 
         Args:
-            temperature (`float`, *optional*, defaults to 0.5):
+            temperature (`float`, *optional*):
                 The value used to modulate the next token probabilities.
             max_fine_history_length (`int`, *optional*, defaults to 512):
                 Max length of the fine history vector.
@@ -224,6 +224,13 @@ def __init__(
         self.max_fine_input_length = max_fine_input_length
         self.n_fine_codebooks = n_fine_codebooks
 
+    def validate(self):
+        """
+        Overrides GenerationConfig.validate because BarkFineGenerationConfig don't use any parameters outside
+        temperature.
+        """
+        pass
+
 
 class BarkGenerationConfig(GenerationConfig):
     model_type = "bark"
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index 1cc42bc811dee4..368c0b5e01e22c 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -1336,7 +1336,7 @@ def generate(
             input_buffer = fine_input[:, start_idx : start_idx + max_fine_input_length, :]
             for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
                 logits = self.forward(n_inner, input_buffer).logits
-                if temperature is None:
+                if temperature is None or temperature == 1.0:
                     relevant_logits = logits[:, rel_start_fill_idx:, :codebook_size]
                     codebook_preds = torch.argmax(relevant_logits, -1)
                 else:
@@ -1499,8 +1499,8 @@ def generate(
         ```python
         >>> from transformers import AutoProcessor, BarkModel
 
-        >>> processor = AutoProcessor.from_pretrained("ylacombe/bark-small")
-        >>> model = BarkModel.from_pretrained("ylacombe/bark-small")
+        >>> processor = AutoProcessor.from_pretrained("suno/bark-small")
+        >>> model = BarkModel.from_pretrained("suno/bark-small")
 
         >>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)`
         >>> voice_preset = "v2/en_speaker_6"
diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 82a902ded44aa9..6fc4cb58a63936 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -894,11 +894,11 @@ def test_resize_embeddings_untied(self):
 class BarkModelIntegrationTests(unittest.TestCase):
     @cached_property
     def model(self):
-        return BarkModel.from_pretrained("ylacombe/bark-large").to(torch_device)
+        return BarkModel.from_pretrained("suno/bark").to(torch_device)
 
     @cached_property
     def processor(self):
-        return BarkProcessor.from_pretrained("ylacombe/bark-large")
+        return BarkProcessor.from_pretrained("suno/bark")
 
     @cached_property
     def inputs(self):
@@ -937,6 +937,7 @@ def test_generate_semantic(self):
             output_ids = self.model.semantic.generate(
                 **input_ids,
                 do_sample=False,
+                temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
             )
 
@@ -957,6 +958,7 @@ def test_generate_coarse(self):
             output_ids = self.model.semantic.generate(
                 **input_ids,
                 do_sample=False,
+                temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
             )
 
@@ -964,6 +966,7 @@
                 output_ids,
                 history_prompt=history_prompt,
                 do_sample=False,
+                temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
                 coarse_generation_config=self.coarse_generation_config,
                 codebook_size=self.model.generation_config.codebook_size,
@@ -994,6 +997,7 @@ def test_generate_fine(self):
             output_ids = self.model.semantic.generate(
                 **input_ids,
                 do_sample=False,
+                temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
             )
 
@@ -1001,6 +1005,7 @@
                 output_ids,
                 history_prompt=history_prompt,
                 do_sample=False,
+                temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
                 coarse_generation_config=self.coarse_generation_config,
                 codebook_size=self.model.generation_config.codebook_size,
@@ -1040,9 +1045,16 @@ def test_generate_end_to_end_with_sub_models_args(self):
         input_ids = self.inputs
 
         with torch.no_grad():
-            self.model.generate(**input_ids, do_sample=False, coarse_do_sample=True, coarse_temperature=0.7)
             self.model.generate(
-                **input_ids, do_sample=False, coarse_do_sample=True, coarse_temperature=0.7, fine_temperature=0.3
+                **input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
+            )
+            self.model.generate(
+                **input_ids,
+                do_sample=False,
+                temperature=1.0,
+                coarse_do_sample=True,
+                coarse_temperature=0.7,
+                fine_temperature=0.3,
             )
             self.model.generate(
                 **input_ids,
@@ -1061,7 +1073,7 @@ def test_generate_end_to_end_with_offload(self):
 
         with torch.no_grad():
             # standard generation
-            output_with_no_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
+            output_with_no_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)
 
             torch.cuda.empty_cache()
 
@@ -1088,7 +1100,7 @@ def test_generate_end_to_end_with_offload(self):
             self.assertTrue(hasattr(self.model.semantic, "_hf_hook"))
 
             # output with cpu offload
-            output_with_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
+            output_with_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)
 
             # checks if same output
             self.assertListEqual(output_with_no_offload.tolist(), output_with_offload.tolist())
diff --git a/tests/models/bark/test_processor_bark.py b/tests/models/bark/test_processor_bark.py
index aa25951b5c4106..15b0871d81448d 100644
--- a/tests/models/bark/test_processor_bark.py
+++ b/tests/models/bark/test_processor_bark.py
@@ -26,7 +26,7 @@
 @require_torch
 class BarkProcessorTest(unittest.TestCase):
     def setUp(self):
-        self.checkpoint = "ylacombe/bark-small"
+        self.checkpoint = "suno/bark-small"
         self.tmpdirname = tempfile.mkdtemp()
         self.voice_preset = "en_speaker_1"
         self.input_string = "This is a test string"
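
The `modeling_bark.py` hunk above makes a fine sub-model temperature of `1.0` behave like `None`, i.e. it short-circuits to greedy argmax decoding, so the new config default of `temperature=1.0` is deterministic end to end. The snippet below is a minimal sketch of that decode rule, not code from the patch; the helper name `pick_fine_tokens` is made up for illustration.

```python
from typing import Optional

import torch


def pick_fine_tokens(relevant_logits: torch.Tensor, temperature: Optional[float]) -> torch.Tensor:
    # Greedy path: the patch treats temperature=None and temperature=1.0 the same way.
    if temperature is None or temperature == 1.0:
        return torch.argmax(relevant_logits, dim=-1)
    # Sampling path: scale the logits by the temperature, then draw one token per position.
    probs = torch.softmax(relevant_logits / temperature, dim=-1)
    flat = probs.reshape(-1, probs.shape[-1])
    return torch.multinomial(flat, num_samples=1).reshape(probs.shape[:-1])
```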
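
Taken together, the patch makes `BarkModel.generate` greedy and deterministic by default (`do_sample=False`, `temperature=1.0`), while sampling can still be requested per sub-model, as exercised by the updated `test_generate_end_to_end_with_sub_models_args`. A usage sketch, assuming the `suno/bark-small` checkpoint and the `AutoProcessor`/`BarkModel` API referenced in the patch:

```python
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

# Voice preset taken from the docstring example updated in this patch.
inputs = processor("Hello, my dog is cute", voice_preset="v2/en_speaker_6")

# With the new defaults (do_sample=False, temperature=1.0), this call is deterministic.
audio_greedy = model.generate(**inputs)

# Sampling can still be enabled per sub-model, mirroring the updated integration test.
audio_sampled = model.generate(
    **inputs,
    do_sample=False,
    temperature=1.0,
    coarse_do_sample=True,
    coarse_temperature=0.7,
    fine_temperature=0.3,
)
```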