Skip to content

Commit

Permalink
Swin2sr onnx (#1492)
Browse files Browse the repository at this point in the history
* Add ONNX export support for swin2SR models

* Add feature extraction task

* Add testing model

* Add Swin2srOnnxConfig class

* Update optimum/exporters/tasks.py

---------

Co-authored-by: fxmarty <[email protected]>
  • Loading branch information
baskrahmer and fxmarty authored Nov 9, 2023
1 parent 18ab883 commit 41347fc
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 0 deletions.
3 changes: 3 additions & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ class OnnxConfig(ExportConfig, ABC):
"image-classification": OrderedDict({"logits": {0: "batch_size"}}),
"image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
"image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-to-image": OrderedDict(
{"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
),
"mask-generation": OrderedDict({"logits": {0: "batch_size"}}),
"masked-im": OrderedDict(
{"reconstruction" if check_if_transformers_greater("4.29.0") else "logits": {0: "batch_size"}}
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,10 @@ class SwinOnnxConfig(ViTOnnxConfig):
pass


class Swin2srOnnxConfig(SwinOnnxConfig):
pass


class PoolFormerOnnxConfig(ViTOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
ATOL_FOR_VALIDATION = 2e-3
Expand Down
6 changes: 6 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class TasksManager:
"fill-mask": "AutoModelForMaskedLM",
"image-classification": "AutoModelForImageClassification",
"image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
"image-to-image": "AutoModelForImageToImage",
"image-to-text": "AutoModelForVision2Seq",
"mask-generation": "AutoModel",
"masked-im": "AutoModelForMaskedImageModeling",
Expand Down Expand Up @@ -884,6 +885,11 @@ class TasksManager:
"masked-im",
onnx="SwinOnnxConfig",
),
"swin2sr": supported_tasks_mapping(
"feature-extraction",
"image-to-image",
onnx="Swin2srOnnxConfig",
),
"t5": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
"splinter": "hf-internal-testing/tiny-random-SplinterModel",
"squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"swin2sr": "hf-internal-testing/tiny-random-Swin2SRModel",
"t5": "hf-internal-testing/tiny-random-t5",
"vit": "hf-internal-testing/tiny-random-vit",
"yolos": "hf-internal-testing/tiny-random-YolosModel",
Expand Down

0 comments on commit 41347fc

Please sign in to comment.