Skip to content

Commit

Permalink
Add High-level Op Support (apache#5)
Browse files Browse the repository at this point in the history
* high-level-op support

* format

* format

* follow relay convention

* format

* fix
  • Loading branch information
jinhongyii authored and MasterJH5574 committed Nov 19, 2022
1 parent 304048c commit 7425128
Show file tree
Hide file tree
Showing 18 changed files with 730 additions and 107 deletions.
28 changes: 28 additions & 0 deletions include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,34 @@ struct AssertOpAttrs : public tvm::AttrsNode<AssertOpAttrs> {
}
};

/*! \brief Attributes used in MaxPool2d operator */
struct MaxPool2dAttrs : public tvm::AttrsNode<MaxPool2dAttrs> {
Array<PrimExpr> kernel_size;
Array<PrimExpr> stride;
Array<PrimExpr> padding;
Array<PrimExpr> dilation;
TVM_DECLARE_ATTRS(MaxPool2dAttrs, "relax.attrs.MaxPool2dAttrs") {
TVM_ATTR_FIELD(kernel_size).describe("The size of the window to take a max over.");
TVM_ATTR_FIELD(stride).describe("The stride of the window.");
TVM_ATTR_FIELD(padding).describe("The padding on the input.");
TVM_ATTR_FIELD(dilation).describe("The stride of elements in the window.");
}
}; // struct MaxPool2dAttrs

/*! \brief Attributes used in Conv2d operator */
struct Conv2dAttrs : public tvm::AttrsNode<Conv2dAttrs> {
Array<PrimExpr> kernel_size;
Array<PrimExpr> stride;
Array<PrimExpr> padding;
Array<PrimExpr> dilation;
TVM_DECLARE_ATTRS(Conv2dAttrs, "relax.attrs.Conv2dAttrs") {
TVM_ATTR_FIELD(kernel_size).describe("The size of the convolving kernel.");
TVM_ATTR_FIELD(stride).describe("The stride of the convolution.");
TVM_ATTR_FIELD(padding).describe("The padding on the input.");
TVM_ATTR_FIELD(dilation).describe("The spacing between kernel elements.");
}
}; // struct Conv2dAttrs

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_OP_ATTR_TYPES_H_
3 changes: 2 additions & 1 deletion python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

# Operators
from .base import *
from .tensor import *
from .nn import *
from .op_attrs import *
from .tensor import *
from . import builtin
from . import memory
19 changes: 19 additions & 0 deletions python/tvm/relax/op/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Neural network related operators."""
from .nn import *
20 changes: 20 additions & 0 deletions python/tvm/relax/op/nn/_make.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
import tvm._ffi

tvm._ffi._init_api("relax.op.nn", __name__)
48 changes: 48 additions & 0 deletions python/tvm/relax/op/nn/nn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from . import _make
from ...expr import Expr


def dense(lhs: Expr, rhs: Expr) -> Expr:
return _make.dense(lhs, rhs)


def conv2d(
lhs: Expr, rhs: Expr, kernel_size, stride=(1, 1), padding=[0, 0], dilation=[1, 1]
) -> Expr:
return _make.conv2d(lhs, rhs, kernel_size, stride, padding, dilation)


def relu(data: Expr) -> Expr:
return _make.relu(data)


def softmax(data: Expr) -> Expr:
return _make.softmax(data)


def flatten(data: Expr) -> Expr:
return _make.flatten(data)


def max_pool2d(data: Expr, kernel_size, stride=None, padding=(0, 0), dilation=(1, 1)) -> Expr:
if stride is None:
stride = kernel_size
return _make.max_pool2d(data, kernel_size, stride, padding, dilation)

5 changes: 3 additions & 2 deletions python/tvm/relax/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, invalid-name
"""Basic tensor operations."""
import numpy as np # type: ignore
Expand Down Expand Up @@ -83,13 +84,13 @@ def numpy_unique(
Uses numpy.unique to compute unique elements.
"""
# TODO(prakalp): add support for returning a tuple when return_inverse or return_counts is True
# TODO(prakalp) : add support for returning a tuple when return_inverse or return_counts is True
if bool(return_inverse) or bool(return_counts):
raise NotImplementedError("missing support return_inverse or return_counts set to true")
if dim < 0:
dim = None
a_numpy = a.numpy()
# TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci.
# TODO(prakalp) : use torch.unique instead of numpy when torch is installed in ci.
output_sorted_numpy, indices = np.unique(a_numpy, return_index=True)
if sort:
return tvm.nd.array(output_sorted_numpy)
Expand Down
50 changes: 50 additions & 0 deletions src/relax/op/nn/convolution.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "convolution.h"

