diff --git a/dabnn/layers/BinConv.cpp b/dabnn/layers/BinConv.cpp index ccf7ff1..ab3095a 100644 --- a/dabnn/layers/BinConv.cpp +++ b/dabnn/layers/BinConv.cpp @@ -83,11 +83,13 @@ BinConv::Method BinConv::method() const { return Method::DIRECT_CONV; } else if (gemm_compatible()) { return Method::BGEMM; - } else { + } else if (input_mat->elem_c == 64) { return Method::BCONV_NAIVE; + } else { + return Method::BGEMM_NAIVE; } } else { - if (weight_mat->c == 1) { + if (input_mat->elem_c == 64) { return Method::BCONV_NAIVE; } else { return Method::BGEMM_NAIVE; @@ -128,11 +130,11 @@ bool BinConv::gemm_compatible() const { #ifdef __aarch64__ return true; #else - // If weight_mat->c == 1 (weight_mat has 64 channels), we use bconv_64 + // If input_mat->elem_c == 1 (weight_mat has 64 channels), we use bconv_64 // in aarch64 for the fastest speed, however, bconv_64 is not implemented // in armv7 // TODO: Implement bconv_64 for armv7 - return weight_mat->c != 1; + return input_mat->elem_c != 64; #endif #else return false;