Skip to content

Commit

Permalink
Support custom binary conv list
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed Aug 20, 2019
1 parent 2ff86f2 commit 3ee4c79
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 8 deletions.
25 changes: 21 additions & 4 deletions tools/onnx2bnn/OnnxConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,10 @@ std::vector<OnnxConverter::BTensor> OnnxConverter::split(
return outputs;
}

std::vector<std::string> OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,
const std::string &filepath,
const OnnxConverter::Level level) {
std::vector<std::string> OnnxConverter::Convert(
const ONNX_NAMESPACE::ModelProto &model_proto, const std::string &filepath,
const OnnxConverter::Level level,
const std::vector<std::string> &expected_binary_conv_outputs) {
GOOGLE_PROTOBUF_VERIFY_VERSION;

// We recognize binary convolutions in our custom ONNX optimizers.
Expand Down Expand Up @@ -271,7 +272,12 @@ std::vector<std::string> OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto
}

auto ori_weight_name = m(node.input(1));
const bool binary_conv = (node.domain() == "dabnn");
const bool binary_conv =
(node.domain() == "dabnn") ||
(std::find(expected_binary_conv_outputs.begin(),
expected_binary_conv_outputs.end(),
node.output(0)) !=
expected_binary_conv_outputs.end());
if (binary_conv) {
binary_conv_outputs.push_back(node.output(0));
}
Expand Down Expand Up @@ -476,6 +482,17 @@ std::vector<std::string> OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto
throw std::invalid_argument("Unsupported operator " + op);
}
}

for (const auto &expected : expected_binary_conv_outputs) {
if (std::find(binary_conv_outputs.begin(), binary_conv_outputs.end(),
expected) == binary_conv_outputs.end()) {
throw std::invalid_argument(
expected +
" is in the list file but not in the ONNX model, please check "
"your list file");
}
}

auto flat_layers = builder_.CreateVector(layers_);
auto flat_inputs = builder_.CreateVector(inputs);
auto flat_tensors = builder_.CreateVector(tensors_);
Expand Down
2 changes: 1 addition & 1 deletion tools/onnx2bnn/OnnxConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class OnnxConverter {
};
std::vector<std::string> Convert(const ONNX_NAMESPACE::ModelProto &model,
const std::string &filepath,
const Level level=Level::kModerate);
const Level level, const std::vector<std::string> &expected_binary_conv_outputs);
};

template <>
Expand Down
24 changes: 21 additions & 3 deletions tools/onnx2bnn/onnx2bnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void usage(const std::string &filename) {
std::cout << "Usage:" << std::endl;
std::cout << " " << filename
<< " onnx_model output_filename [ --strict | --moderate | "
"--aggressive ] [--verbose]"
"--aggressive ] [--binary-list] [--verbose]"
<< std::endl;
std::cout << std::endl;
std::cout << "Options:" << std::endl;
Expand All @@ -41,6 +41,11 @@ void usage(const std::string &filename) {
"A Conv operator, whose input is got from a Sign op and a Pad op "
"(the order doesn't matter), and weight is got from a Sign op."
<< std::endl;
std::cout
<< " --binary-list A text file containing the **output "
"names** of some convolutions, which will be treated as binary "
"convlutions unconditionally. It is mainly for benchmark purpose."
<< std::endl;
std::cout << std::endl;
std::cout << "Example:" << std::endl;
std::cout << " " << filename
Expand All @@ -55,6 +60,7 @@ void usage(const std::string &filename) {

int main(int argc, char **argv) {
argh::parser cmdl;
cmdl.add_param("--binary-list");
cmdl.parse(argc, argv);
google::InitGoogleLogging(cmdl[0].c_str());
FLAGS_alsologtostderr = true;
Expand Down Expand Up @@ -85,6 +91,18 @@ int main(int argc, char **argv) {
FLAGS_v = 5;
}

const auto binary_list_filepath = cmdl("binary-list").str();
vector<string> expected_binary_conv_outputs;
if (!binary_list_filepath.empty()) {
std::ifstream ifs(binary_list_filepath);
if (ifs.is_open()) {
string binary_conv_output;
while (ifs >> binary_conv_output) {
expected_binary_conv_outputs.push_back(binary_conv_output);
}
}
}

ONNX_NAMESPACE::ModelProto model_proto;
{
std::ifstream ifs(cmdl[1], std::ios::in | std::ios::binary);
Expand All @@ -93,8 +111,8 @@ int main(int argc, char **argv) {
}

bnn::OnnxConverter converter;
const auto binary_conv_outputs =
converter.Convert(model_proto, cmdl[2], opt_level);
const auto binary_conv_outputs = converter.Convert(
model_proto, cmdl[2], opt_level, expected_binary_conv_outputs);

LOG(INFO) << "Conversion completed! Found " << binary_conv_outputs.size()
<< " binary convolutions. Add --verbose to get what they are.";
Expand Down

0 comments on commit 3ee4c79

Please sign in to comment.