
Add encoder decoder model #851

Merged: 16 commits, Sep 1, 2023
2 changes: 2 additions & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -40,6 +40,7 @@ Supported architectures:
- DistilBert
- Donut-Swin
- Electra
- Encoder Decoder
- Flaubert
- GPT-2
- GPT-BigCode
@@ -88,6 +89,7 @@ Supported architectures:
- TROCR
- UniSpeech
- UniSpeech SAT
- Vision Encoder Decoder
- Vit
- Wav2Vec2
- Wav2Vec2 Conformer
30 changes: 29 additions & 1 deletion optimum/exporters/onnx/config.py
@@ -267,7 +267,7 @@ def torch_to_onnx_input_map(self) -> Dict[str, str]:
return {}


class EncoderDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
class EncoderDecoderBaseOnnxConfig(OnnxSeq2SeqConfigWithPast):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)

def __init__(
@@ -341,6 +341,34 @@ def __init__(

self.DUMMY_INPUT_GENERATOR_CLASSES += self._past_key_values_generator

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {}
if self._behavior is not ConfigBehavior.DECODER:
common_inputs["input_ids"] = {0: "batch_size", 1: "encoder_sequence_length"}

common_inputs["attention_mask"] = {0: "batch_size", 1: "encoder_sequence_length"}

if self._behavior is not ConfigBehavior.ENCODER:
# TODO: it is likely this pop() is unwanted as we then always hit
# https://github.com/huggingface/transformers/blob/v4.26.0/src/transformers/models/t5/modeling_t5.py#L965-L969
common_inputs.pop("attention_mask")

if self.use_past_in_inputs:
# TODO: validate the axis name for attention_mask
# common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length"
Comment on lines +358 to +359

Contributor: ?

Contributor Author: This was copied from `TextSeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast)` (L167), so if a change is made in the future it can be applied in both places.
common_inputs["decoder_input_ids"] = {0: "batch_size"}
else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")

if self._behavior is ConfigBehavior.DECODER:
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}

return common_inputs
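For reference, a minimal sketch of what this property resolves to for the two split behaviors, assuming `use_past_in_inputs=False` (axis names are taken from the code above; the variable names are illustrative only):

```python
# Illustrative sketch, assuming use_past_in_inputs=False.
# ConfigBehavior.ENCODER: text inputs with dynamic batch and sequence axes.
encoder_inputs = {
    "input_ids": {0: "batch_size", 1: "encoder_sequence_length"},
    "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
}

# ConfigBehavior.DECODER: attention_mask is popped (see the TODO above) and
# the decoder consumes the encoder hidden states through encoder_outputs.
decoder_inputs = {
    "decoder_input_ids": {0: "batch_size", 1: "decoder_sequence_length"},
    "encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length"},
}
```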

@property
def torch_to_onnx_input_map(self) -> Dict[str, str]:
if self._behavior is ConfigBehavior.DECODER:
8 changes: 6 additions & 2 deletions optimum/exporters/onnx/model_configs.py
@@ -47,7 +47,7 @@
from .config import (
AudioOnnxConfig,
AudioToTextOnnxConfig,
EncoderDecoderOnnxConfig,
EncoderDecoderBaseOnnxConfig,
TextAndVisionOnnxConfig,
TextDecoderOnnxConfig,
TextEncoderOnnxConfig,
@@ -1168,7 +1168,7 @@ class TrOCROnnxConfig(TextSeq2SeqOnnxConfig):
)


class VisionEncoderDecoderOnnxConfig(EncoderDecoderOnnxConfig):
class VisionEncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
ATOL_FOR_VALIDATION = 1e-3

@@ -1439,3 +1439,7 @@ def overwrite_shape_and_generate_input(
dummy_input = dummy_input_gen.generate(input_name, framework=framework)

return dummy_input


class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
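With `EncoderDecoderOnnxConfig` registered, the export path can be exercised end to end. A minimal sketch, assuming the tiny bert2bert test checkpoint referenced in the test utilities below (the output directory name is arbitrary):

```python
# A minimal export sketch; checkpoint and output path are assumptions.
from optimum.exporters.onnx import main_export

main_export(
    "hf-internal-testing/tiny-random-EncoderDecoderModel-bert-bert",
    output="encoder_decoder_onnx",
    # "text2text-generation-with-past" is skipped for this model: the BERT
    # decoder used here does not support past key values (see tests below).
    task="text2text-generation",
)
```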
5 changes: 5 additions & 0 deletions optimum/exporters/tasks.py
@@ -497,6 +497,11 @@ class TasksManager:
onnx="ElectraOnnxConfig",
tflite="ElectraTFLiteConfig",
),
"encoder-decoder": supported_tasks_mapping(
"text2text-generation",
"text2text-generation-with-past",
onnx="EncoderDecoderOnnxConfig",
),
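As a sanity check, the new mapping resolves through `TasksManager` with the same call pattern the tests below use; a sketch:

```python
# Sketch: resolve the ONNX config constructor registered above.
from optimum.exporters.tasks import TasksManager

constructor = TasksManager.get_exporter_config_constructor(
    "onnx", task="text2text-generation", model_type="encoder-decoder"
)
print(constructor)  # a partial wrapping EncoderDecoderOnnxConfig
```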
"flaubert": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
38 changes: 38 additions & 0 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -1092,6 +1092,43 @@ class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin):
auto_model_class = AutoModelForSeq2SeqLM
main_input_name = "input_ids"

def __init__(
self,
encoder_session: ort.InferenceSession,
decoder_session: ort.InferenceSession,
config: "PretrainedConfig",
onnx_paths: List[str],
decoder_with_past_session: Optional[ort.InferenceSession] = None,
use_cache: bool = True,
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
preprocessors: Optional[List] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
):
super().__init__(
encoder_session,
decoder_session,
config,
onnx_paths,
decoder_with_past_session,
use_cache,
use_io_binding,
model_save_dir,
preprocessors,
generation_config,
**kwargs,
)

if config.model_type == "encoder-decoder":
self.encoder.normalized_config = NormalizedConfigManager.get_normalized_config_class(
config.encoder.model_type
)(config.encoder)

self.decoder.normalized_config = NormalizedConfigManager.get_normalized_config_class(
config.decoder.model_type
)(config.decoder)
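For context, a minimal inference sketch with the updated wrapper, assuming the bert2bert summarization checkpoint listed in the test utilities below; `use_cache=False` because that decoder does not support past key values, and the input sentence is a placeholder:

```python
# A minimal usage sketch; model choice and input text are assumptions.
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import AutoTokenizer

model_id = "patrickvonplaten/bert2bert_cnn_daily_mail"
model = ORTModelForSeq2SeqLM.from_pretrained(model_id, export=True, use_cache=False)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Some long news article to summarize.", return_tensors="pt")
summary_ids = model.generate(**inputs)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```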

def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder:
return ORTEncoder(session, self)

@@ -1153,6 +1190,7 @@ def prepare_inputs_for_generation(
input_ids,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
@@ -220,6 +220,7 @@ class NormalizedConfigManager:
"distilbert": NormalizedTextConfig.with_args(num_attention_heads="n_heads", hidden_size="dim"),
"donut-swin": NormalizedVisionConfig,
"electra": NormalizedTextConfig,
"encoder-decoder": NormalizedEncoderDecoderConfig,
"gpt2": GPT2LikeNormalizedTextConfig,
"gpt-bigcode": GPT2LikeNormalizedTextConfig,
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
6 changes: 4 additions & 2 deletions tests/exporters/exporters_utils.py
@@ -42,6 +42,7 @@
"camembert": "hf-internal-testing/tiny-random-camembert",
"clip": "hf-internal-testing/tiny-random-CLIPModel",
"convbert": "hf-internal-testing/tiny-random-ConvBertModel",
"convnext": "hf-internal-testing/tiny-random-convnext",
"codegen": "hf-internal-testing/tiny-random-CodeGenModel",
"cvt": "hf-internal-testing/tiny-random-CvTModel",
"data2vec-text": "hf-internal-testing/tiny-random-Data2VecTextModel",
@@ -51,10 +52,10 @@
"deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model",
"deit": "hf-internal-testing/tiny-random-DeiTModel",
"donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel",
"convnext": "hf-internal-testing/tiny-random-convnext",
"detr": "hf-internal-testing/tiny-random-DetrModel", # hf-internal-testing/tiny-random-detr is larger
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"encoder-decoder": "hf-internal-testing/tiny-random-EncoderDecoderModel-bert-bert",
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
@@ -161,17 +162,18 @@
"camembert": "camembert-base",
"clip": "openai/clip-vit-base-patch32",
"convbert": "YituTech/conv-bert-base",
"convnext": "facebook/convnext-tiny-224",
"codegen": "hf-internal-testing/tiny-random-CodeGenModel", # Not using Salesforce/codegen-350M-multi because it takes too much time for testing.
"data2vec-text": "facebook/data2vec-text-base",
"data2vec-vision": "facebook/data2vec-vision-base",
"data2vec-audio": "facebook/data2vec-audio-base",
"deberta": "hf-internal-testing/tiny-random-DebertaModel", # Not using microsoft/deberta-base because it takes too much time for testing.
"deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model", # Not using microsoft/deberta-v2-xlarge because it takes too much time for testing.
"deit": "facebook/deit-small-patch16-224",
"convnext": "facebook/convnext-tiny-224",
"detr": "hf-internal-testing/tiny-random-detr", # Not using facebook/detr-resnet-50 because it takes too much time for testing.
"distilbert": "distilbert-base-cased",
"electra": "google/electra-base-generator",
"encoder-decoder": "patrickvonplaten/bert2bert_cnn_daily_mail",
"flaubert": "hf-internal-testing/tiny-random-flaubert", # TODO
"gpt2": "gpt2",
"gpt-neo": "EleutherAI/gpt-neo-125M",
11 changes: 10 additions & 1 deletion tests/exporters/onnx/test_exporters_onnx_cli.py
@@ -57,6 +57,9 @@ def _get_models_to_test(export_models_dict: Dict):

for model_name, tasks in model_tasks.items():
for task in tasks:
if model_type == "encoder-decoder" and task == "text2text-generation-with-past":
# The model uses bert as decoder and does not support past key values
continue
onnx_config_class = TasksManager.get_exporter_config_constructor(
"onnx", task=task, model_type=model_type
)
@@ -117,7 +120,13 @@ def _get_models_to_test(export_models_dict: Dict):
# TODO: segformer task can not be automatically inferred
# TODO: xlm-roberta model auto-infers text-generation, but we don't support it
# TODO: perceiver auto-infers default, but we don't support it (why?)
if model_type not in ["segformer", "xlm-roberta", "perceiver", "vision-encoder-decoder"]:
# TODO: encoder-decoder auto-infers text2text-generation, but it uses bert as decoder and does not support past key values
if model_type not in [
"segformer",
"xlm-roberta",
"perceiver",
"encoder-decoder",
]:
models_to_test.append(
(f"{model_type}_no_task", model_type, model_name, "auto", "default", False, False)
)
4 changes: 4 additions & 0 deletions tests/exporters/onnx/test_onnx_export.py
@@ -161,6 +161,10 @@ def _get_models_to_test(export_models_dict: Dict):

for model_name, tasks in model_tasks.items():
for task in tasks:
if model_type == "encoder-decoder" and task == "seq2seq-lm-with-past":
# The model uses bert as decoder and does not support past key values
continue

onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model_type=model_type, exporter="onnx", task=task, model_name=model_name
)