Skip to content

Commit

Permalink
Workaround to make conv2d_transpose compilation for CUDA work
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Dec 6, 2019
1 parent ba9d96b commit 10f4b18
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
16 changes: 16 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,14 @@ def test_forward_convolution():
'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 12, 17, 17])
# kernel 2x2, strides (2,2)
_test_convolution('conv_transpose', [4, 19, 8, 8], [2, 2, 19, 19], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 19, 16, 16])
_test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 12, 32], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 12, 16, 16])
# output channel is 1
_test_convolution('conv_transpose', [1, 19, 8, 8], [1, 1, 1, 19], [1, 1], [1, 1], 'VALID',
'NCHW', [1, 1, 8, 8])

_test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
Expand All @@ -386,6 +394,14 @@ def test_forward_convolution():
'NHWC', [4, 17, 17, 124])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 12])
# kernel 2x2, strides (2,2)
_test_convolution('conv_transpose', [4, 8, 8, 19], [2, 2, 19, 19], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 16, 16, 19])
_test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 12, 32], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 16, 16, 12])
# output channel is 1
_test_convolution('conv_transpose', [1, 8, 8, 19], [1, 1, 1, 19], [1, 1], [1, 1], 'VALID',
'NHWC', [1, 8, 8, 1])


#######################################################################
Expand Down
20 changes: 18 additions & 2 deletions topi/python/topi/cuda/conv2d_transpose_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,24 @@ def _callback(op):
cfg.define_knob("unroll_explicit", [0, 1])

if cfg.is_fallback:
N, F, Y, X = get_const_tuple(conv.shape)
_fallback_schedule(N, F, Y, X)
ko = int(kernel.shape[1])
kh = int(kernel.shape[2])
kw = int(kernel.shape[3])
stride_h, stride_w = cfg.stride
# Workaround to make CUDA compilation work. Issue #4470
# TODO make _fallback_schedule work for all kernel/strides combinations
# after issue #4470 is resolved
do_fallback = True
if ko == 1:
do_fallback = False
elif (kh, kw) == (1, 1):
do_fallback = True
elif (kh, kw) == (stride_h, stride_w):
do_fallback = False

if do_fallback:
N, F, Y, X = get_const_tuple(conv.shape)
_fallback_schedule(N, F, Y, X)

##### space definition end #####

Expand Down

0 comments on commit 10f4b18

Please sign in to comment.