Skip to content

Commit

Permalink
Automatic TFLite uint8 determination (#4515)
Browse files Browse the repository at this point in the history
* Auto TFLite uint8 detection

This PR automatically determines if TFLite models are uint8 quantized rather than accepting a manual argument.

The quantization determination is based on @zldrobit comment ultralytics/yolov5#1127 (comment)

* Cleanup
  • Loading branch information
MichaelAnderson-AI committed Aug 23, 2021
1 parent 6722788 commit c226c6c
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def run(weights='yolov5s.pt', # model.pt path(s)
hide_labels=False, # hide labels
hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference
tfl_int8=False, # INT8 quantized TFLite model
):
save_img = not nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
Expand Down Expand Up @@ -104,6 +103,7 @@ def wrap_frozen_graph(gd, inputs, outputs):
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
int8 = input_details[0]['dtype'] == np.uint8 # is TFLite quantized uint8 model
imgsz = check_img_size(imgsz, s=stride) # check image size

# Dataloader
Expand Down Expand Up @@ -145,15 +145,15 @@ def wrap_frozen_graph(gd, inputs, outputs):
elif saved_model:
pred = model(imn, training=False).numpy()
elif tflite:
if tfl_int8:
if int8:
scale, zero_point = input_details[0]['quantization']
imn = (imn / scale + zero_point).astype(np.uint8)
imn = (imn / scale + zero_point).astype(np.uint8) # de-scale
interpreter.set_tensor(input_details[0]['index'], imn)
interpreter.invoke()
pred = interpreter.get_tensor(output_details[0]['index'])
if tfl_int8:
if int8:
scale, zero_point = output_details[0]['quantization']
pred = (pred.astype(np.float32) - zero_point) * scale
pred = (pred.astype(np.float32) - zero_point) * scale # re-scale
pred[..., 0] *= imgsz[1] # x
pred[..., 1] *= imgsz[0] # y
pred[..., 2] *= imgsz[1] # w
Expand Down Expand Up @@ -268,7 +268,6 @@ def parse_opt():
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
return opt
Expand Down

0 comments on commit c226c6c

Please sign in to comment.