Commit

merge fixes

maheshambule committed Mar 13, 2020
1 parent 30b6d07 commit 636694e
Showing 3 changed files with 4 additions and 198 deletions.
142 changes: 0 additions & 142 deletions src/relay/op/nn/convolution.h
@@ -431,148 +431,6 @@ bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
  return true;
}

template <typename AttrType>
bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr) return false;
  static const Layout kNCHW("NCHW");
  static const Layout kOIHW("IHW");

  const AttrType* param = attrs.as<AttrType>();
  CHECK(param != nullptr);
  const Layout in_layout(param->data_layout);
  const Layout kernel_layout(param->kernel_layout);

  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
  CHECK(trans_in_layout.defined())
      << "Dilation2D only support input layouts that are convertible from NCHW."
      << " But got " << in_layout;

  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
  CHECK(trans_kernel_layout.defined())
      << "Dilation2D only support kernel layouts that are convertible from OIHW."
      << " But got " << kernel_layout;

  Layout out_layout(param->data_layout);
  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
  CHECK(trans_out_layout.defined())
      << "Dilation2D only support output layouts that are convertible from NCHW."
      << " But got " << out_layout;

  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);

  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;

  // use weight to infer the conv shape.
  if (weight == nullptr) return false;
  auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
  channels = wshape[0];

  dilated_ksize_y = 1 + (wshape[1] - 1) * param->rates[0];
  dilated_ksize_x = 1 + (wshape[2] - 1) * param->rates[1];

  // dilation
  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
  IndexExpr pad_h, pad_w;
  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
  if (!dshape_nchw[2].as<tir::AnyNode>()) {
    oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
                           param->strides[0]) + 1);
  } else {
    oshape.Set(2, dshape_nchw[2]);
  }

  if (!dshape_nchw[3].as<tir::AnyNode>()) {
    oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
                           param->strides[1]) + 1);
  } else {
    oshape.Set(3, dshape_nchw[3]);
  }

  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = data->dtype;
  }
  oshape = trans_out_layout.BackwardShape(oshape);
  // assign output type
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}

template <typename AttrType>
bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr) return false;
  static const Layout kNCHW("NCHW");
  static const Layout kOIHW("IHW");

  const AttrType* param = attrs.as<AttrType>();
  CHECK(param != nullptr);
  const Layout in_layout(param->data_layout);
  const Layout kernel_layout(param->kernel_layout);

  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
  CHECK(trans_in_layout.defined())
      << "Dilation2D only support input layouts that are convertible from NCHW."
      << " But got " << in_layout;

  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
  CHECK(trans_kernel_layout.defined())
      << "Dilation2D only support kernel layouts that are convertible from OIHW."
      << " But got " << kernel_layout;

  Layout out_layout(param->data_layout);
  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
  CHECK(trans_out_layout.defined())
      << "Dilation2D only support output layouts that are convertible from NCHW."
      << " But got " << out_layout;

  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);

  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;

  // use weight to infer the conv shape.
  if (weight == nullptr) return false;
  auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
  channels = wshape[0];

  dilated_ksize_y = 1 + (wshape[1] - 1) * param->rates[0];
  dilated_ksize_x = 1 + (wshape[2] - 1) * param->rates[1];

  // dilation
  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
  IndexExpr pad_h, pad_w;
  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
  if (!dshape_nchw[2].as<tir::AnyNode>()) {
    oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
                           param->strides[0]) + 1);
  } else {
    oshape.Set(2, dshape_nchw[2]);
  }

  if (!dshape_nchw[3].as<tir::AnyNode>()) {
    oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
                           param->strides[1]) + 1);
  } else {
    oshape.Set(3, dshape_nchw[3]);
  }

  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = data->dtype;
  }
  oshape = trans_out_layout.BackwardShape(oshape);
  // assign output type
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_NN_CONVOLUTION_H_
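
The two deleted blocks are byte-for-byte duplicates of the Dilation2DRel relation that survives above the hunk; this commit drops the copies left over from a merge. The shape arithmetic they all share: the dilated kernel extent is 1 + (k - 1) * rate, and each spatial output dim is floor((in + pad - dilated_ksize) / stride) + 1. A minimal Python sketch of that arithmetic, for sanity-checking the relation (the helper name is hypothetical, not part of the commit):

import math  # not strictly needed; // already floors for non-negative operands

def dilation2d_out_dim(in_dim, ksize, stride, rate, pad):
    # Effective kernel extent after dilation, matching dilated_ksize_y/x above.
    dilated_ksize = 1 + (ksize - 1) * rate
    # indexdiv in the C++ relation is a floor division.
    return (in_dim + pad - dilated_ksize) // stride + 1

# e.g. a 3x3 input, 2x2 filter, stride 1, rate 1, no padding -> 2x2 output
assert dilation2d_out_dim(3, 2, 1, 1, 0) == 2
# rate 2 dilates the 2x2 filter to an effective 3x3 extent -> 1x1 output
assert dilation2d_out_dim(3, 2, 1, 2, 0) == 1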
4 changes: 4 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -2985,6 +2985,9 @@ def test_forward_add_n():
    _test_forward_add_n(in5)


#######################################################################
# Dilation2D
# ----------------------
def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
                     strides, dilations, padding):
    """ One iteration of dilation2d with given shapes and attributes """
@@ -3029,6 +3032,7 @@ def test_forward_dilation():
    _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 2, 2, 1], "SAME")
    _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID")


#######################################################################
# Main
# ----
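
Each _test_dilation2d call above exercises tf.nn.dilation2d once with the given input shape, filter shape, strides, dilation rates, and padding, then checks the TVM frontend against TensorFlow. As context, a self-contained sketch of the TensorFlow half of one such iteration (TF 1.x style; the helper name run_tf_dilation2d and the random-input setup are assumptions, not the actual test body):

import numpy as np
import tensorflow as tf

def run_tf_dilation2d(tensor_in_sizes, filter_in_sizes, strides, dilations, padding):
    # Random inputs with the shapes used by test_forward_dilation above.
    data = np.random.uniform(size=tensor_in_sizes).astype(np.float32)
    filt = np.random.uniform(size=filter_in_sizes).astype(np.float32)
    with tf.Graph().as_default():
        inp = tf.placeholder(tf.float32, shape=tensor_in_sizes, name="input")
        # TF 1.x signature: the `rates` argument carries the dilation factors.
        out = tf.nn.dilation2d(inp, filt, strides=strides,
                               rates=dilations, padding=padding)
        with tf.Session() as sess:
            return sess.run(out, feed_dict={inp: data})

# e.g. the first case above: NHWC input [1, 3, 3, 1], filter [2, 2, 1]
# result = run_tf_dilation2d([1, 3, 3, 1], [2, 2, 1],
#                            [1, 1, 1, 1], [1, 2, 2, 1], "SAME")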
56 changes: 0 additions & 56 deletions topi/python/topi/generic/nn.py
@@ -652,41 +652,13 @@ def schedule_batch_matmul(outs):

def schedule_dilation2d_nchw(outs):
    """Schedule for dilation2d
    Parameters
    ----------
    outs : Array of Tensor
        The computation graph description of dilation2d
        in the format of an array of tensors.
    Returns
    -------
    sch : Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)

def schedule_dilation2d_nhwc(outs):
    """Schedule for dilation2d
    Parameters
    ----------
    outs : Array of Tensor
        The computation graph description of dilation2d
        in the format of an array of tensors.
    Returns
    -------
    sch : Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)

def schedule_dilation2d_nchw(outs):
    """Schedule for dilation2d
    Parameters
    ----------
    outs : Array of Tensor
        The computation graph description of dilation2d
        in the format of an array of tensors.
    Returns
    -------
    sch : Schedule
@@ -697,41 +669,13 @@ def schedule_dilation2d_nchw(outs):

def schedule_dilation2d_nhwc(outs):
    """Schedule for dilation2d
    Parameters
    ----------
    outs : Array of Tensor
        The computation graph description of dilation2d
        in the format of an array of tensors.
    Returns
    -------
    sch : Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)

def schedule_dilation2d_nchw(outs):
    """Schedule for dilation2d
    Parameters
    ----------
    outs : Array of Tensor
        The computation graph description of dilation2d
        in the format of an array of tensors.
    Returns
    -------
    sch : Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)

def schedule_dilation2d_nhwc(outs):
    """Schedule for dilation2d
    Parameters
    ----------
    outs : Array of Tensor
        The computation graph description of dilation2d
        in the format of an array of tensors.
    Returns
    -------
    sch : Schedule
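
All of these deleted stubs are duplicates of the surviving ones and simply return _default_schedule(outs, False): a plain schedule over the output ops with auto-inlining disabled. A hedged sketch of that fallback behavior, assuming the usual TOPI pattern (the name dilation2d_fallback_schedule is hypothetical):

from tvm import te

def dilation2d_fallback_schedule(outs):
    # Normalize a single output tensor to a list, then build an
    # untransformed schedule over the output ops. This mirrors the
    # essence of topi's _default_schedule with auto_inline=False.
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else list(outs)
    return te.create_schedule([x.op for x in outs])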
