[ConvertLayout] Support QNN ops. apache#5066
shoubhik committed May 11, 2020
1 parent 1394d04 commit a1d12cc
Showing 19 changed files with 324 additions and 102 deletions.
python/tvm/relay/op/nn/_nn.py (10 changes: 1 addition & 9 deletions)

@@ -268,22 +268,14 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
     """

     from tvm import relay
-    data_layout = attrs['data_layout']
-    kernel_layout = attrs['kernel_layout']
     data, weight = inputs
     assert desired_layout == 'NCHW', \
             "Currently only transformation to NCHW layout is supported."
     if desired_layout == 'NCHW':
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = desired_layout
         new_attrs['kernel_layout'] = 'OIHW'
-
-        if data_layout == 'NHWC' and kernel_layout == 'HWIO':
-            # Convert (NHWC, HWIO) to (NCHW, OIHW)
-            return relay.nn.conv2d(data, weight, **new_attrs)
-        if data_layout == 'NHWC' and kernel_layout == 'HWOI':
-            # Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d.
-            return relay.nn.conv2d(data, weight, **new_attrs)
+        return relay.nn.conv2d(data, weight, **new_attrs)
     return None

 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)

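For context, a minimal sketch of how this registered hook is exercised: building an NHWC conv2d and running the ConvertLayout pass invokes convert_conv2d, which rewrites the op to NCHW/OIHW. Shapes and names are illustrative, and the API shown follows the TVM of this era, so details may differ in other versions.

    import tvm
    from tvm import relay

    # NHWC/HWIO conv2d, as produced by e.g. a TensorFlow-style frontend.
    data = relay.var("data", shape=(1, 56, 56, 64))
    weight = relay.var("weight", shape=(3, 3, 64, 64))
    conv = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3),
                           padding=(1, 1), data_layout="NHWC",
                           kernel_layout="HWIO")
    mod = tvm.IRModule.from_expr(conv)

    # ConvertLayout looks up the convert_conv2d hook registered above and
    # rewrites the op to NCHW/OIHW, inserting layout_transform ops as needed.
    mod = relay.transform.ConvertLayout("NCHW")(mod)
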
python/tvm/relay/qnn/op/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -19,4 +19,4 @@
 from __future__ import absolute_import as _abs
 from .qnn import *
 from .op import register_qnn_legalize
-from . import legalizations
+from . import legalizations, layout_conversions

python/tvm/relay/qnn/op/layout_conversions.py (51 changes: 51 additions & 0 deletions, new file)

@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""Convert layout related registration"""
+from __future__ import absolute_import
+
+from tvm.relay.op import op as reg
+
+
+@reg.register_convert_op_layout("qnn.conv2d")
+def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layout):
+    """Convert Layout pass registration for QNN conv2d op.
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layout : str
+        The desired layout
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay
+    assert desired_layout == 'NCHW', \
+            "Currently only transformation to NCHW layout is supported."
+    if desired_layout == 'NCHW':
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = desired_layout
+        new_attrs['kernel_layout'] = 'OIHW'
+        return relay.qnn.op.conv2d(*inputs, **new_attrs)
+    return None

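This new registration is what the rest of the commit enables. A hedged sketch mirroring the pattern used in ConvertLayout unit tests of this era (zero points, scales, and shapes are placeholder values; the exact qnn.conv2d argument list varies across TVM versions):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 56, 56, 64), dtype="int8")
    weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
    conv = relay.qnn.op.conv2d(data, weight,
                               relay.const(0, "int32"),      # input_zero_point
                               relay.const(0, "int32"),      # kernel_zero_point
                               relay.const(1.0, "float32"),  # input_scale
                               relay.const(1.0, "float32"),  # kernel_scale
                               channels=64, kernel_size=(3, 3), padding=(1, 1),
                               data_layout="NHWC", kernel_layout="HWIO")
    mod = tvm.IRModule.from_expr(conv)

    # convert_qnn_conv2d rewrites the op to NCHW/OIHW, exactly as the
    # fp32 hook does for nn.conv2d.
    mod = relay.transform.ConvertLayout("NCHW")(mod)
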
src/relay/op/nn/bitserial.cc (2 changes: 1 addition & 1 deletion)

@@ -38,7 +38,7 @@ template <typename T>
 Array<Array<Layout>> BinaryConv2DInferCorrectLayout(const Attrs& attrs,
                                                     const Array<Layout>& new_in_layouts,
                                                     const Array<Layout>& old_in_layouts,
-                                                    const Array<Array<IndexExpr>>& old_in_shapes) {
+                                                    const Array<tvm::relay::Type>& old_in_types) {
   const T* params = attrs.as<T>();

   // We always make other operators to fit the layouts of convolution layers

src/relay/op/nn/convolution.cc (14 changes: 0 additions & 14 deletions)

@@ -37,20 +37,6 @@ namespace relay {
 // relay.nn.conv2d
 TVM_REGISTER_NODE_TYPE(Conv2DAttrs);

-template<typename T>
-Array<Array<Layout> > Conv2DInferCorrectLayout(
-    const Attrs& attrs,
-    const Array<Layout>& new_in_layouts,
-    const Array<Layout>& old_in_layouts,
-    const Array<Array<IndexExpr>> &old_in_shapes) {
-  const T* params = attrs.as<T>();
-
-  // We always make other operators to fit the layouts of convolution layers
-  // So this inference ignores all inputs
-  return Array<Array<Layout> >{{params->data_layout, params->kernel_layout},
-                               {params->out_layout == "" ?
-                                params->data_layout : params->out_layout}};
-}

 // Positional relay function to create conv2d operator
 // used by frontend FFI.

src/relay/op/nn/convolution.h (16 changes: 16 additions & 0 deletions)

@@ -138,6 +138,22 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }

+template<typename T>
+Array<Array<Layout> > Conv2DInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<tvm::relay::Type> &old_in_types) {
+  const T* params = attrs.as<T>();
+
+  // We always make other operators to fit the layouts of convolution layers
+  // So this inference ignores all inputs
+  return Array<Array<Layout> >{{params->data_layout, params->kernel_layout},
+                               {params->out_layout == "" ?
+                                params->data_layout : params->out_layout}};
+}
+
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_NN_CONVOLUTION_H_

src/relay/op/nn/nn.cc (12 changes: 9 additions & 3 deletions)

@@ -271,10 +271,10 @@ Array<Array<Layout> > PReluInferCorrectLayout(
     const Attrs& attrs,
     const Array<Layout>& new_in_layouts,
     const Array<Layout>& old_in_layouts,
-    const Array<Array<IndexExpr>> &old_in_shapes) {
+    const Array<tvm::relay::Type> &old_in_types) {

   CHECK_EQ(old_in_layouts.size(), 2U);
-  CHECK_EQ(old_in_shapes.size(), 2U);
+  CHECK_EQ(old_in_types.size(), 2U);
   Layout data_layout = old_in_layouts[0];
   if (new_in_layouts.defined()) {
     CHECK_EQ(new_in_layouts.size(), 2U);

@@ -619,9 +619,15 @@ TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
 Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
                                                  const Array<Layout>& new_in_layouts,
                                                  const Array<Layout>& old_in_layouts,
-                                                 const Array<Array<IndexExpr>>& old_in_shapes) {
+                                                 const Array<tvm::relay::Type>& old_in_types) {
   BatchNormAttrs* param = const_cast<BatchNormAttrs*>(attrs.as<BatchNormAttrs>());

+  Array<Array<IndexExpr>> old_in_shapes;
+  for (auto old_in_t : old_in_types) {
+    CHECK(old_in_t.as<TensorTypeNode>());
+    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+  }
+
   size_t axis =
       param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);

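The shapes recovered from old_in_types are what let BatchNormInferCorrectLayout remap param->axis when the layout changes. A hedged Python illustration of the effect (illustrative shapes; behavior as expected from the code above):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 56, 56, 64))
    weight = relay.var("weight", shape=(3, 3, 64, 32))
    conv = relay.nn.conv2d(data, weight, channels=32, kernel_size=(3, 3),
                           padding=(1, 1), data_layout="NHWC",
                           kernel_layout="HWIO")
    # In NHWC the channel axis is 3; batch_norm normalizes over it.
    bn = relay.nn.batch_norm(conv,
                             relay.var("gamma", shape=(32,)),
                             relay.var("beta", shape=(32,)),
                             relay.var("mean", shape=(32,)),
                             relay.var("var", shape=(32,)),
                             axis=3)
    mod = tvm.IRModule.from_expr(bn[0])

    # After the pass, batch_norm's axis should be remapped to 1, the
    # channel dimension of the new NCHW layout.
    mod = relay.transform.ConvertLayout("NCHW")(mod)
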
src/relay/op/nn/pad.cc (2 changes: 1 addition & 1 deletion)

@@ -39,7 +39,7 @@ Array<Array<Layout> > PadInferCorrectLayout(
     const Attrs& attrs,
     const Array<Layout>& new_in_layouts,
     const Array<Layout>& old_in_layouts,
-    const Array<Array<IndexExpr>> &old_in_shapes) {
+    const Array<tvm::relay::Type> &old_in_types) {
   // NOTE: Discard "const" qualifier here.
   PadAttrs *params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());

src/relay/op/nn/pooling.cc (2 changes: 1 addition & 1 deletion)

@@ -41,7 +41,7 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
     const Attrs& attrs,
     const Array<Layout>& new_in_layouts,
     const Array<Layout>& old_in_layouts,
-    const Array<Array<IndexExpr>> &old_in_shapes) {
+    const Array<tvm::relay::Type> &old_in_types) {
   // NOTE: Discard "const" qualifier here.
   T *params = const_cast<T*>(attrs.as<T>());

src/relay/op/nn/upsampling.cc (2 changes: 1 addition & 1 deletion)

@@ -39,7 +39,7 @@ Array<Array<Layout> > UpsamplingInferCorrectLayout(
     const Attrs& attrs,
     const Array<Layout>& new_in_layouts,
     const Array<Layout>& old_in_layouts,
-    const Array<Array<IndexExpr>> &old_in_shapes) {
+    const Array<tvm::relay::Type> &old_in_types) {
   // NOTE: Discard "const" qualifier here.
   T *params = const_cast<T*>(attrs.as<T>());

src/relay/op/tensor/reduce.cc (7 changes: 6 additions & 1 deletion)

@@ -122,11 +122,16 @@ Array<Integer> GetExcludeAxes(size_t indim,
 Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
                                               const Array<Layout>& new_in_layouts,
                                               const Array<Layout>& old_in_layouts,
-                                              const Array<Array<IndexExpr>>& old_in_shapes) {
+                                              const Array<tvm::relay::Type>& old_in_types) {
   // NOTE: Discard "const" qualifier here.
   ReduceAttrs* params = const_cast<ReduceAttrs*>(attrs.as<ReduceAttrs>());

   // Get the reduce axes.
+  Array<Array<IndexExpr>> old_in_shapes;
+  for (auto old_in_t : old_in_types) {
+    CHECK(old_in_t.as<TensorTypeNode>());
+    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+  }
   uint32_t indim = old_in_shapes[0].size();
   auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);

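Same idea for reductions: the recovered input shape supplies the rank needed to remap the reduce axes. A hedged sketch of the behavior this enables (shapes illustrative):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 56, 56, 64))
    weight = relay.var("weight", shape=(3, 3, 64, 64))
    conv = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3),
                           padding=(1, 1), data_layout="NHWC",
                           kernel_layout="HWIO")
    # Global spatial mean: in NHWC the H and W axes are [1, 2].
    out = relay.mean(conv, axis=[1, 2], keepdims=True)
    mod = tvm.IRModule.from_expr(out)

    # After conversion the reduce axes should follow H and W to [2, 3]
    # in NCHW.
    mod = relay.transform.ConvertLayout("NCHW")(mod)
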
src/relay/op/tensor/transform.cc (55 changes: 6 additions & 49 deletions)

@@ -278,54 +278,6 @@ Array<Tensor> ConcatenateCompute(const Attrs& attrs,
   return { topi::concatenate(inputs, param->axis) };
 }

-Array<Array<Layout>> ConcatenateLayout(
-    const Attrs& attrs,
-    const Array<Layout>& new_in_layouts,
-    const Array<Layout>& old_in_layouts,
-    const Array<Array<IndexExpr>> &old_in_shapes) {
-  ConcatenateAttrs* param = const_cast<ConcatenateAttrs*>(attrs.as<ConcatenateAttrs>());
-
-  size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
-                static_cast<size_t>(param->axis);
-
-  Layout ret;
-  bool is_new_layout_selected = false;
-  if (new_in_layouts.defined()) {  // this function is called after some operators are alternated.
-    // If all the new input layouts are same, the new in layout gets selected. For axis, the new
-    // axis in the new layout is identified. The param->axis is then modified on the fly to conform
-    // to the new input layout.
-    const auto& concate_dim = old_in_layouts[0][axis];
-    bool all_input_layouts_same = true;
-    for (auto new_layout : new_in_layouts) {
-      if (!new_layout.Equals(new_in_layouts[0])) {
-        all_input_layouts_same = false;
-      }
-    }
-    if (all_input_layouts_same) {
-      auto new_index = new_in_layouts[0].IndexOf(concate_dim);
-      ret = new_in_layouts[0];
-      param->axis = new_index;
-      is_new_layout_selected = true;
-    }
-  }
-
-  if (!is_new_layout_selected) {
-    // this function is called on the original correct relay ir
-    for (size_t i = 0; i < old_in_layouts.size(); ++i) {
-      if (old_in_layouts[i].defined()) {
-        ret = old_in_layouts[i];
-        break;
-      }
-    }
-
-    if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
-      return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
-    }
-  }
-
-  return Array<Array<Layout> > {Array<Layout>(old_in_layouts.size(), ret), {ret}};
-}
-
 Expr MakeConcatenate(Expr data,
                      int axis) {
   auto attrs = make_node<ConcatenateAttrs>();

@@ -1933,9 +1885,14 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
     const Attrs& attrs,
     const Array<Layout>& new_in_layouts,
     const Array<Layout>& old_in_layouts,
-    const Array<Array<IndexExpr>>& old_in_shapes) {
+    const Array<tvm::relay::Type>& old_in_types) {
   CHECK(old_in_layouts.defined());
   CHECK_EQ(old_in_layouts.size(), 1);
+  Array<Array<IndexExpr>> old_in_shapes;
+  for (auto old_in_t : old_in_types) {
+    CHECK(old_in_t.as<TensorTypeNode>());
+    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+  }
   CHECK(old_in_shapes.defined());
   CHECK_EQ(old_in_shapes.size(), 1);

src/relay/op/tensor/transform.h (59 changes: 59 additions & 0 deletions)

@@ -24,6 +24,8 @@
 #ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
 #define TVM_RELAY_OP_TENSOR_TRANSFORM_H_

+#include <tvm/relay/attrs/transform.h>
+#include <tvm/data_layout.h>
 #include <vector>
 #include <algorithm>
 #include <limits>

@@ -123,6 +125,63 @@ bool ConcatenateRel(const Array<Type>& types,
   return true;
 }

+static inline Array<Array<Layout>> ConcatenateLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<tvm::relay::Type> &old_in_types) {
+  ConcatenateAttrs* param = const_cast<ConcatenateAttrs*>(attrs.as<ConcatenateAttrs>());
+
+  Array<Array<IndexExpr>> old_in_shapes;
+  CHECK_EQ(old_in_types.size(), 1);
+  for (auto old_in_tuple_t : old_in_types) {
+    CHECK(old_in_tuple_t.as<TupleTypeNode>());
+    for (auto old_in_t : old_in_tuple_t.as<TupleTypeNode>()->fields) {
+      old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+    }
+  }
+
+  size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
+                static_cast<size_t>(param->axis);
+
+  Layout ret;
+  bool is_new_layout_selected = false;
+  if (new_in_layouts.defined()) {  // this function is called after some operators are alternated.
+    // If all the new input layouts are same, the new in layout gets selected. For axis, the new
+    // axis in the new layout is identified. The param->axis is then modified on the fly to conform
+    // to the new input layout.
+    const auto& concate_dim = old_in_layouts[0][axis];
+    bool all_input_layouts_same = true;
+    for (auto new_layout : new_in_layouts) {
+      if (!new_layout.Equals(new_in_layouts[0])) {
+        all_input_layouts_same = false;
+      }
+    }
+    if (all_input_layouts_same) {
+      auto new_index = new_in_layouts[0].IndexOf(concate_dim);
+      ret = new_in_layouts[0];
+      param->axis = new_index;
+      is_new_layout_selected = true;
+    }
+  }
+
+  if (!is_new_layout_selected) {
+    // this function is called on the original correct relay ir
+    for (size_t i = 0; i < old_in_layouts.size(); ++i) {
+      if (old_in_layouts[i].defined()) {
+        ret = old_in_layouts[i];
+        break;
+      }
+    }
+
+    if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
+      return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
+    }
+  }
+
+  return Array<Array<Layout> > {Array<Layout>(old_in_layouts.size(), ret), {ret}};
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_TENSOR_TRANSFORM_H_
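ConcatenateLayout's on-the-fly axis rewrite can be seen end to end with a small graph; a hedged sketch (shapes illustrative):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 56, 56, 64))
    w1 = relay.var("w1", shape=(3, 3, 64, 32))
    w2 = relay.var("w2", shape=(3, 3, 64, 32))
    c1 = relay.nn.conv2d(data, w1, channels=32, kernel_size=(3, 3),
                         padding=(1, 1), data_layout="NHWC",
                         kernel_layout="HWIO")
    c2 = relay.nn.conv2d(data, w2, channels=32, kernel_size=(3, 3),
                         padding=(1, 1), data_layout="NHWC",
                         kernel_layout="HWIO")
    # Concatenate along the NHWC channel axis.
    out = relay.concatenate([c1, c2], axis=3)
    mod = tvm.IRModule.from_expr(out)

    # Both inputs arrive in the same new layout (NCHW), so ConcatenateLayout
    # selects it and rewrites axis=3 to axis=1.
    mod = relay.transform.ConvertLayout("NCHW")(mod)
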
(6 more changed files not shown.)
