Skip to content

Commit

Permalink
New model support RTDETR (huggingface#29077)
Browse files Browse the repository at this point in the history
* fill out docs string in configuration
https://github.com/huggingface/transformers/pull/29077/files/75dcd3a0e82cca36f12178b65bbd071ab7b25088#r1506391856

* reduce the input image size for the tests

* remove the unappropriate tests

* only 5 failes exists

* make style

* fill up missed architecture for object detection in docs

* fix auto modeling

* simple fix in missing import

* major change including backbone refactor and objectdetectionoutput refactor

* minor fix only 4 fails left

* intermediate fix

* revert __init__.py

* revert __init__.py

* make style

* fixes in pr_docs

* intermediate fix

* make style

* two fixes

* pass doctest

* only one fix left

* intermediate commit

* all fixed

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update tests/models/rt_detr/test_modeling_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* function class above the model definition in dice_loss

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* simple fix

* layernorm add config.layer_norm_eps

* fix inputs_docstring

* make style

* simple fix

* add custom coco loading test in image_processor

* fix error in BaseModelOutput
huggingface#29077 (comment)

* simple typo

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* intermediate fix

* fix with load_backbone format

* remove unused configuration

* 3 fix test left

* make style

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: Sounak Dey <[email protected]>

* change last_hidden_state to first index

* all pass fix
TO DO: minor update in comments

* make fix-copies

* remove deepcopy

* pr_document fix

* revert deepcopy due to the issue of unexpceted behavior in decoderlayer

* add atol in final

* add no_split_module

* _no_split_modules = None

* device transfer for model parallelism

* minor fix

* make fix-copies

* fix typo

* add test_image_processor with post_processing

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* add config in RTDETRPredictionHead

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* set lru_cache with max_size 32

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* add lru_cache import and configuration change

* change the order of definition

* make fix-copies

* add docs and change config error

* revert strange make-fix

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* test pass

* fix get_clones related and remove deepcopy

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* nit for paper section

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* rename denoising related parameters

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* check the image transformation logic

* make style

* make style

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* pe_encoding -> positional_encoding_temperature

* remove TODO

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* remove eval_idx since transformer DETR is giving all decoder output

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* change variable name

* make style and docs import update

* Revert "Update src/transformers/models/rt_detr/image_processing_rt_detr.py"

This reverts commit 74aa3e1.

* fix typo

* add postprocessing in docs

* move import scipy to top

* change varaible name

* make fix-copies

* remove eval_idx in test

* move to after first sentence

* update image_processor since box loss requires normalized one

* change appropriate name to auxiliary_outputs

* Update src/transformers/models/rt_detr/__init__.py

Co-authored-by: NielsRogge <[email protected]>

* Update src/transformers/models/rt_detr/__init__.py

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/model_doc/rt_detr.md

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/model_doc/rt_detr.md

Co-authored-by: NielsRogge <[email protected]>

* make style

* remove panoptic related comments

* make style

* revert valid_processor_keys

* fix aux related test

* make style

* change origination from config to backbone API

* enable the dn_loss

* fix test and conversion

* renewal weight initialization

* change initializer_range

* make fix-up

* fix the loss issue in the auxiliary output and denoising part

* change weight loss to original RTDETR

* fix in initialization

* sync shape format of dn and aux

* make style

* stable fine-tuning and compatible conversion for resnet101

* make style

* skip input_embed

* change encoder related variable

* enable converting rtdetr_r101

* add r101 related conversion code

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update docs/source/en/model_doc/rt_detr.md

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/__init__.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/__init__.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* change name _shape to _reshape

* Update src/transformers/__init__.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/__init__.py

Co-authored-by: amyeroberts <[email protected]>

* maket style

* make fix-copies

* remove deprecated import

* more fix

* remove last_hidden_state for task-specific model

* Revert "remove last_hidden_state for task-specific model"

This reverts commit ccb7a34.

* minore change in convert

* remove print

* make style and fix-copies

* add custom rtdetr backbone for r18, r34

* remove print

* change copied

* add pad_size

* make style

* change layertype to optional to pass the CI

* make style

* add test in modeling_resnet_rt_detr

* make fix-copies

* skip tmp file test

* fix comment

* add docs

* change to modeling_resnet file format

* enabling resnet50 above

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: Jason Wu <[email protected]>

* enable all the rtdetr model :)

