Skip to content

Commit

Permalink
Add new unittests for gIOHW format in conv_transpose_mkldnn_op (Paddl…
Browse files Browse the repository at this point in the history
…ePaddle#37344)

* Add new unittests

* Replace I with O channel for filter groups

* Undo changes affecting other operators

* Fix oneDNN namespace typo

* Fix code format error
  • Loading branch information
Silv3S authored and Zjq9409 committed Dec 10, 2021
1 parent c3f0ab0 commit 86df4f6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
10 changes: 4 additions & 6 deletions paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,12 @@ using Tensor = framework::Tensor;
using framework::DataLayout;

inline dnnl::memory::dims GetWeightsTz(const Tensor* filter, const int groups) {
auto iohw_weights_tz = framework::vectorize(filter->dims());
auto weights_tz = iohw_weights_tz;

// IOHW -> OIHW
weights_tz[0] = iohw_weights_tz[1];
weights_tz[1] = iohw_weights_tz[0];
auto weights_tz = framework::vectorize(filter->dims());
int g = std::max(groups, 1);
int g_dim = (g > 1) ? 1 : 0;
platform::GetGroupConvWeightsTz(weights_tz, g);
// gIOHW -> gOIHW || IOHW -> OIHW
std::swap(weights_tz[g_dim + 0], weights_tz[g_dim + 1]);
return weights_tz;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,27 @@ def init_test_case(self):
self.padding_algorithm = "EXPLICIT"


class TestMKLDNNWithGroups(TestConv2DTransposeMKLDNNOp):
def init_test_case(self):
TestConv2DTransposeMKLDNNOp.init_test_case(self)
self.pad = [1, 1]
self.groups = 2
self.input_size = [2, 4, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 3, 3, 3]


class TestMKLDNNWithGroups_NHWC(TestConv2DTransposeMKLDNNOp):
def init_test_case(self):
TestConv2DTransposeMKLDNNOp.init_test_case(self)
self.pad = [1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 4] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3]
self.data_format = 'NHWC'


if __name__ == '__main__':
enable_static()
unittest.main()

0 comments on commit 86df4f6

Please sign in to comment.