Generate: skip tests on unsupported models instead of passing #27265

Merged · 2 commits · Nov 7, 2023
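
The gist of the change, shown as a minimal standalone sketch (the class and test names below are illustrative, not taken from the transformers test suite): in `unittest`, returning early from a test method is reported as a pass, while `self.skipTest(...)` raises `unittest.SkipTest`, so the runner reports the test as skipped and the missing coverage stays visible.

```python
import unittest


class CacheSupportSketch(unittest.TestCase):
    """Illustrative only -- not part of the transformers test suite."""

    supports_cache = False  # stand-in for hasattr(config, "use_cache")

    def test_unsupported_with_return(self):
        if not self.supports_cache:
            return  # counted as PASSED, silently hiding the missing coverage
        self.assertTrue(self.supports_cache)

    def test_unsupported_with_skiptest(self):
        if not self.supports_cache:
            self.skipTest("This model doesn't support caching")  # counted as SKIPPED
        self.assertTrue(self.supports_cache)


if __name__ == "__main__":
    unittest.main(verbosity=2)  # prints "ok" for the first test, "skipped '...'" for the second
```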
62 changes: 24 additions & 38 deletions tests/generation/test_utils.py
@@ -749,8 +749,7 @@ def test_greedy_generate_dict_outputs_use_cache(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

if not hasattr(config, "use_cache"):
- # only relevant if model has "use_cache"
- return
+ self.skipTest("This model doesn't support caching")

config.use_cache = True
config.is_decoder = True
@@ -983,8 +982,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
config.forced_eos_token_id = None

if not hasattr(config, "use_cache"):
- # only relevant if model has "use_cache"
- return
+ self.skipTest("This model doesn't support caching")

model = model_class(config).to(torch_device).eval()
if model.config.is_encoder_decoder:
@@ -1420,13 +1418,13 @@ def test_contrastive_generate(self):
for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- return
+ self.skipTest("Won't fix: old model with different cache format")

config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- return
+ self.skipTest("This model doesn't support caching")
config.use_cache = True
config.is_decoder = True

@@ -1441,14 +1439,14 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- return
+ self.skipTest("Won't fix: old model with different cache format")

# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- return
+ self.skipTest("This model doesn't support caching")
config.use_cache = True
config.is_decoder = True

@@ -1472,18 +1470,16 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
def test_contrastive_generate_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
- # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
- if any(
- model_name in model_class.__name__.lower()
- for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
- ):
- return
+ if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
+ self.skipTest("Won't fix: old model with different cache format")
+ if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
+ self.skipTest("TODO: fix me")

config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- return
+ self.skipTest("This model doesn't support caching")

config.use_cache = True
config.is_decoder = True
@@ -1510,8 +1506,6 @@ def test_contrastive_generate_low_memory(self):
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

- return

@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
def test_assisted_decoding_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
@@ -1522,15 +1516,13 @@ def test_assisted_decoding_matches_greedy_search(self):
# - assisted_decoding does not support `batch_size > 1`

for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- return
- # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
+ self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
):
- return
+ self.skipTest("May fix in the future: need model-specific fixes")

# This for loop is a naive and temporary effort to make the test less flaky.
failed = 0
@@ -1540,7 +1532,7 @@ def test_assisted_decoding_matches_greedy_search(self):

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- return
+ self.skipTest("This model doesn't support caching")

config.use_cache = True
config.is_decoder = True
@@ -1587,24 +1579,21 @@ def test_assisted_decoding_sample(self):
def test_assisted_decoding_sample(self):
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).

for model_class in self.all_generative_model_classes:
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- return
- # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
+ self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
):
- return
+ self.skipTest("May fix in the future: need model-specific fixes")

# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- return
+ self.skipTest("This model doesn't support caching")

config.use_cache = True
config.is_decoder = True
@@ -1716,7 +1705,7 @@ def test_past_key_values_format(self):

# If it doesn't support cache, pass the test
if not hasattr(config, "use_cache"):
- return
+ self.skipTest("This model doesn't support caching")

model = model_class(config).to(torch_device)
if "use_cache" not in inputs:
@@ -1725,7 +1714,7 @@ def test_past_key_values_format(self):

# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
if "past_key_values" not in outputs:
- return
+ self.skipTest("This model doesn't return `past_key_values`")

num_hidden_layers = (
getattr(config, "decoder_layers", None)
@@ -1832,18 +1821,15 @@ def test_generate_from_inputs_embeds_decoder_only(self):
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
- # won't fix: old models with unique inputs/caches/others
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
- return
- # may fix in the future: needs modeling or test input preparation fixes for compatibility
+ self.skipTest("Won't fix: old model with unique inputs/caches/other")
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
- return
+ self.skipTest("TODO: needs modeling or test input preparation fixes for compatibility")
Collaborator:

This is better, but as we are in `for model_class in self.all_generative_model_classes`, it probably means we unnecessarily skip many `model_class` entries if an earlier one is skipped here.

Good to merge. I can do a follow-up PR to use `subTest` (as an exercise for me).

Collaborator:

After a second look, subTest doesn't make much sense here.
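
For reference, a minimal sketch of the `subTest` pattern mentioned above (the test class is a hypothetical stand-in, not the real `GenerationTesterMixin`): wrapping each loop iteration in `self.subTest(...)` records a skip or failure for that `model_class` alone and then continues with the next class, whereas a bare `self.skipTest(...)` at the top of the loop ends the whole test method on the first unsupported class.

```python
import unittest


class GenerationLoopSketch(unittest.TestCase):
    # Hypothetical stand-in for self.all_generative_model_classes
    all_generative_model_classes = ()

    def test_contrastive_generate(self):
        for model_class in self.all_generative_model_classes:
            # Each model class becomes its own sub-result: a skip or failure
            # raised inside the `with` block is recorded for this model_class
            # only, and the loop moves on to the remaining classes.
            with self.subTest(model_class=model_class.__name__):
                if any(name in model_class.__name__.lower() for name in ["fsmt", "reformer"]):
                    self.skipTest("Won't fix: old model with different cache format")
                # ... the actual generation checks would go here ...
```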


config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

# If it doesn't support cache, pass the test
if not hasattr(config, "use_cache"):
- return
+ self.skipTest("This model doesn't support caching")

# Let's make it always:
# 1. use cache (for obvious reasons)
@@ -1862,10 +1848,10 @@ def test_generate_continue_from_past_key_values(self):
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None

# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
if "past_key_values" not in outputs:
- return
+ self.skipTest("This model doesn't return `past_key_values`")

# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)