From f2618c9de7b4a1b16f06c174704aa5c9f69e3fe2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 3 Nov 2023 10:50:43 +0000 Subject: [PATCH 1/2] skip instead of return --- tests/generation/test_utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 7e2f242c6fd66c..5186816c6b597e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1832,18 +1832,15 @@ def test_generate_from_inputs_embeds_decoder_only(self): def test_generate_continue_from_past_key_values(self): # Tests that we can continue generating from past key values, returned from a previous `generate` call for model_class in self.all_generative_model_classes: - # won't fix: old models with unique inputs/caches/others if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]): - return - # may fix in the future: needs modeling or test input preparation fixes for compatibility + self.skipTest("Won't fix: old model with unique inputs/caches/other") if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): - return + self.skipTest("TODO: needs modeling or test input preparation fixes for compatibility") config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - # If it doesn't support cache, pass the test if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") # Let's make it always: # 1. use cache (for obvious reasons) @@ -1862,10 +1859,10 @@ def test_generate_continue_from_past_key_values(self): model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 model.generation_config.forced_eos_token_id = None - # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format) + # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) outputs = model(**inputs) if "past_key_values" not in outputs: - return + self.skipTest("This model doesn't return `past_key_values`") # Traditional way of generating text, with `return_dict_in_generate` to return the past key values outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True) From 97bc0588f546ee3e90197602d3e73f415dbd4e3f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 7 Nov 2023 11:41:38 +0000 Subject: [PATCH 2/2] skip more tests (instead of returning) --- tests/generation/test_utils.py | 49 +++++++++++++--------------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 5186816c6b597e..7531502be28922 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -749,8 +749,7 @@ def test_greedy_generate_dict_outputs_use_cache(self): config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() if not hasattr(config, "use_cache"): - # only relevant if model has "use_cache" - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -983,8 +982,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self): config.forced_eos_token_id = None if not hasattr(config, "use_cache"): - # only relevant if model has "use_cache" - return + self.skipTest("This model doesn't support caching") model = model_class(config).to(torch_device).eval() if model.config.is_encoder_decoder: @@ -1420,13 +1418,13 @@ def test_contrastive_generate(self): for model_class in self.all_generative_model_classes: # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return + self.skipTest("Won't fix: old model with different cache format") config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -1441,14 +1439,14 @@ def test_contrastive_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return + self.skipTest("Won't fix: old model with different cache format") # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -1472,18 +1470,16 @@ def test_contrastive_generate_dict_outputs_use_cache(self): def test_contrastive_generate_low_memory(self): # Check that choosing 'low_memory' does not change the model output for model_class in self.all_generative_model_classes: - # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format). - if any( - model_name in model_class.__name__.lower() - for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"] - ): - return + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]): + self.skipTest("Won't fix: old model with different cache format") + if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]): + self.skipTest("TODO: fix me") config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -1510,8 +1506,6 @@ def test_contrastive_generate_low_memory(self): ) self.assertListEqual(low_output.tolist(), high_output.tolist()) - return - @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%. def test_assisted_decoding_matches_greedy_search(self): # This test ensures that the assisted generation does not introduce output changes over greedy search. @@ -1522,15 +1516,13 @@ def test_assisted_decoding_matches_greedy_search(self): # - assisted_decoding does not support `batch_size > 1` for model_class in self.all_generative_model_classes: - # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return - # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes + self.skipTest("Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"] ): - return + self.skipTest("May fix in the future: need model-specific fixes") # This for loop is a naive and temporary effort to make the test less flaky. failed = 0 @@ -1540,7 +1532,7 @@ def test_assisted_decoding_matches_greedy_search(self): # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -1587,24 +1579,21 @@ def test_assisted_decoding_matches_greedy_search(self): def test_assisted_decoding_sample(self): # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking). - for model_class in self.all_generative_model_classes: - # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return - # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes + self.skipTest("Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"] ): - return + self.skipTest("May fix in the future: need model-specific fixes") # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -1716,7 +1705,7 @@ def test_past_key_values_format(self): # If it doesn't support cache, pass the test if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") model = model_class(config).to(torch_device) if "use_cache" not in inputs: @@ -1725,7 +1714,7 @@ def test_past_key_values_format(self): # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format) if "past_key_values" not in outputs: - return + self.skipTest("This model doesn't return `past_key_values`") num_hidden_layers = ( getattr(config, "decoder_layers", None)