Skip to content

Commit

Permalink
Merge pull request #85 from breezedeus/pytorch
Browse files Browse the repository at this point in the history
add param static_resized_shape
  • Loading branch information
breezedeus authored Jun 22, 2024
2 parents cea2d47 + 6598db7 commit 7de2abb
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 1 deletion.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@

</div>

<div align="center">

[English](./README_en.md) | 中文

</div>


# CnSTD
# Update 2024.06.16:发布 V1.2.4

Expand Down
22 changes: 22 additions & 0 deletions README_en.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
<div align="center">
<img src="./docs/logo.png" width="250px"/>
<div>&nbsp;</div>

[![Downloads](https://static.pepy.tech/personalized-badge/cnstd?period=total&units=international_system&left_color=grey&right_color=orange&left_text=Downloads)](https://pepy.tech/project/cnstd)
[![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fgithub.com%2Fbreezedeus%2FCnSTD&label=Visitors&countColor=%23f5c791&style=flat&labelStyle=none)](https://visitorbadge.io/status?path=https%3A%2F%2Fgithub.com%2Fbreezedeus%2FCnSTD)
[![license](https://img.shields.io/github/license/breezedeus/cnstd)](./LICENSE)
[![PyPI version](https://badge.fury.io/py/cnstd.svg)](https://badge.fury.io/py/cnstd)
[![forks](https://img.shields.io/github/forks/breezedeus/cnstd)](https://img.shields.io/github/forks/breezedeus/cnstd)
[![stars](https://img.shields.io/github/stars/breezedeus/cnstd)](https://github.com/breezedeus/cnocr)
![last-releast](https://img.shields.io/github/release-date/breezedeus/cnstd?style=plastic)
![last-commit](https://img.shields.io/github/last-commit/breezedeus/cnstd)
[![Twitter](https://img.shields.io/twitter/url?url=https%3A%2F%2Ftwitter.com%2Fbreezedeus)](https://twitter.com/breezedeus)

</div>

<div align="center">

[中文](./README.md) | English

</div>

# CnSTD

## Update 2024.06.16: Release V1.2.4
Expand Down
11 changes: 11 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
# Release Notes

# Update 2024.06.22:发布 V1.2.4.2

Major Changes:

* Added a new parameter `static_resized_shape` when initializing `YoloDetector`, which is used to resize the input image to a fixed size. Some formats of models require fixed-size input images during inference, such as `CoreML`.

主要变更:

* `YoloDetector` 初始化时加入了参数 `static_resized_shape`, 用于把输入图片 resize 为固定大小。某些格式的模型在推理时需要固定大小的输入图片,如 `CoreML`


# Update 2024.06.17:发布 V1.2.4.1

Major Changes:
Expand Down
2 changes: 1 addition & 1 deletion cnstd/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
# specific language governing permissions and limitations
# under the License.

__version__ = '1.2.4.1'
__version__ = '1.2.4.2'
18 changes: 18 additions & 0 deletions cnstd/yolo_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,39 @@

from pathlib import Path
from typing import Union, Optional, Any, List, Dict, Tuple
import logging

from PIL import Image
import numpy as np
from ultralytics import YOLO

from .utils import sort_boxes, dedup_boxes, xyxy24p, select_device, expand_box_by_margin

logger = logging.getLogger(__name__)


class YoloDetector(object):
def __init__(
self,
*,
model_path: Optional[str] = None,
device: Optional[str] = None,
static_resized_shape: Optional[Union[int, Tuple[int, int]]] = None,
**kwargs,
):
"""
YOLO Detector based on Ultralytics.
Args:
model_path (optional str): model path, default is None.
device (optional str): device to use, default is None.
static_resized_shape (optional int or tuple): static resized shape, default is None.
When it is not None, the input image will be resized to this shape before detection,
ignoring the input parameter `resized_shape` if .detect() is called.
Some format of models may require a fixed input size, such as CoreML.
**kwargs (): other parameters.
"""
self.device = select_device(device)
self.static_resized_shape = static_resized_shape
self.model = YOLO(model_path, task='detect')

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -85,6 +101,8 @@ def detect(
for img in img_list
]

if self.static_resized_shape is not None:
resized_shape = self.static_resized_shape
batch_results = self.model.predict(
img_list, imgsz=resized_shape, conf=conf, device=self.device, **kwargs
)
Expand Down

0 comments on commit 7de2abb

Please sign in to comment.