Skip to content

Commit

Permalink
Assert engine precision ultralytics#6777
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBaldsiefen committed Feb 27, 2022
1 parent 63ddb6f commit ac409f2
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)

# Half
half &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16 supported on limited backends with CUDA
if engine:
assert (model.trt_fp16_input == half), 'model ' + ('requires' if model.trt_fp16_input else 'incompatible with') + ' --half'
if pt or jit:
model.model.half() if half else model.model.float()

Expand Down
3 changes: 3 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
w = str(weights[0] if isinstance(weights, list) else weights)
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
trt_fp16_input = False
w = attempt_download(w) # download if not local
if data: # data.yaml path (optional)
with open(data, errors='ignore') as f:
Expand Down Expand Up @@ -348,6 +349,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
shape = tuple(model.get_binding_shape(index))
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
if model.binding_is_input(index) and dtype == np.float16:
trt_fp16_input = dtype == np.float16
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context()
batch_size = bindings['images'].shape[0]
Expand Down
1 change: 1 addition & 0 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def run(data,
if pt or jit:
model.model.half() if half else model.model.float()
elif engine:
assert (model.trt_fp16_input == half), 'model ' + ('requires' if model.trt_fp16_input else 'incompatible with') + ' --half'
batch_size = model.batch_size
else:
half = False
Expand Down

0 comments on commit ac409f2

Please sign in to comment.