Skip to content

Commit

Permalink
Fix TensorRT potential unordered binding addresses (ultralytics#5826)
Browse files Browse the repository at this point in the history
* feat: change file suffix in pythonic way

* fix: enforce binding addresses order

* fix: enforce binding addresses order
  • Loading branch information
imyhxy committed Nov 30, 2021
1 parent 777d5ba commit 57fadd4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
3 changes: 2 additions & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
assert onnx.exists(), f'failed to export ONNX file: {onnx}'

LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
f = str(file).replace('.pt', '.engine') # TensorRT engine file
f = file.with_suffix('.engine') # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO)
if verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE
Expand Down Expand Up @@ -310,6 +310,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')


@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
weights=ROOT / 'yolov5s.pt', # weights path
Expand Down
6 changes: 3 additions & 3 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import math
import platform
import warnings
from collections import namedtuple
from collections import OrderedDict, namedtuple
from copy import copy
from pathlib import Path

Expand Down Expand Up @@ -326,14 +326,14 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
logger = trt.Logger(trt.Logger.INFO)
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(f.read())
bindings = dict()
bindings = OrderedDict()
for index in range(model.num_bindings):
name = model.get_binding_name(index)
dtype = trt.nptype(model.get_binding_dtype(index))
shape = tuple(model.get_binding_shape(index))
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
binding_addrs = {n: d.ptr for n, d in bindings.items()}
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context()
batch_size = bindings['images'].shape[0]
else: # TensorFlow model (TFLite, pb, saved_model)
Expand Down

0 comments on commit 57fadd4

Please sign in to comment.