rebase to master and merge 2d/3d unification
optima2005 committed Dec 4, 2019
1 parent f138f37 commit 828127e
Showing 5 changed files with 4 additions and 245 deletions.
241 changes: 0 additions & 241 deletions src/runtime/contrib/cudnn/conv_forward.cc
@@ -429,246 +429,5 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
w_dim, y_dim, data_dtype, conv_dtype, ret);
});

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
.set_body([](TVMArgs args, TVMRetValue *ret) {
int mode = args[0];
int format = args[1];
int algo = args[2];
int pad_v[3], stride_v[3], dilation_v[3];
for (int i = 0; i < 3; i++) {
pad_v[i] = args[3 + i];
stride_v[i] = args[6 + i];
dilation_v[i] = args[9 + i];
}
DLTensor *x = args[12];
DLTensor *w = args[13];
DLTensor *y = args[14];
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Mode
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// Set Algo
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
// Set Ctx
entry_ptr->conv_entry.ctx = x->ctx;
// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);

// Set Desc
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
3,
pad_v,
stride_v,
dilation_v,
entry_ptr->conv_entry.mode,
entry_ptr->conv_entry.data_type));
// Set Filter
int dim_v[5];
int tensor_stride_v[5];
for (int i = 0; i < 5; i++) {
dim_v[i] = static_cast<int>(w->shape[i]);
}
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.data_type,
CUDNN_TENSOR_NCHW,
5,
dim_v));
// Set Input
for (int i = 0; i < 5; i++) {
dim_v[i] = static_cast<int>(x->shape[i]);
}
tensor_stride_v[4] = 1;
for (int i = 4; i > 0; i--) {
tensor_stride_v[i - 1] = tensor_stride_v[i] * dim_v[i];
}
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.data_type,
5,
dim_v,
tensor_stride_v));
// Set Output
for (int i = 0; i < 5; i++) {
dim_v[i] = static_cast<int>(y->shape[i]);
}

tensor_stride_v[4] = 1;
for (int i = 4; i > 0; i--) {
tensor_stride_v[i - 1] = tensor_stride_v[i] * dim_v[i];
}
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.data_type,
5,
dim_v,
tensor_stride_v));
// Set workspace
size_t workspace_size = 0;
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle,
entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.conv_desc,
entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.fwd_algo,
&workspace_size));
entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
CUDNN_CALL(cudnnConvolutionForward(entry_ptr->handle,
CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
entry_ptr->conv_entry.input_desc,
x->data,
entry_ptr->conv_entry.filter_desc,
w->data,
entry_ptr->conv_entry.conv_desc,
entry_ptr->conv_entry.fwd_algo,
entry_ptr->conv_entry.workspace,
workspace_size,
CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
entry_ptr->conv_entry.output_desc,
y->data));
});


TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.output_shape")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
int format = args[0];
int pad_v[3], stride_v[3], dilation_v[3];
for (int i = 0; i < 3; i++) {
pad_v[i] = args[1 + i];
stride_v[i] = args[4 + i];
dilation_v[i] = args[7 + i];
}
int x_dim_v[5], w_dim_v[5];
for (int i = 0; i < 5; i++) {
x_dim_v[i] = args[10 + i];
w_dim_v[i] = args[15 + i];
}
void *out_shape = args[20];

// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// conv desc
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
3,
pad_v,
stride_v,
dilation_v,
CUDNN_CROSS_CORRELATION,
entry_ptr->conv_entry.data_type));
// input desc
int tensor_stride_v[5];
tensor_stride_v[4] = 1;
for (int i = 1; i < 5; i++) {
tensor_stride_v[4 - i] = tensor_stride_v[5 - i] * x_dim_v[i];
}
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
CUDNN_DATA_FLOAT,
5,
x_dim_v,
tensor_stride_v));
// filter desc
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
5,
w_dim_v));

CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(entry_ptr->conv_entry.conv_desc,
entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.filter_desc,
5,
static_cast<int*>(out_shape)));
});


TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.find_algo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
int format = args[0];
int pad_v[3], stride_v[3], dilation_v[3];
for (int i = 0; i < 3; i++) {
pad_v[i] = args[1 + i];
stride_v[i] = args[4 + i];
dilation_v[i] = args[7 + i];
}
int x_dim_v[5], w_dim_v[5], y_dim_v[5];
for (int i = 0; i < 5; i++) {
x_dim_v[i] = args[10 + i];
w_dim_v[i] = args[15 + i];
y_dim_v[i] = args[20 + i];
}

// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// conv desc
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
3,
pad_v,
stride_v,
dilation_v,
CUDNN_CROSS_CORRELATION,
entry_ptr->conv_entry.data_type));
// input desc
int tensor_stride_v[5];
tensor_stride_v[4] = 1;
for (int i = 4; i > 0; i--) {
tensor_stride_v[i - 1] = tensor_stride_v[i] * x_dim_v[i];
}
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
CUDNN_DATA_FLOAT,
5,
x_dim_v,
tensor_stride_v));
// filter desc
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
5,
w_dim_v));

// output desc
tensor_stride_v[4] = 1;
for (int i = 4; i > 0; i--) {
tensor_stride_v[i - 1] = tensor_stride_v[i] * y_dim_v[i];
}
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.data_type,
5,
y_dim_v,
tensor_stride_v));

int returned_algo_count = 0;
cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(entry_ptr->handle,
entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.conv_desc,
entry_ptr->conv_entry.output_desc,
CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
&returned_algo_count,
perf_results));

const std::vector<std::string> fwd_algo_names{
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
"CUDNN_CONVOLUTION_FWD_ALGO_FFT",
"CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"
};

auto best_algo = perf_results[0].algo;
LOG(INFO) << "\tCUDNN Found " << returned_algo_count
<< " fwd algorithms, choosing " << fwd_algo_names[best_algo];
for (int i = 0; i < returned_algo_count; ++i) {
LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
<< " - time: " << perf_results[i].time << " ms"
<< ", Memory: " << perf_results[i].memory;
}

ret[0] = best_algo;
});

} // namespace contrib
} // namespace tvm
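All three deleted bodies rebuild fully packed strides for their 5-D descriptors with the same backward loop (tensor_stride_v[4] = 1, then tensor_stride_v[i - 1] = tensor_stride_v[i] * dim_v[i]). As a reference for what cudnnSetTensorNdDescriptor was being fed, here is a minimal standalone sketch of that row-major rule (illustrative only, not part of the commit):

def packed_strides(shape):
    """Element strides of a densely packed row-major tensor, i.e. the
    tensor_stride_v values the deleted code passed to
    cudnnSetTensorNdDescriptor."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 1, 0, -1):
        strides[i - 1] = strides[i] * shape[i]
    return strides

# NCDHW shape from the conv3d test below: innermost (W) stride is 1,
# outermost (N) stride covers the whole 3x5x224x224 volume.
assert packed_strides([1, 3, 5, 224, 224]) == [752640, 250880, 50176, 224, 1]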
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level2.py
@@ -332,7 +332,7 @@ def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape,
op_res1 = intrp1.evaluate(func)(data, kernel)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)

-# normal conv2d
+# normal conv3d
dshape = (1, 3, 5, 224, 224)
kshape = (10, 3, 3, 3, 3)
run_test_conv3d("float32", "float32", 1, dshape, kshape,
3 changes: 2 additions & 1 deletion topi/python/topi/cuda/conv3d.py
@@ -100,7 +100,8 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o
dilation_w,
conv_mode=1,
tensor_format=tensor_format,
-algo=-1) # let CUDNN choose the best algo
+algo=-1, # let CUDNN choose the best algo
+conv_dtype=dtype)

if layout == 'NCDHW':
return nn.conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype)
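For context, a hedged sketch of how the patched call site is exercised end to end; the wrapper name (tvm.contrib.cudnn.conv_forward) and the v0.6-era tvm.placeholder API are assumptions, not shown in this diff. The deleted C++ entry points derived the compute type from the input dtype (see conv3d.forward above); the explicit conv_dtype added in this hunk lets the unified path choose it independently, e.g. float32 accumulation over float16 tensors:

import tvm
from tvm.contrib import cudnn

# NCDHW data and OIDHW filter; names and shapes are illustrative.
x = tvm.placeholder((1, 3, 5, 32, 32), name="x", dtype="float16")
w = tvm.placeholder((8, 3, 3, 3, 3), name="w", dtype="float16")

# float16 storage with float32 accumulation: expressible only because
# the call in this hunk now forwards conv_dtype instead of letting the
# backend infer it from x.dtype.
y = cudnn.conv_forward(x, w,
                       pad=[1, 1, 1],
                       stride=[1, 1, 1],
                       dilation=[1, 1, 1],
                       conv_mode=1,      # cross-correlation, as above
                       tensor_format=0,  # NCDHW
                       algo=-1,          # let cuDNN choose the best algo
                       conv_dtype="float32")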
2 changes: 1 addition & 1 deletion topi/python/topi/nn/util.py
@@ -157,7 +157,7 @@ def get_pad_tuple3d(padding, kernel):
pad_w = padding[1] * 2
pad_d = padding[2] * 2
elif isinstance(padding, int):
-pad_h = pad_w = pad_h = padding * 2
+pad_d = pad_w = pad_h = padding * 2
elif padding == "VALID":
pad_h = 0
pad_w = 0
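The one-character fix above is easy to miss: the old line assigned pad_h twice, so for integer padding pad_d was never set in this branch. A minimal sketch of the corrected integer case, assuming get_pad_tuple3d splits each total into front/back halves the way TOPI's 2-D get_pad_tuple does (the surrounding function is only partially shown in this hunk):

def pad_tuple3d_int(padding):
    # Corrected: all three spatial axes get the doubled total; the old
    # typo (pad_h = pad_w = pad_h = ...) left pad_d unset here.
    pad_d = pad_w = pad_h = padding * 2
    pad_front = (pad_d + 1) // 2
    pad_top = (pad_h + 1) // 2
    pad_left = (pad_w + 1) // 2
    return (pad_front, pad_top, pad_left,
            pad_d - pad_front, pad_h - pad_top, pad_w - pad_left)

# padding=1 pads every face of the volume by one element.
assert pad_tuple3d_int(1) == (1, 1, 1, 1, 1, 1)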
1 change: 0 additions & 1 deletion topi/python/topi/testing/__init__.py
@@ -21,7 +21,6 @@
"""
from __future__ import absolute_import as _abs

-from .conv3d_ncdhw_python import conv3d_ncdhw_python
from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .conv2d_nhwc_python import conv2d_nhwc_python
