From 41347fc1b4897045483d61ebf0c8192af1d92d97 Mon Sep 17 00:00:00 2001 From: Bas Krahmer Date: Thu, 9 Nov 2023 16:54:38 +0100 Subject: [PATCH] Swin2sr onnx (#1492) * Add ONNX export support for swin2SR models * Add feature extraction task * Add testing model * Add Swin2srOnnxConfig class * Update optimum/exporters/tasks.py --------- Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> --- optimum/exporters/onnx/base.py | 3 +++ optimum/exporters/onnx/model_configs.py | 4 ++++ optimum/exporters/tasks.py | 6 ++++++ tests/exporters/exporters_utils.py | 1 + 4 files changed, 14 insertions(+) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 6765f3310c..2958d3d920 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -152,6 +152,9 @@ class OnnxConfig(ExportConfig, ABC): "image-classification": OrderedDict({"logits": {0: "batch_size"}}), "image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}), "image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), + "image-to-image": OrderedDict( + {"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}} + ), "mask-generation": OrderedDict({"logits": {0: "batch_size"}}), "masked-im": OrderedDict( {"reconstruction" if check_if_transformers_greater("4.29.0") else "logits": {0: "batch_size"}} diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index b5d67e5040..a330e23bfb 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -739,6 +739,10 @@ class SwinOnnxConfig(ViTOnnxConfig): pass +class Swin2srOnnxConfig(SwinOnnxConfig): + pass + + class PoolFormerOnnxConfig(ViTOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig ATOL_FOR_VALIDATION = 2e-3 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 46ee59be34..2208a6fe26 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -168,6 +168,7 @@ class TasksManager: "fill-mask": "AutoModelForMaskedLM", "image-classification": "AutoModelForImageClassification", "image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"), + "image-to-image": "AutoModelForImageToImage", "image-to-text": "AutoModelForVision2Seq", "mask-generation": "AutoModel", "masked-im": "AutoModelForMaskedImageModeling", @@ -884,6 +885,11 @@ class TasksManager: "masked-im", onnx="SwinOnnxConfig", ), + "swin2sr": supported_tasks_mapping( + "feature-extraction", + "image-to-image", + onnx="Swin2srOnnxConfig", + ), "t5": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 105a5a7d77..e0640b7657 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -120,6 +120,7 @@ "splinter": "hf-internal-testing/tiny-random-SplinterModel", "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", "swin": "hf-internal-testing/tiny-random-SwinModel", + "swin2sr": "hf-internal-testing/tiny-random-Swin2SRModel", "t5": "hf-internal-testing/tiny-random-t5", "vit": "hf-internal-testing/tiny-random-vit", "yolos": "hf-internal-testing/tiny-random-YolosModel",