diff --git a/tools/onnx2bnn/onnx2bnn.cpp b/tools/onnx2bnn/onnx2bnn.cpp index 4439375..8d3e82a 100644 --- a/tools/onnx2bnn/onnx2bnn.cpp +++ b/tools/onnx2bnn/onnx2bnn.cpp @@ -6,33 +6,62 @@ #include #include +#include #include #include "NodeAttrHelper.h" #include "OnnxConverter.h" -#include #include "common/log_helper.h" using std::string; using std::vector; void usage(const std::string &filename) { - std::cout << "Usage: " << filename << " onnx_model output_filename" << std::endl; + std::cout + << "Usage: " << filename + << " onnx_model output_filename [--optimize strict|moderate|aggressive]" + << std::endl; + std::cout << "Example: " << filename + << " model.onnx model.dab (The optimization leval will be " + "\"aggressive\")" + << std::endl; + std::cout << "Example: " << filename + << " model.onnx model.dab --optimize strict (The optimization " + "level will be \"strict\")" + << std::endl; } int main(int argc, char **argv) { - argh::parser cmdl(argc, argv); + argh::parser cmdl; + cmdl.add_param("optimize"); + cmdl.parse(argc, argv); google::InitGoogleLogging(cmdl[0].c_str()); FLAGS_alsologtostderr = true; if (!cmdl(2)) { usage(cmdl[0]); return -1; } - bnn::OnnxConverter::Level opt_level = bnn::OnnxConverter::Level::kModerate; - if (cmdl["strict"]) { - opt_level = bnn::OnnxConverter::Level::kStrict; + // flags like 'onnx2bnn --strict' is not supported now + for (const auto flag : cmdl.flags()) { + std::cout << "Invalid flag: " << flag << std::endl; + usage(cmdl[0]); + return -2; } - if (cmdl["aggressive"]) { + + const std::string opt_level_str = + cmdl("optimize").str().empty() ? "aggressive" : cmdl("optimize").str(); + + bnn::OnnxConverter::Level opt_level; + if (opt_level_str == "strict") { + opt_level = bnn::OnnxConverter::Level::kStrict; + } else if (opt_level_str == "moderate") { + opt_level = bnn::OnnxConverter::Level::kModerate; + } else if (opt_level_str == "aggressive") { opt_level = bnn::OnnxConverter::Level::kAggressive; + } else { + std::cout << "Invalid optimization level: " << opt_level_str + << std::endl; + usage(cmdl[0]); + return -3; } ONNX_NAMESPACE::ModelProto model_proto;