diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index e70e8afa84..9852ec162c 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -40,6 +40,7 @@ Supported architectures: - DistilBert - Donut-Swin - Electra +- Encoder Decoder - Flaubert - GPT-2 - GPT-BigCode @@ -88,6 +89,7 @@ Supported architectures: - TROCR - UniSpeech - UniSpeech SAT +- Vision Encoder Decoder - Vit - Wav2Vec2 - Wav2Vec2 Conformer diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 28d32a55fb..2db092ead1 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -267,7 +267,7 @@ def torch_to_onnx_input_map(self) -> Dict[str, str]: return {} -class EncoderDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast): +class EncoderDecoderBaseOnnxConfig(OnnxSeq2SeqConfigWithPast): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,) def __init__( @@ -341,6 +341,34 @@ def __init__( self.DUMMY_INPUT_GENERATOR_CLASSES += self._past_key_values_generator + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = {} + if self._behavior is not ConfigBehavior.DECODER: + common_inputs["input_ids"] = {0: "batch_size", 1: "encoder_sequence_length"} + + common_inputs["attention_mask"] = {0: "batch_size", 1: "encoder_sequence_length"} + + if self._behavior is not ConfigBehavior.ENCODER: + # TODO: it is likely this pop() is unwanted as we then always hit + # https://github.com/huggingface/transformers/blob/v4.26.0/src/transformers/models/t5/modeling_t5.py#L965-L969 + common_inputs.pop("attention_mask") + + if self.use_past_in_inputs: + # TODO: validate the axis name for attention_mask + # common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length" + common_inputs["decoder_input_ids"] = {0: "batch_size"} + else: + common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} + + if self.use_past_in_inputs: + self.add_past_key_values(common_inputs, direction="inputs") + + if self._behavior is ConfigBehavior.DECODER: + common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"} + + return common_inputs + @property def torch_to_onnx_input_map(self) -> Dict[str, str]: if self._behavior is ConfigBehavior.DECODER: diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index db0256e4d0..d20b668884 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -47,7 +47,7 @@ from .config import ( AudioOnnxConfig, AudioToTextOnnxConfig, - EncoderDecoderOnnxConfig, + EncoderDecoderBaseOnnxConfig, TextAndVisionOnnxConfig, TextDecoderOnnxConfig, TextEncoderOnnxConfig, @@ -1168,7 +1168,7 @@ class TrOCROnnxConfig(TextSeq2SeqOnnxConfig): ) -class VisionEncoderDecoderOnnxConfig(EncoderDecoderOnnxConfig): +class VisionEncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig ATOL_FOR_VALIDATION = 1e-3 @@ -1439,3 +1439,7 @@ def overwrite_shape_and_generate_input( dummy_input = dummy_input_gen.generate(input_name, framework=framework) return dummy_input + + +class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 0ebcfc2759..f4908dcb35 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -497,6 +497,11 @@ class TasksManager: 
onnx="ElectraOnnxConfig", tflite="ElectraTFLiteConfig", ), + "encoder-decoder": supported_tasks_mapping( + "text2text-generation", + "text2text-generation-with-past", + onnx="EncoderDecoderOnnxConfig", + ), "flaubert": supported_tasks_mapping( "feature-extraction", "fill-mask", diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index c436a900cb..42952a2581 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -1092,6 +1092,43 @@ class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin): auto_model_class = AutoModelForSeq2SeqLM main_input_name = "input_ids" + def __init__( + self, + encoder_session: ort.InferenceSession, + decoder_session: ort.InferenceSession, + config: "PretrainedConfig", + onnx_paths: List[str], + decoder_with_past_session: Optional[ort.InferenceSession] = None, + use_cache: bool = True, + use_io_binding: Optional[bool] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + preprocessors: Optional[List] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, + ): + super().__init__( + encoder_session, + decoder_session, + config, + onnx_paths, + decoder_with_past_session, + use_cache, + use_io_binding, + model_save_dir, + preprocessors, + generation_config, + **kwargs, + ) + + if config.model_type == "encoder-decoder": + self.encoder.normalized_config = NormalizedConfigManager.get_normalized_config_class( + config.encoder.model_type + )(config.encoder) + + self.decoder.normalized_config = NormalizedConfigManager.get_normalized_config_class( + config.decoder.model_type + )(config.decoder) + def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: return ORTEncoder(session, self) @@ -1153,6 +1190,7 @@ def prepare_inputs_for_generation( input_ids, past_key_values=None, attention_mask=None, + token_type_ids=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index c5f3d5ce4c..e65c3c42d6 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -220,6 +220,7 @@ class NormalizedConfigManager: "distilbert": NormalizedTextConfig.with_args(num_attention_heads="n_heads", hidden_size="dim"), "donut-swin": NormalizedVisionConfig, "electra": NormalizedTextConfig, + "encoder-decoder": NormalizedEncoderDecoderConfig, "gpt2": GPT2LikeNormalizedTextConfig, "gpt-bigcode": GPT2LikeNormalizedTextConfig, "gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"), diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 7a20fa4528..53d08f58af 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -42,6 +42,7 @@ "camembert": "hf-internal-testing/tiny-random-camembert", "clip": "hf-internal-testing/tiny-random-CLIPModel", "convbert": "hf-internal-testing/tiny-random-ConvBertModel", + "convnext": "hf-internal-testing/tiny-random-convnext", "codegen": "hf-internal-testing/tiny-random-CodeGenModel", "cvt": "hf-internal-testing/tiny-random-CvTModel", "data2vec-text": "hf-internal-testing/tiny-random-Data2VecTextModel", @@ -51,10 +52,10 @@ "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model", "deit": "hf-internal-testing/tiny-random-DeiTModel", "donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel", - "convnext": "hf-internal-testing/tiny-random-convnext", "detr": 
"hf-internal-testing/tiny-random-DetrModel", # hf-internal-testing/tiny-random-detr is larger "distilbert": "hf-internal-testing/tiny-random-DistilBertModel", "electra": "hf-internal-testing/tiny-random-ElectraModel", + "encoder-decoder": "hf-internal-testing/tiny-random-EncoderDecoderModel-bert-bert", "flaubert": "hf-internal-testing/tiny-random-flaubert", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", @@ -161,6 +162,7 @@ "camembert": "camembert-base", "clip": "openai/clip-vit-base-patch32", "convbert": "YituTech/conv-bert-base", + "convnext": "facebook/convnext-tiny-224", "codegen": "hf-internal-testing/tiny-random-CodeGenModel", # Not using Salesforce/codegen-350M-multi because it takes too much time for testing. "data2vec-text": "facebook/data2vec-text-base", "data2vec-vision": "facebook/data2vec-vision-base", @@ -168,10 +170,10 @@ "deberta": "hf-internal-testing/tiny-random-DebertaModel", # Not using microsoft/deberta-base because it takes too much time for testing. "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model", # Not using microsoft/deberta-v2-xlarge because it takes too much time for testing. "deit": "facebook/deit-small-patch16-224", - "convnext": "facebook/convnext-tiny-224", "detr": "hf-internal-testing/tiny-random-detr", # Not using facebook/detr-resnet-50 because it takes too much time for testing. "distilbert": "distilbert-base-cased", "electra": "google/electra-base-generator", + "encoder-decoder": "patrickvonplaten/bert2bert_cnn_daily_mail", "flaubert": "hf-internal-testing/tiny-random-flaubert", # TODO "gpt2": "gpt2", "gpt-neo": "EleutherAI/gpt-neo-125M", diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 1d25240c18..b9291fa407 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -57,6 +57,9 @@ def _get_models_to_test(export_models_dict: Dict): for model_name, tasks in model_tasks.items(): for task in tasks: + if model_type == "encoder-decoder" and task == "text2text-generation-with-past": + # The model uses bert as decoder and does not support past key values + continue onnx_config_class = TasksManager.get_exporter_config_constructor( "onnx", task=task, model_type=model_type ) @@ -117,7 +120,13 @@ def _get_models_to_test(export_models_dict: Dict): # TODO: segformer task can not be automatically inferred # TODO: xlm-roberta model auto-infers text-generation, but we don't support it # TODO: perceiver auto-infers default, but we don't support it (why?) 
- if model_type not in ["segformer", "xlm-roberta", "perceiver", "vision-encoder-decoder"]: + # TODO: encoder-decoder auto-infers text3text-generation, but it uses bert as decoder and does not support past key values + if model_type not in [ + "segformer", + "xlm-roberta", + "perceiver", + "encoder-decoder", + ]: models_to_test.append( (f"{model_type}_no_task", model_type, model_name, "auto", "default", False, False) ) diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 7e172452cd..fc2d143aec 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -161,6 +161,10 @@ def _get_models_to_test(export_models_dict: Dict): for model_name, tasks in model_tasks.items(): for task in tasks: + if model_type == "encoder-decoder" and task == "seq2seq-lm-with-past": + # The model uses bert as decoder and does not support past key values + continue + onnx_config_constructor = TasksManager.get_exporter_config_constructor( model_type=model_type, exporter="onnx", task=task, model_name=model_name ) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index f28a3676be..c8aff67cde 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -3059,6 +3059,7 @@ class ORTModelForSeq2SeqLMIntegrationTest(ORTModelTestMixin): # "bigbird_pegasus", "blenderbot", "blenderbot_small", + "encoder-decoder", "longt5", "m2m_100", "marian", @@ -3097,11 +3098,13 @@ def test_load_vanilla_transformers_which_is_not_supported(self): @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]})) def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str): + if model_arch == "encoder-decoder" and use_cache is True: + self.skipTest("encoder-decoder model type with use_cache=True is not supported for bert as a decoder") model_args = {"test_name": test_name, "model_arch": model_arch, "use_cache": use_cache} self._setup(model_args) model_id = MODEL_NAMES[model_arch] - model = ORTModelForSeq2SeqLM.from_pretrained(self.onnx_model_dirs[test_name]) + model = ORTModelForSeq2SeqLM.from_pretrained(self.onnx_model_dirs[test_name], use_cache=use_cache) tokenizer = get_preprocessor(model_id) text = "This is a sample output" tokens = tokenizer(text, return_tensors="pt") @@ -3120,6 +3123,9 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_transformers_and_save(self, model_arch): + if model_arch == "encoder-decoder": + self.skipTest("encoder-decoder model type with use_merged=True is not supported for bert as a decoder") + if "text2text-generation-with-past" not in TasksManager.get_supported_tasks_for_model_type( model_arch.replace("_", "-"), exporter="onnx" ): @@ -3139,6 +3145,9 @@ def test_merge_from_transformers_and_save(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_onnx_and_save(self, model_arch): + if model_arch == "encoder-decoder": + self.skipTest("encoder-decoder model type with use_merged=True is not supported for bert as a decoder") + model_id = MODEL_NAMES[model_arch] task = "text2text-generation-with-past" @@ -3164,6 +3173,9 @@ def test_merge_from_onnx_and_save(self, model_arch): @parameterized.expand(grid_parameters(FULL_GRID)) def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): + if model_arch == "encoder-decoder" and use_cache is 
True: + self.skipTest("encoder-decoder model type with use_cache=True is not supported for bert as a decoder") + if use_cache is False and use_merged is True: self.skipTest("use_cache=False, use_merged=True are uncompatible") @@ -3173,6 +3185,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach "use_cache": use_cache, "use_merged": use_merged, } + self._setup(model_args) model_id = MODEL_NAMES[model_arch] @@ -3201,6 +3214,9 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach tokenizer = get_preprocessor(model_id) tokens = tokenizer("This is a sample output", return_tensors="pt") decoder_start_token_id = transformers_model.config.decoder_start_token_id if model_arch != "mbart" else 2 + if model_arch == "encoder-decoder": + decoder_start_token_id = tokenizer.cls_token_id + decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id} with torch.no_grad(): @@ -3224,6 +3240,9 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach @parameterized.expand(grid_parameters(FULL_GRID)) def test_pipeline_text_generation(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): + if model_arch == "encoder-decoder" and use_cache is True: + self.skipTest("encoder-decoder model type with use_cache=True is not supported for bert as a decoder") + if use_cache is False and use_merged is True: self.skipTest("use_cache=False, use_merged=True are uncompatible") @@ -3233,30 +3252,35 @@ def test_pipeline_text_generation(self, test_name: str, model_arch: str, use_cac "use_cache": use_cache, "use_merged": use_merged, } + self._setup(model_args) model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForSeq2SeqLM.from_pretrained(self.onnx_model_dirs[test_name], use_cache=use_cache) tokenizer = get_preprocessor(model_id) + decoder_start_token_id = onnx_model.config.decoder_start_token_id if model_arch != "mbart" else 2 + if model_arch == "encoder-decoder": + decoder_start_token_id = tokenizer.cls_token_id + # Text2Text generation pipe = pipeline("text2text-generation", model=onnx_model, tokenizer=tokenizer) text = "This is a test" - outputs = pipe(text) + outputs = pipe(text, decoder_start_token_id=decoder_start_token_id) self.assertEqual(pipe.device, onnx_model.device) self.assertIsInstance(outputs[0]["generated_text"], str) # Summarization pipe = pipeline("summarization", model=onnx_model, tokenizer=tokenizer) text = "This is a test" - outputs = pipe(text) + outputs = pipe(text, decoder_start_token_id=decoder_start_token_id) self.assertEqual(pipe.device, onnx_model.device) self.assertIsInstance(outputs[0]["summary_text"], str) # Translation pipe = pipeline("translation_en_to_de", model=onnx_model, tokenizer=tokenizer) text = "This is a test" - outputs = pipe(text) + outputs = pipe(text, decoder_start_token_id=decoder_start_token_id) self.assertEqual(pipe.device, onnx_model.device) self.assertIsInstance(outputs[0]["translation_text"], str) @@ -3287,6 +3311,8 @@ def test_pipeline_model_is_none(self): @require_torch_gpu @pytest.mark.gpu_test def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool): + if model_arch == "encoder-decoder": + use_cache = False model_args = {"test_name": test_name, "model_arch": model_arch, "use_cache": use_cache} self._setup(model_args) @@ -3358,8 +3384,8 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st @parameterized.expand(SUPPORTED_ARCHITECTURES) @pytest.mark.gpu_test # mark as 
GPU test as well to run the without/with cache timing test on the slow tests def test_compare_with_and_without_past_key_values(self, model_arch: str): - if model_arch == "m2m_100": - return # TODO: this test is failing for m2m_100 + if model_arch == "m2m_100" or model_arch == "encoder-decoder": + self.skipTest("Comparison with/without past key values fails for m2m_100 and is not supported for encoder-decoder") model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False} self._setup(model_args) model_args = {"test_name": model_arch + "_True", "model_arch": model_arch, "use_cache": True} @@ -3401,6 +3427,8 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]})) def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, model_arch: str, use_cache: bool): + if model_arch == "encoder-decoder" and use_cache is True: + self.skipTest("encoder-decoder model type with use_cache=True is not supported for bert as a decoder") model_args = { "test_name": test_name + "_True", "model_arch": model_arch, @@ -3446,6 +3474,8 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode @require_torch_gpu @pytest.mark.gpu_test def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): + if model_arch == "encoder-decoder": + use_cache = False if use_cache is False and use_merged is True: self.skipTest("use_cache=False, use_merged=True are uncompatible") @@ -3455,15 +3485,16 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: "use_cache": use_cache, "use_merged": use_merged, } + self._setup(model_args) model_id = MODEL_NAMES[model_arch] - onnx_model = ORTModelForSeq2SeqLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False).to( - "cuda" - ) - io_model = ORTModelForSeq2SeqLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to( - "cuda" - ) + onnx_model = ORTModelForSeq2SeqLM.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=False, use_cache=use_cache + ).to("cuda") + io_model = ORTModelForSeq2SeqLM.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=True, use_cache=use_cache + ).to("cuda") self.assertFalse(onnx_model.use_io_binding) self.assertTrue(io_model.use_io_binding) @@ -3491,6 +3522,8 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: def test_compare_generation_to_io_binding( self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool ): + if model_arch == "encoder-decoder": + use_cache = False if use_cache is False and use_merged is True: self.skipTest("use_cache=False, use_merged=True are uncompatible") @@ -3500,15 +3533,16 @@ def test_compare_generation_to_io_binding( "use_cache": use_cache, "use_merged": use_merged, } + self._setup(model_args) model_id = MODEL_NAMES[model_arch] - onnx_model = ORTModelForSeq2SeqLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False).to( - "cuda" - ) - io_model = ORTModelForSeq2SeqLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to( - "cuda" - ) + onnx_model = ORTModelForSeq2SeqLM.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=False, use_cache=use_cache + ).to("cuda") + io_model = ORTModelForSeq2SeqLM.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=True, use_cache=use_cache + ).to("cuda") tokenizer =
get_preprocessor(model_id) tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda") diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index be0f3d0c31..cf776f11ed 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -39,6 +39,7 @@ "camembert": "hf-internal-testing/tiny-random-camembert", "clip": "hf-internal-testing/tiny-random-CLIPModel", "convbert": "hf-internal-testing/tiny-random-ConvBertModel", + "convnext": "hf-internal-testing/tiny-random-convnext", "codegen": "hf-internal-testing/tiny-random-CodeGenModel", "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel", "data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel", @@ -46,10 +47,10 @@ "deberta": "hf-internal-testing/tiny-random-DebertaModel", "deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model", "deit": "hf-internal-testing/tiny-random-DeiTModel", - "convnext": "hf-internal-testing/tiny-random-convnext", "detr": "hf-internal-testing/tiny-random-detr", "distilbert": "hf-internal-testing/tiny-random-DistilBertModel", "electra": "hf-internal-testing/tiny-random-ElectraModel", + "encoder-decoder": "hf-internal-testing/tiny-random-EncoderDecoderModel-bert-bert", "flaubert": "hf-internal-testing/tiny-random-flaubert", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
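For reviewers, a minimal usage sketch of what this changeset enables follows; it is not part of the diff. It assumes a recent optimum where `ORTModelForSeq2SeqLM.from_pretrained(..., export=True)` performs the ONNX export on the fly, and it reuses the `patrickvonplaten/bert2bert_cnn_daily_mail` checkpoint referenced in the exporter tests above. `use_cache=False` and the explicit `decoder_start_token_id` mirror the skips and pipeline calls added in the tests, since a BERT decoder has no past-key-values support.

```python
# Sketch only (not part of the diff): export a composite EncoderDecoderModel
# checkpoint to ONNX on the fly and generate with ONNX Runtime.
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForSeq2SeqLM

model_id = "patrickvonplaten/bert2bert_cnn_daily_mail"  # checkpoint used in the exporter tests above
tokenizer = AutoTokenizer.from_pretrained(model_id)

# use_cache=False: the BERT decoder has no past-key-values support, which is why
# the *-with-past tasks are skipped for "encoder-decoder" in the tests.
model = ORTModelForSeq2SeqLM.from_pretrained(model_id, export=True, use_cache=False)

inputs = tokenizer("This is a test", return_tensors="pt")
# BERT-based decoders start generation from the CLS token, as in test_pipeline_text_generation.
outputs = model.generate(**inputs, decoder_start_token_id=tokenizer.cls_token_id)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

The same export should also be possible ahead of time with `optimum-cli export onnx --model patrickvonplaten/bert2bert_cnn_daily_mail --task text2text-generation <output_dir>`, using the `text2text-generation` task registered for `encoder-decoder` in `tasks.py`.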