* finish except CI

* add RTDetrResNetBackbone

* make fix-copies

* fix
TO DO: CI enable

* make style

* rename test

* add docs

* add special fix

* revert resnet

* Update src/transformers/models/rt_detr/modeling_rt_detr_resnet.py

Co-authored-by: NielsRogge <[email protected]>

* add more comment

* remove swin comment

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <[email protected]>

* rename convert and add verify backbone

* Update docs/source/en/_toctree.yml

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/model_doc/rt_detr.md

Co-authored-by: NielsRogge <[email protected]>

* Update docs/source/en/model_doc/rt_detr.md

Co-authored-by: NielsRogge <[email protected]>

* make style

* requests for docs

* more general test docs

* general script docs

* make fix-copies

* final commit

* Revert "Update src/transformers/models/rt_detr/configuration_rt_detr.py"

This reverts commit d136225.

* skip test_model_get_set_embeddings

* remove target

* add changes

* make fix-copies

* remove decoder_attention_mask

* add load_backbone function for auto_backbone

* remove comment

* fix repo name

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: amyeroberts <[email protected]>

* final commit

* remove unused downsample_in_bottleneck

* new test for autobackbone

* change to appropriate indices

* test fix

* fix dict in test_image_processor

* fix test

* [run-slow] rt_detr, rt_detr_resnet

* change the slow test

* [run-slow] rt_detr

* [run-slow] rt_detr, rt_detr_resnet

* make in to same cuda in CSPRepLayer

* [run-slow] rt_detr, rt_detr_resnet

---------

Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Sounak Dey <[email protected]>
Co-authored-by: NielsRogge <[email protected]>
Co-authored-by: Jason Wu <[email protected]>
Co-authored-by: ChoiSangBum <[email protected]>
  • Loading branch information
6 people authored Jun 21, 2024
1 parent 8b7cd40 commit 74a2074
Show file tree
Hide file tree
Showing 24 changed files with 6,892 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,8 @@
title: RegNet
- local: model_doc/resnet
title: ResNet
- local: model_doc/rt_detr
title: RT-DETR
- local: model_doc/segformer
title: SegFormer
- local: model_doc/seggpt
Expand Down
2 changes: 2 additions & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ Flax), PyTorch, and/or TensorFlow.
| [RoBERTa-PreLayerNorm](model_doc/roberta-prelayernorm) ||||
| [RoCBert](model_doc/roc_bert) ||||
| [RoFormer](model_doc/roformer) ||||
| [RT-DETR](model_doc/rt_detr) ||||
| [RT-DETR-ResNet](model_doc/rt_detr_resnet) ||||
| [RWKV](model_doc/rwkv) ||||
| [SAM](model_doc/sam) ||||
| [SeamlessM4T](model_doc/seamless_m4t) ||||
Expand Down
85 changes: 85 additions & 0 deletions docs/source/en/model_doc/rt_detr.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# RT-DETR

## Overview


