[Relay] Support deformable Conv2D NHWC (#7075)
* [Relay] Support deformable conv2D NHWC

* add test case

* fix lint

* lint
comaniac committed Dec 10, 2020
1 parent ec60a50 commit fcead9f
Showing 4 changed files with 110 additions and 46 deletions.
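
With this change, relay.nn.deformable_conv2d accepts NHWC data with an HWIO kernel layout in addition to the existing NCHW/OIHW path. Below is a minimal usage sketch mirroring the NHWC test case added in this commit; the concrete shapes and the InferType scaffolding are illustrative, not part of the diff.

import tvm
from tvm import relay

# Illustrative NHWC shapes: batch=1, 16x16 spatial, 4 input and 4 output channels.
data = relay.var("data", shape=(1, 16, 16, 4))
offset = relay.var("offset")   # type filled in by the deformable_conv2d type relation
kernel = relay.var("kernel")   # type inferred from channels / kernel_size

out = relay.nn.deformable_conv2d(
    data,
    offset,
    kernel,
    strides=(1, 1),
    padding=(1, 1),
    dilation=(1, 1),
    data_layout="NHWC",
    kernel_layout="HWIO",
    kernel_size=(3, 3),
    deformable_groups=1,
    groups=1,
    channels=4,
)

mod = tvm.IRModule.from_expr(relay.Function([data, offset, kernel], out))
mod = relay.transform.InferType()(mod)
print(mod)  # offset: (1, 16, 16, 18), kernel: (3, 3, 4, 4), output: (1, 16, 16, 4)
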
12 changes: 0 additions & 12 deletions python/tvm/relay/op/strategy/cuda.py
@@ -101,18 +101,6 @@ def schedule_lrn_cuda(attrs, outs, target):
return topi.cuda.schedule_lrn(outs)


def naive_schedule(_, outs, target):
"""Return the naive default schedule"""
if "gpu" in target.keys:
# For GPU, we at least need thread binding to make a valid schedule.
# So the naive schedule cannot be compiled.
raise RuntimeError(
"Cannot compile for GPU targets if no tuned schedule is found. "
"Please see the warning messages above for more information about the failed workloads."
)
return tvm.te.create_schedule(outs[-1].op)


@conv2d_strategy.register(["cuda", "gpu"])
def conv2d_strategy_cuda(attrs, inputs, out_type, target):
"""conv2d cuda strategy"""
36 changes: 28 additions & 8 deletions python/tvm/relay/op/strategy/generic.py
@@ -27,6 +27,18 @@
logger = logging.getLogger("strategy")


def naive_schedule(_, outs, target):
"""Return the naive default schedule"""
if "gpu" in target.keys:
# For GPU, we at least need thread binding to make a valid schedule.
# So the naive schedule cannot be compiled.
raise RuntimeError(
"Cannot compile for GPU targets if no tuned schedule is found. "
"Please see the warning messages above for more information about the failed workloads."
)
return te.create_schedule(outs[-1].op)


def wrap_topi_schedule(topi_schedule):
"""Wrap TOPI schedule which doesn't use attrs"""

@@ -357,7 +369,6 @@ def wrap_compute_deformable_conv2d(topi_compute):
"""wrap deformable_conv2d topi compute"""

def _compute_deformable_conv2d(attrs, inputs, out_dtype):
assert attrs.data_layout == "NCHW"
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
@@ -384,15 +395,24 @@ def _compute_deformable_conv2d(attrs, inputs, out_dtype):
@override_native_generic_func("deformable_conv2d_strategy")
def deformable_conv2d_strategy(attrs, inputs, out_type, target):
"""deformable_conv2d generic strategy"""
logger.warning("deformable_conv2d is not optimized for this platform.")
layout = attrs.data_layout
assert layout == "NCHW"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw),
name="deformable_conv2d.generic",
)

if layout == "NCHW":
strategy.add_implementation(
wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw),
name="deformable_conv2d_nchw.generic",
)
elif layout == "NHWC":
# This implementation should never be picked by autotvm
strategy.add_implementation(
wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nhwc),
wrap_topi_schedule(naive_schedule),
name="deformable_conv2d_nhwc.generic",
)
else:
raise RuntimeError("Layout %s is not supported in deformable conv2d" % layout)
return strategy


59 changes: 49 additions & 10 deletions src/relay/op/nn/convolution.h
@@ -1106,18 +1106,54 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
const auto* weight = types[2].as<TensorTypeNode>();

ICHECK(data);
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");

auto* param = attrs.as<AttrType>();
ICHECK_EQ(param->data_layout, "NCHW") << "data layout not supported.";
ICHECK_EQ(param->kernel_layout, "OIHW") << "kernel_layout not supported.";
ICHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);

const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
if (!trans_in_layout.defined()) {
reporter->GetDiagCtx().Emit(
Diagnostic::Error(reporter->GetSpan())
<< "deformable_conv2d only support input layouts that are convertible from NCHW."
<< " The provided layout is: " << in_layout);
return false;
}

const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
if (!trans_kernel_layout.defined()) {
reporter->GetDiagCtx().Emit(
Diagnostic::Error(reporter->GetSpan())
<< "deformable_conv2d only support kernel layouts that are convertible from OIHW."
<< " The provided layout is: " << kernel_layout);
return false;
}

Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
if (!trans_out_layout.defined()) {
reporter->GetDiagCtx().Emit(
Diagnostic::Error(reporter->GetSpan())
<< "deformable_conv2d only support output layouts that are convertible from NCHW."
<< "The provided layout is: " << out_layout);
return false;
}

Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);

IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;

// infer weight shape if kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
ICHECK_EQ(param->kernel_size.size(), 2);
ICHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape({param->channels, indexdiv(data->shape[1], param->groups),
Array<IndexExpr> wshape({param->channels, indexdiv(dshape_nchw[1], param->groups),
param->kernel_size[0], param->kernel_size[1]});

wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels;
ksize_y = param->kernel_size[0];
ksize_x = param->kernel_size[1];
@@ -1128,7 +1164,8 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = weight->shape;
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);

if (param->kernel_size.defined()) {
ICHECK_EQ(param->kernel_size.size(), 2);
// check the size
@@ -1142,8 +1179,8 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
<< "DeformableConv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << wshape;
}
if (!data->shape[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]));
}
channels = wshape[0];
ksize_y = wshape[2];
@@ -1152,22 +1189,24 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});

IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
DataType out_dtype = param->out_dtype;

// infer offset shape
Array<IndexExpr> offset_shape(
{data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]});
{dshape_nchw[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]});
offset_shape = trans_in_layout.BackwardShape(offset_shape);
reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}

oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
}
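The type relation above no longer hard-codes NCHW/OIHW; it converts the given data, kernel, and output layouts to canonical NCHW/OIHW for shape inference and maps the inferred offset and output shapes back. Below is a minimal Python-level sketch of the same idea, assuming the tvm.tir.bijective_layout helper (the Python counterpart of the C++ tir::BijectiveLayout used here); the shapes are illustrative.

import tvm

# Illustrative NHWC input shape: (batch, height, width, channels).
nhwc_shape = [1, 16, 16, 4]

# Map NHWC -> NCHW so shape arithmetic can be done in the canonical layout ...
to_nchw = tvm.tir.bijective_layout("NHWC", "NCHW")
print(to_nchw.forward_shape(nhwc_shape))  # [1, 4, 16, 16]

# ... and map shapes computed in NCHW back to the requested layout, as the
# relation does for the offset and output tensors via BackwardShape.
offset_nchw = [1, 2 * 3 * 3 * 1, 16, 16]  # (N, 2*kh*kw*deformable_groups, H_out, W_out)
print(to_nchw.backward_shape(offset_nchw))  # [1, 16, 16, 18]
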
49 changes: 33 additions & 16 deletions tests/python/relay/test_op_level5.py
@@ -787,39 +787,57 @@ def verify_yolo_reorg(shape, stride):

@tvm.testing.uses_gpu
def test_deformable_conv2d():
def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, groups):
data_shape = (batch, in_channel, size, size)
def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, groups, layout):
kernel_size = (3, 3)
if layout == "NCHW":
kernel_layout = "OIHW"
data_shape = (batch, in_channel, size, size)
weight_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
out_shape = (batch, out_channel, size, size)
offset_shape = (
batch,
2 * kernel_size[0] * kernel_size[1] * deformable_groups,
out_shape[2],
out_shape[3],
)
else:
kernel_layout = "HWIO"
data_shape = (batch, size, size, in_channel)
weight_shape = (kernel_size[0], kernel_size[1], in_channel // groups, out_channel)
out_shape = (batch, size, size, out_channel)
offset_shape = (
batch,
out_shape[1],
out_shape[2],
2 * kernel_size[0] * kernel_size[1] * deformable_groups,
)

data = relay.var("data", shape=data_shape)
offset = relay.var("offset")
kernel = relay.var("kernel")
kernel_size = (3, 3)
y = relay.nn.deformable_conv2d(
data,
offset,
kernel,
strides=(1, 1),
padding=(1, 1),
dilation=(1, 1),
data_layout=layout,
kernel_layout=kernel_layout,
kernel_size=kernel_size,
deformable_groups=deformable_groups,
groups=groups,
channels=out_channel,
)
weight_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
out_shape = (batch, out_channel, size, size)
offset_shape = (
batch,
2 * kernel_size[0] * kernel_size[1] * deformable_groups,
out_shape[2],
out_shape[3],
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(out_shape)
assert yy.checked_type == relay.TensorType(out_shape), yy.checked_type
assert yy.args[1].checked_type == relay.TensorType(offset_shape), yy.args[1].checked_type
assert yy.args[2].checked_type == relay.TensorType(weight_shape)
assert yy.args[2].checked_type == relay.TensorType(weight_shape), yy.args[2].checked_type

test_infer_type(1, 4, 16, 4, 4, 1)
test_infer_type(2, 4, 16, 4, 1, 2)
test_infer_type(1, 4, 16, 4, 4, 1, "NCHW")
test_infer_type(2, 4, 16, 4, 1, 2, "NCHW")
test_infer_type(1, 4, 16, 4, 4, 1, "NHWC")
test_infer_type(2, 4, 16, 4, 1, 2, "NHWC")

def test_run(batch, in_channel, size, out_channel, deformable_groups, groups):
kernel_size = (3, 3)
@@ -1216,4 +1234,3 @@ def verify_batch_to_space_nd(dshape, block_shape, crops):
test_affine_grid()
test_grid_sample()
test_space_to_batch_nd()
test_batch_to_space_nd()
