[Relay] Support deformable Conv2D NHWC #7075

Merged 4 commits on Dec 10, 2020
12 changes: 0 additions & 12 deletions python/tvm/relay/op/strategy/cuda.py
@@ -101,18 +101,6 @@ def schedule_lrn_cuda(attrs, outs, target):
return topi.cuda.schedule_lrn(outs)


def naive_schedule(_, outs, target):
"""Return the naive default schedule"""
if "gpu" in target.keys:
# For GPU, we at least need thread binding to make a valid schedule.
# So the naive schedule cannot be compiled.
raise RuntimeError(
"Cannot compile for GPU targets if no tuned schedule is found. "
"Please see the warning messages above for more information about the failed workloads."
)
return tvm.te.create_schedule(outs[-1].op)


@conv2d_strategy.register(["cuda", "gpu"])
def conv2d_strategy_cuda(attrs, inputs, out_type, target):
"""conv2d cuda strategy"""
36 changes: 28 additions & 8 deletions python/tvm/relay/op/strategy/generic.py
@@ -27,6 +27,18 @@
logger = logging.getLogger("strategy")


def naive_schedule(_, outs, target):
"""Return the naive default schedule"""
if "gpu" in target.keys:
# For GPU, we at least need thread binding to make a valid schedule.
# So the naive schedule cannot be compiled.
raise RuntimeError(
"Cannot compile for GPU targets if no tuned schedule is found. "
"Please see the warning messages above for more information about the failed workloads."
)
return te.create_schedule(outs[-1].op)
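
Note: `naive_schedule` is just `te.create_schedule` over the last output op, i.e. plain serial loops with no thread binding, which is why it raises for GPU targets. A minimal sketch (the element-wise compute below is a made-up stand-in, not part of this PR):

```python
import tvm
from tvm import te

# Stand-in for a TOPI compute whose outputs reach naive_schedule.
n = te.var("n")
A = te.placeholder((n,), name="A", dtype="float32")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")

# The "naive" schedule: untouched serial loops over the output op.
s = te.create_schedule(B.op)

# This lowers for CPU targets; a GPU target would reject the unbound loops.
print(tvm.lower(s, [A, B], simple_mode=True))
```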


def wrap_topi_schedule(topi_schedule):
"""Wrap TOPI schedule which doesn't use attrs"""

@@ -357,7 +369,6 @@ def wrap_compute_deformable_conv2d(topi_compute):
"""wrap deformable_conv2d topi compute"""

def _compute_deformable_conv2d(attrs, inputs, out_dtype):
assert attrs.data_layout == "NCHW"
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
@@ -384,15 +395,24 @@ def _compute_deformable_conv2d(attrs, inputs, out_dtype):
@override_native_generic_func("deformable_conv2d_strategy")
def deformable_conv2d_strategy(attrs, inputs, out_type, target):
"""deformable_conv2d generic strategy"""
logger.warning("deformable_conv2d is not optimized for this platform.")
layout = attrs.data_layout
assert layout == "NCHW"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw),
name="deformable_conv2d.generic",
)

if layout == "NCHW":
strategy.add_implementation(
wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw),
name="deformable_conv2d_nchw.generic",
)
elif layout == "NHWC":
# This implementation should never be picked by autotvm
strategy.add_implementation(
wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nhwc),
wrap_topi_schedule(naive_schedule),
name="deformable_conv2d_nhwc.generic",
)
else:
raise RuntimeError("Layout %s is not supported in deformable conv2d" % layout)
return strategy
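
Usage sketch for the new NHWC path (the shapes and the `llvm` target are assumptions, not taken from the diff): with no tuned schedule available, the generic strategy above selects `deformable_conv2d_nhwc.generic` with the naive schedule.

```python
import tvm
from tvm import relay

# Assumed shapes: NHWC data, HWIO kernel, 3x3 kernel, one deformable group.
data = relay.var("data", shape=(1, 16, 16, 4), dtype="float32")
offset = relay.var("offset", shape=(1, 16, 16, 2 * 3 * 3), dtype="float32")
kernel = relay.var("kernel", shape=(3, 3, 4, 4), dtype="float32")

out = relay.nn.deformable_conv2d(
    data,
    offset,
    kernel,
    strides=(1, 1),
    padding=(1, 1),
    dilation=(1, 1),
    deformable_groups=1,
    groups=1,
    channels=4,
    kernel_size=(3, 3),
    data_layout="NHWC",
    kernel_layout="HWIO",
)
mod = tvm.IRModule.from_expr(relay.Function([data, offset, kernel], out))

# On a CPU target this exercises the generic NHWC compute + naive schedule.
lib = relay.build(mod, target="llvm")
```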


59 changes: 49 additions & 10 deletions src/relay/op/nn/convolution.h
@@ -1106,18 +1106,54 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
const auto* weight = types[2].as<TensorTypeNode>();

ICHECK(data);
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");

auto* param = attrs.as<AttrType>();
ICHECK_EQ(param->data_layout, "NCHW") << "data layout not supported.";
ICHECK_EQ(param->kernel_layout, "OIHW") << "kernel_layout not supported.";
ICHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);

const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
if (!trans_in_layout.defined()) {
reporter->GetDiagCtx().Emit(
Diagnostic::Error(reporter->GetSpan())
<< "deformable_conv2d only support input layouts that are convertible from NCHW."
<< " The provided layout is: " << in_layout);
return false;
}

const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
if (!trans_kernel_layout.defined()) {
reporter->GetDiagCtx().Emit(
Diagnostic::Error(reporter->GetSpan())
<< "deformable_conv2d only support kernel layouts that are convertible from OIHW."
<< " The provided layout is: " << kernel_layout);
return false;
}

Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
if (!trans_out_layout.defined()) {
reporter->GetDiagCtx().Emit(
Diagnostic::Error(reporter->GetSpan())
<< "deformable_conv2d only support output layouts that are convertible from NCHW."
<< "The provided layout is: " << out_layout);
return false;
}

Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);

IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;

// infer weight shape if kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
ICHECK_EQ(param->kernel_size.size(), 2);
ICHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape({param->channels, indexdiv(data->shape[1], param->groups),
Array<IndexExpr> wshape({param->channels, indexdiv(dshape_nchw[1], param->groups),
param->kernel_size[0], param->kernel_size[1]});

wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels;
ksize_y = param->kernel_size[0];
ksize_x = param->kernel_size[1];
@@ -1128,7 +1164,8 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = weight->shape;
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);

if (param->kernel_size.defined()) {
ICHECK_EQ(param->kernel_size.size(), 2);
// check the size
@@ -1142,8 +1179,8 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
<< "DeformableConv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << wshape;
}
if (!data->shape[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]));
}
channels = wshape[0];
ksize_y = wshape[2];
@@ -1152,22 +1189,24 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});

IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
DataType out_dtype = param->out_dtype;

// infer offset shape
Array<IndexExpr> offset_shape(
{data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]});
{dshape_nchw[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]});
offset_shape = trans_in_layout.BackwardShape(offset_shape);
reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}

oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
}
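
The layout handling above is the same conversion that `tvm.tir.bijective_layout` exposes on the Python side; a small sketch of the forward/backward shape mapping the relation relies on (the NHWC shape is a made-up example):

```python
import tvm

# Map an NHWC data shape to NCHW, as trans_in_layout.ForwardShape does.
to_nchw = tvm.tir.bijective_layout("NHWC", "NCHW")
dshape_nchw = to_nchw.forward_shape([1, 16, 16, 4])  # -> [1, 4, 16, 16]

# Shapes inferred in NCHW (offset, output) are mapped back with BackwardShape.
offset_nchw = [1, 2 * 3 * 3 * 1, 16, 16]           # (N, 2*kh*kw*dg, H, W)
offset_nhwc = to_nchw.backward_shape(offset_nchw)  # -> [1, 16, 16, 18]

print(dshape_nchw, offset_nhwc)
```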
49 changes: 33 additions & 16 deletions tests/python/relay/test_op_level5.py
@@ -787,39 +787,57 @@ def verify_yolo_reorg(shape, stride):

@tvm.testing.uses_gpu
def test_deformable_conv2d():
def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, groups):
data_shape = (batch, in_channel, size, size)
def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, groups, layout):
kernel_size = (3, 3)
if layout == "NCHW":
kernel_layout = "OIHW"
data_shape = (batch, in_channel, size, size)
weight_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
out_shape = (batch, out_channel, size, size)
offset_shape = (
batch,
2 * kernel_size[0] * kernel_size[1] * deformable_groups,
out_shape[2],
out_shape[3],
)
else:
kernel_layout = "HWIO"
data_shape = (batch, size, size, in_channel)
weight_shape = (kernel_size[0], kernel_size[1], in_channel // groups, out_channel)
out_shape = (batch, size, size, out_channel)
offset_shape = (
batch,
out_shape[1],
out_shape[2],
2 * kernel_size[0] * kernel_size[1] * deformable_groups,
)

data = relay.var("data", shape=data_shape)
offset = relay.var("offset")
kernel = relay.var("kernel")
kernel_size = (3, 3)
y = relay.nn.deformable_conv2d(
data,
offset,
kernel,
strides=(1, 1),
padding=(1, 1),
dilation=(1, 1),
data_layout=layout,
kernel_layout=kernel_layout,
kernel_size=kernel_size,
deformable_groups=deformable_groups,
groups=groups,
channels=out_channel,
)
weight_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
out_shape = (batch, out_channel, size, size)
offset_shape = (
batch,
2 * kernel_size[0] * kernel_size[1] * deformable_groups,
out_shape[2],
out_shape[3],
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(out_shape)
assert yy.checked_type == relay.TensorType(out_shape), yy.checked_type
assert yy.args[1].checked_type == relay.TensorType(offset_shape), yy.args[1].checked_type
assert yy.args[2].checked_type == relay.TensorType(weight_shape)
assert yy.args[2].checked_type == relay.TensorType(weight_shape), yy.args[2].checked_type

test_infer_type(1, 4, 16, 4, 4, 1)
test_infer_type(2, 4, 16, 4, 1, 2)
test_infer_type(1, 4, 16, 4, 4, 1, "NCHW")
test_infer_type(2, 4, 16, 4, 1, 2, "NCHW")
test_infer_type(1, 4, 16, 4, 4, 1, "NHWC")
test_infer_type(2, 4, 16, 4, 1, 2, "NHWC")

def test_run(batch, in_channel, size, out_channel, deformable_groups, groups):
kernel_size = (3, 3)
@@ -1216,4 +1234,3 @@ def verify_batch_to_space_nd(dshape, block_shape, crops):
test_affine_grid()
test_grid_sample()
test_space_to_batch_nd()
test_batch_to_space_nd()
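
Quick sanity check on the NHWC offset shapes the new `test_infer_type` cases expect (stride 1, 'same' padding, so the spatial size is unchanged; the helper below is hypothetical, not part of the test file):

```python
def nhwc_offset_shape(batch, size, kernel_size, deformable_groups):
    """NHWC offset shape: (N, H, W, 2 * kh * kw * deformable_groups)."""
    kh, kw = kernel_size
    return (batch, size, size, 2 * kh * kw * deformable_groups)

# test_infer_type(1, 4, 16, 4, 4, 1, "NHWC") expects offset (1, 16, 16, 72)
assert nhwc_offset_shape(1, 16, (3, 3), 4) == (1, 16, 16, 72)
# test_infer_type(2, 4, 16, 4, 1, 2, "NHWC") expects offset (2, 16, 16, 18)
assert nhwc_offset_shape(2, 16, (3, 3), 1) == (2, 16, 16, 18)
```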