Skip to content

Commit

Permalink
change rates to dilations
Browse files Browse the repository at this point in the history
  • Loading branch information
maheshambule committed Mar 15, 2020
1 parent 423e5c3 commit b8eb6b9
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 38 deletions.
6 changes: 3 additions & 3 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
struct Dilation2DAttrs : public tvm::AttrsNode<Dilation2DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> rates;
Array<IndexExpr> dilations;
std::string data_layout;
std::string kernel_layout;
DataType out_dtype;
Expand All @@ -174,8 +174,8 @@ struct Dilation2DAttrs : public tvm::AttrsNode<Dilation2DAttrs> {
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(rates).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use. [rate_height, rate_width]");
TVM_ATTR_FIELD(dilations).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use. [dilation_height, dilation_width]");
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,9 @@ def _impl(inputs, attr, params):

if attr['data_format'] in ['NHWC', 'NCHW']:
if 'rates' in attr:
attr['rates'] = (attr['rates'][1], attr['rates'][2])
attr['dilations'] = attr['rates']
if 'dilations' in attr:
attr['rates'] = (attr['dilations'][1], attr['dilations'][2])
attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
attr['strides'] = (attr['strides'][1], attr['strides'][2])
else:
msg = 'Value {} in attribute "data_format" of operator Dilation2D is ' \
Expand All @@ -454,8 +454,8 @@ def _impl(inputs, attr, params):
in_h = input_shape[2]
in_w = input_shape[3]

dilation_h = attr['rates'][0]
dilation_w = attr['rates'][1]
dilation_h = attr['dilations'][0]
dilation_w = attr['dilations'][1]
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
Expand Down Expand Up @@ -484,7 +484,7 @@ def _impl(inputs, attr, params):
attr['kernel_layout'] = 'HWI' if attr['data_format'] == 'NHWC' else 'IHW'
out = AttrCvt(
op_name='dilation2d',
ignores=['explicit_paddings', 'dilations'],
ignores=['explicit_paddings', 'rates'],
transforms={
'data_format': 'data_layout',
})([inputs[0], inputs[1]], attr)
Expand Down
14 changes: 11 additions & 3 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2377,7 +2377,7 @@ def dilation2d(data,
weight,
strides=(1, 1),
padding=(0, 0),
rates=(1, 1),
dilations=(1, 1),
data_layout="NCHW",
kernel_layout="IHW",
out_dtype=""):
Expand All @@ -2404,24 +2404,32 @@ def dilation2d(data,
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
strides : Optional[Tuple[int]]
The strides of convolution.
padding : Optional[Tuple[int]]
The padding of convolution on both sides of inputs before convolution.
rates : Optional[Tuple[int]]
dilations : Optional[Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.
data_layout : Optional[str]
Layout of the input.
kernel_layout : Optional[str]
Layout of the weight.
out_dtype : Optional[str]
Specifies the output data type.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.dilation2d(data, weight, strides, padding, rates, data_layout,
return _make.dilation2d(data, weight, strides, padding, dilations, data_layout,
kernel_layout, out_dtype)
8 changes: 4 additions & 4 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,12 @@ def wrap_compute_dilation2d(topi_compute, need_data_layout=False, need_out_layou
def _compute_dilation2d(attrs, inputs, out_type):
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
rates = get_const_tuple(attrs.rates)
dilations = get_const_tuple(attrs.dilations)
data_layout = attrs.get_str("data_layout")
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
args = [inputs[0], inputs[1], strides, padding, rates]
args = [inputs[0], inputs[1], strides, padding, dilations]
if need_data_layout:
args.append(data_layout)
args.append(out_dtype)
Expand All @@ -467,12 +467,12 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
"""dilation2d_strategy generic strategy"""
logger.warning("dilation2d_strategy is not optimized for this platform.")
strategy = _op.OpStrategy()
rates = get_const_tuple(attrs.rates)
dilations = get_const_tuple(attrs.dilations)
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout

assert layout in ["NCHW", "NHWC"]
(dilation_h, dilation_w) = rates
(dilation_h, dilation_w) = dilations
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")

Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1063,14 +1063,14 @@ Expr MakeDilation2D(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> rates,
Array<IndexExpr> dilations,
std::string data_layout,
std::string kernel_layout,
DataType out_dtype) {
auto attrs = make_object<Dilation2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->rates = std::move(rates);
attrs->dilations = std::move(dilations);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_dtype = std::move(out_dtype);
Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/nn/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,8 @@ bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
channels = wshape[0];

dilated_ksize_y = 1 + (wshape[1] - 1) * param->rates[0];
dilated_ksize_x = 1 + (wshape[2] - 1) * param->rates[1];
dilated_ksize_y = 1 + (wshape[1] - 1) * param->dilations[0];
dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilations[1];

// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
Expand Down
4 changes: 2 additions & 2 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2986,7 +2986,7 @@ def test_forward_add_n():


def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
strides, rates, padding):
strides, dilations, padding):
""" One iteration of dilation2d with given shapes and attributes """

total_size_1 = np.prod(tensor_in_sizes)
Expand All @@ -3004,7 +3004,7 @@ def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
nn_ops.dilation2d(in_data,
in_filter,
strides=strides,
rates=rates,
rates=dilations,
padding=padding)

compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
Expand Down
10 changes: 5 additions & 5 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,7 @@ def test_dilation2d_infer_type():
y = relay.nn.dilation2d(x, w,
# kernel_size=(3, 3),
strides=[1, 1, 1, 1],
rates=[1, 1, 1, 1],
dilations=[1, 1, 1, 1],
padding=[0, 0, 0, 0])
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
Expand All @@ -1240,7 +1240,7 @@ def run_test_dilation2d(indata, kernel, out,
dtype='float32',
strides=[1, 1],
padding=[0, 0],
rates=[1, 1],
dilations=[1, 1],
except_targets=['cuda'],
**attrs):

Expand All @@ -1254,7 +1254,7 @@ def run_test_dilation2d(indata, kernel, out,
w = relay.var("w", shape=kshape, dtype=dtype)
y = relay.nn.dilation2d(x, w,
strides=strides,
rates=rates,
dilations=dilations,
padding=padding,
**attrs)
func = relay.Function([x, w], y)
Expand Down Expand Up @@ -1313,8 +1313,8 @@ def _convert_data(indata, kernel, out, layout=None):
image = [[[[.1], [.2], [.3]], [[.4], [.5], [.6]], [[.7], [.8], [.9]]]]
kernel = [[[.4], [.3]], [[.1], [.2]]]
out = [[[[.7], [.8], [.6]], [[1.0], [1.1], [.9]], [[.8], [.9], [.9]]]]
run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[1, 1], rates=[2, 2])
run_test_dilation2d(*_convert_data(image, kernel, out), padding=[1, 1], rates=[2, 2],
run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[1, 1], dilations=[2, 2])
run_test_dilation2d(*_convert_data(image, kernel, out), padding=[1, 1], dilations=[2, 2],
data_layout='NHWC', kernel_layout='HWI')

image = [[[[.1], [.2], [.3], [.4]], [[.5], [.6], [.7], [.8]],
Expand Down
38 changes: 26 additions & 12 deletions topi/python/topi/nn/dilation2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,29 @@
from .util import get_pad_tuple


def dilation2d_nchw(input, filter, stride, padding, dilation, out_dtype=None):
def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None):
"""Dilation2D operator in NCHW layout.
Parameters
----------
input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
filter : tvm.Tensor
3-D with shape [ in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size
dilation: int or a list/tuple of two ints
dilations: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
out_dtype : Optional[str]
Specifies the output data type.
Returns
-------
Output : tvm.Tensor
Expand All @@ -48,16 +55,16 @@ def dilation2d_nchw(input, filter, stride, padding, dilation, out_dtype=None):
if out_dtype is None:
out_dtype = input.dtype
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(dilation, int) or len(dilation) == 2
assert isinstance(dilations, int) or len(dilations) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
if isinstance(dilations, int):
dilation_h = dilation_w = dilations
else:
dilation_h, dilation_w = dilation
dilation_h, dilation_w = dilations

batch, in_channel, in_height, in_width = input.shape
channel, kernel_h, kernel_w = filter.shape
Expand Down Expand Up @@ -88,22 +95,29 @@ def dilation2d_nchw(input, filter, stride, padding, dilation, out_dtype=None):
axis=[ry, rx]), tag="dilation2d_nchw")


def dilation2d_nhwc(input, filter, stride, padding, dilation, out_dtype=None):
def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None):
"""Dilation2D operator in NHWC layout.
Parameters
----------
input : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
filter : tvm.Tensor
3-D with shape [filter_height, filter_width, in_channel]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int
Padding size
dilation: int or a list/tuple of two ints
dilations: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
out_dtype : Optional[str]
Specifies the output data type.
Returns
-------
Output : tvm.Tensor
Expand All @@ -112,16 +126,16 @@ def dilation2d_nhwc(input, filter, stride, padding, dilation, out_dtype=None):
if out_dtype is None:
out_dtype = input.dtype
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(dilation, int) or len(dilation) == 2
assert isinstance(dilations, int) or len(dilations) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride

if isinstance(dilation, int):
dilation_h = dilation_w = dilation
if isinstance(dilations, int):
dilation_h = dilation_w = dilations
else:
dilation_h, dilation_w = dilation
dilation_h, dilation_w = dilations

batch, in_height, in_width, in_channel = input.shape
kernel_h, kernel_w, channel = filter.shape
Expand Down

0 comments on commit b8eb6b9

Please sign in to comment.