diff --git a/yolort/models/__init__.py b/yolort/models/__init__.py index f933888a..e60dfe73 100644 --- a/yolort/models/__init__.py +++ b/yolort/models/__init__.py @@ -73,7 +73,7 @@ def yolotr(upstream_version: str = 'v4.0', export_friendly: bool = False, **kwar export_friendly (bool): Deciding whether to use (ONNX/TVM) export friendly mode. """ if upstream_version == 'v4.0': - model = YOLOModule(arch="yolov5_darknet_pan_s_tr", **kwargs) + model = YOLOModule(arch="yolov5_darknet_tan_s_r40", **kwargs) else: raise NotImplementedError("Currently only supports v4.0 versions") diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py index 3b7bb95f..c907c5b3 100644 --- a/yolort/models/yolo.py +++ b/yolort/models/yolo.py @@ -128,13 +128,15 @@ def forward( model_urls_root = 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0' model_urls = { + # Path Aggregation Network 'yolov5_darknet_pan_s_r31_coco': f'{model_urls_root}/yolov5_darknet_pan_s_r31_coco-eb728698.pt', 'yolov5_darknet_pan_m_r31_coco': f'{model_urls_root}/yolov5_darknet_pan_m_r31_coco-670dc553.pt', 'yolov5_darknet_pan_l_r31_coco': f'{model_urls_root}/yolov5_darknet_pan_l_r31_coco-4dcc8209.pt', 'yolov5_darknet_pan_s_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_s_r40_coco-e3fd213d.pt', 'yolov5_darknet_pan_m_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_m_r40_coco-d295cb02.pt', 'yolov5_darknet_pan_l_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_l_r40_coco-4416841f.pt', - 'yolov5_darknet_pan_s_tr_coco': f'{model_urls_root}/yolov5_darknet_pan_s_tr_coco-f09f21f7.pt', + # Tranformer Attention Network + 'yolov5_darknet_tan_s_r40_coco': f'{model_urls_root}/yolov5_darknet_tan_s_r40_coco-fe1069ce.pt', } @@ -312,7 +314,7 @@ def yolov5_darknet_tan_s_r40(pretrained: bool = False, progress: bool = True, nu progress (bool): If True, displays a progress bar of the download to stderr """ backbone_name = 'darknet_s_r4_0' - weights_name = 'yolov5_darknet_pan_s_tr_coco' + weights_name = 'yolov5_darknet_tan_s_r40_coco' depth_multiple = 0.33 width_multiple = 0.5 version = 'v4.0'