Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: GenerationConfig throws an exception when generate args are passed #27757

Merged
merged 1 commit into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,24 @@ def validate(self, is_init=False):
f"({self.num_beams})."
)

# 5. check common issue: passing `generate` arguments inside the generation config
generate_arguments = (
"logits_processor",
"stopping_criteria",
"prefix_allowed_tokens_fn",
"synced_gpus",
"assistant_model",
"streamer",
"negative_prompt_ids",
"negative_prompt_attention_mask",
)
for arg in generate_arguments:
if hasattr(self, arg):
raise ValueError(
f"Argument `{arg}` is not a valid argument of `GenerationConfig`. It should be passed to "
"`generate()` (or a pipeline) directly."
)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
Expand Down
28 changes: 28 additions & 0 deletions tests/generation/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,34 @@ def test_kwarg_init(self):
self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.num_beams, 1) # default value

def test_validate(self):
"""
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
"""
# Case 1: A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig()
self.assertEqual(len(captured_warnings), 0)

# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(temperature=0.5)
self.assertEqual(len(captured_warnings), 1)

# Case 3: Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(num_return_sequences=2)

# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(logits_processor="foo")

# Case 5: Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0)

def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation."""

Expand Down
Loading