Skip to content

Commit

Permalink
Add max pool fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed May 29, 2019
1 parent 1f47ce7 commit 5a4fc6e
Showing 1 changed file with 49 additions and 9 deletions.
58 changes: 49 additions & 9 deletions dabnn/layers/MaxPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,45 @@ void maxpool3x3(const bnn::Mat &input, bnn::Mat &output, const int stride_h = 1,
}
#endif // __ARM_NEON

void max_pool_fallback(const bnn::Mat &input, const size_t pad_h,
const size_t pad_w, const size_t stride_h,
const size_t stride_w, const size_t kernel_h,
const size_t kernel_w, bnn::Mat &output) {
const int output_h =
(input.h + 2 * pad_h - ((kernel_h - 1) + 1)) / stride_h + 1;
const int output_w =
(input.w + 2 * pad_w - ((kernel_w - 1) + 1)) / stride_w + 1;

BNN_ASSERT(input.w * input.c * input.elemsize % 16 == 0, "Not align");
BNN_ASSERT(output.w * output.c * output.elemsize % 16 == 0, "Not align");

int input_y = 0;
FORZ(output_y, output_h) {
int input_x = 0;
FORZ(output_x, output_w) {
FORZ(output_c, input.c) {
float m = -std::numeric_limits<float>::max();
FORZ(kh, kernel_h) {
int y = input_y - pad_h + kh;
const float *input_ptr = input.point<float>(y, 0);
FORZ(kw, kernel_w) {
int x = input_x - pad_w + kw;
if (!(y < 0 || y >= input.h || x < 0 || x >= input.w)) {
const auto val = input_ptr[x * input.c + output_c];
m = std::max(m, val);
}
}
}

output[output_y * output_w * input.c + output_x * input.c +
output_c] = m;
}
input_x += stride_w;
}
input_y += stride_h;
}
}

MaxPool::MaxPool(NetCP net, const std::string &name, css input, css output,
int kernel_h, int kernel_w, int pad_h, int pad_w, int stride_h,
int stride_w)
Expand All @@ -229,22 +268,23 @@ MaxPool::MaxPool(NetCP net, const std::string &name, css input, css output,
}
void MaxPool::forward_impl() const {
#ifdef __ARM_NEON
// std::numeric_limits<float>::min() is the closest value to 0, so we uses
// -max()
pad(*input_mat, pad_h, pad_w, *padded_mat,
-std::numeric_limits<float>::max());
BNN_ASSERT(
(kernel_h == 3 && kernel_w == 3) || (kernel_h == 2 && kernel_w == 2),
"Not supported max_pool");
if (kernel_h == 3 && kernel_w == 3) {
// std::numeric_limits<float>::min() is the closest value to 0, so we uses
// -max()
pad(*input_mat, pad_h, pad_w, *padded_mat,
-std::numeric_limits<float>::max());
maxpool3x3(*padded_mat, *output_mat, stride_h, stride_w);
} else if (kernel_h == 2 && kernel_w == 2) {
pad(*input_mat, pad_h, pad_w, *padded_mat,
-std::numeric_limits<float>::max());
maxpool2x2(*padded_mat, *output_mat, stride_h, stride_w);
} else {
throw std::invalid_argument("Not supported max_pool");
max_pool_fallback(*input_mat, pad_h, pad_w, stride_h, stride_w,
kernel_h, kernel_w, *output_mat);
}
#else
throw std::invalid_argument("Not supported max_pool");
max_pool_fallback(*input_mat, pad_h, pad_w, stride_h, stride_w,
kernel_h, kernel_w, *output_mat);
#endif // __aarch64__
}

Expand Down

0 comments on commit 5a4fc6e

Please sign in to comment.