diff --git a/tools/onnx2bnn/OnnxConverter.cpp b/tools/onnx2bnn/OnnxConverter.cpp index 25433b4..6f0da4f 100644 --- a/tools/onnx2bnn/OnnxConverter.cpp +++ b/tools/onnx2bnn/OnnxConverter.cpp @@ -260,12 +260,14 @@ std::vector OnnxConverter::Convert( vector binary_conv_outputs; bool has_reshape = false; for (const auto &node : model_proto_.graph().node()) { - if (has_reshape) { - throw std::invalid_argument( - "Reshape can only be the last layer for now"); - } NodeAttrHelper helper(node); const auto &op = node.op_type(); + if (has_reshape && op != "Gemm") { + throw std::invalid_argument( + "Reshape can only be the last layer or precede a gemm layer " + "for now"); + } + has_reshape = false; VLOG(5) << "Node " << node.name(); if (op == "Conv") { VLOG(5) << "Start converting Conv";