Skip to content

Commit

Permalink
Add extremesoft level
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed May 15, 2019
1 parent d55df70 commit 0b3e8e5
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
7 changes: 5 additions & 2 deletions tools/onnx2bnn/OnnxConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ vector<bin_t> bitpack(const float *data, Shape shape) {

void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,
const std::string &filepath,
const bool strict) {
const OnnxConverter::Level level) {
GOOGLE_PROTOBUF_VERIFY_VERSION;

// We recognize binary convolutions in our custom ONNX optimizers.
Expand All @@ -223,9 +223,12 @@ void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,
// https://github.com/daquexian/onnx/blob/optimizer_for_bnn/onnx/optimizer/passes/dabnn_bconv_strict.h
// for details.
vector<string> optimizers{"eliminate_nop_pad", "dabnn_bconv_strict"};
if (!strict) {
if (level == Level::kSoft || level == Level::kExtremeSoft) {
optimizers.push_back("dabnn_bconv_soft");
}
if (level == Level::kExtremeSoft) {
optimizers.push_back("dabnn_bconv_extreme_soft");
}
// model_proto is only used here. Please use the member variable model_proto_
// in the following code
model_proto_ = ONNX_NAMESPACE::optimization::Optimize(
Expand Down
7 changes: 6 additions & 1 deletion tools/onnx2bnn/OnnxConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,14 @@ class OnnxConverter {
}

public:
enum class Level {
kStrict,
kSoft,
kExtremeSoft,
};
void Convert(const ONNX_NAMESPACE::ModelProto &model,
const std::string &filepath,
const bool strict=false);
const Level level=Level::kSoft);
};

template <>
Expand Down
16 changes: 12 additions & 4 deletions tools/onnx2bnn/onnx2bnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@ void usage(const std::string &filename) {
}

int main(int argc, char **argv) {
argh::parser cmdl(argv);
argh::parser cmdl(argc, argv);
google::InitGoogleLogging(cmdl[0].c_str());
FLAGS_alsologtostderr = true;
if (argc < 3) {
usage(argv[0]);
if (!cmdl(2)) {
usage(cmdl[0]);
return -1;
}
bnn::OnnxConverter::Level opt_level = bnn::OnnxConverter::Level::kSoft;
if (cmdl["strict"]) {
opt_level = bnn::OnnxConverter::Level::kStrict;
}
if (cmdl["extremesoft"]) {
opt_level = bnn::OnnxConverter::Level::kExtremeSoft;
}

ONNX_NAMESPACE::ModelProto model_proto;
{
std::ifstream ifs(cmdl[1], std::ios::in | std::ios::binary);
Expand All @@ -35,7 +43,7 @@ int main(int argc, char **argv) {
}

bnn::OnnxConverter converter;
converter.Convert(model_proto, cmdl[2], cmdl["strict"]);
converter.Convert(model_proto, cmdl[2], opt_level);

google::protobuf::ShutdownProtobufLibrary();
return 0;
Expand Down

0 comments on commit 0b3e8e5

Please sign in to comment.