-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor to expose MakeOp functions to C++ (#6047)
* Initial Refactor * add templated nn Make* functions * fix build typo * inline functions, fix unit tests
- Loading branch information
Matthew Brookhart
authored
Jul 14, 2020
1 parent
e4a0aa5
commit bfe83eb
Showing
14 changed files
with
347 additions
and
271 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* | ||
* \file tvm/relay/op/make_op.h | ||
* \brief Header of internal operator functions | ||
* to assist in creating ops in C++ | ||
*/ | ||
#ifndef TVM_RELAY_OP_MAKE_OP_H_ | ||
#define TVM_RELAY_OP_MAKE_OP_H_ | ||
|
||
#include <tvm/relay/expr.h> | ||
#include <tvm/relay/op.h> | ||
|
||
// Include Templated Make Functions | ||
#include "nn/convolution_make.h" | ||
#include "nn/pooling.h" | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
Expr MakeBroadCastTo(Expr data, Expr shape); | ||
|
||
Expr MakeCast(Expr data, DataType dtype); | ||
|
||
Expr MakeClip(Expr a, double a_min, double a_max); | ||
|
||
Expr MakeConcatenate(Expr data, int axis); | ||
|
||
Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype); | ||
|
||
Expr MakeExpandDims(Expr data, int axis, int num_newaxis); | ||
|
||
Expr MakeFull(Expr fill_value, Expr shape, DataType dtype); | ||
|
||
Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); | ||
|
||
Expr MakeOnes(Expr shape, DataType dtype); | ||
|
||
Expr MakePad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value, String pad_mode); | ||
|
||
Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name); | ||
|
||
Expr MakeRepeat(Expr data, int repeats, int axis); | ||
|
||
Expr MakeReshape(Expr data, Array<Integer> newshape); | ||
|
||
Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); | ||
|
||
Expr MakeSqueeze(Expr data, Array<Integer> axis); | ||
|
||
Expr MakeStack(Expr data, int axis); | ||
|
||
Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode); | ||
|
||
Expr MakeTile(Expr data, Array<Integer> reps); | ||
|
||
Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype); | ||
|
||
Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude); | ||
|
||
Expr MakeZeros(Expr shape, DataType dtype); | ||
|
||
} // namespace relay | ||
} // namespace tvm | ||
#endif // TVM_RELAY_OP_MAKE_OP_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file src/relay/op/nn/make_convolution.h | ||
* \brief utilities for creating convolution ops | ||
*/ | ||
#ifndef TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_ | ||
#define TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_ | ||
|
||
#include <tvm/relay/attrs/nn.h> | ||
#include <tvm/relay/op.h> | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
template <typename T> | ||
inline Expr MakeConv(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding, | ||
Array<IndexExpr> dilation, int groups, IndexExpr channels, | ||
Array<IndexExpr> kernel_size, std::string data_layout, | ||
std::string kernel_layout, std::string out_layout, DataType out_dtype, | ||
std::string op_name) { | ||
auto attrs = make_object<T>(); | ||
attrs->strides = std::move(strides); | ||
attrs->padding = std::move(padding); | ||
attrs->dilation = std::move(dilation); | ||
attrs->groups = groups; | ||
attrs->channels = std::move(channels); | ||
attrs->kernel_size = std::move(kernel_size); | ||
attrs->data_layout = std::move(data_layout); | ||
attrs->kernel_layout = std::move(kernel_layout); | ||
attrs->out_layout = std::move(out_layout); | ||
attrs->out_dtype = std::move(out_dtype); | ||
const Op& op = Op::Get(op_name); | ||
return Call(op, {data, weight}, Attrs(attrs), {}); | ||
} | ||
|
||
template <typename T> | ||
inline Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array<IndexExpr> strides, | ||
Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups, | ||
IndexExpr channels, Array<IndexExpr> kernel_size, | ||
std::string data_layout, std::string kernel_layout, | ||
std::string out_layout, DataType out_dtype, std::string op_name) { | ||
auto attrs = make_object<T>(); | ||
attrs->tile_size = tile_size; | ||
attrs->strides = std::move(strides); | ||
attrs->padding = std::move(padding); | ||
attrs->dilation = std::move(dilation); | ||
attrs->groups = groups; | ||
attrs->channels = std::move(channels); | ||
attrs->kernel_size = std::move(kernel_size); | ||
attrs->data_layout = std::move(data_layout); | ||
attrs->kernel_layout = std::move(kernel_layout); | ||
attrs->out_layout = std::move(out_layout); | ||
attrs->out_dtype = std::move(out_dtype); | ||
const Op& op = Op::Get(op_name); | ||
return Call(op, {data, weight}, Attrs(attrs), {}); | ||
} | ||
|
||
template <typename T> | ||
inline Expr MakeConvGemm(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding, | ||
Array<IndexExpr> dilation, int groups, IndexExpr channels, | ||
Array<IndexExpr> kernel_size, std::string data_layout, | ||
std::string kernel_layout, std::string out_layout, DataType out_dtype, | ||
std::string op_name) { | ||
auto attrs = make_object<T>(); | ||
attrs->strides = std::move(strides); | ||
attrs->padding = std::move(padding); | ||
attrs->dilation = std::move(dilation); | ||
attrs->groups = groups; | ||
attrs->channels = std::move(channels); | ||
attrs->kernel_size = std::move(kernel_size); | ||
attrs->data_layout = std::move(data_layout); | ||
attrs->kernel_layout = std::move(kernel_layout); | ||
attrs->out_layout = std::move(out_layout); | ||
attrs->out_dtype = std::move(out_dtype); | ||
const Op& op = Op::Get(op_name); | ||
return Call(op, {data, weight}, Attrs(attrs), {}); | ||
} | ||
|
||
template <typename T> | ||
inline Expr MakeConvTranspose(Expr data, Expr weight, Array<IndexExpr> strides, | ||
Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups, | ||
IndexExpr channels, Array<IndexExpr> kernel_size, | ||
std::string data_layout, std::string kernel_layout, | ||
std::string out_layout, Array<IndexExpr> output_padding, | ||
DataType out_dtype, std::string op_name) { | ||
auto attrs = make_object<T>(); | ||
attrs->strides = std::move(strides); | ||
attrs->padding = std::move(padding); | ||
attrs->dilation = std::move(dilation); | ||
attrs->groups = groups; | ||
attrs->channels = std::move(channels); | ||
attrs->kernel_size = std::move(kernel_size); | ||
attrs->data_layout = std::move(data_layout); | ||
attrs->kernel_layout = std::move(kernel_layout); | ||
attrs->out_layout = std::move(out_layout); | ||
attrs->output_padding = std::move(output_padding); | ||
attrs->out_dtype = std::move(out_dtype); | ||
const Op& op = Op::Get(op_name); | ||
return Call(op, {data, weight}, Attrs(attrs), {}); | ||
} | ||
|
||
template <typename T> | ||
inline Expr MakeDeformableConv(Expr data, Expr offset, Expr weight, Array<IndexExpr> strides, | ||
Array<IndexExpr> padding, Array<IndexExpr> dilation, | ||
int deformable_groups, int groups, int channels, | ||
Array<IndexExpr> kernel_size, std::string data_layout, | ||
std::string kernel_layout, std::string out_layout, | ||
DataType out_dtype, std::string op_name) { | ||
auto attrs = make_object<T>(); | ||
attrs->strides = strides; | ||
attrs->padding = padding; | ||
attrs->dilation = dilation; | ||
attrs->deformable_groups = deformable_groups; | ||
attrs->groups = groups; | ||
attrs->channels = channels; | ||
attrs->kernel_size = kernel_size; | ||
attrs->data_layout = data_layout; | ||
attrs->kernel_layout = kernel_layout; | ||
attrs->out_layout = out_layout; | ||
attrs->out_dtype = out_dtype; | ||
const Op& op = Op::Get(op_name); | ||
return Call(op, {data, offset, weight}, Attrs{attrs}, {}); | ||
} | ||
|
||
} // namespace relay | ||
} // namespace tvm | ||
#endif // TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
|
||
#include <vector> | ||
|
||
#include "../make_op.h" | ||
#include "../op_common.h" | ||
|
||
namespace tvm { | ||
|
Oops, something went wrong.