Skip to content

Commit

Permalink
Fixing export_onnx and refactor simplify_onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Oct 2, 2021
1 parent f0746c6 commit 4a7a094
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
except ImportError:
onnxsim = None

from yolort import models
from yolort.models import YOLOv5


def get_parser():
Expand Down Expand Up @@ -76,44 +76,43 @@ def export_onnx(model, inputs, export_onnx_path, opset_version, enable_simplify)
)

if enable_simplify:
export_onnx_sim_path = export_onnx_path.with_suffix(".sim.onnx")
if onnxsim is None:
raise ImportError("onnx-simplifier not found and is required by yolort")
input_shapes = {"images_tensors": list(inputs[0][0].shape)}
simplify_onnx(export_onnx_path, input_shapes)

print(f"Simplifing with onnx-simplifier {onnxsim.__version__}...")

# load onnx mode
onnx_model = onnx.load(export_onnx_path)
def simplify_onnx(onnx_path, input_shapes):
if onnxsim is None:
raise ImportError("onnx-simplifier not found and is required by yolort")

# conver mode
model_sim, check = onnxsim.simplify(
onnx_model,
input_shapes={"images_tensors": list(inputs[0][0].shape)},
dynamic_input_shape=True,
)
print(f"Simplifing with onnx-simplifier {onnxsim.__version__}...")

assert check, "Simplified ONNX model could not be validated"
# Load onnx mode
onnx_model = onnx.load(onnx_path)

onnx.save(model_sim, export_onnx_sim_path)
# Simlify the ONNX model
model_sim, check = onnxsim.simplify(
onnx_model,
input_shapes=input_shapes,
dynamic_input_shape=True,
)

assert check, "Simplified ONNX model could not be validated"
export_onnx_sim_path = onnx_path.with_suffix(".sim.onnx")
onnx.save(model_sim, export_onnx_sim_path)


def cli_main():
parser = get_parser()
args = parser.parse_args()
print("Command Line Args: {}".format(args))
print(f"Command Line Args: {args}")
checkpoint_path = Path(args.checkpoint_path)
assert checkpoint_path.is_file(), f"Not found checkpoint: {checkpoint_path}"

# input data
images = torch.rand(3, args.image_size, args.image_size)
inputs = ([images],)
images = [torch.rand(3, args.image_size, args.image_size)]
inputs = (images,)

model = models.__dict__[args.arch](
num_classes=args.num_classes,
export_friendly=args.export_friendly,
score_thresh=args.score_thresh,
)
model.load_from_yolov5(checkpoint_path)
model = YOLOv5.load_from_yolov5(checkpoint_path, score_thresh=args.score_thresh)
model.eval()

# export ONNX models
Expand Down

0 comments on commit 4a7a094

Please sign in to comment.