diff --git a/tools/onnx2bnn/OnnxConverter.cpp b/tools/onnx2bnn/OnnxConverter.cpp index e5de8cf..71219b4 100644 --- a/tools/onnx2bnn/OnnxConverter.cpp +++ b/tools/onnx2bnn/OnnxConverter.cpp @@ -262,14 +262,12 @@ std::vector OnnxConverter::Convert( for (const auto &node : model_proto_.graph().node()) { NodeAttrHelper helper(node); const auto &op = node.op_type(); - if (has_reshape && op != "Gemm") { - throw std::invalid_argument( - "Reshape can only be the last layer or precede a gemm layer " - "for now"); - } - has_reshape = false; VLOG(5) << "Node " << node.name(); if (op == "Conv") { + if (has_reshape) { + throw std::invalid_argument("Reshape before " + op + + " is not supported"); + } VLOG(5) << "Start converting Conv"; auto strides = helper.get("strides", vector{1, 1}); auto pads = helper.get("pads", vector{0, 0, 0, 0}); @@ -319,6 +317,10 @@ std::vector OnnxConverter::Convert( VLOG(5) << "Converting Conv completed"; } else if (op == "AveragePool" || op == "MaxPool" || op == "GlobalAveragePool" || op == "GlobalMaxPool") { + if (has_reshape) { + throw std::invalid_argument("Reshape before " + op + + " is not supported"); + } VLOG(5) << "Start converting Pool"; auto input_name = m(node.input(0)); auto output_name = m(node.output(0)); @@ -407,6 +409,10 @@ std::vector OnnxConverter::Convert( layers_.push_back(layer); VLOG(5) << "Converting Relu completed"; } else if (op == "Add") { + if (has_reshape) { + throw std::invalid_argument("Reshape before " + op + + " is not supported"); + } VLOG(5) << "Start converting Add"; auto input1_name = m(node.input(0)); auto input2_name = m(node.input(1)); @@ -420,6 +426,9 @@ std::vector OnnxConverter::Convert( layers_.push_back(layer); VLOG(5) << "Converting Add completed"; } else if (op == "Gemm") { + if (has_reshape) { + has_reshape = false; + } VLOG(5) << "Start converting Gemm"; auto transA = helper.get("transA", 0); auto transB = helper.get("transB", 0); @@ -478,6 +487,10 @@ std::vector OnnxConverter::Convert( layers_.push_back(layer); VLOG(5) << "Converting Softmax completed"; } else if (op == "Concat") { + if (has_reshape) { + throw std::invalid_argument("Reshape before " + op + + " is not supported"); + } VLOG(5) << "Start converting Concat"; vector concat_inputs_str; for (const auto &onnx_input : node.input()) {