Skip to content

Commit

Permalink
[Topi] Tensorcore support for Conv3D (apache#5284)
Browse files Browse the repository at this point in the history
* one weird trick.

* Added schedule knob for different workloads.

* Initial conv3d tensorcore working.

* Added conv3d tensorcore strategy.

* Added layout conversion to tensorcore friendly format for conv2d and conv3d.

* Add target name check.

* Fixed bad names and depthwise check.

* Removed duplicated attribute assignment.
  • Loading branch information
jwfromm authored and zhiics committed Apr 17, 2020
1 parent 09bd5fb commit 04388b4
Show file tree
Hide file tree
Showing 7 changed files with 545 additions and 25 deletions.
52 changes: 48 additions & 4 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .. import strategy
from ..op import OpPattern
from .._tensor import elemwise_shape_func
from ..strategy.generic import is_depthwise_conv2d

# relu
reg.register_broadcast_schedule("nn.relu")
Expand Down Expand Up @@ -139,13 +140,21 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
# pylint: disable=import-outside-toplevel
from tvm import relay
data, weight = inputs
assert desired_layout == 'NCHW', \
"Currently only transformation to NCHW layout is supported."
new_attrs = dict(attrs)
new_attrs['data_layout'] = desired_layout
if desired_layout == 'NCHW':
new_attrs = dict(attrs)
new_attrs['data_layout'] = desired_layout
new_attrs['kernel_layout'] = 'OIHW'
return relay.nn.conv2d(data, weight, **new_attrs)
elif desired_layout == 'NHWC':
# Check for depthwise convolution.
if is_depthwise_conv2d(data.shape, attrs['data_layout'], weight.shape,
attrs['kernel_layout'], attrs['groups']):
new_attrs['kernel_layout'] = 'HWOI'
else:
new_attrs['kernel_layout'] = 'HWIO'
return relay.nn.conv2d(data, weight, **new_attrs)
else:
assert "Layout %s is not yet supported." % (desired_layout)
return None


Expand Down Expand Up @@ -183,6 +192,41 @@ def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type):
"""Alternate the layout of conv3d"""
return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type)

@reg.register_convert_op_layout("nn.conv3d")
def convert_conv3d(attrs, inputs, tinfos, desired_layout):
"""Convert Layout pass registration for conv3d op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layout : str
The desired layout
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
# pylint: disable=import-outside-toplevel
from tvm import relay
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = desired_layout
if desired_layout == 'NCDHW':
new_attrs['kernel_layout'] = 'OIDHW'
return relay.nn.conv3d(data, weight, **new_attrs)
elif desired_layout == "NDHWC":
new_attrs['kernel_layout'] = 'DHWIO'
return relay.nn.conv3d(data, weight, **new_attrs)
else:
assert "Layout %s is not yet supported" % desired_layout
return None

# conv3d_winograd related operators
reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform",
strategy.conv3d_winograd_without_weight_transfrom_strategy)
Expand Down
47 changes: 31 additions & 16 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,16 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
name="conv2d_nhwc.cuda")
N, _, _, _ = get_const_tuple(data.shape)
_, _, CI, CO = get_const_tuple(kernel.shape)
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
(N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
name="conv2d_nhwc_tensorcore.cuda",
plevel=20)
if target.target_name == "cuda":
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
(N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
name="conv2d_nhwc_tensorcore.cuda",
plevel=20)
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
Expand All @@ -170,7 +171,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
name="dpethwise_nchw.cuda")
name="depthwise_conv2d_nchw.cuda")
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
Expand Down Expand Up @@ -249,7 +250,7 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
def conv3d_strategy_cuda(attrs, inputs, out_type, target):
"""conv3d cuda strategy"""
strategy = _op.OpStrategy()
_, kernel = inputs
data, kernel = inputs
layout = attrs.data_layout
_, stride_h, stride_w = attrs.get_int_tuple("strides")
_, dilation_h, dilation_w = attrs.get_int_tuple("dilation")
Expand All @@ -268,11 +269,25 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd),
name="conv3d_ncdhw_winograd.cuda",
plevel=5)
else: # layout == "NDHWC":
strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
name="conv3d_ndhwc.cuda",
plevel=10)
else: # layout == "NDHWC":
strategy.add_implementation(
wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
name="conv3d_ndhwc.cuda",
plevel=10)
N, _, _, _, _ = get_const_tuple(data.shape)
_, _, _, CI, CO = get_const_tuple(kernel.shape)
if target.target_name == "cuda":
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
(N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
strategy.add_implementation(
wrap_compute_conv3d(topi.cuda.conv3d_ndhwc_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc_tensorcore),
name="conv3d_ndhwc_tensorcore.cuda",
plevel=20)

if target.target_name == "cuda" and "cudnn" in target.libs:
strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@
from .rcnn import *
from .sort import *
from .conv2d_nhwc_tensorcore import *
from .conv3d_ndhwc_tensorcore import *
from .dense_tensorcore import *
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/conv2d_nhwc_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp
# convert data type of input feature maps and weights
TransPaddedInput = te.compute(
PaddedInput.shape,
lambda h, w, i, o: PaddedInput[h, w, i, o].astype('float16'))
lambda n, h, w, c: PaddedInput[n, h, w, c].astype('float16'))
TransFilter = te.compute(
Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype('float16'))
Output = te.compute(
Expand Down
Loading

0 comments on commit 04388b4

Please sign in to comment.