#include "../tensor/binary.h"
namespace tvm {
namespace relax {

TVM_REGISTER_NODE_TYPE(Conv2dAttrs);

RELAY_REGISTER_OP("relax.nn.conv2d")
.set_num_inputs(2)
.add_argument("e1", "Expr", "The input expression")
.add_argument("e2", "Expr", "The input expression")
.set_attrs_type<Conv2dAttrs>()
.set_attr<FInferShape>("FInferShape", InferShapeConv2d)
.set_attr<FInferType>("FInferType", InferTypeBinaryBroadcast);

Expr MakeConv2d(Expr expr1, Expr expr2, Array<PrimExpr> kernel_size, Array<PrimExpr> stride,
Array<PrimExpr> padding, Array<PrimExpr> dilation) {
static const Op& op = Op::Get("relax.nn.conv2d");
auto attrs = make_object<Conv2dAttrs>();
attrs->kernel_size = kernel_size;
attrs->stride = stride;
attrs->padding = padding;
attrs->dilation = dilation;
return Call(op, {expr1, expr2}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(MakeConv2d);

} // namespace relax
} // namespace tvm
71 changes: 71 additions & 0 deletions src/relax/op/nn/convolution.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_RELAX_OP_NN_CONVOLUTION_H_
#define TVM_RELAX_OP_NN_CONVOLUTION_H_

#include <tvm/relax/expr.h>
#include <tvm/relax/type.h>

#include "../op_common.h"
namespace tvm {
namespace relax {

Optional<Expr> InferShapeConv2d(const Call& call, DiagnosticContext diag_ctx) {
if (call->args.size() != 2) {
diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Conv2d op should have 2 arguments");
}
Expr shape0 = call->args[0]->shape();
Expr shape1 = call->args[1]->shape();
auto* s0 = shape0.as<ShapeExprNode>();
auto* s1 = shape1.as<ShapeExprNode>();
auto* attrs = call->attrs.as<Conv2dAttrs>();
if (s0 && s1) {
std::vector<PrimExpr> output_shape;
size_t ndim0 = s0->values.size();
size_t ndim1 = s1->values.size();
if (ndim0 != 4 || ndim1 != 4) {
LOG(INFO) << ndim0;
LOG(INFO) << ndim1;
diag_ctx.EmitFatal(Diagnostic::Error(call->span)
<< "The 2 arguments of Conv2d must be 4D Tensors");
}
// N
output_shape.push_back(s0->values[0]);
// C
output_shape.push_back(s1->values[0]);
// H
output_shape.push_back((s0->values[2] + 2 * attrs->padding[0] -
attrs->dilation[0] * (attrs->kernel_size[0] - 1) - 1) /
attrs->stride[0] +
1);
// W
output_shape.push_back((s0->values[3] + 2 * attrs->padding[1] -
attrs->dilation[1] * (attrs->kernel_size[1] - 1) - 1) /
attrs->stride[1] +
1);
return ShapeExpr(Array<PrimExpr>{output_shape.begin(), output_shape.end()});
} else {
return NullOpt;
}
}

} // namespace relax
} // namespace tvm
#endif
56 changes: 56 additions & 0 deletions src/relax/op/nn/nn.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "nn.h"

namespace tvm {
namespace relax {

RELAY_REGISTER_OP("relax.nn.dense")
.set_num_inputs(2)
.add_argument("e1", "Expr", "The input expression")
.add_argument("e2", "Expr", "The input expression")
.set_attr<FInferShape>("FInferShape", InferShapeDense)
.set_attr<FInferType>("FInferType", InferTypeDense);

Expr MakeDense(Expr expr1, Expr expr2) {
static const Op& op = Op::Get("relax.nn.dense");
return Call(op, {expr1, expr2}, {}, {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.dense").set_body_typed(MakeDense);

RELAX_REGISTER_UNARY_OP("nn.softmax");

RELAX_REGISTER_UNARY_OP("nn.relu");

RELAY_REGISTER_OP("relax.nn.flatten")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
.set_attr<FInferShape>("FInferShape", InferShapeFlatten)
.set_attr<FInferType>("FInferType", InferTypeFlatten);

Expr MakeFlatten(Expr data) {
static const Op& op = Op::Get("relax.nn.flatten");
return Call(op, {data}, {}, {});
}
TVM_REGISTER_GLOBAL("relax.op.nn.flatten").set_body_typed(MakeFlatten);

} // namespace relax
} // namespace tvm
Loading

0 comments on commit 7425128

Please sign in to comment.