Skip to content

Commit

Permalink
Add variable-stride inference support (ultralytics#2091)
Browse files: browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Jan 30, 2021
1 parent 4e8253d commit c776b5c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
7 changes: 4 additions & 3 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def detect(save_img=False):

# Load model
model = attempt_load(weights, map_location=device) # load FP32 model
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
stride = int(model.stride.max()) # model stride
imgsz = check_img_size(imgsz, s=stride) # check img_size
if half:
model.half() # to FP16

Expand All @@ -46,10 +47,10 @@ def detect(save_img=False):
if webcam:
view_img = True
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz)
dataset = LoadStreams(source, img_size=imgsz, stride=stride)
else:
save_img = True
dataset = LoadImages(source, img_size=imgsz)
dataset = LoadImages(source, img_size=imgsz, stride=stride)

# Get names and colors
names = model.module.names if hasattr(model, 'module') else model.names
Expand Down
23 changes: 13 additions & 10 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __iter__(self):


class LoadImages: # for inference
def __init__(self, path, img_size=640):
def __init__(self, path, img_size=640, stride=32):
p = str(Path(path)) # os-agnostic
p = os.path.abspath(p) # absolute path
if '*' in p:
Expand All @@ -136,6 +136,7 @@ def __init__(self, path, img_size=640):
ni, nv = len(images), len(videos)

self.img_size = img_size
self.stride = stride
self.files = images + videos
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
Expand Down Expand Up @@ -181,7 +182,7 @@ def __next__(self):
print(f'image {self.count}/{self.nf} {path}: ', end='')

# Padded resize
img = letterbox(img0, new_shape=self.img_size)[0]
img = letterbox(img0, self.img_size, stride=self.stride)[0]

# Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
Expand All @@ -199,8 +200,9 @@ def __len__(self):


class LoadWebcam: # for inference
def __init__(self, pipe='0', img_size=640):
def __init__(self, pipe='0', img_size=640, stride=32):
self.img_size = img_size
self.stride = stride

if pipe.isnumeric():
pipe = eval(pipe) # local camera
Expand Down Expand Up @@ -243,7 +245,7 @@ def __next__(self):
print(f'webcam {self.count}: ', end='')

# Padded resize
img = letterbox(img0, new_shape=self.img_size)[0]
img = letterbox(img0, self.img_size, stride=self.stride)[0]

# Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
Expand All @@ -256,9 +258,10 @@ def __len__(self):


class LoadStreams: # multiple IP or RTSP cameras
def __init__(self, sources='streams.txt', img_size=640):
def __init__(self, sources='streams.txt', img_size=640, stride=32):
self.mode = 'stream'
self.img_size = img_size
self.stride = stride

if os.path.isfile(sources):
with open(sources, 'r') as f:
Expand All @@ -284,7 +287,7 @@ def __init__(self, sources='streams.txt', img_size=640):
print('') # newline

# check for common shapes
s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0) # inference shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect:
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
Expand Down Expand Up @@ -313,7 +316,7 @@ def __next__(self):
raise StopIteration

# Letterbox
img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]
img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

# Stack
img = np.stack(img, 0)
Expand Down Expand Up @@ -784,8 +787,8 @@ def replicate(img, labels):
return img, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
# Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# Resize and pad image while meeting stride-multiple constraints
shape = img.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
Expand All @@ -800,7 +803,7 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle
dw, dh = np.mod(dw, 32), np.mod(dh, 32) # wh padding
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
elif scaleFill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
Expand Down

0 comments on commit c776b5c

Please sign in to comment.