From 74a207404e8d4524d1fdc4aa23789694f9eef347 Mon Sep 17 00:00:00 2001 From: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> Date: Sat, 22 Jun 2024 01:50:08 +0900 Subject: [PATCH] New model support RTDETR (#29077) * fill out docs string in configuration https://github.com/huggingface/transformers/pull/29077/files/75dcd3a0e82cca36f12178b65bbd071ab7b25088#r1506391856 * reduce the input image size for the tests * remove the unappropriate tests * only 5 failes exists * make style * fill up missed architecture for object detection in docs * fix auto modeling * simple fix in missing import * major change including backbone refactor and objectdetectionoutput refactor * minor fix only 4 fails left * intermediate fix * revert __init__.py * revert __init__.py * make style * fixes in pr_docs * intermediate fix * make style * two fixes * pass doctest * only one fix left * intermediate commit * all fixed * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/rt_detr/test_modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * function class above the model definition in dice_loss * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * simple fix * layernorm add config.layer_norm_eps * fix inputs_docstring * make style * simple fix * add custom coco loading test in image_processor * fix error in BaseModelOutput https://github.com/huggingface/transformers/pull/29077#discussion_r1516657790 * simple typo * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * intermediate fix * fix with load_backbone format * remove unused configuration * 3 fix test left * make style * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: Sounak Dey * change last_hidden_state to first index * all pass fix TO DO: minor update in comments * make fix-copies * remove deepcopy * pr_document fix * revert deepcopy due to the issue of unexpceted behavior in decoderlayer * add atol in final * add no_split_module * _no_split_modules = None * device transfer for model parallelism * minor fix * make fix-copies * fix typo * add test_image_processor with post_processing * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add config in RTDETRPredictionHead * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set lru_cache with max_size 32 * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add lru_cache import and configuration change * change the order of definition * make fix-copies * add docs and change config error * revert strange make-fix * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * test pass * fix get_clones related and remove deepcopy * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * nit for paper section * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * rename denoising related parameters * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * check the image transformation logic * make style * make style * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * pe_encoding -> positional_encoding_temperature * remove TODO * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * remove eval_idx since transformer DETR is giving all decoder output * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * change variable name * make style and docs import update * Revert "Update src/transformers/models/rt_detr/image_processing_rt_detr.py" This reverts commit 74aa3e1de0ca0cd3d354161d38ef28b4389c0eee. * fix typo * add postprocessing in docs * move import scipy to top * change varaible name * make fix-copies * remove eval_idx in test * move to after first sentence * update image_processor since box loss requires normalized one * change appropriate name to auxiliary_outputs * Update src/transformers/models/rt_detr/__init__.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/__init__.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/rt_detr.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/rt_detr.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * make style * remove panoptic related comments * make style * revert valid_processor_keys * fix aux related test * make style * change origination from config to backbone API * enable the dn_loss * fix test and conversion * renewal weight initialization * change initializer_range * make fix-up * fix the loss issue in the auxiliary output and denoising part * change weight loss to original RTDETR * fix in initialization * sync shape format of dn and aux * make style * stable fine-tuning and compatible conversion for resnet101 * make style * skip input_embed * change encoder related variable * enable converting rtdetr_r101 * add r101 related conversion code * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/rt_detr.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * change name _shape to _reshape * Update src/transformers/__init__.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * maket style * make fix-copies * remove deprecated import * more fix * remove last_hidden_state for task-specific model * Revert "remove last_hidden_state for task-specific model" This reverts commit ccb7a34051d69b9fc7aa17ed8644664d3fdbdaca. * minore change in convert * remove print * make style and fix-copies * add custom rtdetr backbone for r18, r34 * remove print * change copied * add pad_size * make style * change layertype to optional to pass the CI * make style * add test in modeling_resnet_rt_detr * make fix-copies * skip tmp file test * fix comment * add docs * change to modeling_resnet file format * enabling resnet50 above * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: Jason Wu * enable all the rtdetr model :) * finish except CI * add RTDetrResNetBackbone * make fix-copies * fix TO DO: CI enable * make style * rename test * add docs * add special fix * revert resnet * Update src/transformers/models/rt_detr/modeling_rt_detr_resnet.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * add more comment * remove swin comment * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * rename convert and add verify backbone * Update docs/source/en/_toctree.yml Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/rt_detr.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/rt_detr.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * make style * requests for docs * more general test docs * general script docs * make fix-copies * final commit * Revert "Update src/transformers/models/rt_detr/configuration_rt_detr.py" This reverts commit d136225cd3f64f510d303ce1d227698174f43fff. * skip test_model_get_set_embeddings * remove target * add changes * make fix-copies * remove decoder_attention_mask * add load_backbone function for auto_backbone * remove comment * fix repo name * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * final commit * remove unused downsample_in_bottleneck * new test for autobackbone * change to appropriate indices * test fix * fix dict in test_image_processor * fix test * [run-slow] rt_detr, rt_detr_resnet * change the slow test * [run-slow] rt_detr * [run-slow] rt_detr, rt_detr_resnet * make in to same cuda in CSPRepLayer * [run-slow] rt_detr, rt_detr_resnet --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sounak Dey Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Jason Wu Co-authored-by: ChoiSangBum --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 2 + docs/source/en/model_doc/rt_detr.md | 85 + src/transformers/__init__.py | 23 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 5 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 3 + .../modeling_deformable_detr.py | 14 +- src/transformers/models/rt_detr/__init__.py | 78 + .../models/rt_detr/configuration_rt_detr.py | 352 +++ .../rt_detr/configuration_rt_detr_resnet.py | 111 + ..._detr_original_pytorch_checkpoint_to_hf.py | 782 +++++ .../rt_detr/image_processing_rt_detr.py | 1120 +++++++ .../models/rt_detr/modeling_rt_detr.py | 2675 +++++++++++++++++ .../models/rt_detr/modeling_rt_detr_resnet.py | 426 +++ .../timm_backbone/modeling_timm_backbone.py | 4 +- src/transformers/utils/backbone_utils.py | 1 - src/transformers/utils/dummy_pt_objects.py | 35 + .../utils/dummy_vision_objects.py | 7 + tests/models/rt_detr/__init__.py | 0 .../rt_detr/test_image_processing_rt_detr.py | 364 +++ tests/models/rt_detr/test_modeling_rt_detr.py | 680 +++++ .../rt_detr/test_modeling_rt_detr_resnet.py | 130 + 24 files changed, 6892 insertions(+), 9 deletions(-) create mode 100644 docs/source/en/model_doc/rt_detr.md create mode 100644 src/transformers/models/rt_detr/__init__.py create mode 100644 src/transformers/models/rt_detr/configuration_rt_detr.py create mode 100644 src/transformers/models/rt_detr/configuration_rt_detr_resnet.py create mode 100644 src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_hf.py create mode 100644 src/transformers/models/rt_detr/image_processing_rt_detr.py create mode 100644 src/transformers/models/rt_detr/modeling_rt_detr.py create mode 100644 src/transformers/models/rt_detr/modeling_rt_detr_resnet.py create mode 100644 tests/models/rt_detr/__init__.py create mode 100644 tests/models/rt_detr/test_image_processing_rt_detr.py create mode 100644 tests/models/rt_detr/test_modeling_rt_detr.py create mode 100644 tests/models/rt_detr/test_modeling_rt_detr_resnet.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index be3001dc761a90..88b82c890522f8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -627,6 +627,8 @@ title: RegNet - local: model_doc/resnet title: ResNet + - local: model_doc/rt_detr + title: RT-DETR - local: model_doc/segformer title: SegFormer - local: model_doc/seggpt diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 72237d13839569..09724c755689fe 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -262,6 +262,8 @@ Flax), PyTorch, and/or TensorFlow. | [RoBERTa-PreLayerNorm](model_doc/roberta-prelayernorm) | ✅ | ✅ | ✅ | | [RoCBert](model_doc/roc_bert) | ✅ | ❌ | ❌ | | [RoFormer](model_doc/roformer) | ✅ | ✅ | ✅ | +| [RT-DETR](model_doc/rt_detr) | ✅ | ❌ | ❌ | +| [RT-DETR-ResNet](model_doc/rt_detr_resnet) | ✅ | ❌ | ❌ | | [RWKV](model_doc/rwkv) | ✅ | ❌ | ❌ | | [SAM](model_doc/sam) | ✅ | ✅ | ❌ | | [SeamlessM4T](model_doc/seamless_m4t) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/rt_detr.md b/docs/source/en/model_doc/rt_detr.md new file mode 100644 index 00000000000000..11f1b795daa285 --- /dev/null +++ b/docs/source/en/model_doc/rt_detr.md @@ -0,0 +1,85 @@ + + +# RT-DETR + +## Overview + + +The RT-DETR model was proposed in [DETRs Beat YOLOs on Real-time Object Detection](https://arxiv.org/abs/2304.08069) by Wenyu Lv, Yian Zhao, Shangliang Xu, Jinman Wei, Guanzhong Wang, Cheng Cui, Yuning Du, Qingqing Dang, Yi Liu. + +RT-DETR is an object detection model that stands for "Real-Time DEtection Transformer." This model is designed to perform object detection tasks with a focus on achieving real-time performance while maintaining high accuracy. Leveraging the transformer architecture, which has gained significant popularity in various fields of deep learning, RT-DETR processes images to identify and locate multiple objects within them. + +The abstract from the paper is the following: + +*Recently, end-to-end transformer-based detectors (DETRs) have achieved remarkable performance. However, the issue of the high computational cost of DETRs has not been effectively addressed, limiting their practical application and preventing them from fully exploiting the benefits of no post-processing, such as non-maximum suppression (NMS). In this paper, we first analyze the influence of NMS in modern real-time object detectors on inference speed, and establish an end-to-end speed benchmark. To avoid the inference delay caused by NMS, we propose a Real-Time DEtection TRansformer (RT-DETR), the first real-time end-to-end object detector to our best knowledge. Specifically, we design an efficient hybrid encoder to efficiently process multi-scale features by decoupling the intra-scale interaction and cross-scale fusion, and propose IoU-aware query selection to improve the initialization of object queries. In addition, our proposed detector supports flexibly adjustment of the inference speed by using different decoder layers without the need for retraining, which facilitates the practical application of real-time object detectors. Our RT-DETR-L achieves 53.0% AP on COCO val2017 and 114 FPS on T4 GPU, while RT-DETR-X achieves 54.8% AP and 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. Furthermore, our RT-DETR-R50 achieves 53.1% AP and 108 FPS, outperforming DINO-Deformable-DETR-R50 by 2.2% AP in accuracy and by about 21 times in FPS.* + +The model version was contributed by [rafaelpadilla](https://huggingface.co/rafaelpadilla) and [sangbumchoi](https://github.com/SangbumChoi). The original code can be found [here](https://github.com/lyuwenyu/RT-DETR/). + + +## Usage tips + +Initially, an image is processed using a pre-trained convolutional neural network, specifically a Resnet-D variant as referenced in the original code. This network extracts features from the final three layers of the architecture. Following this, a hybrid encoder is employed to convert the multi-scale features into a sequential array of image features. Then, a decoder, equipped with auxiliary prediction heads is used to refine the object queries. This process facilitates the direct generation of bounding boxes, eliminating the need for any additional post-processing to acquire the logits and coordinates for the bounding boxes. + +```py +from transformers import RTDetrForObjectDetection, RTDetrImageProcessor +from PIL import Image +import json +import torch +import requests + +url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +image = Image.open(requests.get(url, stream=True).raw) + +image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd") +model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd") + +inputs = image_processor(images=image, return_tensors="pt") + +with torch.no_grad(): + outputs = model(**inputs) + +results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]), threshold=0.3) +``` + +## RTDetrConfig + +[[autodoc]] RTDetrConfig + +## RTDetrResNetConfig + +[[autodoc]] RTDetrResNetConfig + +## RTDetrImageProcessor + +[[autodoc]] RTDetrImageProcessor + - preprocess + - post_process_object_detection + +## RTDetrModel + +[[autodoc]] RTDetrModel + - forward + +## RTDetrForObjectDetection + +[[autodoc]] RTDetrForObjectDetection + - forward + +## RTDetrResNetBackbone + +[[autodoc]] RTDetrResNetBackbone + - forward \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4976a4a1b90e7e..fd7f3c1cf7901f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -654,6 +654,7 @@ "RoFormerConfig", "RoFormerTokenizer", ], + "models.rt_detr": ["RTDetrConfig", "RTDetrResNetConfig"], "models.rwkv": ["RwkvConfig"], "models.sam": [ "SamConfig", @@ -1153,6 +1154,7 @@ _import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"]) _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) _import_structure["models.pvt"].extend(["PvtImageProcessor"]) + _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"]) _import_structure["models.sam"].extend(["SamImageProcessor"]) _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"]) _import_structure["models.seggpt"].extend(["SegGptImageProcessor"]) @@ -3004,6 +3006,15 @@ "load_tf_weights_in_roformer", ] ) + _import_structure["models.rt_detr"].extend( + [ + "RTDetrForObjectDetection", + "RTDetrModel", + "RTDetrPreTrainedModel", + "RTDetrResNetBackbone", + "RTDetrResNetPreTrainedModel", + ] + ) _import_structure["models.rwkv"].extend( [ "RwkvForCausalLM", @@ -5270,6 +5281,10 @@ RoFormerConfig, RoFormerTokenizer, ) + from .models.rt_detr import ( + RTDetrConfig, + RTDetrResNetConfig, + ) from .models.rwkv import RwkvConfig from .models.sam import ( SamConfig, @@ -5792,6 +5807,7 @@ PoolFormerImageProcessor, ) from .models.pvt import PvtImageProcessor + from .models.rt_detr import RTDetrImageProcessor from .models.sam import SamImageProcessor from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor from .models.seggpt import SegGptImageProcessor @@ -7295,6 +7311,13 @@ RoFormerPreTrainedModel, load_tf_weights_in_roformer, ) + from .models.rt_detr import ( + RTDetrForObjectDetection, + RTDetrModel, + RTDetrPreTrainedModel, + RTDetrResNetBackbone, + RTDetrResNetPreTrainedModel, + ) from .models.rwkv import ( RwkvForCausalLM, RwkvModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 24b602f18c8f38..82a26b0f3061ba 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -193,6 +193,7 @@ roberta_prelayernorm, roc_bert, roformer, + rt_detr, rwkv, sam, seamless_m4t, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 40e282166ef99e..8793df948531e6 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -214,6 +214,8 @@ ("roberta-prelayernorm", "RobertaPreLayerNormConfig"), ("roc_bert", "RoCBertConfig"), ("roformer", "RoFormerConfig"), + ("rt_detr", "RTDetrConfig"), + ("rt_detr_resnet", "RTDetrResNetConfig"), ("rwkv", "RwkvConfig"), ("sam", "SamConfig"), ("seamless_m4t", "SeamlessM4TConfig"), @@ -499,6 +501,8 @@ ("roberta-prelayernorm", "RoBERTa-PreLayerNorm"), ("roc_bert", "RoCBert"), ("roformer", "RoFormer"), + ("rt_detr", "RT-DETR"), + ("rt_detr_resnet", "RT-DETR-ResNet"), ("rwkv", "RWKV"), ("sam", "SAM"), ("seamless_m4t", "SeamlessM4T"), @@ -623,6 +627,7 @@ ("clip_vision_model", "clip"), ("siglip_vision_model", "siglip"), ("chinese_clip_vision_model", "chinese_clip"), + ("rt_detr_resnet", "rt_detr"), ] ) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 055f2ca733ce99..a9df7adc396243 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -114,6 +114,7 @@ ("pvt_v2", ("PvtImageProcessor",)), ("regnet", ("ConvNextImageProcessor",)), ("resnet", ("ConvNextImageProcessor",)), + ("rt_detr", "RTDetrImageProcessor"), ("sam", ("SamImageProcessor",)), ("segformer", ("SegformerImageProcessor",)), ("seggpt", ("SegGptImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index adfcc7af9fbc88..d371183fc6c800 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -202,6 +202,7 @@ ("roberta-prelayernorm", "RobertaPreLayerNormModel"), ("roc_bert", "RoCBertModel"), ("roformer", "RoFormerModel"), + ("rt_detr", "RTDetrModel"), ("rwkv", "RwkvModel"), ("sam", "SamModel"), ("seamless_m4t", "SeamlessM4TModel"), @@ -765,6 +766,7 @@ ("deformable_detr", "DeformableDetrForObjectDetection"), ("deta", "DetaForObjectDetection"), ("detr", "DetrForObjectDetection"), + ("rt_detr", "RTDetrForObjectDetection"), ("table-transformer", "TableTransformerForObjectDetection"), ("yolos", "YolosForObjectDetection"), ] @@ -1252,6 +1254,7 @@ ("nat", "NatBackbone"), ("pvt_v2", "PvtV2Backbone"), ("resnet", "ResNetBackbone"), + ("rt_detr_resnet", "RTDetrResNetBackbone"), ("swin", "SwinBackbone"), ("swinv2", "Swinv2Backbone"), ("timm_backbone", "TimmBackbone"), diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 4920262443035d..cfa08e3974b78b 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -29,22 +29,24 @@ from torch.autograd.function import once_differentiable from ...activations import ACT2FN -from ...file_utils import ( +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import meshgrid +from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_accelerate_available, + is_ninja_available, is_scipy_available, is_timm_available, is_torch_cuda_available, is_vision_available, + logging, replace_return_docstrings, requires_backends, ) -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import meshgrid -from ...utils import is_accelerate_available, is_ninja_available, logging from ...utils.backbone_utils import load_backbone from .configuration_deformable_detr import DeformableDetrConfig diff --git a/src/transformers/models/rt_detr/__init__.py b/src/transformers/models/rt_detr/__init__.py new file mode 100644 index 00000000000000..94a428c66685a6 --- /dev/null +++ b/src/transformers/models/rt_detr/__init__.py @@ -0,0 +1,78 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available + + +_import_structure = {"configuration_rt_detr": ["RTDetrConfig"], "configuration_rt_detr_resnet": ["RTDetrResNetConfig"]} + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_rt_detr"] = ["RTDetrImageProcessor"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_rt_detr"] = [ + "RTDetrForObjectDetection", + "RTDetrModel", + "RTDetrPreTrainedModel", + ] + _import_structure["modeling_rt_detr_resnet"] = [ + "RTDetrResNetBackbone", + "RTDetrResNetPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_rt_detr import RTDetrConfig + from .configuration_rt_detr_resnet import RTDetrResNetConfig + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_rt_detr import RTDetrImageProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_rt_detr import ( + RTDetrForObjectDetection, + RTDetrModel, + RTDetrPreTrainedModel, + ) + from .modeling_rt_detr_resnet import ( + RTDetrResNetBackbone, + RTDetrResNetPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/rt_detr/configuration_rt_detr.py b/src/transformers/models/rt_detr/configuration_rt_detr.py new file mode 100644 index 00000000000000..a3d49fafeaedc7 --- /dev/null +++ b/src/transformers/models/rt_detr/configuration_rt_detr.py @@ -0,0 +1,352 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RT-DETR model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING +from .configuration_rt_detr_resnet import RTDetrResNetConfig + + +logger = logging.get_logger(__name__) + + +class RTDetrConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RTDetrModel`]. It is used to instantiate a + RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the RT-DETR + [checkpoing/todo](https://huggingface.co/checkpoing/todo) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch normalization layers. + backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`): + The configuration of the backbone model. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + encoder_hidden_dim (`int`, *optional*, defaults to 256): + Dimension of the layers in hybrid encoder. + encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`): + Multi level features input for encoder. + feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`): + Strides used in each feature map. + encoder_layers (`int`, *optional*, defaults to 1): + Total of layers to be used by the encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`): + Indexes of the projected layers to be used in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The temperature parameter used to create the positional encodings. + encoder_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + activation_function (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the general layer. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + eval_size (`Tuple[int, int]`, *optional*): + Height and width used to computes the effective height and width of the position embeddings after taking + into account the stride. + normalize_before (`bool`, *optional*, defaults to `False`): + Determine whether to apply layer normalization in the transformer encoder layer before self-attention and + feed-forward modules. + hidden_expansion (`float`, *optional*, defaults to 1.0): + Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer. + d_model (`int`, *optional*, defaults to 256): + Dimension of the layers exclude hybrid encoder. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries. + decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`): + Multi level features dimension for decoder + decoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of input feature levels. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_denoising (`int`, *optional*, defaults to 100): + The total number of denoising tasks or queries to be used for contrastive denoising. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + The fraction of denoising labels to which random noise should be added. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale or magnitude of noise to be added to the bounding boxes. + learn_initial_query (`bool`, *optional*, defaults to `False`): + Indicates whether the initial query embeddings for the decoder should be learned during training + anchor_image_size (`Tuple[int, int]`, *optional*, defaults to `[640, 640]`): + Height and width of the input image used during evaluation to generate the bounding box anchors. + disable_custom_kernels (`bool`, *optional*, defaults to `True`): + Whether to disable custom kernels. + with_box_refine (`bool`, *optional*, defaults to `True`): + Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes + based on the predictions from the previous layer. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the architecture has an encoder decoder structure. + matcher_alpha (`float`, *optional*, defaults to 0.25): + Parameter alpha used by the Hungarian Matcher. + matcher_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used by the Hungarian Matcher. + matcher_class_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the class loss used by the Hungarian Matcher. + matcher_bbox_cost (`float`, *optional*, defaults to 5.0): + The relative weight of the bounding box loss used by the Hungarian Matcher. + matcher_giou_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the giou loss of used by the Hungarian Matcher. + use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal focal should be used. + auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. + weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.0001): + Relative classification weight of the 'no-object' class in the object detection loss. + + Examples: + + ```python + >>> from transformers import RTDetrConfig, RTDetrModel + + >>> # Initializing a RT-DETR configuration + >>> configuration = RTDetrConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RTDetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "rt_detr" + layer_types = ["basic", "bottleneck"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + initializer_range=0.01, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + # backbone + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + # encoder HybridEncoder + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + # decoder RTDetrTransformer + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=[640, 640], + disable_custom_kernels=True, + with_box_refine=True, + is_encoder_decoder=True, + # Loss + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, + focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_giou=2.0, + eos_coefficient=1e-4, + **kwargs, + ): + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + # backbone + if backbone_config is None and backbone is None: + logger.info( + "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetr-ResNet` backbone." + ) + backbone_config = RTDetrResNetConfig( + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=[2, 3, 4], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.pop("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + # encoder + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.eval_size = eval_size + self.normalize_before = normalize_before + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + # decoder + self.d_model = d_model + self.num_queries = num_queries + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_in_channels = decoder_in_channels + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.auxiliary_loss = auxiliary_loss + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + # Loss + self.matcher_alpha = matcher_alpha + self.matcher_gamma = matcher_gamma + self.matcher_class_cost = matcher_class_cost + self.matcher_bbox_cost = matcher_bbox_cost + self.matcher_giou_cost = matcher_giou_cost + self.use_focal_loss = use_focal_loss + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.weight_loss_vfl = weight_loss_vfl + self.weight_loss_bbox = weight_loss_bbox + self.weight_loss_giou = weight_loss_giou + self.eos_coefficient = eos_coefficient + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model + + @classmethod + def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs): + """Instantiate a [`RTDetrConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model + configuration. + + Args: + backbone_config ([`PretrainedConfig`]): + The backbone configuration. + + Returns: + [`RTDetrConfig`]: An instance of a configuration object + """ + return cls( + backbone_config=backbone_config, + **kwargs, + ) diff --git a/src/transformers/models/rt_detr/configuration_rt_detr_resnet.py b/src/transformers/models/rt_detr/configuration_rt_detr_resnet.py new file mode 100644 index 00000000000000..fb46086296a45b --- /dev/null +++ b/src/transformers/models/rt_detr/configuration_rt_detr_resnet.py @@ -0,0 +1,111 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RT-DETR ResNet model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class RTDetrResNetConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RTDetrResnetBackbone`]. It is used to instantiate an + ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the ResNet + [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embedding_size (`int`, *optional*, defaults to 64): + Dimensionality (hidden size) for the embedding layer. + hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`): + Dimensionality (hidden size) at each stage. + depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`): + Depth (number of layers) for each stage. + layer_type (`str`, *optional*, defaults to `"bottleneck"`): + The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or + `"bottleneck"` (used for larger models like resnet-50 and above). + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` + are supported. + downsample_in_first_stage (`bool`, *optional*, defaults to `False`): + If `True`, the first stage will downsample the inputs using a `stride` of 2. + downsample_in_bottleneck (`bool`, *optional*, defaults to `False`): + If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + + Example: + ```python + >>> from transformers import RTDetrResNetConfig, RTDetrResnetBackbone + + >>> # Initializing a ResNet resnet-50 style configuration + >>> configuration = RTDetrResNetConfig() + + >>> # Initializing a model (with random weights) from the resnet-50 style configuration + >>> model = RTDetrResnetBackbone(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "rt_detr_resnet" + layer_types = ["basic", "bottleneck"] + + def __init__( + self, + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + if layer_type not in self.layer_types: + raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}") + self.num_channels = num_channels + self.embedding_size = embedding_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.layer_type = layer_type + self.hidden_act = hidden_act + self.downsample_in_first_stage = downsample_in_first_stage + self.downsample_in_bottleneck = downsample_in_bottleneck + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) diff --git a/src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_hf.py b/src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_hf.py new file mode 100644 index 00000000000000..9f2271930e1378 --- /dev/null +++ b/src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_hf.py @@ -0,0 +1,782 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert RT Detr checkpoints with Timm backbone""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision import transforms + +from transformers import RTDetrConfig, RTDetrForObjectDetection, RTDetrImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_rt_detr_config(model_name: str) -> RTDetrConfig: + config = RTDetrConfig() + + config.num_labels = 80 + repo_id = "huggingface/label-files" + filename = "coco-detection-mmdet-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + if model_name == "rtdetr_r18vd": + config.backbone_config.hidden_sizes = [64, 128, 256, 512] + config.backbone_config.depths = [2, 2, 2, 2] + config.backbone_config.layer_type = "basic" + config.encoder_in_channels = [128, 256, 512] + config.hidden_expansion = 0.5 + config.decoder_layers = 3 + elif model_name == "rtdetr_r34vd": + config.backbone_config.hidden_sizes = [64, 128, 256, 512] + config.backbone_config.depths = [3, 4, 6, 3] + config.backbone_config.layer_type = "basic" + config.encoder_in_channels = [128, 256, 512] + config.hidden_expansion = 0.5 + config.decoder_layers = 4 + elif model_name == "rtdetr_r50vd_m": + pass + elif model_name == "rtdetr_r50vd": + pass + elif model_name == "rtdetr_r101vd": + config.backbone_config.depths = [3, 4, 23, 3] + config.encoder_ffn_dim = 2048 + config.encoder_hidden_dim = 384 + config.decoder_in_channels = [384, 384, 384] + elif model_name == "rtdetr_r18vd_coco_o365": + config.backbone_config.hidden_sizes = [64, 128, 256, 512] + config.backbone_config.depths = [2, 2, 2, 2] + config.backbone_config.layer_type = "basic" + config.encoder_in_channels = [128, 256, 512] + config.hidden_expansion = 0.5 + config.decoder_layers = 3 + elif model_name == "rtdetr_r50vd_coco_o365": + pass + elif model_name == "rtdetr_r101vd_coco_o365": + config.backbone_config.depths = [3, 4, 23, 3] + config.encoder_ffn_dim = 2048 + config.encoder_hidden_dim = 384 + config.decoder_in_channels = [384, 384, 384] + + return config + + +def create_rename_keys(config): + # here we list all keys to be renamed (original name on the left, our name on the right) + rename_keys = [] + + # stem + # fmt: off + last_key = ["weight", "bias", "running_mean", "running_var"] + + for level in range(3): + rename_keys.append((f"backbone.conv1.conv1_{level+1}.conv.weight", f"model.backbone.model.embedder.embedder.{level}.convolution.weight")) + for last in last_key: + rename_keys.append((f"backbone.conv1.conv1_{level+1}.norm.{last}", f"model.backbone.model.embedder.embedder.{level}.normalization.{last}")) + + for stage_idx in range(len(config.backbone_config.depths)): + for layer_idx in range(config.backbone_config.depths[stage_idx]): + # shortcut + if layer_idx == 0: + if stage_idx == 0: + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.weight", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.convolution.weight", + ) + ) + for last in last_key: + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.0.short.norm.{last}", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.normalization.{last}", + ) + ) + else: + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.conv.weight", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.1.convolution.weight", + ) + ) + for last in last_key: + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.norm.{last}", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.1.normalization.{last}", + ) + ) + + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2a.conv.weight", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.convolution.weight", + ) + ) + for last in last_key: + rename_keys.append(( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2a.norm.{last}", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.normalization.{last}", + )) + + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2b.conv.weight", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.convolution.weight", + ) + ) + for last in last_key: + rename_keys.append(( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2b.norm.{last}", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.normalization.{last}", + )) + + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/nn/backbone/presnet.py#L171 + if config.backbone_config.layer_type != "basic": + rename_keys.append( + ( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2c.conv.weight", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.2.convolution.weight", + ) + ) + for last in last_key: + rename_keys.append(( + f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2c.norm.{last}", + f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.2.normalization.{last}", + )) + # fmt: on + + for i in range(config.encoder_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.self_attn.out_proj.weight", + f"model.encoder.encoder.{i}.layers.0.self_attn.out_proj.weight", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.self_attn.out_proj.bias", + f"model.encoder.encoder.{i}.layers.0.self_attn.out_proj.bias", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.linear1.weight", + f"model.encoder.encoder.{i}.layers.0.fc1.weight", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.linear1.bias", + f"model.encoder.encoder.{i}.layers.0.fc1.bias", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.linear2.weight", + f"model.encoder.encoder.{i}.layers.0.fc2.weight", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.linear2.bias", + f"model.encoder.encoder.{i}.layers.0.fc2.bias", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.norm1.weight", + f"model.encoder.encoder.{i}.layers.0.self_attn_layer_norm.weight", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.norm1.bias", + f"model.encoder.encoder.{i}.layers.0.self_attn_layer_norm.bias", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.norm2.weight", + f"model.encoder.encoder.{i}.layers.0.final_layer_norm.weight", + ) + ) + rename_keys.append( + ( + f"encoder.encoder.{i}.layers.0.norm2.bias", + f"model.encoder.encoder.{i}.layers.0.final_layer_norm.bias", + ) + ) + + for j in range(0, 3): + rename_keys.append((f"encoder.input_proj.{j}.0.weight", f"model.encoder_input_proj.{j}.0.weight")) + for last in last_key: + rename_keys.append((f"encoder.input_proj.{j}.1.{last}", f"model.encoder_input_proj.{j}.1.{last}")) + + block_levels = 3 if config.backbone_config.layer_type != "basic" else 4 + + for i in range(len(config.encoder_in_channels) - 1): + # encoder layers: hybridencoder parts + for j in range(1, block_levels): + rename_keys.append( + (f"encoder.fpn_blocks.{i}.conv{j}.conv.weight", f"model.encoder.fpn_blocks.{i}.conv{j}.conv.weight") + ) + for last in last_key: + rename_keys.append( + ( + f"encoder.fpn_blocks.{i}.conv{j}.norm.{last}", + f"model.encoder.fpn_blocks.{i}.conv{j}.norm.{last}", + ) + ) + + rename_keys.append((f"encoder.lateral_convs.{i}.conv.weight", f"model.encoder.lateral_convs.{i}.conv.weight")) + for last in last_key: + rename_keys.append( + (f"encoder.lateral_convs.{i}.norm.{last}", f"model.encoder.lateral_convs.{i}.norm.{last}") + ) + + for j in range(3): + for k in range(1, 3): + rename_keys.append( + ( + f"encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight", + f"model.encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight", + ) + ) + for last in last_key: + rename_keys.append( + ( + f"encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}", + f"model.encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}", + ) + ) + + for j in range(1, block_levels): + rename_keys.append( + (f"encoder.pan_blocks.{i}.conv{j}.conv.weight", f"model.encoder.pan_blocks.{i}.conv{j}.conv.weight") + ) + for last in last_key: + rename_keys.append( + ( + f"encoder.pan_blocks.{i}.conv{j}.norm.{last}", + f"model.encoder.pan_blocks.{i}.conv{j}.norm.{last}", + ) + ) + + for j in range(3): + for k in range(1, 3): + rename_keys.append( + ( + f"encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight", + f"model.encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight", + ) + ) + for last in last_key: + rename_keys.append( + ( + f"encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}", + f"model.encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}", + ) + ) + + rename_keys.append( + (f"encoder.downsample_convs.{i}.conv.weight", f"model.encoder.downsample_convs.{i}.conv.weight") + ) + for last in last_key: + rename_keys.append( + (f"encoder.downsample_convs.{i}.norm.{last}", f"model.encoder.downsample_convs.{i}.norm.{last}") + ) + + for i in range(config.decoder_layers): + # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.self_attn.out_proj.weight", + f"model.decoder.layers.{i}.self_attn.out_proj.weight", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.self_attn.out_proj.bias", + f"model.decoder.layers.{i}.self_attn.out_proj.bias", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.sampling_offsets.weight", + f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.weight", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.sampling_offsets.bias", + f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.bias", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.attention_weights.weight", + f"model.decoder.layers.{i}.encoder_attn.attention_weights.weight", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.attention_weights.bias", + f"model.decoder.layers.{i}.encoder_attn.attention_weights.bias", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.value_proj.weight", + f"model.decoder.layers.{i}.encoder_attn.value_proj.weight", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.value_proj.bias", + f"model.decoder.layers.{i}.encoder_attn.value_proj.bias", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.output_proj.weight", + f"model.decoder.layers.{i}.encoder_attn.output_proj.weight", + ) + ) + rename_keys.append( + ( + f"decoder.decoder.layers.{i}.cross_attn.output_proj.bias", + f"model.decoder.layers.{i}.encoder_attn.output_proj.bias", + ) + ) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm1.weight", f"model.decoder.layers.{i}.self_attn_layer_norm.weight") + ) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm1.bias", f"model.decoder.layers.{i}.self_attn_layer_norm.bias") + ) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm2.weight", f"model.decoder.layers.{i}.encoder_attn_layer_norm.weight") + ) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm2.bias", f"model.decoder.layers.{i}.encoder_attn_layer_norm.bias") + ) + rename_keys.append((f"decoder.decoder.layers.{i}.linear1.weight", f"model.decoder.layers.{i}.fc1.weight")) + rename_keys.append((f"decoder.decoder.layers.{i}.linear1.bias", f"model.decoder.layers.{i}.fc1.bias")) + rename_keys.append((f"decoder.decoder.layers.{i}.linear2.weight", f"model.decoder.layers.{i}.fc2.weight")) + rename_keys.append((f"decoder.decoder.layers.{i}.linear2.bias", f"model.decoder.layers.{i}.fc2.bias")) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm3.weight", f"model.decoder.layers.{i}.final_layer_norm.weight") + ) + rename_keys.append( + (f"decoder.decoder.layers.{i}.norm3.bias", f"model.decoder.layers.{i}.final_layer_norm.bias") + ) + + for i in range(config.decoder_layers): + # decoder + class and bounding box heads + rename_keys.append( + ( + f"decoder.dec_score_head.{i}.weight", + f"model.decoder.class_embed.{i}.weight", + ) + ) + rename_keys.append( + ( + f"decoder.dec_score_head.{i}.bias", + f"model.decoder.class_embed.{i}.bias", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.0.weight", + f"model.decoder.bbox_embed.{i}.layers.0.weight", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.0.bias", + f"model.decoder.bbox_embed.{i}.layers.0.bias", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.1.weight", + f"model.decoder.bbox_embed.{i}.layers.1.weight", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.1.bias", + f"model.decoder.bbox_embed.{i}.layers.1.bias", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.2.weight", + f"model.decoder.bbox_embed.{i}.layers.2.weight", + ) + ) + rename_keys.append( + ( + f"decoder.dec_bbox_head.{i}.layers.2.bias", + f"model.decoder.bbox_embed.{i}.layers.2.bias", + ) + ) + + # decoder projection + for i in range(len(config.decoder_in_channels)): + rename_keys.append( + ( + f"decoder.input_proj.{i}.conv.weight", + f"model.decoder_input_proj.{i}.0.weight", + ) + ) + for last in last_key: + rename_keys.append( + ( + f"decoder.input_proj.{i}.norm.{last}", + f"model.decoder_input_proj.{i}.1.{last}", + ) + ) + + # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads + rename_keys.extend( + [ + ("decoder.denoising_class_embed.weight", "model.denoising_class_embed.weight"), + ("decoder.query_pos_head.layers.0.weight", "model.decoder.query_pos_head.layers.0.weight"), + ("decoder.query_pos_head.layers.0.bias", "model.decoder.query_pos_head.layers.0.bias"), + ("decoder.query_pos_head.layers.1.weight", "model.decoder.query_pos_head.layers.1.weight"), + ("decoder.query_pos_head.layers.1.bias", "model.decoder.query_pos_head.layers.1.bias"), + ("decoder.enc_output.0.weight", "model.enc_output.0.weight"), + ("decoder.enc_output.0.bias", "model.enc_output.0.bias"), + ("decoder.enc_output.1.weight", "model.enc_output.1.weight"), + ("decoder.enc_output.1.bias", "model.enc_output.1.bias"), + ("decoder.enc_score_head.weight", "model.enc_score_head.weight"), + ("decoder.enc_score_head.bias", "model.enc_score_head.bias"), + ("decoder.enc_bbox_head.layers.0.weight", "model.enc_bbox_head.layers.0.weight"), + ("decoder.enc_bbox_head.layers.0.bias", "model.enc_bbox_head.layers.0.bias"), + ("decoder.enc_bbox_head.layers.1.weight", "model.enc_bbox_head.layers.1.weight"), + ("decoder.enc_bbox_head.layers.1.bias", "model.enc_bbox_head.layers.1.bias"), + ("decoder.enc_bbox_head.layers.2.weight", "model.enc_bbox_head.layers.2.weight"), + ("decoder.enc_bbox_head.layers.2.bias", "model.enc_bbox_head.layers.2.bias"), + ] + ) + + return rename_keys + + +def rename_key(state_dict, old, new): + try: + val = state_dict.pop(old) + state_dict[new] = val + except Exception: + pass + + +def read_in_q_k_v(state_dict, config): + prefix = "" + encoder_hidden_dim = config.encoder_hidden_dim + + # first: transformer encoder + for i in range(config.encoder_layers): + # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[ + :encoder_hidden_dim, : + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[ + encoder_hidden_dim : 2 * encoder_hidden_dim, : + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[ + encoder_hidden_dim : 2 * encoder_hidden_dim + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[ + -encoder_hidden_dim:, : + ] + state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:] + # next: transformer decoder (which is a bit more complex because it also includes cross-attention) + for i in range(config.decoder_layers): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + + return im + + +@torch.no_grad() +def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, repo_id): + """ + Copy/paste/tweak model's weights to our RTDETR structure. + """ + + # load default config + config = get_rt_detr_config(model_name) + + # load original model from torch hub + model_name_to_checkpoint_url = { + "rtdetr_r18vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_dec3_6x_coco_from_paddle.pth", + "rtdetr_r34vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r34vd_dec4_6x_coco_from_paddle.pth", + "rtdetr_r50vd_m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_m_6x_coco_from_paddle.pth", + "rtdetr_r50vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_6x_coco_from_paddle.pth", + "rtdetr_r101vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_6x_coco_from_paddle.pth", + "rtdetr_r18vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_5x_coco_objects365_from_paddle.pth", + "rtdetr_r50vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_2x_coco_objects365_from_paddle.pth", + "rtdetr_r101vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_2x_coco_objects365_from_paddle.pth", + } + logger.info(f"Converting model {model_name}...") + state_dict = torch.hub.load_state_dict_from_url(model_name_to_checkpoint_url[model_name], map_location="cpu")[ + "ema" + ]["module"] + + # rename keys + for src, dest in create_rename_keys(config): + rename_key(state_dict, src, dest) + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict, config) + # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them + for key in state_dict.copy().keys(): + if key.endswith("num_batches_tracked"): + del state_dict[key] + # for two_stage + if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key): + state_dict[key.split("model.decoder.")[-1]] = state_dict[key] + + # finally, create HuggingFace model and load state dict + model = RTDetrForObjectDetection(config) + model.load_state_dict(state_dict) + model.eval() + + # load image processor + image_processor = RTDetrImageProcessor() + + # prepare image + img = prepare_img() + + # preprocess image + transformations = transforms.Compose( + [ + transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + ) + original_pixel_values = transformations(img).unsqueeze(0) # insert batch dimension + + encoding = image_processor(images=img, return_tensors="pt") + pixel_values = encoding["pixel_values"] + + assert torch.allclose(original_pixel_values, pixel_values) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + pixel_values = pixel_values.to(device) + + # Pass image by the model + outputs = model(pixel_values) + + if model_name == "rtdetr_r18vd": + expected_slice_logits = torch.tensor( + [ + [-4.3364253, -6.465683, -3.6130402], + [-4.083815, -6.4039373, -6.97881], + [-4.192215, -7.3410473, -6.9027247], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.16868353, 0.19833282, 0.21182671], + [0.25559652, 0.55121744, 0.47988364], + [0.7698693, 0.4124569, 0.46036878], + ] + ) + elif model_name == "rtdetr_r34vd": + expected_slice_logits = torch.tensor( + [ + [-4.3727384, -4.7921476, -5.7299604], + [-4.840536, -8.455345, -4.1745796], + [-4.1277084, -5.2154565, -5.7852697], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.258278, 0.5497808, 0.4732004], + [0.16889669, 0.19890057, 0.21138911], + [0.76632994, 0.4147879, 0.46851268], + ] + ) + elif model_name == "rtdetr_r50vd_m": + expected_slice_logits = torch.tensor( + [ + [-4.319764, -6.1349025, -6.094794], + [-5.1056995, -7.744766, -4.803956], + [-4.7685347, -7.9278393, -4.5751696], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.2582739, 0.55071366, 0.47660282], + [0.16811174, 0.19954777, 0.21292639], + [0.54986024, 0.2752091, 0.0561416], + ] + ) + elif model_name == "rtdetr_r50vd": + expected_slice_logits = torch.tensor( + [ + [-4.6476398, -5.001154, -4.9785104], + [-4.1593494, -4.7038546, -5.946485], + [-4.4374595, -4.658361, -6.2352347], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.16880608, 0.19992264, 0.21225442], + [0.76837635, 0.4122631, 0.46368608], + [0.2595386, 0.5483334, 0.4777486], + ] + ) + elif model_name == "rtdetr_r101vd": + expected_slice_logits = torch.tensor( + [ + [-4.6162, -4.9189, -4.6656], + [-4.4701, -4.4997, -4.9659], + [-5.6641, -7.9000, -5.0725], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.7707, 0.4124, 0.4585], + [0.2589, 0.5492, 0.4735], + [0.1688, 0.1993, 0.2108], + ] + ) + elif model_name == "rtdetr_r18vd_coco_o365": + expected_slice_logits = torch.tensor( + [ + [-4.8726, -5.9066, -5.2450], + [-4.8157, -6.8764, -5.1656], + [-4.7492, -5.7006, -5.1333], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.2552, 0.5501, 0.4773], + [0.1685, 0.1986, 0.2104], + [0.7692, 0.4141, 0.4620], + ] + ) + elif model_name == "rtdetr_r50vd_coco_o365": + expected_slice_logits = torch.tensor( + [ + [-4.6491, -3.9252, -5.3163], + [-4.1386, -5.0348, -3.9016], + [-4.4778, -4.5423, -5.7356], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.2583, 0.5492, 0.4747], + [0.5501, 0.2754, 0.0574], + [0.7693, 0.4137, 0.4613], + ] + ) + elif model_name == "rtdetr_r101vd_coco_o365": + expected_slice_logits = torch.tensor( + [ + [-4.5152, -5.6811, -5.7311], + [-4.5358, -7.2422, -5.0941], + [-4.6919, -5.5834, -6.0145], + ] + ) + expected_slice_boxes = torch.tensor( + [ + [0.7703, 0.4140, 0.4583], + [0.1686, 0.1991, 0.2107], + [0.2570, 0.5496, 0.4750], + ] + ) + else: + raise ValueError(f"Unknown rt_detr_name: {model_name}") + + assert torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits.to(outputs.logits.device), atol=1e-4) + assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes.to(outputs.pred_boxes.device), atol=1e-3) + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + # Upload model, image processor and config to the hub + logger.info("Uploading PyTorch model and image processor to the hub...") + config.push_to_hub( + repo_id=repo_id, commit_message="Add config from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py" + ) + model.push_to_hub( + repo_id=repo_id, commit_message="Add model from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py" + ) + image_processor.push_to_hub( + repo_id=repo_id, + commit_message="Add image processor from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + default="rtdetr_r50vd", + type=str, + help="model_name of the checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.") + parser.add_argument( + "--repo_id", + type=str, + help="repo_id where the model will be pushed to.", + ) + args = parser.parse_args() + convert_rt_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.repo_id) diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr.py b/src/transformers/models/rt_detr/image_processing_rt_detr.py new file mode 100644 index 00000000000000..1c66dee5c9836f --- /dev/null +++ b/src/transformers/models/rt_detr/image_processing_rt_detr.py @@ -0,0 +1,1120 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for RT-DETR.""" + +import pathlib +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils import BaseImageProcessor, get_size_dict +from ...image_transforms import ( + PaddingMode, + center_to_corners_format, + corners_to_center_format, + pad, + rescale, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_annotations, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import ( + is_flax_available, + is_jax_tensor, + is_tf_available, + is_tf_tensor, + is_torch_available, + is_torch_tensor, + logging, + requires_backends, +) +from ...utils.generic import TensorType + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,) + + +# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. + + Args: + image_size (`Tuple[int, int]`): + The input image size. + size (`int`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + """ + height, width = image_size + raw_size = None + if max_size is not None: + min_original_size = float(min((height, width))) + max_original_size = float(max((height, width))) + if max_original_size / min_original_size * size > max_size: + raw_size = max_size * min_original_size / max_original_size + size = int(round(raw_size)) + + if (height <= width and height == size) or (width <= height and width == size): + oh, ow = height, width + elif width < height: + ow = size + if max_size is not None and raw_size is not None: + oh = int(raw_size * height / width) + else: + oh = int(size * height / width) + else: + oh = size + if max_size is not None and raw_size is not None: + ow = int(raw_size * width / height) + else: + ow = int(size * width / height) + + return (oh, ow) + + +# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size +def get_resize_output_image_size( + input_image: np.ndarray, + size: Union[int, Tuple[int, int], List[int]], + max_size: Optional[int] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> Tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. If the desired output size + is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output + image size is computed by keeping the aspect ratio of the input image size. + + Args: + input_image (`np.ndarray`): + The image to resize. + size (`int` or `Tuple[int, int]` or `List[int]`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. + """ + image_size = get_image_size(input_image, input_data_format) + if isinstance(size, (list, tuple)): + return size + + return get_size_with_aspect_ratio(image_size, size, max_size) + + +# Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width +def get_image_size_for_max_height_width( + input_image: np.ndarray, + max_height: int, + max_width: int, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> Tuple[int, int]: + """ + Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. + Important, even if image_height < max_height and image_width < max_width, the image will be resized + to at least one of the edges be equal to max_height or max_width. + For example: + - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) + - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) + Args: + input_image (`np.ndarray`): + The image to resize. + max_height (`int`): + The maximum allowed height. + max_width (`int`): + The maximum allowed width. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. + """ + image_size = get_image_size(input_image, input_data_format) + height, width = image_size + height_scale = max_height / height + width_scale = max_width / width + min_scale = min(height_scale, width_scale) + new_height = int(height * min_scale) + new_width = int(width * min_scale) + return new_height, new_width + + +# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn +def get_numpy_to_framework_fn(arr) -> Callable: + """ + Returns a function that converts a numpy array to the framework of the input array. + + Args: + arr (`np.ndarray`): The array to convert. + """ + if isinstance(arr, np.ndarray): + return np.array + if is_tf_available() and is_tf_tensor(arr): + import tensorflow as tf + + return tf.convert_to_tensor + if is_torch_available() and is_torch_tensor(arr): + import torch + + return torch.tensor + if is_flax_available() and is_jax_tensor(arr): + import jax.numpy as jnp + + return jnp.array + raise ValueError(f"Cannot convert arrays of type {type(arr)}") + + +# Copied from transformers.models.detr.image_processing_detr.safe_squeeze +def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: + """ + Squeezes an array, but only if the axis specified has dim 1. + """ + if axis is None: + return arr.squeeze() + + try: + return arr.squeeze(axis=axis) + except ValueError: + return arr + + +# Copied from transformers.models.detr.image_processing_detr.normalize_annotation +def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict: + image_height, image_width = image_size + norm_annotation = {} + for key, value in annotation.items(): + if key == "boxes": + boxes = value + boxes = corners_to_center_format(boxes) + boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32) + norm_annotation[key] = boxes + else: + norm_annotation[key] = value + return norm_annotation + + +# Copied from transformers.models.detr.image_processing_detr.max_across_indices +def max_across_indices(values: Iterable[Any]) -> List[Any]: + """ + Return the maximum value across all indices of an iterable of values. + """ + return [max(values_i) for values_i in zip(*values)] + + +# Copied from transformers.models.detr.image_processing_detr.get_max_height_width +def get_max_height_width( + images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> List[int]: + """ + Get the maximum height and width across all images in a batch. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(images[0]) + + if input_data_format == ChannelDimension.FIRST: + _, max_height, max_width = max_across_indices([img.shape for img in images]) + elif input_data_format == ChannelDimension.LAST: + max_height, max_width, _ = max_across_indices([img.shape for img in images]) + else: + raise ValueError(f"Invalid channel dimension format: {input_data_format}") + return (max_height, max_width) + + +# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask +def make_pixel_mask( + image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> np.ndarray: + """ + Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + + Args: + image (`np.ndarray`): + Image to make the pixel mask for. + output_size (`Tuple[int, int]`): + Output size of the mask. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + mask = np.zeros(output_size, dtype=np.int64) + mask[:input_height, :input_width] = 1 + return mask + + +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by RTDETR. + """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + + image_id = target["image_id"] + image_id = np.asarray([image_id], dtype=np.int64) + + # Get all COCO annotations for the given image. + annotations = target["annotations"] + annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0] + + classes = [obj["category_id"] for obj in annotations] + classes = np.asarray(classes, dtype=np.int64) + + # for conversion to coco api + area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32) + iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=np.int64) + + boxes = [obj["bbox"] for obj in annotations] + # guard against no boxes via resizing + boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = {} + new_target["image_id"] = image_id + new_target["class_labels"] = classes[keep] + new_target["boxes"] = boxes[keep] + new_target["area"] = area[keep] + new_target["iscrowd"] = iscrowd[keep] + new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64) + + if annotations and "keypoints" in annotations[0]: + keypoints = [obj["keypoints"] for obj in annotations] + # Converting the filtered keypoints list to a numpy array + keypoints = np.asarray(keypoints, dtype=np.float32) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + return new_target + + +# Copied from transformers.models.detr.image_processing_detr.resize_annotation +def resize_annotation( + annotation: Dict[str, Any], + orig_size: Tuple[int, int], + target_size: Tuple[int, int], + threshold: float = 0.5, + resample: PILImageResampling = PILImageResampling.NEAREST, +): + """ + Resizes an annotation to a target size. + + Args: + annotation (`Dict[str, Any]`): + The annotation dictionary. + orig_size (`Tuple[int, int]`): + The original size of the input image. + target_size (`Tuple[int, int]`): + The target size of the image, as returned by the preprocessing `resize` step. + threshold (`float`, *optional*, defaults to 0.5): + The threshold used to binarize the segmentation masks. + resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`): + The resampling filter to use when resizing the masks. + """ + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size)) + ratio_height, ratio_width = ratios + + new_annotation = {} + new_annotation["size"] = target_size + + for key, value in annotation.items(): + if key == "boxes": + boxes = value + scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32) + new_annotation["boxes"] = scaled_boxes + elif key == "area": + area = value + scaled_area = area * (ratio_width * ratio_height) + new_annotation["area"] = scaled_area + elif key == "masks": + masks = value[:, None] + masks = np.array([resize(mask, target_size, resample=resample) for mask in masks]) + masks = masks.astype(np.float32) + masks = masks[:, 0] > threshold + new_annotation["masks"] = masks + elif key == "size": + new_annotation["size"] = target_size + else: + new_annotation[key] = value + + return new_annotation + + +class RTDetrImageProcessor(BaseImageProcessor): + r""" + Constructs a RT-DETR image processor. + + Args: + format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`): + Data format of the annotations. One of "coco_detection" or "coco_panoptic". + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be + overridden by the `do_resize` parameter in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 640, "width": 640}`): + Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter + in the `preprocess` method. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `False`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean values to use when normalizing the image. Can be a single value or a list of values, one for each + channel. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one + for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_annotations (`bool`, *optional*, defaults to `True`): + Controls whether to convert the annotations to the format expected by the DETR model. Converts the + bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. + Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `False`): + Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` + method. If `True`, padding will be applied to the bottom and right of the image with zeros. + If `pad_size` is provided, the image will be padded to the specified dimensions. + Otherwise, the image will be padded to the maximum height and width of the batch. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + + model_input_names = ["pixel_values", "pixel_mask"] + + def __init__( + self, + format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = False, + image_mean: Union[float, List[float]] = None, + image_std: Union[float, List[float]] = None, + do_convert_annotations: bool = True, + do_pad: bool = False, + pad_size: Optional[Dict[str, int]] = None, + **kwargs, + ) -> None: + size = size if size is not None else {"height": 640, "width": 640} + size = get_size_dict(size, default_to_square=False) + + if do_convert_annotations is None: + do_convert_annotations = do_normalize + + super().__init__(**kwargs) + self.format = format + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.do_convert_annotations = do_convert_annotations + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_pad = do_pad + self.pad_size = pad_size + self._valid_processor_keys = [ + "images", + "annotations", + "return_segmentation_masks", + "masks_path", + "do_resize", + "size", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "do_convert_annotations", + "image_mean", + "image_std", + "do_pad", + "pad_size", + "format", + "return_tensors", + "data_format", + "input_data_format", + ] + + def prepare_annotation( + self, + image: np.ndarray, + target: Dict, + format: Optional[AnnotationFormat] = None, + return_segmentation_masks: bool = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> Dict: + """ + Prepare an annotation for feeding into RTDETR model. + """ + format = format if format is not None else self.format + + if format == AnnotationFormat.COCO_DETECTION: + return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_detection_annotation( + image, target, return_segmentation_masks, input_data_format=input_data_format + ) + else: + raise ValueError(f"Format {format} is not supported.") + return target + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None + size = get_size_dict(size, max_size=max_size, default_to_square=False) + if "shortest_edge" in size and "longest_edge" in size: + new_size = get_resize_output_image_size( + image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format + ) + elif "max_height" in size and "max_width" in size: + new_size = get_image_size_for_max_height_width( + image, size["max_height"], size["max_width"], input_data_format=input_data_format + ) + elif "height" in size and "width" in size: + new_size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." + ) + image = resize( + image, + size=new_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return image + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation + def resize_annotation( + self, + annotation, + orig_size, + size, + resample: PILImageResampling = PILImageResampling.NEAREST, + ) -> Dict: + """ + Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched + to this number. + """ + return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale + def rescale( + self, + image: np.ndarray, + rescale_factor: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Rescale the image by the given factor. image = image * rescale_factor. + + Args: + image (`np.ndarray`): + Image to rescale. + rescale_factor (`float`): + The value to use for rescaling. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. If unset, is inferred from the input image. Can be + one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation + def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict: + """ + Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to + `[center_x, center_y, width, height]` format and from absolute to relative pixel values. + """ + return normalize_annotation(annotation, image_size=image_size) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image + def _update_annotation_for_padded_image( + self, + annotation: Dict, + input_image_size: Tuple[int, int], + output_image_size: Tuple[int, int], + padding, + update_bboxes, + ) -> Dict: + """ + Update the annotation for a padded image. + """ + new_annotation = {} + new_annotation["size"] = output_image_size + + for key, value in annotation.items(): + if key == "masks": + masks = value + masks = pad( + masks, + padding, + mode=PaddingMode.CONSTANT, + constant_values=0, + input_data_format=ChannelDimension.FIRST, + ) + masks = safe_squeeze(masks, 1) + new_annotation["masks"] = masks + elif key == "boxes" and update_bboxes: + boxes = value + boxes *= np.asarray( + [ + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + ] + ) + new_annotation["boxes"] = boxes + elif key == "size": + new_annotation["size"] = output_image_size + else: + new_annotation[key] = value + return new_annotation + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image + def _pad_image( + self, + image: np.ndarray, + output_size: Tuple[int, int], + annotation: Optional[Dict[str, Any]] = None, + constant_values: Union[float, Iterable[float]] = 0, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + ) -> np.ndarray: + """ + Pad an image with zeros to the given size. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = output_size + + pad_bottom = output_height - input_height + pad_right = output_width - input_width + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, + padding, + mode=PaddingMode.CONSTANT, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + if annotation is not None: + annotation = self._update_annotation_for_padded_image( + annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes + ) + return padded_image, annotation + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad + def pad( + self, + images: List[np.ndarray], + annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None, + constant_values: Union[float, Iterable[float]] = 0, + return_pixel_mask: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + pad_size: Optional[Dict[str, int]] = None, + ) -> BatchFeature: + """ + Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width + in the batch and optionally returns their corresponding pixel mask. + + Args: + images (List[`np.ndarray`]): + Images to pad. + annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): + Annotations to transform according to the padding that is applied to the images. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether to return a pixel mask. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + update_bboxes (`bool`, *optional*, defaults to `True`): + Whether to update the bounding boxes in the annotations to match the padded images. If the + bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)` + format, the bounding boxes will not be updated. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + pad_size = pad_size if pad_size is not None else self.pad_size + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(images, input_data_format=input_data_format) + + annotation_list = annotations if annotations is not None else [None] * len(images) + padded_images = [] + padded_annotations = [] + for image, annotation in zip(images, annotation_list): + padded_image, padded_annotation = self._pad_image( + image, + padded_size, + annotation, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + update_bboxes=update_bboxes, + ) + padded_images.append(padded_image) + padded_annotations.append(padded_annotation) + + data = {"pixel_values": padded_images} + + if return_pixel_mask: + masks = [ + make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format) + for image in images + ] + data["pixel_mask"] = masks + + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations + ] + + return encoded_inputs + + def preprocess( + self, + images: ImageInput, + annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None, + return_segmentation_masks: bool = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample=None, # PILImageResampling + do_rescale: Optional[bool] = None, + rescale_factor: Optional[Union[int, float]] = None, + do_normalize: Optional[bool] = None, + do_convert_annotations: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + format: Optional[Union[str, AnnotationFormat]] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[Dict[str, int]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or a batch of images so that it can be used by the model. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging + from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. + return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks): + Whether to return segmentation masks. + masks_path (`str` or `pathlib.Path`, *optional*): + Path to the directory containing the segmentation masks. + do_resize (`bool`, *optional*, defaults to self.do_resize): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to self.size): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to self.resample): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to self.do_rescale): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to self.rescale_factor): + Rescale factor to use when rescaling the image. + do_normalize (`bool`, *optional*, defaults to self.do_normalize): + Whether to normalize the image. + do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations): + Whether to convert the annotations to the format expected by the model. Converts the bounding + boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)` + and in relative coordinates. + image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean): + Mean to use when normalizing the image. + image_std (`float` or `List[float]`, *optional*, defaults to self.image_std): + Standard deviation to use when normalizing the image. + do_pad (`bool`, *optional*, defaults to self.do_pad): + Whether to pad the image. If `True`, padding will be applied to the bottom and right of + the image with zeros. If `pad_size` is provided, the image will be padded to the specified + dimensions. Otherwise, the image will be padded to the maximum height and width of the batch. + format (`str` or `AnnotationFormat`, *optional*, defaults to self.format): + Format of the annotations. + return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors): + Type of tensors to return. If `None`, will return the list of images. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + do_resize = self.do_resize if do_resize is None else do_resize + size = self.size if size is None else size + size = get_size_dict(size=size, default_to_square=True) + resample = self.resample if resample is None else resample + do_rescale = self.do_rescale if do_rescale is None else do_rescale + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = self.do_normalize if do_normalize is None else do_normalize + image_mean = self.image_mean if image_mean is None else image_mean + image_std = self.image_std if image_std is None else image_std + do_convert_annotations = ( + self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations + ) + do_pad = self.do_pad if do_pad is None else do_pad + pad_size = self.pad_size if pad_size is None else pad_size + format = self.format if format is None else format + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if annotations is not None and isinstance(annotations, dict): + annotations = [annotations] + + if annotations is not None and len(images) != len(annotations): + raise ValueError( + f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." + ) + + format = AnnotationFormat(format) + if annotations is not None: + validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) + + images = make_list_of_images(images) + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + # All transformations expect numpy arrays + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) + if annotations is not None: + prepared_images = [] + prepared_annotations = [] + for image, target in zip(images, annotations): + target = self.prepare_annotation( + image, + target, + format, + return_segmentation_masks=return_segmentation_masks, + masks_path=masks_path, + input_data_format=input_data_format, + ) + prepared_images.append(image) + prepared_annotations.append(target) + images = prepared_images + annotations = prepared_annotations + del prepared_images, prepared_annotations + + # transformations + if do_resize: + if annotations is not None: + resized_images, resized_annotations = [], [] + for image, target in zip(images, annotations): + orig_size = get_image_size(image, input_data_format) + resized_image = self.resize( + image, size=size, resample=resample, input_data_format=input_data_format + ) + resized_annotation = self.resize_annotation( + target, orig_size, get_image_size(resized_image, input_data_format) + ) + resized_images.append(resized_image) + resized_annotations.append(resized_annotation) + images = resized_images + annotations = resized_annotations + del resized_images, resized_annotations + else: + images = [ + self.resize(image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images] + + if do_normalize: + images = [ + self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images + ] + + if do_convert_annotations and annotations is not None: + annotations = [ + self.normalize_annotation(annotation, get_image_size(image, input_data_format)) + for annotation, image in zip(annotations, images) + ] + + if do_pad: + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + encoded_inputs = self.pad( + images, + annotations=annotations, + return_pixel_mask=True, + data_format=data_format, + input_data_format=input_data_format, + update_bboxes=do_convert_annotations, + return_tensors=return_tensors, + pad_size=pad_size, + ) + else: + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in images + ] + encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations + ] + + return encoded_inputs + + def post_process_object_detection( + self, + outputs, + threshold: float = 0.5, + target_sizes: Union[TensorType, List[Tuple]] = None, + use_focal_loss: bool = True, + ): + """ + Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + use_focal_loss (`bool` defaults to `True`): + Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied + to compute the scores of each detection, otherwise, a softmax function is used. + + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + requires_backends(self, ["torch"]) + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + # convert from relative cxcywh to absolute xyxy + boxes = center_to_corners_format(out_bbox) + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if isinstance(target_sizes, List): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + num_top_queries = out_logits.shape[1] + num_classes = out_logits.shape[2] + + if use_focal_loss: + scores = torch.nn.functional.sigmoid(out_logits) + scores, index = torch.topk(scores.flatten(1), num_top_queries, axis=-1) + labels = index % num_classes + index = index // num_classes + boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) + else: + scores = torch.nn.functional.softmax(out_logits)[:, :, :-1] + scores, labels = scores.max(dim=-1) + if scores.shape[1] > num_top_queries: + scores, index = torch.topk(scores, num_top_queries, dim=-1) + labels = torch.gather(labels, dim=1, index=index) + boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) + + results = [] + for s, l, b in zip(scores, labels, boxes): + score = s[s > threshold] + label = l[s > threshold] + box = b[s > threshold] + results.append({"scores": score, "labels": label, "boxes": box}) + + return results diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py new file mode 100644 index 00000000000000..807589cc62deab --- /dev/null +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -0,0 +1,2675 @@ +# coding=utf-8 +# Copyright 2024 Baidu Inc and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch RT-DETR model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from functools import lru_cache, partial +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from ...activations import ACT2CLS, ACT2FN +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_ninja_available, + is_scipy_available, + is_torch_cuda_available, + logging, + replace_return_docstrings, + requires_backends, +) +from ...utils.backbone_utils import load_backbone +from .configuration_rt_detr import RTDetrConfig + + +if is_scipy_available(): + from scipy.optimize import linear_sum_assignment + +logger = logging.get_logger(__name__) + +MultiScaleDeformableAttention = None + + +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.load_cuda_kernels +def load_cuda_kernels(): + from torch.utils.cpp_extension import load + + global MultiScaleDeformableAttention + + root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr" + src_files = [ + root / filename + for filename in [ + "vision.cpp", + os.path.join("cpu", "ms_deform_attn_cpu.cpp"), + os.path.join("cuda", "ms_deform_attn_cuda.cu"), + ] + ] + + MultiScaleDeformableAttention = load( + "MultiScaleDeformableAttention", + src_files, + with_cuda=True, + extra_include_paths=[str(root)], + extra_cflags=["-DWITH_CUDA=1"], + extra_cuda_cflags=[ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ], + ) + + +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction +class MultiScaleDeformableAttentionFunction(Function): + @staticmethod + def forward( + context, + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step, + ): + context.im2col_step = im2col_step + output = MultiScaleDeformableAttention.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + context.im2col_step, + ) + context.save_for_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights + ) + return output + + @staticmethod + @once_differentiable + def backward(context, grad_output): + ( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) = context.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + context.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "RTDetrConfig" +# TODO: Replace all occurrences of the checkpoint with the final one +_CHECKPOINT_FOR_DOC = "PekingU/rtdetr_r50vd" + + +@dataclass +class RTDetrDecoderOutput(ModelOutput): + """ + Base class for outputs of the RTDetrDecoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions, namely: + - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer) + - a stacked tensor of intermediate reference points. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + intermediate_hidden_states: torch.FloatTensor = None + intermediate_logits: torch.FloatTensor = None + intermediate_reference_points: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class RTDetrModelOutput(ModelOutput): + """ + Base class for outputs of the RT-DETR encoder-decoder model. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries, + num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): + Logits of predicted bounding boxes coordinates in the encoder stage. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values + """ + + last_hidden_state: torch.FloatTensor = None + intermediate_hidden_states: torch.FloatTensor = None + intermediate_logits: torch.FloatTensor = None + intermediate_reference_points: torch.FloatTensor = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + init_reference_points: torch.FloatTensor = None + enc_topk_logits: Optional[torch.FloatTensor] = None + enc_topk_bboxes: Optional[torch.FloatTensor] = None + enc_outputs_class: Optional[torch.FloatTensor] = None + enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + denoising_meta_values: Optional[Dict] = None + + +@dataclass +class RTDetrObjectDetectionOutput(ModelOutput): + """ + Output type of [`RTDetrForObjectDetection`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~RTDetrImageProcessor.post_process_object_detection`] to retrieve the + unnormalized (absolute) bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries, + num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the encoder. + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the encoder. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[Dict] = None + logits: torch.FloatTensor = None + pred_boxes: torch.FloatTensor = None + auxiliary_outputs: Optional[List[Dict]] = None + last_hidden_state: torch.FloatTensor = None + intermediate_hidden_states: torch.FloatTensor = None + intermediate_logits: torch.FloatTensor = None + intermediate_reference_points: torch.FloatTensor = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + init_reference_points: Optional[Tuple[torch.FloatTensor]] = None + enc_topk_logits: Optional[torch.FloatTensor] = None + enc_topk_bboxes: Optional[torch.FloatTensor] = None + enc_outputs_class: Optional[torch.FloatTensor] = None + enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + denoising_meta_values: Optional[Dict] = None + + +def _get_clones(partial_module, N): + return nn.ModuleList([partial_module() for i in range(N)]) + + +# Copied from transformers.models.conditional_detr.modeling_conditional_detr.inverse_sigmoid +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->RTDetr +class RTDetrFrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->RTDetr +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrFrozenBatchNorm2d`. + + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = RTDetrFrozenBatchNorm2d(module.num_features) + + if not module.weight.device == torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising_queries=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): + """ + Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes. + + Args: + targets (`List[dict]`): + The target objects, each containing 'class_labels' and 'boxes' for objects in an image. + num_classes (`int`): + Total number of classes in the dataset. + num_queries (`int`): + Number of query slots in the transformer. + class_embed (`callable`): + A function or a model layer to embed class labels. + num_denoising_queries (`int`, *optional*, defaults to 100): + Number of denoising queries. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + Ratio of noise applied to labels. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale of noise applied to bounding boxes. + Returns: + `tuple` comprising various elements: + - **input_query_class** (`torch.FloatTensor`) -- + Class queries with applied label noise. + - **input_query_bbox** (`torch.FloatTensor`) -- + Bounding box queries with applied box noise. + - **attn_mask** (`torch.FloatTensor`) -- + Attention mask for separating denoising and reconstruction queries. + - **denoising_meta_values** (`dict`) -- + Metadata including denoising positive indices, number of groups, and split sizes. + """ + + if num_denoising_queries <= 0: + return None, None, None, None + + num_ground_truths = [len(t["class_labels"]) for t in targets] + device = targets[0]["class_labels"].device + + max_gt_num = max(num_ground_truths) + if max_gt_num == 0: + return None, None, None, None + + num_groups_denoising_queries = num_denoising_queries // max_gt_num + num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries + # pad gt to max_num of a batch + batch_size = len(num_ground_truths) + + input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device) + input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device) + pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device) + + for i in range(batch_size): + num_gt = num_ground_truths[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]["class_labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. + input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries]) + input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1]) + pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries]) + # positive and negative mask + negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device) + negative_gt_mask[:, max_gt_num:] = 1 + negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1]) + positive_gt_mask = 1 - negative_gt_mask + # contrastive denoising training positive index + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask + denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] + denoise_positive_idx = torch.split( + denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths] + ) + # total denoising queries + num_denoising_queries = int(max_gt_num * 2 * num_groups_denoising_queries) + + if label_noise_ratio > 0: + mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) + # randomly put a new one here + new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) + input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) + + if box_noise_scale > 0: + known_bbox = center_to_corners_format(input_query_bbox) + diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale + rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 + rand_part = torch.rand_like(input_query_bbox) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + input_query_bbox = corners_to_center_format(known_bbox) + input_query_bbox = inverse_sigmoid(input_query_bbox) + + input_query_class = class_embed(input_query_class) + + target_size = num_denoising_queries + num_queries + attn_mask = torch.full([target_size, target_size], False, dtype=torch.bool, device=device) + # match query cannot see the reconstruction + attn_mask[num_denoising_queries:, :num_denoising_queries] = True + + # reconstructions cannot see each other + for i in range(num_groups_denoising_queries): + idx_block_start = max_gt_num * 2 * i + idx_block_end = max_gt_num * 2 * (i + 1) + attn_mask[idx_block_start:idx_block_end, :idx_block_start] = True + attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = True + + denoising_meta_values = { + "dn_positive_idx": denoise_positive_idx, + "dn_num_group": num_groups_denoising_queries, + "dn_num_split": [num_denoising_queries, num_queries], + } + + return input_query_class, input_query_bbox, attn_mask, denoising_meta_values + + +class RTDetrConvEncoder(nn.Module): + """ + Convolutional backbone using the modeling_rt_detr_resnet.py. + + nn.BatchNorm2d layers are replaced by RTDetrFrozenBatchNorm2d as defined above. + https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/nn/backbone/presnet.py#L142 + """ + + def __init__(self, config): + super().__init__() + + backbone = load_backbone(config) + + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +class RTDetrConvNormLayer(nn.Module): + def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RTDetrEncoderLayer(nn.Module): + def __init__(self, config: RTDetrConfig): + super().__init__() + self.normalize_before = config.normalize_before + + # self-attention + self.self_attn = RTDetrMultiheadAttention( + embed_dim=config.encoder_hidden_dim, + num_heads=config.num_attention_heads, + dropout=config.dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.encoder_activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim) + self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor = None, + output_attentions: bool = False, + **kwargs, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + position_embeddings (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + residual = hidden_states + + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class RTDetrRepVggBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: RTDetrConfig): + super().__init__() + + activation = config.activation_function + hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion) + self.conv1 = RTDetrConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1) + self.conv2 = RTDetrConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +class RTDetrCSPRepLayer(nn.Module): + """ + Cross Stage Partial (CSP) network layer with RepVGG blocks. + """ + + def __init__(self, config: RTDetrConfig): + super().__init__() + + in_channels = config.encoder_hidden_dim * 2 + out_channels = config.encoder_hidden_dim + num_blocks = 3 + activation = config.activation_function + + hidden_channels = int(out_channels * config.hidden_expansion) + self.conv1 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.conv2 = RTDetrConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.bottlenecks = nn.Sequential(*[RTDetrRepVggBlock(config) for _ in range(num_blocks)]) + if hidden_channels != out_channels: + self.conv3 = RTDetrConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation) + else: + self.conv3 = nn.Identity() + + def forward(self, hidden_state): + device = hidden_state.device + hidden_state_1 = self.conv1(hidden_state) + hidden_state_1 = self.bottlenecks(hidden_state_1).to(device) + hidden_state_2 = self.conv2(hidden_state).to(device) + return self.conv3(hidden_state_1 + hidden_state_2) + + +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention +def multi_scale_deformable_attention( + value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr +class RTDetrMultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. + """ + + def __init__(self, config: RTDetrConfig, num_heads: int, n_points: int): + super().__init__() + + kernel_loaded = MultiScaleDeformableAttention is not None + if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded: + try: + load_cuda_kernels() + except Exception as e: + logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in RTDetrMultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 64 + + self.d_model = config.d_model + self.n_levels = config.num_feature_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.disable_custom_kernels = config.disable_custom_kernels + + self._reset_parameters() + + def _reset_parameters(self): + nn.init.constant_(self.sampling_offsets.weight.data, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(self.attention_weights.weight.data, 0.0) + nn.init.constant_(self.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(self.value_proj.weight.data) + nn.init.constant_(self.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(self.output_proj.weight.data) + nn.init.constant_(self.output_proj.bias.data, 0.0) + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + output_attentions: bool = False, + ): + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + # we invert the attention_mask + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points + ) + # batch_size, num_queries, n_heads, n_levels, n_points, 2 + num_coordinates = reference_points.shape[-1] + if num_coordinates == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif num_coordinates == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + if self.disable_custom_kernels: + # PyTorch implementation + output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) + else: + try: + # custom kernel + output = MultiScaleDeformableAttentionFunction.apply( + value, + spatial_shapes, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + except Exception: + # PyTorch implementation + output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) + output = self.output_proj(output) + + return output, attention_weights + + +class RTDetrMultiheadAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, target_len, embed_dim = hidden_states.size() + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # get queries, keys and values + query_states = self.q_proj(hidden_states) * self.scaling + key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + # expand attention_mask + if attention_mask is not None: + # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size()) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.size()}" + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class RTDetrDecoderLayer(nn.Module): + def __init__(self, config: RTDetrConfig): + super().__init__() + # self-attention + self.self_attn = RTDetrMultiheadAttention( + embed_dim=config.d_model, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.decoder_activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + # cross-attention + self.encoder_attn = RTDetrMultiscaleDeformableAttention( + config, + num_heads=config.decoder_attention_heads, + n_points=config.decoder_n_points, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + # feedforward neural networks + self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model) + self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(seq_len, batch, embed_dim)`. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings that are added to the queries and keys in the self-attention layer. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + second_residual = hidden_states + + # Cross-Attention + cross_attn_weights = None + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = second_residual + hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class RTDetrPreTrainedModel(PreTrainedModel): + config_class = RTDetrConfig + base_model_prefix = "rt_detr" + main_input_name = "pixel_values" + _no_split_modules = [r"RTDetrConvEncoder", r"RTDetrEncoderLayer", r"RTDetrDecoderLayer"] + + def _init_weights(self, module): + """Initalize the weights""" + + """initialize conv/fc bias value according to a given probability value.""" + if isinstance(module, nn.Linear) and hasattr(module, "class_embed"): + prior_prob = self.config.initializer_range + bias = float(-math.log((1 - prior_prob) / prior_prob)) + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, bias) + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: + nn.init.xavier_uniform_(module.weight_embedding.weight) + if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: + nn.init.xavier_uniform_(module.denoising_class_embed.weight) + + +RTDETR_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`RTDetrConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +RTDETR_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`RTDetrImageProcessor.__call__`] for details. + pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`List[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class RTDetrEncoder(nn.Module): + def __init__(self, config: RTDetrConfig): + super().__init__() + + self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)]) + + def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor: + hidden_states = src + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=src_mask, + position_embeddings=pos_embed, + output_attentions=output_attentions, + ) + return hidden_states + + +class RTDetrHybridEncoder(nn.Module): + """ + Decoder consisting of a projection layer, a set of `RTDetrEncoder`, a top-down Feature Pyramid Network + (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://arxiv.org/abs/2304.08069 + + Args: + config: RTDetrConfig + """ + + def __init__(self, config: RTDetrConfig): + super().__init__() + self.config = config + self.in_channels = config.encoder_in_channels + self.feat_strides = config.feat_strides + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encode_proj_layers = config.encode_proj_layers + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + self.out_strides = self.feat_strides + activation_function = config.activation_function + + # encoder transformer + self.encoder = nn.ModuleList([RTDetrEncoder(config) for _ in range(len(self.encode_proj_layers))]) + # top-down fpn + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1, 0, -1): + self.lateral_convs.append( + RTDetrConvNormLayer( + config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1, activation=activation_function + ) + ) + self.fpn_blocks.append(RTDetrCSPRepLayer(config)) + + # bottom-up pan + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1): + self.downsample_convs.append( + RTDetrConvNormLayer( + config, self.encoder_hidden_dim, self.encoder_hidden_dim, 3, 2, activation=activation_function + ) + ) + self.pan_blocks.append(RTDetrCSPRepLayer(config)) + + @staticmethod + def build_2d_sincos_position_embedding(width, height, embed_dim=256, temperature=10000.0): + grid_w = torch.arange(int(width), dtype=torch.float32) + grid_h = torch.arange(int(height), dtype=torch.float32) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + if embed_dim % 4 != 0: + raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + position_embeddings=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): + Starting index of each feature map. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Ratio of valid area in each feature level. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + # encoder + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + height, width = hidden_states[enc_ind].shape[2:] + # flatten [batch, channel, height, width] to [batch, height*width, channel] + src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1) + if self.training or self.eval_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + width, height, self.encoder_hidden_dim, self.positional_encoding_temperature + ).to(src_flatten.device) + else: + pos_embed = None + + layer_outputs = self.encoder[i]( + src_flatten, + pos_embed=pos_embed, + output_attentions=output_attentions, + ) + hidden_states[enc_ind] = ( + layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous() + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + + # broadcasting and fusion + fpn_feature_maps = [hidden_states[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_high = fpn_feature_maps[0] + feat_low = hidden_states[idx - 1] + feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high) + fpn_feature_maps[0] = feat_high + upsample_feat = F.interpolate(feat_high, scale_factor=2.0, mode="nearest") + fps_map = self.fpn_blocks[len(self.in_channels) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1)) + fpn_feature_maps.insert(0, fps_map) + + fpn_states = [fpn_feature_maps[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = fpn_states[-1] + feat_high = fpn_feature_maps[idx + 1] + downsample_feat = self.downsample_convs[idx](feat_low) + hidden_states = self.pan_blocks[idx]( + torch.concat([downsample_feat, feat_high.to(downsample_feat.device)], dim=1) + ) + fpn_states.append(hidden_states) + + if not return_dict: + return tuple(v for v in [fpn_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=fpn_states, hidden_states=encoder_states, attentions=all_attentions) + + +class RTDetrDecoder(RTDetrPreTrainedModel): + def __init__(self, config: RTDetrConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layers = nn.ModuleList([RTDetrDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.query_pos_head = RTDetrMLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2) + + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings=None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + The query embeddings that are passed into the decoder. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Position embeddings that are added to the queries and keys in each self-attention layer. + reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*): + Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area. + spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of the feature maps. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*): + Indexes for the start of each feature level. In range `[0, sequence_length]`. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*): + Ratio of valid area in each feature level. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + intermediate = () + intermediate_reference_points = () + intermediate_logits = () + + reference_points = F.sigmoid(reference_points) + + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L252 + for idx, decoder_layer in enumerate(self.layers): + reference_points_input = reference_points.unsqueeze(2) + position_embeddings = self.query_pos_head(reference_points) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points_input, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[idx](hidden_states) + new_reference_points = F.sigmoid(tmp + inverse_sigmoid(reference_points)) + reference_points = new_reference_points.detach() + + intermediate += (hidden_states,) + intermediate_reference_points += ( + (new_reference_points,) if self.bbox_embed is not None else (reference_points,) + ) + + if self.class_embed is not None: + logits = self.class_embed[idx](hidden_states) + intermediate_logits += (logits,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # Keep batch_size as first dimension + intermediate = torch.stack(intermediate, dim=1) + intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) + if self.class_embed is not None: + intermediate_logits = torch.stack(intermediate_logits, dim=1) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + intermediate, + intermediate_logits, + intermediate_reference_points, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return RTDetrDecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_logits=intermediate_logits, + intermediate_reference_points=intermediate_reference_points, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top. + """, + RTDETR_START_DOCSTRING, +) +class RTDetrModel(RTDetrPreTrainedModel): + def __init__(self, config: RTDetrConfig): + super().__init__(config) + + # Create backbone + self.backbone = RTDetrConvEncoder(config) + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + + # Create encoder input projection layers + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L212 + num_backbone_outs = len(intermediate_channel_sizes) + encoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = intermediate_channel_sizes[_] + encoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list) + + # Create encoder + self.encoder = RTDetrHybridEncoder(config) + + # denoising part + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + # decoder embedding + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = RTDetrMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3) + + # init encoder output anchors and valid_mask + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors() + + # Create decoder input projection layers + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412 + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = config.decoder_in_channels[_] + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + in_channels = config.d_model + self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list) + + # decoder + self.decoder = RTDetrDecoder(config) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(True) + + @lru_cache(maxsize=32) + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device="cpu"): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)] + for s in self.config.feat_strides + ] + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid( + torch.arange(end=height, dtype=dtype), torch.arange(end=width, dtype=dtype), indexing="ij" + ) + grid_xy = torch.stack([grid_x, grid_y], -1) + valid_wh = torch.tensor([width, height]).to(dtype) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_wh + wh = torch.ones_like(grid_xy) * grid_size * (2.0**level) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = torch.concat(anchors, 1).to(device) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.inf) + + return anchors, valid_mask + + @add_start_docstrings_to_model_forward(RTDETR_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RTDetrModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], RTDetrModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, RTDetrModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd") + >>> model = RTDetrModel.from_pretrained("PekingU/rtdetr_r50vd") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + features = self.backbone(pixel_values, pixel_mask) + + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + + if encoder_outputs is None: + encoder_outputs = self.encoder( + proj_feats, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if output_hidden_states else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 + else encoder_outputs[1] + if output_attentions + else None, + ) + + # Equivalent to def _get_encoder_input + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412 + sources = [] + for level, source in enumerate(encoder_outputs[0]): + sources.append(self.decoder_input_proj[level](source)) + + # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage + if self.config.num_feature_levels > len(sources): + _len_sources = len(sources) + sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1]) + for i in range(_len_sources + 1, self.config.num_feature_levels): + sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1])) + + # Prepare encoder inputs (by flattening) + source_flatten = [] + spatial_shapes = [] + for level, source in enumerate(sources): + batch_size, num_channels, height, width = source.shape + spatial_shape = (height, width) + spatial_shapes.append(spatial_shape) + source = source.flatten(2).transpose(1, 2) + source_flatten.append(source) + source_flatten = torch.cat(source_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + + # prepare denoising training + if self.training and self.config.num_denoising > 0 and labels is not None: + ( + denoising_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = get_contrastive_denoising_training_group( + targets=labels, + num_classes=self.config.num_labels, + num_queries=self.config.num_queries, + class_embed=self.denoising_class_embed, + num_denoising_queries=self.config.num_denoising, + label_noise_ratio=self.config.label_noise_ratio, + box_noise_scale=self.config.box_noise_scale, + ) + else: + denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None + + batch_size = len(source_flatten) + device = source_flatten.device + + # prepare input for decoder + if self.training or self.config.anchor_image_size is None: + anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device) + else: + anchors, valid_mask = self.anchors.to(device), self.valid_mask.to(device) + + # use the valid_mask to selectively retain values in the feature map where the mask is `True` + memory = valid_mask.to(source_flatten.dtype) * source_flatten + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1]) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) + ) + + # extract region features + if self.config.learn_initial_query: + target = self.weight_embedding.tile([batch_size, 1, 1]) + else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + # decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + enc_outputs = tuple( + value + for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits] + if value is not None + ) + dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values]) + tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs + + return tuple_outputs + + return RTDetrModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + + +# Copied from transformers.models.detr.modeling_detr.dice_loss +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs (0 for the negative class and 1 for the positive + class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss +def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + + Args: + inputs (`torch.FloatTensor` of arbitrary shape): + The predictions for each example. + targets (`torch.FloatTensor` with the same shape as `inputs`) + A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class + and 1 for the positive class). + alpha (`float`, *optional*, defaults to `0.25`): + Optional weighting factor in the range (0,1) to balance positive vs. negative examples. + gamma (`int`, *optional*, defaults to `2`): + Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. + + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + # add modulating factor + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +class RTDetrLoss(nn.Module): + """ + This class computes the losses for RTDetr. The process happens in two steps: 1) we compute hungarian assignment + between ground truth boxes and the outputs of the model 2) we supervise each pair of matched ground-truth / + prediction (supervise class and box). + + Args: + matcher (`DetrHungarianMatcher`): + Module able to compute a matching between targets and proposals. + weight_dict (`Dict`): + Dictionary relating each loss with its weights. These losses are configured in RTDetrConf as + `weight_loss_vfl`, `weight_loss_bbox`, `weight_loss_giou` + losses (`List[str]`): + List of all the losses to be applied. See `get_loss` for a list of all available losses. + alpha (`float`): + Parameter alpha used to compute the focal loss. + gamma (`float`): + Parameter gamma used to compute the focal loss. + eos_coef (`float`): + Relative classification weight applied to the no-object category. + num_classes (`int`): + Number of object categories, omitting the special no-object category. + """ + + def __init__(self, config): + super().__init__() + + self.matcher = RTDetrHungarianMatcher(config) + self.num_classes = config.num_labels + self.weight_dict = { + "loss_vfl": config.weight_loss_vfl, + "loss_bbox": config.weight_loss_bbox, + "loss_giou": config.weight_loss_giou, + } + self.losses = ["vfl", "boxes"] + self.eos_coef = config.eos_coefficient + empty_weight = torch.ones(config.num_labels + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer("empty_weight", empty_weight) + self.alpha = config.focal_loss_alpha + self.gamma = config.focal_loss_gamma + + def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True): + if "pred_boxes" not in outputs: + raise KeyError("No predicted boxes found in outputs") + if "logits" not in outputs: + raise KeyError("No predicted logits found in outputs") + idx = self._get_source_permutation_idx(indices) + + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([_target["boxes"][i] for _target, (_, i) in zip(targets, indices)], dim=0) + ious, _ = box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes)) + ious = torch.diag(ious).detach() + + src_logits = outputs["logits"] + target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_original + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + + target_score_original = torch.zeros_like(target_classes, dtype=src_logits.dtype) + target_score_original[idx] = ious.to(target_score_original.dtype) + target_score = target_score_original.unsqueeze(-1) * target + + pred_score = F.sigmoid(src_logits).detach() + weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score + + loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction="none") + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + return {"loss_vfl": loss} + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (NLL) + targets dicts must contain the key "class_labels" containing a tensor of dim [nb_target_boxes] + """ + if "logits" not in outputs: + raise KeyError("No logits were found in the outputs") + + src_logits = outputs["logits"] + + idx = self._get_source_permutation_idx(indices) + target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_original + + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.class_weight) + losses = {"loss_ce": loss_ce} + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ + Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. This is not + really a loss, it is intended for logging purposes only. It doesn't propagate gradients. + """ + logits = outputs["logits"] + device = logits.device + target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1) + card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) + losses = {"cardinality_error": card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """ + Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. Targets dicts must + contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes are expected in + format (center_x, center_y, w, h), normalized by the image size. + """ + if "pred_boxes" not in outputs: + raise KeyError("No predicted boxes found in outputs") + idx = self._get_source_permutation_idx(indices) + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + + losses = {} + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag( + generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes)) + ) + losses["loss_giou"] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes): + """ + Compute the losses related to the masks: the focal loss and the dice loss. Targets dicts must contain the key + "masks" containing a tensor of dim [nb_target_boxes, h, w]. + """ + if "pred_masks" not in outputs: + raise KeyError("No predicted masks found in outputs") + + source_idx = self._get_source_permutation_idx(indices) + target_idx = self._get_target_permutation_idx(indices) + source_masks = outputs["pred_masks"] + source_masks = source_masks[source_idx] + masks = [t["masks"] for t in targets] + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(source_masks) + target_masks = target_masks[target_idx] + + # upsample predictions to the target size + source_masks = nn.functional.interpolate( + source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False + ) + source_masks = source_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(source_masks.shape) + losses = { + "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes), + "loss_dice": dice_loss(source_masks, target_masks, num_boxes), + } + return losses + + def loss_labels_bce(self, outputs, targets, indices, num_boxes, log=True): + src_logits = outputs["logits"] + idx = self._get_source_permutation_idx(indices) + target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_original + + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + loss = F.binary_cross_entropy_with_logits(src_logits, target * 1.0, reduction="none") + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + return {"loss_bce": loss} + + def _get_source_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)]) + source_idx = torch.cat([source for (source, _) in indices]) + return batch_idx, source_idx + + def _get_target_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)]) + target_idx = torch.cat([target for (_, target) in indices]) + return batch_idx, target_idx + + def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True): + if "logits" not in outputs: + raise KeyError("No logits found in outputs") + + src_logits = outputs["logits"] + + idx = self._get_source_permutation_idx(indices) + target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_original + + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + loss = sigmoid_focal_loss(src_logits, target, self.alpha, self.gamma, reduction="none") + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + return {"loss_focal": loss} + + def get_loss(self, loss, outputs, targets, indices, num_boxes): + loss_map = { + "labels": self.loss_labels, + "cardinality": self.loss_cardinality, + "boxes": self.loss_boxes, + "masks": self.loss_masks, + "bce": self.loss_labels_bce, + "focal": self.loss_labels_focal, + "vfl": self.loss_labels_vfl, + } + if loss not in loss_map: + raise ValueError(f"Loss {loss} not supported") + return loss_map[loss](outputs, targets, indices, num_boxes) + + @staticmethod + def get_cdn_matched_indices(dn_meta, targets): + dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"] + num_gts = [len(t["class_labels"]) for t in targets] + device = targets[0]["class_labels"].device + + dn_match_indices = [] + for i, num_gt in enumerate(num_gts): + if num_gt > 0: + gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device) + gt_idx = gt_idx.tile(dn_num_group) + assert len(dn_positive_idx[i]) == len(gt_idx) + dn_match_indices.append((dn_positive_idx[i], gt_idx)) + else: + dn_match_indices.append( + ( + torch.zeros(0, dtype=torch.int64, device=device), + torch.zeros(0, dtype=torch.int64, device=device), + ) + ) + + return dn_match_indices + + def forward(self, outputs, targets): + """ + This performs the loss computation. + + Args: + outputs (`dict`, *optional*): + Dictionary of tensors, see the output specification of the model for the format. + targets (`List[dict]`, *optional*): + List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the + losses applied, see each loss' doc. + """ + outputs_without_aux = {k: v for k, v in outputs.items() if "auxiliary_outputs" not in k} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes across all nodes, for normalization purposes + num_boxes = sum(len(t["class_labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + num_boxes = torch.clamp(num_boxes, min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + losses.update(l_dict) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if "auxiliary_outputs" in outputs: + for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): + indices = self.matcher(auxiliary_outputs, targets) + for loss in self.losses: + if loss == "masks": + # Intermediate masks losses are too costly to compute, we ignore them. + continue + l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + l_dict = {k + f"_aux_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + # In case of cdn auxiliary losses. For rtdetr + if "dn_auxiliary_outputs" in outputs: + if "denoising_meta_values" not in outputs: + raise ValueError( + "The output must have the 'denoising_meta_values` key. Please, ensure that 'outputs' includes a 'denoising_meta_values' entry." + ) + indices = self.get_cdn_matched_indices(outputs["denoising_meta_values"], targets) + num_boxes = num_boxes * outputs["denoising_meta_values"]["dn_num_group"] + + for i, auxiliary_outputs in enumerate(outputs["dn_auxiliary_outputs"]): + # indices = self.matcher(auxiliary_outputs, targets) + for loss in self.losses: + if loss == "masks": + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + l_dict = {k + f"_dn_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class RTDetrMLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_paddle/ppdet/modeling/transformers/utils.py#L453 + + """ + + def __init__(self, config, input_dim, d_model, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [d_model] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class RTDetrHungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more + predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are + un-matched (and thus treated as non-objects). + + Args: + config: RTDetrConfig + """ + + def __init__(self, config): + super().__init__() + requires_backends(self, ["scipy"]) + + self.class_cost = config.matcher_class_cost + self.bbox_cost = config.matcher_bbox_cost + self.giou_cost = config.matcher_giou_cost + + self.use_focal_loss = config.use_focal_loss + self.alpha = config.matcher_alpha + self.gamma = config.matcher_gamma + + if self.class_cost == self.bbox_cost == self.giou_cost == 0: + raise ValueError("All costs of the Matcher can't be 0") + + @torch.no_grad() + def forward(self, outputs, targets): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + batch_size, num_queries = outputs["logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + # Also concat the target labels and boxes + target_ids = torch.cat([v["class_labels"] for v in targets]) + target_bbox = torch.cat([v["boxes"] for v in targets]) + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + if self.use_focal_loss: + out_prob = F.sigmoid(outputs["logits"].flatten(0, 1)) + out_prob = out_prob[:, target_ids] + neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = self.alpha * ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log()) + class_cost = pos_cost_class - neg_cost_class + else: + out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] + class_cost = -out_prob[:, target_ids] + + # Compute the L1 cost between boxes + bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) + # Compute the giou cost betwen boxes + giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox)) + # Compute the final cost matrix + cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost + cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))] + + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +# Copied from transformers.models.detr.modeling_detr._upcast +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +# Copied from transformers.models.detr.modeling_detr.box_area +def box_area(boxes: Tensor) -> Tensor: + """ + Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. + + Args: + boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): + Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 + < x2` and `0 <= y1 < y2`. + + Returns: + `torch.FloatTensor`: a tensor containing the area for each box. + """ + boxes = _upcast(boxes) + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +# Copied from transformers.models.detr.modeling_detr.box_iou +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] + inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +# Copied from transformers.models.detr.modeling_detr.generalized_box_iou +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. + + Returns: + `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): + raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") + if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): + raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") + iou, union = box_iou(boxes1, boxes2) + + top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] + area = width_height[:, :, 0] * width_height[:, :, 1] + + return iou - (area - union) / area + + +# Copied from transformers.models.detr.modeling_detr._max_by_axis +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +# Copied from transformers.models.detr.modeling_detr.NestedTensor +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + if tensor_list[0].ndim == 3: + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + batch_shape = [len(tensor_list)] + max_size + batch_size, num_channels, height, width = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("Only 3-dimensional tensors are supported") + return NestedTensor(tensor, mask) + + +@add_start_docstrings( + """ + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further + decoded into scores and classes. + """, + RTDETR_START_DOCSTRING, +) +class RTDetrForObjectDetection(RTDetrPreTrainedModel): + # When using clones, all layers > 0 will be clones, but layer 0 *is* required + _tied_weights_keys = ["bbox_embed", "class_embed"] + # We can't initialize the model on meta device as some weights are modified during the initialization + _no_split_modules = None + + def __init__(self, config: RTDetrConfig): + super().__init__(config) + + # RTDETR encoder-decoder model + self.model = RTDetrModel(config) + + # Detection heads on top + self.class_embed = partial(nn.Linear, config.d_model, config.num_labels) + self.bbox_embed = partial(RTDetrMLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = config.decoder_layers + if config.with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + else: + self.class_embed = nn.ModuleList([self.class_embed() for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([self.bbox_embed() for _ in range(num_pred)]) + + # hack implementation for iterative bounding box refinement + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + + # Initialize weights and apply final processing + self.post_init() + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + @add_start_docstrings_to_model_forward(RTDETR_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RTDetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]: + r""" + labels (`List[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + + Returns: + + Examples: + + ```python + >>> from transformers import RTDetrImageProcessor, RTDetrForObjectDetection + >>> from PIL import Image + >>> import requests + >>> import torch + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd") + >>> model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 300, 80] + + >>> boxes = outputs.pred_boxes + >>> list(boxes.shape) + [1, 300, 4] + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[ + ... 0 + ... ] + + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... ) + Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21] + Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5] + Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22] + Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48] + Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + denoising_meta_values = ( + outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None + ) + + outputs_class = outputs.intermediate_logits if return_dict else outputs[2] + outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3] + + if self.training and denoising_meta_values is not None: + dn_out_coord, outputs_coord = torch.split(outputs_coord, denoising_meta_values["dn_num_split"], dim=2) + dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) + + logits = outputs_class[:, -1] + pred_boxes = outputs_coord[:, -1] + + loss, loss_dict, auxiliary_outputs = None, None, None + if labels is not None: + # First: create the criterion + criterion = RTDetrLoss(self.config) + criterion.to(self.device) + # Second: compute the losses, based on outputs and labels + outputs_loss = {} + outputs_loss["logits"] = logits + outputs_loss["pred_boxes"] = pred_boxes + if self.config.auxiliary_loss: + enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5] + enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4] + auxiliary_outputs = self._set_aux_loss( + outputs_class[:, :-1].transpose(0, 1), outputs_coord[:, :-1].transpose(0, 1) + ) + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + outputs_loss["auxiliary_outputs"].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes])) + if self.training and denoising_meta_values is not None: + outputs_loss["dn_auxiliary_outputs"] = self._set_aux_loss( + dn_out_class.transpose(0, 1), dn_out_coord.transpose(0, 1) + ) + outputs_loss["denoising_meta_values"] = denoising_meta_values + + loss_dict = criterion(outputs_loss, labels) + + loss = sum(loss_dict.values()) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs + else: + output = (logits, pred_boxes) + outputs + return ((loss, loss_dict) + output) if loss is not None else output + + return RTDetrObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_logits=outputs.intermediate_logits, + intermediate_reference_points=outputs.intermediate_reference_points, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + init_reference_points=outputs.init_reference_points, + enc_topk_logits=outputs.enc_topk_logits, + enc_topk_bboxes=outputs.enc_topk_bboxes, + enc_outputs_class=outputs.enc_outputs_class, + enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + denoising_meta_values=outputs.denoising_meta_values, + ) diff --git a/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py b/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py new file mode 100644 index 00000000000000..75102efab3d73e --- /dev/null +++ b/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py @@ -0,0 +1,426 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PyTorch RTDetr specific ResNet model. The main difference between hugginface ResNet model is that this RTDetrResNet model forces to use shortcut at the first layer in the resnet-18/34 models. +See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details. +""" + +from typing import Optional + +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BackboneOutput, + BaseModelOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_rt_detr_resnet import RTDetrResNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "RTDetrResNetConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/resnet-50" +_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7] + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetConvLayer -> RTDetrResNetConvLayer +class RTDetrResNetConvLayer(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu" + ): + super().__init__() + self.convolution = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False + ) + self.normalization = nn.BatchNorm2d(out_channels) + self.activation = ACT2FN[activation] if activation is not None else nn.Identity() + + def forward(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RTDetrResNetEmbeddings(nn.Module): + """ + ResNet Embeddings (stem) composed of a deep aggressive convolution. + """ + + def __init__(self, config: RTDetrResNetConfig): + super().__init__() + self.embedder = nn.Sequential( + *[ + RTDetrResNetConvLayer( + config.num_channels, + config.embedding_size // 2, + kernel_size=3, + stride=2, + activation=config.hidden_act, + ), + RTDetrResNetConvLayer( + config.embedding_size // 2, + config.embedding_size // 2, + kernel_size=3, + stride=1, + activation=config.hidden_act, + ), + RTDetrResNetConvLayer( + config.embedding_size // 2, + config.embedding_size, + kernel_size=3, + stride=1, + activation=config.hidden_act, + ), + ] + ) + self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.num_channels = config.num_channels + + def forward(self, pixel_values: Tensor) -> Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.embedder(pixel_values) + embedding = self.pooler(embedding) + return embedding + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut -> RTDetrResNetChortCut +class RTDetrResNetShortCut(nn.Module): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2): + super().__init__() + self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + self.normalization = nn.BatchNorm2d(out_channels) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + return hidden_state + + +class RTDetrResNetBasicLayer(nn.Module): + """ + A classic ResNet's residual layer composed by two `3x3` convolutions. + See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L34. + """ + + def __init__( + self, + config: RTDetrResNetConfig, + in_channels: int, + out_channels: int, + stride: int = 1, + should_apply_shortcut: bool = False, + ): + super().__init__() + if in_channels != out_channels: + self.shortcut = ( + nn.Sequential( + *[nn.AvgPool2d(2, 2, 0, ceil_mode=True), RTDetrResNetShortCut(in_channels, out_channels, stride=1)] + ) + if should_apply_shortcut + else nn.Identity() + ) + else: + self.shortcut = ( + RTDetrResNetShortCut(in_channels, out_channels, stride=stride) + if should_apply_shortcut + else nn.Identity() + ) + self.layer = nn.Sequential( + RTDetrResNetConvLayer(in_channels, out_channels, stride=stride), + RTDetrResNetConvLayer(out_channels, out_channels, activation=None), + ) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RTDetrResNetBottleNeckLayer(nn.Module): + """ + A classic RTDetrResNet's bottleneck layer composed by three `3x3` convolutions. + + The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3` + convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If + `downsample_in_bottleneck` is true, downsample will be in the first layer instead of the second layer. + """ + + def __init__( + self, + config: RTDetrResNetConfig, + in_channels: int, + out_channels: int, + stride: int = 1, + ): + super().__init__() + reduction = 4 + should_apply_shortcut = in_channels != out_channels or stride != 1 + reduces_channels = out_channels // reduction + if stride == 2: + self.shortcut = nn.Sequential( + *[ + nn.AvgPool2d(2, 2, 0, ceil_mode=True), + RTDetrResNetShortCut(in_channels, out_channels, stride=1) + if should_apply_shortcut + else nn.Identity(), + ] + ) + else: + self.shortcut = ( + RTDetrResNetShortCut(in_channels, out_channels, stride=stride) + if should_apply_shortcut + else nn.Identity() + ) + self.layer = nn.Sequential( + RTDetrResNetConvLayer( + in_channels, reduces_channels, kernel_size=1, stride=stride if config.downsample_in_bottleneck else 1 + ), + RTDetrResNetConvLayer( + reduces_channels, reduces_channels, stride=stride if not config.downsample_in_bottleneck else 1 + ), + RTDetrResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None), + ) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RTDetrResNetStage(nn.Module): + """ + A RTDetrResNet stage composed by stacked layers. + """ + + def __init__( + self, + config: RTDetrResNetConfig, + in_channels: int, + out_channels: int, + stride: int = 2, + depth: int = 2, + ): + super().__init__() + + layer = RTDetrResNetBottleNeckLayer if config.layer_type == "bottleneck" else RTDetrResNetBasicLayer + + if config.layer_type == "bottleneck": + first_layer = layer( + config, + in_channels, + out_channels, + stride=stride, + ) + else: + first_layer = layer(config, in_channels, out_channels, stride=stride, should_apply_shortcut=True) + self.layers = nn.Sequential( + first_layer, *[layer(config, out_channels, out_channels) for _ in range(depth - 1)] + ) + + def forward(self, input: Tensor) -> Tensor: + hidden_state = input + for layer in self.layers: + hidden_state = layer(hidden_state) + return hidden_state + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetEncoder with ResNet->RTDetrResNet +class RTDetrResNetEncoder(nn.Module): + def __init__(self, config: RTDetrResNetConfig): + super().__init__() + self.stages = nn.ModuleList([]) + # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input + self.stages.append( + RTDetrResNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + ) + ) + in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) + for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]): + self.stages.append(RTDetrResNetStage(config, in_channels, out_channels, depth=depth)) + + def forward( + self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> BaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RTDetrResNet +class RTDetrResNetPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RTDetrResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + _no_split_modules = ["RTDetrResNetConvLayer", "RTDetrResNetShortCut"] + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + +RTDETR_RESNET_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`RTDetrResNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +RTDETR_RESNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`RTDetrImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """ + ResNet backbone, to be used with frameworks like RTDETR. + """, + RTDETR_RESNET_START_DOCSTRING, +) +class RTDetrResNetBackbone(RTDetrResNetPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.embedding_size] + config.hidden_sizes + self.embedder = RTDetrResNetEmbeddings(config) + self.encoder = RTDetrResNetEncoder(config) + + # initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(RTDETR_RESNET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone + >>> import torch + + >>> config = RTDetrResNetConfig() + >>> model = RTDetrResNetBackbone(config) + + >>> pixel_values = torch.randn(1, 3, 224, 224) + + >>> with torch.no_grad(): + ... outputs = model(pixel_values) + + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 2048, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + embedding_output = self.embedder(pixel_values) + + outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True) + + hidden_states = outputs.hidden_states + + feature_maps = () + for idx, stage in enumerate(self.stage_names): + if stage in self.out_features: + feature_maps += (hidden_states[idx],) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) diff --git a/src/transformers/models/timm_backbone/modeling_timm_backbone.py b/src/transformers/models/timm_backbone/modeling_timm_backbone.py index 74e7388b7dcab5..ffe83daf7bc23b 100644 --- a/src/transformers/models/timm_backbone/modeling_timm_backbone.py +++ b/src/transformers/models/timm_backbone/modeling_timm_backbone.py @@ -113,10 +113,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): return super()._from_config(config, **kwargs) def freeze_batch_norm_2d(self): - timm.layers.freeze_batch_norm_2d(self._backbone) + timm.utils.model.freeze_batch_norm_2d(self._backbone) def unfreeze_batch_norm_2d(self): - timm.layers.unfreeze_batch_norm_2d(self._backbone) + timm.utils.model.unfreeze_batch_norm_2d(self._backbone) def _init_weights(self, module): """ diff --git a/src/transformers/utils/backbone_utils.py b/src/transformers/utils/backbone_utils.py index e689fee20fe8fd..86a1fae4ad0c35 100644 --- a/src/transformers/utils/backbone_utils.py +++ b/src/transformers/utils/backbone_utils.py @@ -313,7 +313,6 @@ def load_backbone(config): use_pretrained_backbone = getattr(config, "use_pretrained_backbone", None) backbone_checkpoint = getattr(config, "backbone", None) backbone_kwargs = getattr(config, "backbone_kwargs", None) - backbone_kwargs = {} if backbone_kwargs is None else backbone_kwargs if backbone_kwargs and backbone_config is not None: diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 0cda4ed7b96349..b72b14b93c62b2 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -7492,6 +7492,41 @@ def load_tf_weights_in_roformer(*args, **kwargs): requires_backends(load_tf_weights_in_roformer, ["torch"]) +class RTDetrForObjectDetection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RTDetrModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RTDetrPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RTDetrResNetBackbone(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RTDetrResNetPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class RwkvForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index a27dc024447f42..5a3011be5e2171 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -492,6 +492,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class RTDetrImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class SamImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/rt_detr/__init__.py b/tests/models/rt_detr/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/rt_detr/test_image_processing_rt_detr.py b/tests/models/rt_detr/test_image_processing_rt_detr.py new file mode 100644 index 00000000000000..3960e4401916de --- /dev/null +++ b/tests/models/rt_detr/test_image_processing_rt_detr.py @@ -0,0 +1,364 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import unittest + +import requests + +from transformers.testing_utils import require_torch, require_vision, slow +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_vision_available(): + from PIL import Image + + from transformers import RTDetrImageProcessor + +if is_torch_available(): + import torch + + +class RTDetrImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=4, + num_channels=3, + do_resize=True, + size=None, + do_rescale=True, + rescale_factor=1 / 255, + do_normalize=False, + do_pad=False, + return_tensors="pt", + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.do_resize = do_resize + self.size = size if size is not None else {"height": 640, "width": 640} + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.do_pad = do_pad + self.return_tensors = return_tensors + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_rescale": self.do_rescale, + "rescale_factor": self.rescale_factor, + "do_normalize": self.do_normalize, + "do_pad": self.do_pad, + "return_tensors": self.return_tensors, + } + + def get_expected_values(self): + return self.size["height"], self.size["width"] + + def expected_output_image_shape(self, images): + height, width = self.get_expected_values() + return self.num_channels, height, width + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=30, + max_resolution=400, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class RtDetrImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = RTDetrImageProcessor if is_vision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = RTDetrImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "resample")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "return_tensors")) + + def test_image_processor_from_dict_with_kwargs(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 640, "width": 640}) + + def test_valid_coco_detection_annotations(self): + # prepare image and target + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f: + target = json.loads(f.read()) + + params = {"image_id": 39769, "annotations": target} + + # encode them + image_processing = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd") + + # legal encodings (single image) + _ = image_processing(images=image, annotations=params, return_tensors="pt") + _ = image_processing(images=image, annotations=[params], return_tensors="pt") + + # legal encodings (batch of one image) + _ = image_processing(images=[image], annotations=params, return_tensors="pt") + _ = image_processing(images=[image], annotations=[params], return_tensors="pt") + + # legal encoding (batch of more than one image) + n = 5 + _ = image_processing(images=[image] * n, annotations=[params] * n, return_tensors="pt") + + # example of an illegal encoding (missing the 'image_id' key) + with self.assertRaises(ValueError) as e: + image_processing(images=image, annotations={"annotations": target}, return_tensors="pt") + + self.assertTrue(str(e.exception).startswith("Invalid COCO detection annotations")) + + # example of an illegal encoding (unequal lengths of images and annotations) + with self.assertRaises(ValueError) as e: + image_processing(images=[image] * n, annotations=[params] * (n - 1), return_tensors="pt") + + self.assertTrue(str(e.exception) == "The number of images (5) and annotations (4) do not match.") + + @slow + def test_call_pytorch_with_coco_detection_annotations(self): + # prepare image and target + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f: + target = json.loads(f.read()) + + target = {"image_id": 39769, "annotations": target} + + # encode them + image_processing = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd") + encoding = image_processing(images=image, annotations=target, return_tensors="pt") + + # verify pixel values + expected_shape = torch.Size([1, 3, 640, 640]) + self.assertEqual(encoding["pixel_values"].shape, expected_shape) + + expected_slice = torch.tensor([0.5490, 0.5647, 0.5725]) + self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-4)) + + # verify area + expected_area = torch.tensor([2827.9883, 5403.4761, 235036.7344, 402070.2188, 71068.8281, 79601.2812]) + self.assertTrue(torch.allclose(encoding["labels"][0]["area"], expected_area)) + # verify boxes + expected_boxes_shape = torch.Size([6, 4]) + self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape) + expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215]) + self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"][0], expected_boxes_slice, atol=1e-3)) + # verify image_id + expected_image_id = torch.tensor([39769]) + self.assertTrue(torch.allclose(encoding["labels"][0]["image_id"], expected_image_id)) + # verify is_crowd + expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0]) + self.assertTrue(torch.allclose(encoding["labels"][0]["iscrowd"], expected_is_crowd)) + # verify class_labels + expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17]) + self.assertTrue(torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels)) + # verify orig_size + expected_orig_size = torch.tensor([480, 640]) + self.assertTrue(torch.allclose(encoding["labels"][0]["orig_size"], expected_orig_size)) + # verify size + expected_size = torch.tensor([640, 640]) + self.assertTrue(torch.allclose(encoding["labels"][0]["size"], expected_size)) + + @slow + def test_image_processor_outputs(self): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + + image_processing = self.image_processing_class(**self.image_processor_dict) + encoding = image_processing(images=image, return_tensors="pt") + + # verify pixel values: shape + expected_shape = torch.Size([1, 3, 640, 640]) + self.assertEqual(encoding["pixel_values"].shape, expected_shape) + + # verify pixel values: output values + expected_slice = torch.tensor([0.5490196347236633, 0.5647059082984924, 0.572549045085907]) + self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-5)) + + def test_multiple_images_processor_outputs(self): + images_urls = [ + "http://images.cocodataset.org/val2017/000000000139.jpg", + "http://images.cocodataset.org/val2017/000000000285.jpg", + "http://images.cocodataset.org/val2017/000000000632.jpg", + "http://images.cocodataset.org/val2017/000000000724.jpg", + "http://images.cocodataset.org/val2017/000000000776.jpg", + "http://images.cocodataset.org/val2017/000000000785.jpg", + "http://images.cocodataset.org/val2017/000000000802.jpg", + "http://images.cocodataset.org/val2017/000000000872.jpg", + ] + + images = [] + for url in images_urls: + image = Image.open(requests.get(url, stream=True).raw) + images.append(image) + + # apply image processing + image_processing = self.image_processing_class(**self.image_processor_dict) + encoding = image_processing(images=images, return_tensors="pt") + + # verify if pixel_values is part of the encoding + self.assertIn("pixel_values", encoding) + + # verify pixel values: shape + expected_shape = torch.Size([8, 3, 640, 640]) + self.assertEqual(encoding["pixel_values"].shape, expected_shape) + + # verify pixel values: output values + expected_slices = torch.tensor( + [ + [0.5333333611488342, 0.5568627715110779, 0.5647059082984924], + [0.5372549295425415, 0.4705882668495178, 0.4274510145187378], + [0.3960784673690796, 0.35686275362968445, 0.3686274588108063], + [0.20784315466880798, 0.1882353127002716, 0.15294118225574493], + [0.364705890417099, 0.364705890417099, 0.3686274588108063], + [0.8078432083129883, 0.8078432083129883, 0.8078432083129883], + [0.4431372880935669, 0.4431372880935669, 0.4431372880935669], + [0.19607844948768616, 0.21176472306251526, 0.3607843220233917], + ] + ) + self.assertTrue(torch.allclose(encoding["pixel_values"][:, 1, 0, :3], expected_slices, atol=1e-5)) + + @slow + def test_batched_coco_detection_annotations(self): + image_0 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + image_1 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png").resize((800, 800)) + + with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f: + target = json.loads(f.read()) + + annotations_0 = {"image_id": 39769, "annotations": target} + annotations_1 = {"image_id": 39769, "annotations": target} + + # Adjust the bounding boxes for the resized image + w_0, h_0 = image_0.size + w_1, h_1 = image_1.size + for i in range(len(annotations_1["annotations"])): + coords = annotations_1["annotations"][i]["bbox"] + new_bbox = [ + coords[0] * w_1 / w_0, + coords[1] * h_1 / h_0, + coords[2] * w_1 / w_0, + coords[3] * h_1 / h_0, + ] + annotations_1["annotations"][i]["bbox"] = new_bbox + + images = [image_0, image_1] + annotations = [annotations_0, annotations_1] + + image_processing = RTDetrImageProcessor() + encoding = image_processing( + images=images, + annotations=annotations, + return_segmentation_masks=True, + return_tensors="pt", # do_convert_annotations=True + ) + + # Check the pixel values have been padded + postprocessed_height, postprocessed_width = 640, 640 + expected_shape = torch.Size([2, 3, postprocessed_height, postprocessed_width]) + self.assertEqual(encoding["pixel_values"].shape, expected_shape) + + # Check the bounding boxes have been adjusted for padded images + self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4])) + self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4])) + expected_boxes_0 = torch.tensor( + [ + [0.6879, 0.4609, 0.0755, 0.3691], + [0.2118, 0.3359, 0.2601, 0.1566], + [0.5011, 0.5000, 0.9979, 1.0000], + [0.5010, 0.5020, 0.9979, 0.9959], + [0.3284, 0.5944, 0.5884, 0.8112], + [0.8394, 0.5445, 0.3213, 0.9110], + ] + ) + expected_boxes_1 = torch.tensor( + [ + [0.5503, 0.2765, 0.0604, 0.2215], + [0.1695, 0.2016, 0.2080, 0.0940], + [0.5006, 0.4933, 0.9977, 0.9865], + [0.5008, 0.5002, 0.9983, 0.9955], + [0.2627, 0.5456, 0.4707, 0.8646], + [0.7715, 0.4115, 0.4570, 0.7161], + ] + ) + self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1e-3)) + self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1e-3)) + + # Check if do_convert_annotations=False, then the annotations are not converted to centre_x, centre_y, width, height + # format and not in the range [0, 1] + encoding = image_processing( + images=images, + annotations=annotations, + return_segmentation_masks=True, + do_convert_annotations=False, + return_tensors="pt", + ) + self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4])) + self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4])) + # Convert to absolute coordinates + unnormalized_boxes_0 = torch.vstack( + [ + expected_boxes_0[:, 0] * postprocessed_width, + expected_boxes_0[:, 1] * postprocessed_height, + expected_boxes_0[:, 2] * postprocessed_width, + expected_boxes_0[:, 3] * postprocessed_height, + ] + ).T + unnormalized_boxes_1 = torch.vstack( + [ + expected_boxes_1[:, 0] * postprocessed_width, + expected_boxes_1[:, 1] * postprocessed_height, + expected_boxes_1[:, 2] * postprocessed_width, + expected_boxes_1[:, 3] * postprocessed_height, + ] + ).T + # Convert from centre_x, centre_y, width, height to x_min, y_min, x_max, y_max + expected_boxes_0 = torch.vstack( + [ + unnormalized_boxes_0[:, 0] - unnormalized_boxes_0[:, 2] / 2, + unnormalized_boxes_0[:, 1] - unnormalized_boxes_0[:, 3] / 2, + unnormalized_boxes_0[:, 0] + unnormalized_boxes_0[:, 2] / 2, + unnormalized_boxes_0[:, 1] + unnormalized_boxes_0[:, 3] / 2, + ] + ).T + expected_boxes_1 = torch.vstack( + [ + unnormalized_boxes_1[:, 0] - unnormalized_boxes_1[:, 2] / 2, + unnormalized_boxes_1[:, 1] - unnormalized_boxes_1[:, 3] / 2, + unnormalized_boxes_1[:, 0] + unnormalized_boxes_1[:, 2] / 2, + unnormalized_boxes_1[:, 1] + unnormalized_boxes_1[:, 3] / 2, + ] + ).T + self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1)) + self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1)) diff --git a/tests/models/rt_detr/test_modeling_rt_detr.py b/tests/models/rt_detr/test_modeling_rt_detr.py new file mode 100644 index 00000000000000..05ce68fb92e82d --- /dev/null +++ b/tests/models/rt_detr/test_modeling_rt_detr.py @@ -0,0 +1,680 @@ +# coding = utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch RT_DETR model.""" + +import inspect +import math +import unittest + +from transformers import ( + RTDetrConfig, + RTDetrImageProcessor, + RTDetrResNetConfig, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import require_torch, require_vision, torch_device +from transformers.utils import cached_property + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import RTDetrForObjectDetection, RTDetrModel + +if is_vision_available(): + from PIL import Image + + +CHECKPOINT = "PekingU/rtdetr_r50vd" # TODO: replace + + +class RTDetrModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + use_labels=True, + n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + # backbone + backbone_config=None, + # encoder HybridEncoder + encoder_hidden_dim=32, + encoder_in_channels=[128, 256, 512], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=64, + encoder_attention_heads=2, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + # decoder RTDetrTransformer + d_model=32, + num_queries=30, + decoder_in_channels=[32, 32, 32], + decoder_ffn_dim=64, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=[64, 64], + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.backbone_config = backbone_config + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + + self.encoder_seq_length = math.ceil(self.image_size / 32) * math.ceil(self.image_size / 32) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + + labels = None + if self.use_labels: + # labels is a list of Dict (each Dict being the labels for a given example in the batch) + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + hidden_sizes = [10, 20, 30, 40] + backbone_config = RTDetrResNetConfig( + embeddings_size=10, + hidden_sizes=hidden_sizes, + depths=[1, 1, 2, 1], + out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], + ) + return RTDetrConfig.from_backbone_configs( + backbone_config=backbone_config, + encoder_hidden_dim=self.encoder_hidden_dim, + encoder_in_channels=hidden_sizes[1:], + feat_strides=self.feat_strides, + encoder_layers=self.encoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + dropout=self.dropout, + activation_dropout=self.activation_dropout, + encode_proj_layers=self.encode_proj_layers, + positional_encoding_temperature=self.positional_encoding_temperature, + encoder_activation_function=self.encoder_activation_function, + activation_function=self.activation_function, + eval_size=self.eval_size, + normalize_before=self.normalize_before, + d_model=self.d_model, + num_queries=self.num_queries, + decoder_in_channels=self.decoder_in_channels, + decoder_ffn_dim=self.decoder_ffn_dim, + num_feature_levels=self.num_feature_levels, + decoder_n_points=self.decoder_n_points, + decoder_layers=self.decoder_layers, + decoder_attention_heads=self.decoder_attention_heads, + decoder_activation_function=self.decoder_activation_function, + attention_dropout=self.attention_dropout, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + learn_initial_query=self.learn_initial_query, + anchor_image_size=self.anchor_image_size, + image_size=self.image_size, + disable_custom_kernels=self.disable_custom_kernels, + with_box_refine=self.with_box_refine, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + def create_and_check_rt_detr_model(self, config, pixel_values, pixel_mask, labels): + model = RTDetrModel(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model)) + + def create_and_check_rt_detr_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = RTDetrForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (RTDetrModel, RTDetrForObjectDetection) if is_torch_available() else () + pipeline_model_mapping = ( + {"image-feature-extraction": RTDetrModel, "object-detection": RTDetrForObjectDetection} + if is_torch_available() + else {} + ) + is_encoder_decoder = True + test_torchscript = False + test_pruning = False + test_head_masking = False + test_missing_keys = False + + # special case for head models + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "RTDetrForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = RTDetrModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=RTDetrConfig, + has_text_modality=False, + common_properties=["hidden_size", "num_attention_heads"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_rt_detr_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_rt_detr_model(*config_and_inputs) + + def test_rt_detr_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_rt_detr_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="RTDetr does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="RTDetr does not use test_inputs_embeds_matches_input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="RTDetr does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="RTDetr does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="RTDetr does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip(reason="Feed forward chunking is not implemented") + def test_feed_forward_chunking(self): + pass + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + out_len = len(outputs) + + correct_outlen = 13 + + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + # Object Detection model returns pred_logits and pred_boxes + if model_class.__name__ == "RTDetrForObjectDetection": + correct_outlen += 2 + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [ + self.model_tester.decoder_attention_heads, + self.model_tester.num_queries, + self.model_tester.num_queries, + ], + ) + + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.decoder_attention_heads, + self.model_tester.num_feature_levels, + self.model_tester.decoder_n_points, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + else: + # RTDetr should maintin encoder_hidden_states output + added_hidden_states = 2 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions + + self.assertEqual(len(self_attentions), self.model_tester.encoder_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", len(self.model_tester.encoder_in_channels) - 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[1].shape[-2:]), + [ + self.model_tester.image_size // self.model_tester.feat_strides[-1], + self.model_tester.image_size // self.model_tester.feat_strides[-1], + ], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.decoder_layers + 1 + ) + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.num_queries, self.model_tester.d_model], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + # we take the first output since last_hidden_state is the first item + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_attentions = outputs.encoder_attentions[0] + encoder_hidden_states.retain_grad() + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_different_timm_backbone(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # let's pick a random timm backbone + config.backbone = "tf_mobilenetv3_small_075" + config.backbone_config = None + config.use_timm_backbone = True + config.backbone_kwargs = {"out_indices": [2, 3, 4]} + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == "RTDetrForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + self.model_tester.num_labels, + ) + self.assertEqual(outputs.logits.shape, expected_shape) + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.model.backbone.intermediate_channel_sizes), 3) + else: + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3) + + self.assertTrue(outputs) + + def test_hf_backbone(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Load a pretrained HF checkpoint as backbone + config.backbone = "microsoft/resnet-18" + config.backbone_config = None + config.use_timm_backbone = False + config.use_pretrained_backbone = True + config.backbone_kwargs = {"out_indices": [2, 3, 4]} + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == "RTDetrForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + self.model_tester.num_labels, + ) + self.assertEqual(outputs.logits.shape, expected_shape) + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.model.backbone.intermediate_channel_sizes), 3) + else: + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3) + + self.assertTrue(outputs) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + # Skip the check for the backbone + for name, module in model.named_modules(): + if module.__class__.__name__ == "RTDetrConvEncoder": + backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()] + break + + for name, param in model.named_parameters(): + if param.requires_grad: + if ( + "level_embed" in name + or "sampling_offsets.bias" in name + or "value_proj" in name + or "output_proj" in name + or "reference_points" in name + or name in backbone_params + ): + continue + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + +TOLERANCE = 1e-4 + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +@require_vision +class RTDetrModelIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return RTDetrImageProcessor.from_pretrained(CHECKPOINT) if is_vision_available() else None + + def test_inference_object_detection_head(self): + model = RTDetrForObjectDetection.from_pretrained(CHECKPOINT).to(torch_device) + + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + expected_shape_logits = torch.Size((1, 300, model.config.num_labels)) + self.assertEqual(outputs.logits.shape, expected_shape_logits) + + expected_logits = torch.tensor( + [ + [-4.64763879776001, -5.001153945922852, -4.978509902954102], + [-4.159348487854004, -4.703853607177734, -5.946484565734863], + [-4.437461853027344, -4.65836238861084, -6.235235691070557], + ] + ).to(torch_device) + expected_boxes = torch.tensor( + [ + [0.1688060760498047, 0.19992263615131378, 0.21225441992282867], + [0.768376350402832, 0.41226309537887573, 0.4636859893798828], + [0.25953856110572815, 0.5483334064483643, 0.4777486026287079], + ] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_logits, atol=1e-4)) + + expected_shape_boxes = torch.Size((1, 300, 4)) + self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) + self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4)) + + # verify postprocessing + results = image_processor.post_process_object_detection( + outputs, threshold=0.0, target_sizes=[image.size[::-1]] + )[0] + expected_scores = torch.tensor( + [0.9703017473220825, 0.9599503874778748, 0.9575679302215576, 0.9506784677505493], device=torch_device + ) + expected_labels = [57, 15, 15, 65] + expected_slice_boxes = torch.tensor( + [ + [0.13774872, 0.37821293, 640.13074, 476.21088], + [343.38132, 24.276838, 640.1404, 371.49573], + [13.225126, 54.179348, 318.98422, 472.2207], + [40.114475, 73.44104, 175.9573, 118.48469], + ], + device=torch_device, + ) + + self.assertTrue(torch.allclose(results["scores"][:4], expected_scores, atol=1e-4)) + self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels) + self.assertTrue(torch.allclose(results["boxes"][:4], expected_slice_boxes, atol=1e-4)) diff --git a/tests/models/rt_detr/test_modeling_rt_detr_resnet.py b/tests/models/rt_detr/test_modeling_rt_detr_resnet.py new file mode 100644 index 00000000000000..c925ef14ed0c5a --- /dev/null +++ b/tests/models/rt_detr/test_modeling_rt_detr_resnet.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import RTDetrResNetConfig +from transformers.testing_utils import require_torch, torch_device +from transformers.utils.import_utils import is_torch_available + +from ...test_backbone_common import BackboneTesterMixin +from ...test_modeling_common import floats_tensor, ids_tensor + + +if is_torch_available(): + from transformers import RTDetrResNetBackbone + + +class RTDetrResNetModelTester: + def __init__( + self, + parent, + batch_size=3, + image_size=32, + num_channels=3, + embeddings_size=10, + hidden_sizes=[10, 20, 30, 40], + depths=[1, 1, 2, 1], + is_training=True, + use_labels=True, + hidden_act="relu", + num_labels=3, + scope=None, + out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.embeddings_size = embeddings_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.num_labels = num_labels + self.scope = scope + self.num_stages = len(hidden_sizes) + self.out_features = out_features + self.out_indices = out_indices + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.num_labels) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return RTDetrResNetConfig( + num_channels=self.num_channels, + embeddings_size=self.embeddings_size, + hidden_sizes=self.hidden_sizes, + depths=self.depths, + hidden_act=self.hidden_act, + num_labels=self.num_labels, + out_features=self.out_features, + out_indices=self.out_indices, + ) + + def create_and_check_backbone(self, config, pixel_values, labels): + model = RTDetrResNetBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4]) + + # verify channels + self.parent.assertEqual(len(model.channels), len(config.out_features)) + self.parent.assertListEqual(model.channels, config.hidden_sizes[1:]) + + # verify backbone works with out_features=None + config.out_features = None + model = RTDetrResNetBackbone(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + # verify feature maps + self.parent.assertEqual(len(result.feature_maps), 1) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1]) + + # verify channels + self.parent.assertEqual(len(model.channels), 1) + self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]]) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class RTDetrResNetBackboneTest(BackboneTesterMixin, unittest.TestCase): + all_model_classes = (RTDetrResNetBackbone,) if is_torch_available() else () + has_attentions = False + config_class = RTDetrResNetConfig + + def setUp(self): + self.model_tester = RTDetrResNetModelTester(self)