Add encoder decoder model #851
Conversation
The documentation is not available anymore as the PR was closed or merged.
optimum/exporters/onnx/config.py
Outdated
@@ -304,7 +304,7 @@ def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str,
         return reference_model_inputs


-class EncoderDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
+class DummyEncoderDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
Why "Dummy"?
I want this to be the base class that all encoder-decoder type models inherit from. Do you have other naming suggestions? Maybe EncoderDecoderBaseOnnxConfig
yes
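For readers following along, a minimal sketch of the base-class naming settled on in this exchange might look as follows. This is illustrative only, not the actual implementation in optimum/exporters/onnx/config.py, and the import path is an assumption about how optimum exposes the config classes.

# Minimal sketch of the naming discussed above (illustrative, not the final code).
# Assumption: OnnxSeq2SeqConfigWithPast is importable from optimum.exporters.onnx.
from optimum.exporters.onnx import OnnxSeq2SeqConfigWithPast


class EncoderDecoderBaseOnnxConfig(OnnxSeq2SeqConfigWithPast):
    """Base ONNX config that encoder-decoder style models would inherit from."""
    pass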
This feature is very useful.
if model_type == "encoder-decoder" and task == "seq2seq-lm-with-past":
    # The model uses bert as decoder and does not support past key values
    continue
Not sure I understand this. This means that pkv is not tested, but may be supported depending on which arch is used as decoder?
yes
"vision-encoder-decoder", | ||
"encoder-decoder", |
Can you add a comment for those as well (like above for segformer, etc.)?
tests/onnxruntime/test_modeling.py
Outdated
if model_arch == "encoder-decoder" and use_cache is True:
    return
same
Left a few comments, LGTM otherwise!
@@ -3172,6 +3175,8 @@ def test_merge_from_onnx_and_save(self, model_arch):


    @parameterized.expand(grid_parameters(FULL_GRID))
    def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
        if model_arch == "encoder-decoder" and use_cache is True:
-        if model_arch == "encoder-decoder" and use_cache is True:
+        if model_arch == "encoder-decoder" and use_cache:
@@ -3232,6 +3238,9 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach


    @parameterized.expand(grid_parameters(FULL_GRID))
    def test_pipeline_text_generation(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
        if model_arch == "encoder-decoder" and use_cache is True:
-        if model_arch == "encoder-decoder" and use_cache is True:
+        if model_arch == "encoder-decoder" and use_cache:
Co-authored-by: fxmarty <[email protected]>
# TODO: validate the axis name for attention_mask
# common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length"
?
This was copied from class TextSeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast) (L167), so that if a change is made in the future it is done in both places.
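For context, a hedged sketch of the dynamic-axis naming the commented-out line refers to is shown below. The surrounding dict and axis names are illustrative assumptions, not the final implementation.

# Illustrative sketch of the axis rename under discussion (names assumed, not the actual code).
# Without past key values, the attention mask shares the plain sequence axis:
common_inputs = {
    "attention_mask": {0: "batch_size", 1: "sequence_length"},
}
# The TODO asks whether, once past key values are present, the axis should instead be labelled:
# common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length"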
if model_arch == "encoder-decoder":
    self.skipTest("encoder-decoder model type with use_merged=True is not supported for bert as a decoder")
Does this mean that encoder-decoder is not tested for merged onnx?
The test uses a bert-bert model, so only use_cache=False is used, which cannot work with use_merged=True.
My question remains. There are some encoder-decoder models that support past KV. I am wondering if this is tested anywhere.
Ok, got it. There was no suitable model to add to the tests for such a model type. I will create a custom model and open a new PR now.
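As a rough illustration of the follow-up mentioned here, one way to assemble a small custom encoder-decoder checkpoint whose decoder does support past key values is sketched below. The checkpoint names are assumptions, and extra config would still be needed before it could be used in the tests.

# Hedged sketch: building an encoder-decoder model whose decoder supports past key values
# (e.g. a GPT-2 style decoder instead of BERT). Checkpoint names are illustrative assumptions.
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "hf-internal-testing/tiny-random-bert",  # encoder
    "hf-internal-testing/tiny-random-gpt2",  # decoder with cross-attention and past KV support
)
# pad_token_id, decoder_start_token_id, etc. would still need to be set on model.config
# before pushing the checkpoint for use in the export tests.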
if model_arch == "encoder-decoder" and use_cache is True:
    self.skipTest("encoder-decoder model type with use_cache=True is not supported for bert as a decoder")
same question with cache
same as above, check this comment discussion_r1162759972
if model_arch == "encoder-decoder":
    use_cache = False
There should rather be a skipTest for the use_cache case.
In this particular test, only use_cache=True is tested, so the model was never going to be tested. Hence, for this particular model I changed it to False.
@parameterized.expand(
grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
)
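Following the reviewer's suggestion above, a hedged sketch of what a skipTest guard under this parameter grid could look like is shown below; the test name and skip message are hypothetical.

# Hedged sketch of the skipTest alternative discussed above (test name is hypothetical).
@parameterized.expand(
    grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
)
def test_example(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
    if model_arch == "encoder-decoder" and use_cache:
        self.skipTest("encoder-decoder with use_cache=True is not supported for bert as a decoder")
    ...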
What does this PR do?
Support encoder-decoder export and inference in ORT
Fixes #367
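A hedged usage sketch of what this enables is shown below. It is not taken from the PR itself; the checkpoint name and keyword arguments are assumptions about the public optimum API, and use_cache=False matches the bert-as-decoder limitation discussed above.

# Hedged sketch: export an "encoder-decoder" model type to ONNX and run it with ONNX Runtime.
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import AutoTokenizer

model_id = "patrickvonplaten/bert2bert_cnn_daily_mail"  # illustrative encoder-decoder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForSeq2SeqLM.from_pretrained(model_id, export=True, use_cache=False)

inputs = tokenizer("ONNX Runtime speeds up encoder-decoder inference.", return_tensors="pt")
summary_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))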
Before submitting