From d54540b990b7c49b46a6f1173bf0b23d073a12d2 Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Fri, 16 Jun 2023 16:49:40 +0900 Subject: [PATCH] Add intg test --- .../custom_deformable_detr_detector.py | 31 ++++++ .../adapters/mmdet/models/heads/__init__.py | 2 - .../heads/custom_deformable_detr_head.py | 28 ------ .../resnet50_deformable-detr/model.py | 2 +- tests/e2e/cli/detection/test_detection.py | 42 ++++---- .../cli/detection/test_detection.py | 18 +++- .../mmdet/models/detectors/conftest.py | 2 +- .../test_custom_deformable_detr_detector.py | 2 +- .../adapters/mmdet/models/heads/__init__.py | 4 - .../heads/test_custom_deformable_detr_head.py | 96 ------------------- 10 files changed, 71 insertions(+), 156 deletions(-) delete mode 100644 otx/algorithms/detection/adapters/mmdet/models/heads/custom_deformable_detr_head.py delete mode 100644 tests/unit/algorithms/detection/adapters/mmdet/models/heads/__init__.py delete mode 100644 tests/unit/algorithms/detection/adapters/mmdet/models/heads/test_custom_deformable_detr_head.py diff --git a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_deformable_detr_detector.py b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_deformable_detr_detector.py index 854e9c4d743..e5dda6b3d5e 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_deformable_detr_detector.py +++ b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_deformable_detr_detector.py @@ -7,6 +7,12 @@ from mmdet.models.builder import DETECTORS from mmdet.models.detectors.deformable_detr import DeformableDETR +from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import ( + ActivationMapHook, + FeatureVectorHook, +) +from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled + @DETECTORS.register_module() class CustomDeformableDETR(DeformableDETR): @@ -18,3 +24,28 @@ class CustomDeformableDETR(DeformableDETR): def __init__(self, *args, task_adapt=None, **kwargs): super().__init__(*args, **kwargs) self.task_adapt = task_adapt + + +if is_mmdeploy_enabled(): + from mmdeploy.core import FUNCTION_REWRITER + + @FUNCTION_REWRITER.register_rewriter( + "otx.algorithms.detection.adapters.mmdet.models.detectors.custom_deformable_detr_detector.CustomDeformableDETR.simple_test" + ) + def custom_deformable_detr__simple_test(ctx, self, img, img_metas, **kwargs): + """Function for custom_mask_rcnn__simple_test.""" + height = int(img_metas[0]["img_shape"][0]) + width = int(img_metas[0]["img_shape"][1]) + img_metas[0]["batch_input_shape"] = (height, width) + img_metas[0]["img_shape"] = (height, width, 3) + feat = self.extract_feat(img) + outs = self.bbox_head(feat, img_metas) + bbox_results = self.bbox_head.get_bboxes(*outs, img_metas=img_metas, **kwargs) + + if ctx.cfg["dump_features"]: + feature_vector = FeatureVectorHook.func(feat) + cls_scores = outs[0] + saliency_map = ActivationMapHook.func(cls_scores) + return (*bbox_results, feature_vector, saliency_map) + + return bbox_results diff --git a/otx/algorithms/detection/adapters/mmdet/models/heads/__init__.py b/otx/algorithms/detection/adapters/mmdet/models/heads/__init__.py index 47db0302786..c531b25265d 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/heads/__init__.py +++ b/otx/algorithms/detection/adapters/mmdet/models/heads/__init__.py @@ -6,7 +6,6 @@ from .cross_dataset_detector_head import CrossDatasetDetectorHead from .custom_anchor_generator import SSDAnchorGeneratorClustered from .custom_atss_head import CustomATSSHead, CustomATSSHeadTrackingLossDynamics -from .custom_deformable_detr_head import CustomDeformableDETRHead from .custom_retina_head import CustomRetinaHead from .custom_roi_head import CustomRoIHead from .custom_ssd_head import CustomSSDHead @@ -17,7 +16,6 @@ "CrossDatasetDetectorHead", "SSDAnchorGeneratorClustered", "CustomATSSHead", - "CustomDeformableDETRHead", "CustomRetinaHead", "CustomSSDHead", "CustomRoIHead", diff --git a/otx/algorithms/detection/adapters/mmdet/models/heads/custom_deformable_detr_head.py b/otx/algorithms/detection/adapters/mmdet/models/heads/custom_deformable_detr_head.py deleted file mode 100644 index e553465051b..00000000000 --- a/otx/algorithms/detection/adapters/mmdet/models/heads/custom_deformable_detr_head.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Custom Deformable DETR head for OTX.""" - -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -from mmdet.models.builder import HEADS -from mmdet.models.dense_heads.deformable_detr_head import DeformableDETRHead - - -@HEADS.register_module() -class CustomDeformableDETRHead(DeformableDETRHead): - """Custom Deformable DETR Head. - - Since batch_input_shape are not added in mmdeploy, here this function add it. - However additional if condition may leads time consumption therefore we need to - find better way to add "batch_input_shape" to img_metas when the model is exported. - """ - - def forward(self, mlvl_feats, img_metas): - """Modified forward function for onnx export.""" - - if "batch_input_shape" not in img_metas[0]: - height = int(img_metas[0]["img_shape"][0]) - width = int(img_metas[0]["img_shape"][1]) - img_metas[0]["batch_input_shape"] = (height, width) - img_metas[0]["img_shape"] = (height, width, 3) - return super().forward(mlvl_feats, img_metas) diff --git a/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/model.py b/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/model.py index ccbb656dede..bebbe0b9c80 100644 --- a/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/model.py +++ b/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/model.py @@ -22,7 +22,7 @@ num_outs=4, ), bbox_head=dict( - type="CustomDeformableDETRHead", + type="DeformableDETRHead", num_query=300, num_classes=80, in_channels=2048, diff --git a/tests/e2e/cli/detection/test_detection.py b/tests/e2e/cli/detection/test_detection.py index 561b0629b18..510bd00a40f 100644 --- a/tests/e2e/cli/detection/test_detection.py +++ b/tests/e2e/cli/detection/test_detection.py @@ -90,10 +90,16 @@ templates = Registry("otx/algorithms/detection").filter(task_type="DETECTION").templates templates_ids = [template.model_template_id for template in templates] + template_experimental = parse_model_template( + "otx/algorithms/detection/configs/detection/resnet50_deformable-detr/template_experimental.yaml" + ) + templates_w_experimental = templates + [template_experimental] + templates_ids_w_experimental = templates_ids + [template_experimental.model_template_id] + class TestToolsMPADetection: @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_train(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_train_testing(template, tmp_dir_path, otx_dir, args0) @@ -104,7 +110,7 @@ def test_otx_train(self, template, tmp_dir_path): @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_resume(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection/test_resume" otx_resume_testing(template, tmp_dir_path, otx_dir, args) @@ -118,7 +124,7 @@ def test_otx_resume(self, template, tmp_dir_path): @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) @pytest.mark.parametrize("dump_features", [True, False]) def test_otx_export(self, template, tmp_dir_path, dump_features): tmp_dir_path = tmp_dir_path / "detection" @@ -126,21 +132,21 @@ def test_otx_export(self, template, tmp_dir_path, dump_features): @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_export_fp16(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_export_testing(template, tmp_dir_path, half_precision=True) @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_eval(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_eval_testing(template, tmp_dir_path, otx_dir, args) @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) @pytest.mark.parametrize("half_precision", [True, False]) def test_otx_eval_openvino(self, template, tmp_dir_path, half_precision): tmp_dir_path = tmp_dir_path / "detection" @@ -190,42 +196,42 @@ def test_otx_explain_process_saliency_maps_openvino(self, template, tmp_dir_path @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_demo(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_demo_testing(template, tmp_dir_path, otx_dir, args) @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_demo_openvino(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_demo_openvino_testing(template, tmp_dir_path, otx_dir, args) @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_deploy_openvino(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_deploy_openvino_testing(template, tmp_dir_path, otx_dir, args) @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_eval_deployment(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_eval_deployment_testing(template, tmp_dir_path, otx_dir, args, threshold=0.0) @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_demo_deployment(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_demo_deployment_testing(template, tmp_dir_path, otx_dir, args) @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_hpo(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection/test_hpo" otx_hpo_testing(template, tmp_dir_path, otx_dir, args) @@ -282,7 +288,7 @@ def test_nncf_eval_openvino(self, template, tmp_dir_path): @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_pot_optimize(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" pot_optimize_testing(template, tmp_dir_path, otx_dir, args) @@ -296,7 +302,7 @@ def test_pot_validate_fq(self, template, tmp_dir_path): @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_pot_eval(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" pot_eval_testing(template, tmp_dir_path, otx_dir, args) @@ -304,7 +310,7 @@ def test_pot_eval(self, template, tmp_dir_path): @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") @pytest.mark.skipif(MULTI_GPU_UNAVAILABLE, reason="The number of gpu is insufficient") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_multi_gpu_train(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection/test_multi_gpu" args1 = copy.deepcopy(args) @@ -314,14 +320,14 @@ def test_otx_multi_gpu_train(self, template, tmp_dir_path): class TestToolsMPASemiSLDetection: @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_train(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection/test_semisl" otx_train_testing(template, tmp_dir_path, otx_dir, args_semisl) @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_eval(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection/test_semisl" otx_eval_testing(template, tmp_dir_path, otx_dir, args) @@ -329,7 +335,7 @@ def test_otx_eval(self, template, tmp_dir_path): @e2e_pytest_component @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") @pytest.mark.skipif(MULTI_GPU_UNAVAILABLE, reason="The number of gpu is insufficient") - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_multi_gpu_train_semisl(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection/test_multi_gpu_semisl" args_semisl_multigpu = copy.deepcopy(args_semisl) diff --git a/tests/integration/cli/detection/test_detection.py b/tests/integration/cli/detection/test_detection.py index 883aef94c94..18a2d65f230 100644 --- a/tests/integration/cli/detection/test_detection.py +++ b/tests/integration/cli/detection/test_detection.py @@ -68,10 +68,18 @@ templates = Registry("otx/algorithms/detection").filter(task_type="DETECTION").templates templates_ids = [template.model_template_id for template in templates] +experimental_template = parse_model_template( + "otx/algorithms/detection/configs/detection/resnet50_deformable-detr/template_experimental.yaml" +) +experimental_template_id = experimental_template.model_template_id + +templates_w_experimental = templates + [experimental_template] +templates_ids_w_experimental = templates_ids + [experimental_template_id] + class TestDetectionCLI: @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_train(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_train_testing(template, tmp_dir_path, otx_dir, args) @@ -90,26 +98,26 @@ def test_otx_resume(self, template, tmp_dir_path): otx_resume_testing(template, tmp_dir_path, otx_dir, args1) @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) @pytest.mark.parametrize("dump_features", [True, False]) def test_otx_export(self, template, tmp_dir_path, dump_features): tmp_dir_path = tmp_dir_path / "detection" otx_export_testing(template, tmp_dir_path, dump_features, check_ir_meta=True) @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_export_fp16(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_export_testing(template, tmp_dir_path, half_precision=True) @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_export_onnx(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_export_testing(template, tmp_dir_path, half_precision=False, is_onnx=True) @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_eval(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_eval_testing(template, tmp_dir_path, otx_dir, args) diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py index 006b92a31a8..52b50f2722d 100644 --- a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py +++ b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py @@ -339,7 +339,7 @@ def fxt_cfg_custom_deformable_detr(num_classes: int = 3): num_outs=4, ), bbox_head=dict( - type="CustomDeformableDETRHead", + type="DeformableDETRHead", num_query=300, num_classes=80, in_channels=2048, diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_deformable_detr_detector.py b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_deformable_detr_detector.py index 29fe0c95fdb..ef3f0e8145a 100644 --- a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_deformable_detr_detector.py +++ b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_deformable_detr_detector.py @@ -1,4 +1,4 @@ -"""Test for CustomDeformableDETRHead.""" +"""Test for CustomDeformableDETR Detector.""" # Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/heads/__init__.py b/tests/unit/algorithms/detection/adapters/mmdet/models/heads/__init__.py deleted file mode 100644 index 0567857243c..00000000000 --- a/tests/unit/algorithms/detection/adapters/mmdet/models/heads/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Test for otx.algorithms.mmdetection.adapters.mmdet.models.heads.""" -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/heads/test_custom_deformable_detr_head.py b/tests/unit/algorithms/detection/adapters/mmdet/models/heads/test_custom_deformable_detr_head.py deleted file mode 100644 index 74b9d19dec4..00000000000 --- a/tests/unit/algorithms/detection/adapters/mmdet/models/heads/test_custom_deformable_detr_head.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Test for otx.algorithms.mmdetection.adapters.mmdet.models.heads.custom_deformable_detr_head.""" -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import numpy as np -import torch -import pytest - -from mmcv.utils import ConfigDict -from mmdet.models.builder import build_detector -from mmdet.models.dense_heads.deformable_detr_head import DeformableDETRHead - -from tests.test_suite.e2e_test_system import e2e_pytest_unit - - -class TestCustomDeformableDETRHead: - @pytest.fixture(autouse=True) - def setup(self) -> None: - cfg = ConfigDict( - type="CustomDeformableDETRHead", - num_query=300, - num_classes=80, - in_channels=2048, - sync_cls_avg_factor=True, - with_box_refine=True, - as_two_stage=True, - transformer=dict( - type="DeformableDetrTransformer", - encoder=dict( - type="DetrTransformerEncoder", - num_layers=6, - transformerlayers=dict( - type="BaseTransformerLayer", - attn_cfgs=dict(type="MultiScaleDeformableAttention", embed_dims=256), - feedforward_channels=1024, - ffn_dropout=0.1, - operation_order=("self_attn", "norm", "ffn", "norm"), - ), - ), - decoder=dict( - type="DeformableDetrTransformerDecoder", - num_layers=6, - return_intermediate=True, - transformerlayers=dict( - type="DetrTransformerDecoderLayer", - attn_cfgs=[ - dict(type="MultiheadAttention", embed_dims=256, num_heads=8, dropout=0.1), - dict(type="MultiScaleDeformableAttention", embed_dims=256), - ], - feedforward_channels=1024, - ffn_dropout=0.1, - operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"), - ), - ), - ), - positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True, offset=-0.5), - loss_cls=dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=2.0), - loss_bbox=dict(type="L1Loss", loss_weight=5.0), - loss_iou=dict(type="GIoULoss", loss_weight=2.0), - ) - self.head = build_detector(cfg) - - @e2e_pytest_unit - def test_forward(self, mocker): - def return_second_arg(a, b): - return b - - mocker.patch.object(DeformableDETRHead, "forward", side_effect=return_second_arg) - - feats = ( - torch.randn([1, 256, 100, 167]), - torch.randn([1, 256, 50, 84]), - torch.randn([1, 256, 25, 42]), - torch.randn([1, 256, 13, 21]), - ) - img_metas = [ - { - "filename": None, - "ori_filename": None, - "ori_shape": (128, 128, 3), - "img_shape": torch.Tensor([800, 1333]), - "pad_shape": (800, 1333, 3), - "scale_factor": np.array([10.4140625, 6.25, 10.4140625, 6.25], dtype=np.float32), - "flip": False, - "flip_direction": None, - "img_norm_cfg": { - "mean": np.array([123.675, 116.28, 103.53], dtype=np.float32), - "std": np.array([58.395, 57.12, 57.375], dtype=np.float32), - "to_rgb": False, - }, - } - ] - out = self.head(feats, img_metas) - assert out[0].get("batch_input_shape") == (800, 1333) - assert out[0].get("img_shape") == (800, 1333, 3)