
Add encoder decoder model #851
Merged: 16 commits merged into huggingface:main on Sep 1, 2023

Conversation

mht-sharma (Contributor) commented on Mar 3, 2023:

What does this PR do?

Support encoder-decoder export and inference in ORT

Fixes #367 (Add support for encoder-decoder models)
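
For context, a minimal usage sketch of what this enables on the user side, assuming the standard ORTModelForSeq2SeqLM entry point; the checkpoint name and generation call are illustrative, not taken from this PR:

    # Hedged sketch: export a composite encoder-decoder checkpoint to ONNX and
    # run generation through ONNX Runtime. Checkpoint name is illustrative.
    from transformers import AutoTokenizer
    from optimum.onnxruntime import ORTModelForSeq2SeqLM

    model_id = "patrickvonplaten/bert2bert_cnn_daily_mail"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = ORTModelForSeq2SeqLM.from_pretrained(model_id, export=True)

    inputs = tokenizer("ONNX Runtime is a cross-platform inference engine.", return_tensors="pt")
    summary_ids = model.generate(**inputs)
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))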

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

HuggingFaceDocBuilderDev commented on Mar 3, 2023:

The documentation is not available anymore as the PR was closed or merged.

mht-sharma requested review from fxmarty and michaelbenayoun and removed the review request for fxmarty on Mar 6, 2023.
optimum/exporters/onnx/config.py (review thread resolved)
@@ -304,7 +304,7 @@ def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str,
return reference_model_inputs


- class EncoderDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
+ class DummyEncoderDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
Member: Why "Dummy"?

Contributor (author): I want this to be the base class that all encoder-decoder type models inherit from. Do you have other naming suggestions? Maybe EncoderDecoderBaseOnnxConfig.
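
For illustration only, a sketch of the hierarchy being discussed; the class names and the import path are assumptions, not the merged code:

    # Hypothetical sketch of the suggested naming: one shared base config that
    # concrete encoder-decoder ONNX configs inherit from.
    from optimum.exporters.onnx.base import OnnxSeq2SeqConfigWithPast  # import path assumed

    class EncoderDecoderBaseOnnxConfig(OnnxSeq2SeqConfigWithPast):
        """Shared export behavior for composite encoder-decoder models."""

    class SomeEncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
        """A concrete model type would subclass the base."""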

Contributor: Yes.

tests/exporters/onnx/test_exporters_onnx_cli.py (outdated; review thread resolved)
tests/exporters/onnx/test_onnx_export.py (review thread resolved)
hljjjmssyh commented: This feature is very useful. I'm looking forward to it being merged into master as soon as possible.

Comment on lines 57 to 62
if model_type == "encoder-decoder" and task == "seq2seq-lm-with-past":
# The model uses bert as decoder and does not support past key values
continue
Contributor: Not sure I understand this. This means that pkv is not tested, but may be supported depending on which arch is used as decoder?

Contributor (author): Yes.
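
To make the reviewer's point concrete: whether the with-past export applies depends on the decoder inside the composite model. A hedged sketch, with the decoder choice illustrative rather than taken from this PR:

    # Past-KV support depends on the decoder architecture. Per the test comment
    # above, the bert decoder has no usable cache in this export path, while a
    # GPT-2 decoder does, so a bert-to-gpt2 model could exercise seq2seq-lm-with-past.
    from transformers import EncoderDecoderModel

    bert2gpt2 = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "gpt2"  # encoder checkpoint, decoder checkpoint
    )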

Comment on lines 85 to 128
"vision-encoder-decoder",
"encoder-decoder",
Contributor: Can you add a comment for those as well (like above for segformer, etc.)?

tests/onnxruntime/test_modeling.py (outdated; review thread resolved)
Comment on lines 3130 to 3242
if model_arch == "encoder-decoder" and use_cache is True:
return
Contributor: Same as above.

tests/onnxruntime/test_modeling.py (outdated; review thread resolved)
tests/onnxruntime/utils_onnxruntime_tests.py (outdated; review thread resolved)
michaelbenayoun (Member) left a review: Left a few comments, LGTM otherwise!

optimum/exporters/onnx/config.py (review thread resolved)
@@ -3172,6 +3175,8 @@ def test_merge_from_onnx_and_save(self, model_arch):

@parameterized.expand(grid_parameters(FULL_GRID))
def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
if model_arch == "encoder-decoder" and use_cache is True:
Member: Suggested change:
- if model_arch == "encoder-decoder" and use_cache is True:
+ if model_arch == "encoder-decoder" and use_cache:

@@ -3232,6 +3238,9 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach

@parameterized.expand(grid_parameters(FULL_GRID))
def test_pipeline_text_generation(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
if model_arch == "encoder-decoder" and use_cache is True:
Member: Suggested change:
- if model_arch == "encoder-decoder" and use_cache is True:
+ if model_arch == "encoder-decoder" and use_cache:

mht-sharma merged commit a39b1f5 into huggingface:main on Sep 1, 2023. 64 of 68 checks passed.
Comment on lines +358 to +359
# TODO: validate the axis name for attention_mask
# common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length"
Contributor: ?

Contributor (author): This was copied from class TextSeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast), so if a change is made in the future it can be applied in both places.
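
For readers without the surrounding file, a hedged sketch of the convention the commented-out line follows: these ONNX configs map each input to a dict of dynamic-axis index to symbolic axis name, and with-past variants rename the sequence axis. The exact strings below are illustrative:

    # Sketch of the dynamic-axes convention in these ONNX configs:
    # input name -> {axis index: symbolic axis name}.
    common_inputs = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
    }
    # The commented-out line under discussion would rename the mask's sequence
    # axis once past key values are in play:
    # common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length"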

Comment on lines +3126 to +3127
if model_arch == "encoder-decoder":
self.skipTest("encoder-decoder model type with use_merged=True is not supported for bert as a decoder")
Contributor: Does this mean that encoder-decoder is not tested for merged ONNX?

Contributor (author): The test uses a bert-bert model, so only use_cache=False is exercised, which cannot work with use_merged=True.
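
A hedged sketch of how such a bert-bert test model is typically assembled with transformers (checkpoint names illustrative):

    # Illustrative bert-to-bert composite model like the one used in the test.
    # With no usable decoder cache, only use_cache=False applies, and the merged
    # export (which folds the with-past and without-past decoder ONNX files into
    # one) has nothing to merge.
    from transformers import EncoderDecoderModel

    bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased"
    )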

Contributor: My question remains. There are some encoder-decoder models that support past KV. I am wondering if this is tested anywhere.

Contributor (author): OK, got it. There was no suitable model to add to the tests for such a model type. I will create a custom model and open a new PR now.
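
A hedged sketch of what such a custom test model could look like: a tiny, randomly initialized composite model whose decoder supports a KV cache. All sizes and the GPT-2 decoder choice are assumptions, not from this PR:

    # Hypothetical tiny encoder-decoder with a cache-capable decoder, small
    # enough to serve as a lightweight test checkpoint for the with-past paths.
    from transformers import BertConfig, GPT2Config, EncoderDecoderConfig, EncoderDecoderModel

    encoder_cfg = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
    decoder_cfg = GPT2Config(n_embd=32, n_layer=2, n_head=2)
    # from_encoder_decoder_configs marks the decoder as a decoder and enables cross-attention.
    config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_cfg, decoder_cfg)
    tiny_model = EncoderDecoderModel(config=config)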

Comment on lines +3101 to +3102
if model_arch == "encoder-decoder" and use_cache is True:
self.skipTest("encoder-decoder model type with use_cache=True is not supported for bert as a decoder")
Contributor: Same question, with cache.

Contributor (author): Same as above; see this comment: discussion_r1162759972.

Comment on lines +3525 to +3526
if model_arch == "encoder-decoder":
use_cache = False
Contributor: There should rather be a skipTest for the use_cache case.

Contributor (author), Sep 1, 2023: In this particular test, only use_cache=True is tested, so the model was never going to be exercised. Hence, for this particular model I changed use_cache to False:

    @parameterized.expand(
        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
    )
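
To spell out why the override is needed, here is a hedged stand-in for the grid expansion, written with itertools (grid_parameters itself lives in the test utilities):

    # Illustrative expansion of the grid above: use_cache is pinned to True,
    # so an arch that only works without cache never gets a test case unless
    # the test overrides use_cache itself.
    from itertools import product

    grid = {"model_arch": ["encoder-decoder"], "use_cache": [True], "use_merged": [False, True]}
    cases = [dict(zip(grid, values)) for values in product(*grid.values())]
    # -> [{'model_arch': 'encoder-decoder', 'use_cache': True, 'use_merged': False},
    #     {'model_arch': 'encoder-decoder', 'use_cache': True, 'use_merged': True}]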
