[TOPI][RELAY][PYTORCH] Conv3d_transpose op support added #5737

Merged 4 commits on Jun 11, 2020
4 changes: 4 additions & 0 deletions docs/api/python/topi.rst
@@ -214,6 +214,8 @@ topi.nn
.. autofunction:: topi.nn.conv2d_hwcn
.. autofunction:: topi.nn.depthwise_conv2d_nchw
.. autofunction:: topi.nn.depthwise_conv2d_nhwc
.. autofunction:: topi.nn.conv3d_ncdhw
.. autofunction:: topi.nn.conv3d_transpose_ncdhw
.. autofunction:: topi.nn.fifo_buffer

topi.image
@@ -233,6 +235,8 @@ topi.generic

.. autofunction:: topi.generic.schedule_conv2d_nchw
.. autofunction:: topi.generic.schedule_depthwise_conv2d_nchw
.. autofunction:: topi.generic.schedule_conv3d_ncdhw
.. autofunction:: topi.generic.schedule_conv3d_transpose_ncdhw
.. autofunction:: topi.generic.schedule_reduce
.. autofunction:: topi.generic.schedule_broadcast
.. autofunction:: topi.generic.schedule_injective
4 changes: 3 additions & 1 deletion docs/langref/relay_op.rst
@@ -69,6 +69,8 @@ This level enables typical convnet models.

tvm.relay.nn.conv2d
tvm.relay.nn.conv2d_transpose
tvm.relay.nn.conv3d
tvm.relay.nn.conv3d_transpose
tvm.relay.nn.dense
tvm.relay.nn.max_pool2d
tvm.relay.nn.max_pool3d
@@ -225,4 +227,4 @@ This level supports dialect operators.
:nosignatures:

tvm.relay.qnn.op.requantize
tvm.relay.qnn.op.conv2d
76 changes: 76 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -348,6 +348,82 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
}
};

/*! \brief Attributes used in the 3D transposed convolution operator */
struct Conv3DTransposeAttrs : public tvm::AttrsNode<Conv3DTransposeAttrs> {
IndexExpr channels;
Array<IndexExpr> kernel_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> output_padding;
Array<IndexExpr> dilation;
int groups;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv3DTransposeAttrs, "relay.attrs.Conv3DTransposeAttrs") {
TVM_ATTR_FIELD(channels)
.set_default(NullValue<IndexExpr>())
.describe(
    "The dimensionality of the output space, "
    "i.e. the number of output channels in the convolution.");
TVM_ATTR_FIELD(kernel_size)
.describe("The dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(strides)
.set_default(Array<IndexExpr>({1, 1, 1}))
.describe("The strides of the convolution.");
TVM_ATTR_FIELD(output_padding)
.set_default(Array<IndexExpr>({0, 0, 0}))
.describe(
    "Zero-padding added to one side of the output. "
    "Both symmetric and asymmetric padding are supported: "
    "one int: the same padding is used on all sides; "
    "three ints: front, bottom and right use the same padding as back, top and left; "
    "six ints: padding widths in the order (front, top, left, back, bottom, right).");
TVM_ATTR_FIELD(padding)
.set_default(Array<IndexExpr>({0, 0, 0}))
.describe(
    "If padding is non-zero, the input is implicitly zero-padded. "
    "Both symmetric and asymmetric padding are supported: "
    "one int: the same padding is used on all sides; "
    "three ints: front, bottom and right use the same padding as back, top and left; "
    "six ints: padding widths in the order (front, top, left, back, bottom, right).");
TVM_ATTR_FIELD(dilation)
.set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1).describe(
    "Controls the connections between inputs and outputs. "
    "At groups=1, all inputs are convolved to all outputs. "
    "At groups=2, the operation becomes equivalent to having two convolution "
    "layers side by side, each seeing half the input channels and producing "
    "half the output channels, with both subsequently concatenated.");
TVM_ATTR_FIELD(data_layout)
.set_default("NCDHW")
.describe(
    "Dimension ordering of data. Can be 'NCDHW', 'NDHWC', etc. "
    "'N', 'C', 'D', 'H', 'W' stand for batch, channel, depth, height, and width "
    "dimensions respectively. Convolution is applied on the 'D', 'H' and "
    "'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout)
.set_default("OIDHW")
.describe(
    "Dimension ordering of the weight. Can be 'OIDHW', 'OIDHW16o16i', etc. "
    "'O', 'I', 'D', 'H', 'W' stand for num_filter, input_channel, depth, height, and width "
    "dimensions respectively.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
    "Dimension ordering of the output. Can be 'NCDHW', 'NDHWC', etc. "
    "'N', 'C', 'D', 'H', 'W' stand for batch, channel, depth, height, and width "
    "dimensions respectively. Defaults to the same layout as the input.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type; set to an explicit type under mixed-precision settings.");
}
};

/*! \brief Attributes used in 3d winograd convolution operators */
struct Conv3DWinogradAttrs : public tvm::AttrsNode<Conv3DWinogradAttrs> {
int tile_size;
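For orientation, the output spatial extent implied by these attributes follows the standard transposed-convolution formula: out = (in - 1) * stride - pad_begin - pad_end + dilation * (kernel - 1) + 1 + output_padding, applied per spatial axis. A minimal sketch of that arithmetic (the helper name is hypothetical, not part of this PR):

    # Hypothetical helper: expected output size of conv3d_transpose along one
    # spatial axis (apply once each for D, H, W).
    def conv3d_transpose_out_size(in_size, kernel, stride, pad_begin, pad_end,
                                  dilation=1, output_padding=0):
        return ((in_size - 1) * stride - pad_begin - pad_end
                + dilation * (kernel - 1) + 1 + output_padding)

    # Depth axis example: input D=8, kernel 3, stride 2, symmetric padding 1.
    assert conv3d_transpose_out_size(8, 3, 2, 1, 1) == 15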
21 changes: 14 additions & 7 deletions python/tvm/relay/frontend/pytorch.py
@@ -750,17 +750,24 @@ def _impl(inputs, input_types):
         if isinstance(dilation, _expr.Expr):
             dilation = _infer_shape(dilation)

-        data_layout = "NCHW"
-        kernel_layout = "OIHW"
-        conv_op = _op.nn.conv2d
-
         if use_transpose:
-            assert len(kernel_size) == 2, "ConvTranspose 3D not supported"
-            conv_op = _op.nn.conv2d_transpose
+            if len(kernel_size) == 3:
+                conv_op = _op.nn.conv3d_transpose
+            else:
+                conv_op = _op.nn.conv2d_transpose
+        else:
+            if len(kernel_size) == 3:
+                conv_op = _op.nn.conv3d
+            else:
+                conv_op = _op.nn.conv2d

         if len(kernel_size) == 3:
-            conv_op = _op.nn.conv3d
             data_layout = "NCDHW"
             kernel_layout = "OIDHW"
+        else:
+            data_layout = "NCHW"
+            kernel_layout = "OIHW"

         conv_out = conv_op(data,
                            weight,
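With the frontend dispatch above, a traced torch.nn.ConvTranspose3d module should import directly. A minimal sketch under that assumption (the shapes and the input name "input0" are illustrative):

    import torch
    from tvm import relay

    # Trace a 3D transposed convolution and import it through the PyTorch frontend.
    model = torch.nn.ConvTranspose3d(in_channels=8, out_channels=16,
                                     kernel_size=3, stride=2).eval()
    data = torch.randn(1, 8, 8, 8, 8)
    scripted = torch.jit.trace(model, data)

    # from_pytorch takes a list of (input name, shape) pairs.
    mod, params = relay.frontend.from_pytorch(scripted,
                                              [("input0", list(data.shape))])
    print(mod["main"])  # the printed IR should contain nn.conv3d_transpose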
25 changes: 25 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -192,6 +192,31 @@ def legalize_conv2d_transpose(attrs, inputs, types):
return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)


# conv3d_transpose
reg.register_strategy("nn.conv3d_transpose", strategy.conv3d_transpose_strategy)
reg.register_pattern("nn.conv3d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_legalize("nn.conv3d_transpose")
def legalize_conv3d_transpose(attrs, inputs, types):
"""Legalize conv3d_transpose op.

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of the current transposed convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.conv3d_transpose_legalize(attrs, inputs, types)


# conv3d
reg.register_strategy("nn.conv3d", strategy.conv3d_strategy)
reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)
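As a quick, hedged way to verify the registrations above took effect, the operator registry can be queried directly:

    from tvm import relay

    # Fetch the operator from the global registry; this raises if the op
    # was never registered.
    op = relay.op.get("nn.conv3d_transpose")
    print(op.name)  # nn.conv3d_transpose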
70 changes: 70 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -372,6 +372,76 @@ def contrib_conv3d_winograd_without_weight_transform(data,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)

def conv3d_transpose(data,
weight,
strides=(1, 1, 1),
padding=(0, 0, 0),
dilation=(1, 1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCDHW",
kernel_layout="OIDHW",
out_layout="",
output_padding=(0, 0, 0),
out_dtype=""):
r"""3D transpose convolution.

Parameters
----------
data : tvm.relay.Expr
The input data to the operator.

weight : tvm.relay.Expr
The weight expressions.

strides : Optional[Tuple[int]]
The strides of convolution.

padding : Optional[int, Tuple[int]]
Padding added to both sides of the input before convolution.

dilation : Optional[int, Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.

groups : Optional[int]
Number of groups for grouped convolution.

channels : Optional[int]
Number of output channels of this convolution.

kernel_size : Optional[int, Tuple[int]]
The spatial dimensions of the convolution kernel.

data_layout : Optional[str]
Layout of the input.

kernel_layout : Optional[str]
Layout of the weight.

out_layout : Optional[str]
Layout of the output, by default, out_layout is the same as data_layout

output_padding : Optional[int, Tuple[int]]
    Additional zero-padding added to one side of the output shape.

out_dtype : Optional[str]
    Specifies the output data type for mixed precision conv3d.

Returns
-------
result : tvm.relay.Expr
The computed result.
"""

if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(strides, int):
strides = (strides, strides, strides)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
padding = get_pad_tuple3d(padding)

return _make.conv3d_transpose(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, output_padding, out_dtype)

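For illustration, a minimal type-inference check of the new op; the shapes are made up for this sketch, and note that, as with conv2d_transpose, the "OIDHW" weight is interpreted as (input channels, output channels / groups, D, H, W):

    import tvm
    from tvm import relay

    # NCDHW input with 8 channels; weight maps 8 input to 16 output channels.
    data = relay.var("data", shape=(1, 8, 8, 8, 8), dtype="float32")
    weight = relay.var("weight", shape=(8, 16, 3, 3, 3), dtype="float32")
    out = relay.nn.conv3d_transpose(data, weight, strides=(2, 2, 2),
                                    padding=(1, 1, 1), channels=16,
                                    kernel_size=(3, 3, 3))
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
    mod = relay.transform.InferType()(mod)
    # Per the formula above: (8 - 1) * 2 - 2 + 3 = 15 along each spatial axis.
    print(mod["main"].ret_type)  # Tensor[(1, 16, 15, 15, 15), float32]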
def conv2d_transpose(data,
weight,
3 changes: 3 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -348,6 +348,9 @@ class BinaryDenseAttrs(Attrs):
class Conv2DTransposeAttrs(Attrs):
"""Attributes used in Transposed Conv2D operators"""

@tvm._ffi.register_object("relay.attrs.Conv3DTransposeAttrs")
class Conv3DTransposeAttrs(Attrs):
"""Attributes used in Transposed Conv3D operators"""

@tvm._ffi.register_object("relay.attrs.DilateAttrs")
class DilateAttrs(Attrs):
18 changes: 18 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -313,6 +313,24 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
name="conv2d_transpose_nchw.cuda")
return strategy


@conv3d_transpose_strategy.register(["cuda", "gpu"])
def conv3d_transpose_strategy_cuda(attrs, inputs, out_type, target):
"""conv3d_transpose cuda strategy"""
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
assert layout == "NCDHW", "only NCDHW layout is supported for now"
assert dilation == (1, 1, 1), "dilation is not supported yet"
assert groups == 1, "only groups == 1 is supported for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv3d_transpose(topi.cuda.conv3d_transpose_ncdhw),
wrap_topi_schedule(topi.cuda.schedule_conv3d_transpose_ncdhw),
name="conv3d_transpose_ncdhw.cuda")
return strategy


@conv3d_strategy.register(["cuda", "gpu"])
def conv3d_strategy_cuda(attrs, inputs, out_type, target):
"""conv3d cuda strategy"""
38 changes: 38 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -345,6 +345,44 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
name="conv2d_transpose_nchw.generic")
return strategy


# conv3d_transpose
def wrap_compute_conv3d_transpose(topi_compute):
"""wrap conv3d_transpose topi compute"""
def compute_conv3d_transpose(attrs, inputs, out_dtype):
"""Compute definition of conv3d_transpose"""
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
out = topi_compute(
inputs[0], inputs[1], strides, padding, out_dtype)
output_padding = get_const_tuple(attrs.output_padding)
out = topi.nn.pad(out,
[0, 0, 0, 0, 0],
[0, 0, output_padding[0], output_padding[1], output_padding[2]])
return [out]
return compute_conv3d_transpose


@override_native_generic_func("conv3d_transpose_strategy")
def conv3d_transpose_strategy(attrs, inputs, out_type, target):
"""conv3d_transpose generic strategy"""
logger.warning("conv3d_transpose is not optimized for this platform.")
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
assert layout == "NCDHW", "only NCDHW layout is supported for now"
assert dilation == (1, 1, 1), "dilation is not supported yet"
assert groups == 1, "only groups == 1 is supported for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw),
wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw),
name="conv3d_transpose_ncdhw.generic")
return strategy

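Note how output_padding is handled: wrap_compute_conv3d_transpose calls the base topi compute without it, then pads the result on the trailing (back, bottom, right) side only. A NumPy sketch of the equivalent shape effect:

    import numpy as np

    # Mirror of the topi.nn.pad call above: pad_before is all zeros, pad_after
    # carries output_padding on the three spatial axes of an NCDHW tensor.
    out = np.zeros((1, 16, 15, 15, 15), dtype="float32")
    output_padding = (1, 1, 1)
    padded = np.pad(out, [(0, 0), (0, 0),
                          (0, output_padding[0]),
                          (0, output_padding[1]),
                          (0, output_padding[2])])
    print(padded.shape)  # (1, 16, 16, 16, 16)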
# conv3d
def wrap_compute_conv3d(topi_compute, need_layout=False):
"""wrap conv3d topi compute"""
18 changes: 18 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
@@ -202,6 +202,24 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
name="conv2d_transpose_nchw.x86")
return strategy


@conv3d_transpose_strategy.register("cpu")
def conv3d_transpose_strategy_cpu(attrs, inputs, out_type, target):
"""conv3d_transpose x86 strategy"""
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
assert layout == "NCDHW", "only NCDHW layout is supported for now"
assert dilation == (1, 1, 1), "dilation is not supported yet"
assert groups == 1, "only groups == 1 is supported for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv3d_transpose(topi.x86.conv3d_transpose_ncdhw),
wrap_topi_schedule(topi.x86.schedule_conv3d_transpose_ncdhw),
name="conv3d_transpose_ncdhw.x86")
return strategy


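With the x86 strategy registered, the op compiles and runs on CPU. A hedged end-to-end sketch, reusing the `mod` built in the relay.nn.conv3d_transpose example earlier; the graph-runtime API shown matches TVM of this period and should be treated as an assumption:

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    # `mod` is the conv3d_transpose IRModule from the earlier sketch.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm")

    runtime = graph_runtime.GraphModule(lib["default"](tvm.cpu(0)))
    runtime.set_input("data", np.random.rand(1, 8, 8, 8, 8).astype("float32"))
    runtime.set_input("weight", np.random.rand(8, 16, 3, 3, 3).astype("float32"))
    runtime.run()
    print(runtime.get_output(0).shape)  # (1, 16, 15, 15, 15)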
@conv3d_strategy.register("cpu")
def conv3d_strategy_cpu(attrs, inputs, out_type, target):
"""conv3d generic strategy"""