Skip to content

Commit

Permalink
Update Bark generation configs and tests (huggingface#25409)
Browse files Browse the repository at this point in the history
* update bark generation configs for more coherent parameters

* make style

* update bark hub repo
  • Loading branch information
ylacombe authored and EduardoPach committed Aug 9, 2023
1 parent 5c6de12 commit c66415a
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 20 deletions.
27 changes: 17 additions & 10 deletions src/transformers/models/bark/generation_configuration_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __init__(
return_dict_in_generate=False,
output_hidden_states=False,
output_attentions=False,
temperature=0.7,
do_sample=True,
temperature=1.0,
do_sample=False,
text_encoding_offset=10_048,
text_pad_token=129_595,
semantic_infer_token=129_599,
Expand Down Expand Up @@ -70,9 +70,9 @@ def __init__(
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
temperature (`float`, *optional*, defaults to 0.7):
temperature (`float`, *optional*, defaults to 1.0):
The value used to modulate the next token probabilities.
do_sample (`bool`, *optional*, defaults to `True`):
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling; use greedy decoding otherwise.
text_encoding_offset (`int`, *optional*, defaults to 10_048):
Text encoding offset.
Expand Down Expand Up @@ -119,8 +119,8 @@ def __init__(
return_dict_in_generate=False,
output_hidden_states=False,
output_attentions=False,
temperature=0.7,
do_sample=True,
temperature=1.0,
do_sample=False,
coarse_semantic_pad_token=12_048,
coarse_rate_hz=75,
n_coarse_codebooks=2,
Expand Down Expand Up @@ -150,9 +150,9 @@ def __init__(
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
temperature (`float`, *optional*, defaults to 0.7):
temperature (`float`, *optional*, defaults to 1.0):
The value used to modulate the next token probabilities.
do_sample (`bool`, *optional*, defaults to `True`):
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling; use greedy decoding otherwise.
coarse_semantic_pad_token (`int`, *optional*, defaults to 12_048):
Coarse semantic pad token.
Expand Down Expand Up @@ -194,7 +194,7 @@ class BarkFineGenerationConfig(GenerationConfig):

def __init__(
self,
temperature=0.5,
temperature=1.0,
max_fine_history_length=512,
max_fine_input_length=1024,
n_fine_codebooks=8,
Expand All @@ -209,7 +209,7 @@ def __init__(
documentation from [`GenerationConfig`] for more information.
Args:
temperature (`float`, *optional*, defaults to 0.5):
temperature (`float`, *optional*):
The value used to modulate the next token probabilities.
max_fine_history_length (`int`, *optional*, defaults to 512):
Max length of the fine history vector.
Expand All @@ -224,6 +224,13 @@ def __init__(
self.max_fine_input_length = max_fine_input_length
self.n_fine_codebooks = n_fine_codebooks

def validate(self):
"""
Overrides GenerationConfig.validate because BarkFineGenerationConfig doesn't use any parameters other than
temperature.
"""
pass


class BarkGenerationConfig(GenerationConfig):
model_type = "bark"
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/bark/modeling_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,7 @@ def generate(
input_buffer = fine_input[:, start_idx : start_idx + max_fine_input_length, :]
for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
logits = self.forward(n_inner, input_buffer).logits
if temperature is None:
if temperature is None or temperature == 1.0:
relevant_logits = logits[:, rel_start_fill_idx:, :codebook_size]
codebook_preds = torch.argmax(relevant_logits, -1)
else:
Expand Down Expand Up @@ -1499,8 +1499,8 @@ def generate(
```python
>>> from transformers import AutoProcessor, BarkModel
>>> processor = AutoProcessor.from_pretrained("ylacombe/bark-small")
>>> model = BarkModel.from_pretrained("ylacombe/bark-small")
>>> processor = AutoProcessor.from_pretrained("suno/bark-small")
>>> model = BarkModel.from_pretrained("suno/bark-small")
>>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)`
>>> voice_preset = "v2/en_speaker_6"
Expand Down
24 changes: 18 additions & 6 deletions tests/models/bark/test_modeling_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,11 +894,11 @@ def test_resize_embeddings_untied(self):
class BarkModelIntegrationTests(unittest.TestCase):
@cached_property
def model(self):
return BarkModel.from_pretrained("ylacombe/bark-large").to(torch_device)
return BarkModel.from_pretrained("suno/bark").to(torch_device)

@cached_property
def processor(self):
return BarkProcessor.from_pretrained("ylacombe/bark-large")
return BarkProcessor.from_pretrained("suno/bark")

@cached_property
def inputs(self):
Expand Down Expand Up @@ -937,6 +937,7 @@ def test_generate_semantic(self):
output_ids = self.model.semantic.generate(
**input_ids,
do_sample=False,
temperature=1.0,
semantic_generation_config=self.semantic_generation_config,
)

Expand All @@ -957,13 +958,15 @@ def test_generate_coarse(self):
output_ids = self.model.semantic.generate(
**input_ids,
do_sample=False,
temperature=1.0,
semantic_generation_config=self.semantic_generation_config,
)

output_ids = self.model.coarse_acoustics.generate(
output_ids,
history_prompt=history_prompt,
do_sample=False,
temperature=1.0,
semantic_generation_config=self.semantic_generation_config,
coarse_generation_config=self.coarse_generation_config,
codebook_size=self.model.generation_config.codebook_size,
Expand Down Expand Up @@ -994,13 +997,15 @@ def test_generate_fine(self):
output_ids = self.model.semantic.generate(
**input_ids,
do_sample=False,
temperature=1.0,
semantic_generation_config=self.semantic_generation_config,
)

output_ids = self.model.coarse_acoustics.generate(
output_ids,
history_prompt=history_prompt,
do_sample=False,
temperature=1.0,
semantic_generation_config=self.semantic_generation_config,
coarse_generation_config=self.coarse_generation_config,
codebook_size=self.model.generation_config.codebook_size,
Expand Down Expand Up @@ -1040,9 +1045,16 @@ def test_generate_end_to_end_with_sub_models_args(self):
input_ids = self.inputs

with torch.no_grad():
self.model.generate(**input_ids, do_sample=False, coarse_do_sample=True, coarse_temperature=0.7)
self.model.generate(
**input_ids, do_sample=False, coarse_do_sample=True, coarse_temperature=0.7, fine_temperature=0.3
**input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
)
self.model.generate(
**input_ids,
do_sample=False,
temperature=1.0,
coarse_do_sample=True,
coarse_temperature=0.7,
fine_temperature=0.3,
)
self.model.generate(
**input_ids,
Expand All @@ -1061,7 +1073,7 @@ def test_generate_end_to_end_with_offload(self):

with torch.no_grad():
# standard generation
output_with_no_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
output_with_no_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)

torch.cuda.empty_cache()

Expand All @@ -1088,7 +1100,7 @@ def test_generate_end_to_end_with_offload(self):
self.assertTrue(hasattr(self.model.semantic, "_hf_hook"))

# output with cpu offload
output_with_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
output_with_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)

# checks if same output
self.assertListEqual(output_with_no_offload.tolist(), output_with_offload.tolist())
2 changes: 1 addition & 1 deletion tests/models/bark/test_processor_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
@require_torch
class BarkProcessorTest(unittest.TestCase):
def setUp(self):
self.checkpoint = "ylacombe/bark-small"
self.checkpoint = "suno/bark-small"
self.tmpdirname = tempfile.mkdtemp()
self.voice_preset = "en_speaker_1"
self.input_string = "This is a test string"
Expand Down

0 comments on commit c66415a

Please sign in to comment.