Skip to content

Commit

Permalink
Merge pull request #54 from JDAI-CV/fix_onnx2bnn_bn
Browse files Browse the repository at this point in the history
Refuse the case that bin conv is not followed by a bn, fuse conv bias into bn
  • Loading branch information
daquexian authored Aug 25, 2019
2 parents bcb72c3 + 2d29e6c commit cd3092e
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tools/onnx2bnn/OnnxConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,17 @@ std::vector<std::string> OnnxConverter::Convert(
expected_binary_conv_outputs.end());
if (binary_conv) {
binary_conv_outputs.push_back(node.output(0));
bool precede_bn = false;
for (const auto &node2 : model_proto_.graph().node()) {
if (node2.op_type() == "BatchNormalization" &&
node2.input(0) == node.output(0)) {
precede_bn = true;
break;
}
}
if (!precede_bn) {
throw std::invalid_argument("Binary convolutions should precede BatchNorm");
}
}
AddConv(m(node.input(0)), strides, pads, dilations, group,
ori_weight_name, bias_name, m(node.output(0)), binary_conv);
Expand Down Expand Up @@ -556,6 +567,13 @@ void OnnxConverter::CalculateCoeff(const ONNX_NAMESPACE::NodeProto &node,
height *
coeff_a_data[i];
}
if (node2.input_size() == 3) {
const auto &bias = onnx_float_tensors_[node2.input(2)];

FORZ(i, coeff_b_data.size()) {
coeff_b_data[i] += coeff_a_data[i] * bias.data[i];
}
}
}
{
FORZ(i, coeff_a_data.size()) { coeff_a_data[i] *= -2; }
Expand Down

0 comments on commit cd3092e

Please sign in to comment.