diff --git a/export.py b/export.py index c56a0a99a635..0e8e4242f487 100644 --- a/export.py +++ b/export.py @@ -277,8 +277,6 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te try: import tensorflow as tf - from models.tf import representative_dataset_gen - LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') batch_size, ch, *imgsz = list(im.shape) # BCHW f = str(file).replace('.pt', '-fp16.tflite') @@ -288,6 +286,8 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te converter.target_spec.supported_types = [tf.float16] converter.optimizations = [tf.lite.Optimize.DEFAULT] if int8: + from models.tf import representative_dataset_gen + check_requirements(('flatbuffers==1.12',)) # https://github.com/ultralytics/yolov5/issues/5707 dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib) converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]