diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 1e5704e893..d6e9f69fbd 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -152,6 +152,9 @@ class OnnxConfig(ExportConfig, ABC): "image-classification": OrderedDict({"logits": {0: "batch_size"}}), "image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}), "image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), + "image-to-image": OrderedDict( + {"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}} + ), "mask-generation": OrderedDict({"logits": {0: "batch_size"}}), "masked-im": OrderedDict( {"reconstruction" if check_if_transformers_greater("4.29.0") else "logits": {0: "batch_size"}} diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index e1461c2a0c..04e4c0b11f 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -770,6 +770,10 @@ class SwinOnnxConfig(ViTOnnxConfig): pass +class Swin2srOnnxConfig(SwinOnnxConfig): + pass + + class PoolFormerOnnxConfig(ViTOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig ATOL_FOR_VALIDATION = 2e-3 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 2a0f9076ce..c26dc98da7 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -168,6 +168,7 @@ class TasksManager: "fill-mask": "AutoModelForMaskedLM", "image-classification": "AutoModelForImageClassification", "image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"), + "image-to-image": "AutoModelForImageToImage", "image-to-text": "AutoModelForVision2Seq", "mask-generation": "AutoModel", "masked-im": "AutoModelForMaskedImageModeling", @@ -884,6 +885,11 @@ class TasksManager: "masked-im", onnx="SwinOnnxConfig", ), + "swin2sr": supported_tasks_mapping( + "feature-extraction", + "image-to-image", + onnx="Swin2srOnnxConfig", + ), "t5": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 105a5a7d77..e0640b7657 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -120,6 +120,7 @@ "splinter": "hf-internal-testing/tiny-random-SplinterModel", "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", "swin": "hf-internal-testing/tiny-random-SwinModel", + "swin2sr": "hf-internal-testing/tiny-random-Swin2SRModel", "t5": "hf-internal-testing/tiny-random-t5", "vit": "hf-internal-testing/tiny-random-vit", "yolos": "hf-internal-testing/tiny-random-YolosModel",