The RT-DETR model was proposed in [DETRs Beat YOLOs on Real-time Object Detection](https://arxiv.org/abs/2304.08069) by Wenyu Lv, Yian Zhao, Shangliang Xu, Jinman Wei, Guanzhong Wang, Cheng Cui, Yuning Du, Qingqing Dang, Yi Liu.

RT-DETR is an object detection model that stands for "Real-Time DEtection Transformer." This model is designed to perform object detection tasks with a focus on achieving real-time performance while maintaining high accuracy. Leveraging the transformer architecture, which has gained significant popularity in various fields of deep learning, RT-DETR processes images to identify and locate multiple objects within them.

The abstract from the paper is the following:

*Recently, end-to-end transformer-based detectors (DETRs) have achieved remarkable performance. However, the issue of the high computational cost of DETRs has not been effectively addressed, limiting their practical application and preventing them from fully exploiting the benefits of no post-processing, such as non-maximum suppression (NMS). In this paper, we first analyze the influence of NMS in modern real-time object detectors on inference speed, and establish an end-to-end speed benchmark. To avoid the inference delay caused by NMS, we propose a Real-Time DEtection TRansformer (RT-DETR), the first real-time end-to-end object detector to our best knowledge. Specifically, we design an efficient hybrid encoder to efficiently process multi-scale features by decoupling the intra-scale interaction and cross-scale fusion, and propose IoU-aware query selection to improve the initialization of object queries. In addition, our proposed detector supports flexibly adjustment of the inference speed by using different decoder layers without the need for retraining, which facilitates the practical application of real-time object detectors. Our RT-DETR-L achieves 53.0% AP on COCO val2017 and 114 FPS on T4 GPU, while RT-DETR-X achieves 54.8% AP and 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. Furthermore, our RT-DETR-R50 achieves 53.1% AP and 108 FPS, outperforming DINO-Deformable-DETR-R50 by 2.2% AP in accuracy and by about 21 times in FPS.*

The model version was contributed by [rafaelpadilla](https://huggingface.co/rafaelpadilla) and [sangbumchoi](https://github.com/SangbumChoi). The original code can be found [here](https://github.com/lyuwenyu/RT-DETR/).


## Usage tips

Initially, an image is processed using a pre-trained convolutional neural network, specifically a Resnet-D variant as referenced in the original code. This network extracts features from the final three layers of the architecture. Following this, a hybrid encoder is employed to convert the multi-scale features into a sequential array of image features. Then, a decoder, equipped with auxiliary prediction heads is used to refine the object queries. This process facilitates the direct generation of bounding boxes, eliminating the need for any additional post-processing to acquire the logits and coordinates for the bounding boxes.

```py
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
from PIL import Image
import json
import torch
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
outputs = model(**inputs)

results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]), threshold=0.3)
```

## RTDetrConfig

[[autodoc]] RTDetrConfig

## RTDetrResNetConfig

[[autodoc]] RTDetrResNetConfig

## RTDetrImageProcessor

[[autodoc]] RTDetrImageProcessor
- preprocess
- post_process_object_detection

## RTDetrModel

[[autodoc]] RTDetrModel
- forward

## RTDetrForObjectDetection

[[autodoc]] RTDetrForObjectDetection
- forward

## RTDetrResNetBackbone

[[autodoc]] RTDetrResNetBackbone
- forward
23 changes: 23 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@
"RoFormerConfig",
"RoFormerTokenizer",
],
"models.rt_detr": ["RTDetrConfig", "RTDetrResNetConfig"],
"models.rwkv": ["RwkvConfig"],
"models.sam": [
"SamConfig",
Expand Down Expand Up @@ -1153,6 +1154,7 @@
_import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
_import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
_import_structure["models.pvt"].extend(["PvtImageProcessor"])
_import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"])
_import_structure["models.sam"].extend(["SamImageProcessor"])
_import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
_import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
Expand Down Expand Up @@ -3004,6 +3006,15 @@
"load_tf_weights_in_roformer",
]
)
_import_structure["models.rt_detr"].extend(
[
"RTDetrForObjectDetection",
"RTDetrModel",
"RTDetrPreTrainedModel",
"RTDetrResNetBackbone",
"RTDetrResNetPreTrainedModel",
]
)
_import_structure["models.rwkv"].extend(
[
"RwkvForCausalLM",
Expand Down Expand Up @@ -5270,6 +5281,10 @@
RoFormerConfig,
RoFormerTokenizer,
)
from .models.rt_detr import (
RTDetrConfig,
RTDetrResNetConfig,
)
from .models.rwkv import RwkvConfig
from .models.sam import (
SamConfig,
Expand Down Expand Up @@ -5792,6 +5807,7 @@
PoolFormerImageProcessor,
)
from .models.pvt import PvtImageProcessor
from .models.rt_detr import RTDetrImageProcessor
from .models.sam import SamImageProcessor
from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
from .models.seggpt import SegGptImageProcessor
Expand Down Expand Up @@ -7295,6 +7311,13 @@
RoFormerPreTrainedModel,
load_tf_weights_in_roformer,
)
from .models.rt_detr import (
RTDetrForObjectDetection,
RTDetrModel,
RTDetrPreTrainedModel,
RTDetrResNetBackbone,
RTDetrResNetPreTrainedModel,
)
from .models.rwkv import (
RwkvForCausalLM,
RwkvModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@
roberta_prelayernorm,
roc_bert,
roformer,
rt_detr,
rwkv,
sam,
seamless_m4t,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@
("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
("roc_bert", "RoCBertConfig"),
("roformer", "RoFormerConfig"),
("rt_detr", "RTDetrConfig"),
("rt_detr_resnet", "RTDetrResNetConfig"),
("rwkv", "RwkvConfig"),
("sam", "SamConfig"),
("seamless_m4t", "SeamlessM4TConfig"),
Expand Down Expand Up @@ -499,6 +501,8 @@
("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
("roc_bert", "RoCBert"),
("roformer", "RoFormer"),
("rt_detr", "RT-DETR"),
("rt_detr_resnet", "RT-DETR-ResNet"),
("rwkv", "RWKV"),
("sam", "SAM"),
("seamless_m4t", "SeamlessM4T"),
Expand Down Expand Up @@ -623,6 +627,7 @@
("clip_vision_model", "clip"),
("siglip_vision_model", "siglip"),
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
]
)

Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
("pvt_v2", ("PvtImageProcessor",)),
("regnet", ("ConvNextImageProcessor",)),
("resnet", ("ConvNextImageProcessor",)),
("rt_detr", "RTDetrImageProcessor"),
("sam", ("SamImageProcessor",)),
("segformer", ("SegformerImageProcessor",)),
("seggpt", ("SegGptImageProcessor",)),
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@
("roberta-prelayernorm", "RobertaPreLayerNormModel"),
("roc_bert", "RoCBertModel"),
("roformer", "RoFormerModel"),
("rt_detr", "RTDetrModel"),
("rwkv", "RwkvModel"),
("sam", "SamModel"),
("seamless_m4t", "SeamlessM4TModel"),
Expand Down Expand Up @@ -765,6 +766,7 @@
("deformable_detr", "DeformableDetrForObjectDetection"),
("deta", "DetaForObjectDetection"),
("detr", "DetrForObjectDetection"),
("rt_detr", "RTDetrForObjectDetection"),
("table-transformer", "TableTransformerForObjectDetection"),
("yolos", "YolosForObjectDetection"),
]
Expand Down Expand Up @@ -1252,6 +1254,7 @@
("nat", "NatBackbone"),
("pvt_v2", "PvtV2Backbone"),
("resnet", "ResNetBackbone"),
("rt_detr_resnet", "RTDetrResNetBackbone"),
("swin", "SwinBackbone"),
("swinv2", "Swinv2Backbone"),
("timm_backbone", "TimmBackbone"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,24 @@
from torch.autograd.function import once_differentiable

from ...activations import ACT2FN
from ...file_utils import (
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_ninja_available,
is_scipy_available,
is_timm_available,
is_torch_cuda_available,
is_vision_available,
logging,
replace_return_docstrings,
requires_backends,
)
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_accelerate_available, is_ninja_available, logging
from ...utils.backbone_utils import load_backbone
from .configuration_deformable_detr import DeformableDetrConfig

Expand Down
78 changes: 78 additions & 0 deletions src/transformers/models/rt_detr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {"configuration_rt_detr": ["RTDetrConfig"], "configuration_rt_detr_resnet": ["RTDetrResNetConfig"]}

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_rt_detr"] = ["RTDetrImageProcessor"]

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_rt_detr"] = [
"RTDetrForObjectDetection",
"RTDetrModel",
"RTDetrPreTrainedModel",
]
_import_structure["modeling_rt_detr_resnet"] = [
"RTDetrResNetBackbone",
"RTDetrResNetPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_rt_detr import RTDetrConfig
from .configuration_rt_detr_resnet import RTDetrResNetConfig

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_rt_detr import RTDetrImageProcessor

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_rt_detr import (
RTDetrForObjectDetection,
RTDetrModel,
RTDetrPreTrainedModel,
)
from .modeling_rt_detr_resnet import (
RTDetrResNetBackbone,
RTDetrResNetPreTrainedModel,
)

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Loading

0 comments on commit 74a2074

Please sign in to comment.