Generate: unset GenerationConfig parameters do not raise warning #29119

Merged (4 commits) on Feb 20, 2024
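For context, a minimal sketch of the behavior this PR fixes (paraphrased warning semantics, not verbatim library output). Before this change, `validate()` compared each parameter against its non-`None` default, so an explicitly unset parameter still looked "set": `None != 1.0` is `True` in Python.

import warnings

from transformers import GenerationConfig

with warnings.catch_warnings(record=True) as caught:
    # `temperature=None` means "unset". Before this PR, the greedy-decoding
    # check `self.temperature != 1.0` was True for None, so this spuriously
    # warned that a sampling flag was set together with `do_sample=False`.
    # With this PR, the check also requires `temperature is not None`.
    GenerationConfig(do_sample=False, temperature=None)

print(len(caught))  # 0 once this PR is merged
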
28 changes: 16 additions & 12 deletions src/transformers/generation/configuration_utils.py
@@ -268,7 +268,6 @@ class GenerationConfig(PushToHubMixin):
 
     def __init__(self, **kwargs):
         # Parameters that control the length of the output
-        # if the default `max_length` is updated here, make sure to update the `generate` tests following https://github.com/huggingface/transformers/pull/25030
         self.max_length = kwargs.pop("max_length", 20)
         self.max_new_tokens = kwargs.pop("max_new_tokens", None)
         self.min_length = kwargs.pop("min_length", 0)
@@ -403,32 +402,34 @@ def validate(self, is_init=False):
                 "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
                 + fix_location
             )
-            if self.temperature != 1.0:
+            if self.temperature is not None and self.temperature != 1.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
                     UserWarning,
                 )
-            if self.top_p != 1.0:
+            if self.top_p is not None and self.top_p != 1.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
                     UserWarning,
                 )
-            if self.typical_p != 1.0:
+            if self.typical_p is not None and self.typical_p != 1.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
                     UserWarning,
                 )
-            if self.top_k != 50 and self.penalty_alpha is None:  # contrastive search uses top_k
+            if (
+                self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
+            ):  # contrastive search uses top_k
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
                     UserWarning,
                 )
-            if self.epsilon_cutoff != 0.0:
+            if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
                     UserWarning,
                 )
-            if self.eta_cutoff != 0.0:
+            if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
                     UserWarning,
@@ -449,21 +450,21 @@ def validate(self, is_init=False):
                     single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
                     UserWarning,
                 )
-            if self.num_beam_groups != 1:
+            if self.num_beam_groups is not None and self.num_beam_groups != 1:
                 warnings.warn(
                     single_beam_wrong_parameter_msg.format(
                         flag_name="num_beam_groups", flag_value=self.num_beam_groups
                     ),
                     UserWarning,
                 )
-            if self.diversity_penalty != 0.0:
+            if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
                 warnings.warn(
                     single_beam_wrong_parameter_msg.format(
                         flag_name="diversity_penalty", flag_value=self.diversity_penalty
                     ),
                     UserWarning,
                 )
-            if self.length_penalty != 1.0:
+            if self.length_penalty is not None and self.length_penalty != 1.0:
                 warnings.warn(
                     single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
                     UserWarning,
@@ -487,7 +488,7 @@ def validate(self, is_init=False):
                 raise ValueError(
                     constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
                 )
-            if self.num_beam_groups != 1:
+            if self.num_beam_groups is not None and self.num_beam_groups != 1:
                 raise ValueError(
                     constrained_wrong_parameter_msg.format(
                         flag_name="num_beam_groups", flag_value=self.num_beam_groups
@@ -996,6 +997,9 @@ def update(self, **kwargs):
                 setattr(self, key, value)
                 to_remove.append(key)
 
-        # remove all the attributes that were updated, without modifying the input dict
+        # Confirm that the updated instance is still valid
+        self.validate()
+
+        # Remove all the attributes that were updated, without modifying the input dict
         unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
         return unused_kwargs
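
A short usage sketch of the new `update()` semantics (hypothetical values; `foo` stands in for any model-specific kwarg):

from transformers import GenerationConfig

config = GenerationConfig(do_sample=False)

# update() now validates internally: setting a sampling flag while
# `do_sample=False` emits the UserWarning at update time, instead of
# deferring it until someone remembers to call validate() manually.
unused = config.update(temperature=0.9, foo="bar")

# As before, kwargs that do not match an existing attribute are returned
# untouched, so callers can forward them as model kwargs.
print(unused)  # {'foo': 'bar'}
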
1 change: 0 additions & 1 deletion src/transformers/generation/flax_utils.py
@@ -330,7 +330,6 @@ def generate(
 
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
-        generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
 
         logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()

Review thread on the removed `generation_config.validate()` line:

Collaborator: Why get rid of these validations here (and in the other utils files)?

Member Author: Because I've added the `.validate()` call to `.update()`, which is the line that precedes all these removals :) It makes more sense to validate at update time, rather than relying on the user to manually validate.
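
To make the reviewer's point concrete, a sketch of the call-site pattern (illustrative, not verbatim library code):

# Before this PR: two steps at every generate() call site, and a new call
# site could forget the second one.
model_kwargs = generation_config.update(**kwargs)
generation_config.validate()

# After this PR: update() calls validate() internally, so one step suffices.
model_kwargs = generation_config.update(**kwargs)
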
1 change: 0 additions & 1 deletion src/transformers/generation/tf_utils.py
@@ -736,7 +736,6 @@ def generate(
 
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
-        generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
1 change: 0 additions & 1 deletion src/transformers/generation/utils.py
@@ -1327,7 +1327,6 @@ def generate(
 
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
-        generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 2. Set generation parameters if not already defined
3 changes: 1 addition & 2 deletions src/transformers/utils/quantization_config.py
@@ -152,7 +152,6 @@ def to_json_string(self, use_diff: bool = True) -> str:
         config_dict = self.to_dict()
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
 
-    # Copied from transformers.generation.configuration_utils.GenerationConfig.update
     def update(self, **kwargs):
         """
         Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
@@ -171,7 +170,7 @@ def update(self, **kwargs):
                 setattr(self, key, value)
                 to_remove.append(key)
 
-        # remove all the attributes that were updated, without modifying the input dict
+        # Remove all the attributes that were updated, without modifying the input dict
         unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
         return unused_kwargs
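
(Presumably the `# Copied from ...GenerationConfig.update` marker has to be dropped because the two `update()` implementations now diverge: `GenerationConfig.update` gained a `self.validate()` call that the quantization config's `update()` must not have, so the repository's copy-consistency check would no longer pass.)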
32 changes: 25 additions & 7 deletions tests/generation/test_configuration_utils.py
@@ -124,26 +124,44 @@ def test_validate(self):
         """
         Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
         """
-        # Case 1: A correct configuration will not throw any warning
+        # A correct configuration will not throw any warning
         with warnings.catch_warnings(record=True) as captured_warnings:
             GenerationConfig()
         self.assertEqual(len(captured_warnings), 0)
 
-        # Case 2: Inconsequential but technically wrong configuration will throw a warning (e.g. setting sampling
+        # Inconsequential but technically wrong configuration will throw a warning (e.g. setting sampling
         # parameters with `do_sample=False`). May be escalated to an error in the future.
         with warnings.catch_warnings(record=True) as captured_warnings:
-            GenerationConfig(temperature=0.5)
+            GenerationConfig(do_sample=False, temperature=0.5)
         self.assertEqual(len(captured_warnings), 1)
 
-        # Case 3: Impossible sets of constraints/parameters will raise an exception
+        # Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
+        # that is done by unsetting the parameter (i.e. setting it to None)
+        generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
+        with warnings.catch_warnings(record=True) as captured_warnings:
+            # BAD - 0.9 means it is still set, we should warn
+            generation_config_bad_temperature.update(temperature=0.9)
+        self.assertEqual(len(captured_warnings), 1)
+        generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
+        with warnings.catch_warnings(record=True) as captured_warnings:
+            # CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
+            generation_config_bad_temperature.update(temperature=1.0)
+        self.assertEqual(len(captured_warnings), 0)
+        generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
+        with warnings.catch_warnings(record=True) as captured_warnings:
+            # OK - None means it is unset, nothing to warn about
+            generation_config_bad_temperature.update(temperature=None)
+        self.assertEqual(len(captured_warnings), 0)
+
+        # Impossible sets of constraints/parameters will raise an exception
         with self.assertRaises(ValueError):
-            GenerationConfig(num_return_sequences=2)
+            GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
 
-        # Case 4: Passing `generate()`-only flags to `validate` will raise an exception
+        # Passing `generate()`-only flags to `validate` will raise an exception
         with self.assertRaises(ValueError):
             GenerationConfig(logits_processor="foo")
 
-        # Case 5: Model-specific parameters will NOT raise an exception or a warning
+        # Model-specific parameters will NOT raise an exception or a warning
         with warnings.catch_warnings(record=True) as captured_warnings:
             GenerationConfig(foo="bar")
         self.assertEqual(len(captured_warnings), 0)

Member Author (on the `GenerationConfig(do_sample=False, temperature=0.5)` change): In this commit I've made a few implicit parameters explicit, to make reading the test easier :)
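
Taken together, the tests document the user-facing recipe: to silence such a warning on an existing config, unset the offending parameter rather than pushing it back to its default (a sketch with hypothetical values):

import warnings

from transformers import GenerationConfig

config = GenerationConfig(do_sample=False, temperature=0.5)  # warns once here

with warnings.catch_warnings(record=True) as caught:
    config.update(temperature=None)  # unset -> no longer treated as configured
print(len(caught))  # 0

# Corner case from the tests: config.update(temperature=1.0) is also silent,
# because a user-set default value is indistinguishable from an untouched one.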