Two dimensional size=(h,w) AutoShape support (ultralytics#9072)
* Two dimensional `size=(h,w)` AutoShape support

May resolve ultralytics#9039

Signed-off-by: Glenn Jocher <[email protected]>

* Update hubconf.py

Signed-off-by: Glenn Jocher <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Glenn Jocher <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and Clay Januhowski committed Sep 8, 2022
1 parent 79e9829 commit 338235d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
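
In practice, the change lets torch.hub callers request rectangular inference directly. A minimal sketch of the call pattern, assuming the standard yolov5s checkpoint and the zidane.jpg example image from the forward() docstring below:

    import torch

    # Load a pretrained detection model via torch.hub (downloads weights on first use)
    model = torch.hub.load('ultralytics/yolov5', 'yolov5s')

    # Square inference, as before: an int size is expanded internally to (size, size)
    results = model('https://ultralytics.com/images/zidane.jpg', size=640)

    # New in this commit: explicit (height, width) rectangular inference
    results = model('https://ultralytics.com/images/zidane.jpg', size=(640, 1280))
    results.print()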
hubconf.py: 7 additions & 3 deletions
@@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
 
     from models.common import AutoShape, DetectMultiBackend
     from models.experimental import attempt_load
-    from models.yolo import DetectionModel
+    from models.yolo import ClassificationModel, DetectionModel
     from utils.downloads import attempt_download
     from utils.general import LOGGER, check_requirements, intersect_dicts, logging
     from utils.torch_utils import select_device
@@ -45,8 +45,12 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
         if pretrained and channels == 3 and classes == 80:
             try:
                 model = DetectMultiBackend(path, device=device, fuse=autoshape)  # detection model
-                if autoshape and isinstance(model.model, DetectionModel):
-                    model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
+                if autoshape:
+                    if model.pt and isinstance(model.model, ClassificationModel):
+                        LOGGER.warning('WARNING: YOLOv5 v6.2 ClassificationModel is not yet AutoShape compatible. '
+                                       'You must pass torch tensors in BCHW to this model, i.e. shape(1,3,224,224).')
+                    else:
+                        model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
             except Exception:
                 model = attempt_load(path, device=device, fuse=False)  # arbitrary model
         else:
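
The LOGGER.warning branch above means v6.2 classification checkpoints load without the AutoShape wrapper, so they accept only preprocessed tensors. A sketch of that tensor-only workflow; 'yolov5s-cls.pt' is an illustrative checkpoint name, and the 'custom' torch.hub entry point is assumed:

    import torch

    # 'yolov5s-cls.pt' stands in for a v6.2 classification checkpoint
    model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5s-cls.pt')

    # AutoShape is skipped for ClassificationModel, so the model expects a
    # preprocessed BCHW tensor rather than a file path, URI, PIL or OpenCV image
    im = torch.zeros(1, 3, 224, 224)  # dummy batch: shape(1,3,224,224)
    logits = model(im)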
models/common.py: 5 additions & 3 deletions
@@ -589,7 +589,7 @@ def _apply(self, fn):
 
     @smart_inference_mode()
     def forward(self, ims, size=640, augment=False, profile=False):
-        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
+        # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
         #   file:        ims = 'data/images/zidane.jpg'  # str or PosixPath
         #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
         #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
@@ -600,6 +600,8 @@ def forward(self, ims, size=640, augment=False, profile=False):
 
         dt = (Profile(), Profile(), Profile())
         with dt[0]:
+            if isinstance(size, int):  # expand
+                size = (size, size)
             p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device)  # param
             autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
             if isinstance(ims, torch.Tensor):  # torch
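
The two added lines normalize size up front so the rest of forward() can assume an (h, w) tuple, keeping the existing integer API backward compatible. A standalone sketch of the same normalization; normalize_size is an illustrative helper, not part of the codebase:

    def normalize_size(size):
        # Mirror the two added lines: accept an int (square inference)
        # or an (h, w) tuple (rectangular inference)
        if isinstance(size, int):  # expand
            size = (size, size)
        return size

    assert normalize_size(640) == (640, 640)            # legacy integer API still works
    assert normalize_size((640, 1280)) == (640, 1280)   # new two-dimensional API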
@@ -622,10 +624,10 @@ def forward(self, ims, size=640, augment=False, profile=False):
                 im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)  # enforce 3ch input
                 s = im.shape[:2]  # HWC
                 shape0.append(s)  # image shape
-                g = (size / max(s))  # gain
+                g = max(size) / max(s)  # gain
                 shape1.append([y * g for y in s])
                 ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
-            shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)]  # inf shape
+            shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size  # inf shape
             x = [letterbox(im, shape1, auto=False)[0] for im in ims]  # pad
             x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
             x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
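
The gain fix matters because the old expression g = (size / max(s)) would raise a TypeError once size is a tuple; dividing by max(size) scales by the longest requested side instead. The shape1 fix moves the if self.pt conditional outside the list comprehension, so non-PyTorch backends get the size tuple itself rather than a list of repeated tuples. A numeric walk-through under assumed values, with make_divisible reimplemented here as a stand-in for utils.general.make_divisible and a model stride of 32 assumed:

    import math

    def make_divisible(x, divisor):
        # Stand-in for utils.general.make_divisible: round x up to a multiple of divisor
        return math.ceil(x / divisor) * divisor

    size = (640, 1280)                # requested inference size (height, width)
    s = (720, 1280)                   # original image shape, HW
    g = max(size) / max(s)            # gain = 1280 / 1280 = 1.0
    shape1 = [y * g for y in s]       # scaled shape: [720.0, 1280.0]
    shape1 = [make_divisible(x, 32) for x in shape1]  # stride-aligned: [736, 1280]
    print(g, shape1)                  # 1.0 [736, 1280]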
