Commit: add document

Squashed development history:

- lint ×2
- save ×2
- add more cases
- save
- error
- lint ×2
- commit
- do
- lint
- save
- fix lint
- wrap it back as a func
- lint
- save
- remove dead comment
- fix style
- fix lint
- Update src/relay/pass/partial_eval.cc ×6 (Co-Authored-By: MarisaKirisame <[email protected]>)
- address review feedback
- PE now handles free vars; as a result, preserving functions is now trivial
- test
- add basic test, implement pretty printing for generic functions
- test
- lint
- fix segfault
- save ×2
- do
- test
- fix another error
- address comment
- commit
- save
- address review feedback
- add test for invalidate, fix error in lookup
- rename cont to body
- fix error and add regression test
- Update src/relay/pass/partial_eval.cc (Co-Authored-By: MarisaKirisame <[email protected]>)
- fix error, add test case
- fix lint
- remove extra line
- fix some errors
- pe
- commit
- save ×4
- save (pe/dce broken)

Commits merged in from upstream master:

[DOCKER] Pin flatbuffers checkout to the last release tag (apache#2823). (apache#2879)

[Relay][Text Format] Reverse CallNode Print Order (apache#2882)

[NNPACK] Modernize test (apache#2868)

[Relay] Add list update to prelude (apache#2866)

Add missing sgx includes (apache#2878)

Fix setting up hints for getaddrinfo (apache#2872)

[ARITH] RewriteSimplifier: improved cmp simplification (apache#2851)

do (apache#2883)

[RELAY][Frontend][TF] decompile tf control flow (apache#2830)

* decompile tf control flow

* Add docs

* remove import relay

* move tests under tensorflow frontend

* minor fix

Enhance upsample operator to adapt onnx opset version 9 (apache#2840)

Use version invariant rustfmt (apache#2886)

[Relay][Op] Add group conv2d dispatch to topi function (apache#2870)

* [Relay][Op] Add group conv2d dispatch to topi function

* Rerun tests

[Apps] [howto_deploy] fix cxx-flags order and build directory (apache#2888)

fix prelu so it can be used on 2d input, and add one test (apache#2875)

Add dense schedules to __init__ for cpu (apache#2855)

* Add dense schedules to __init__ for cpu

* Add documentation for topi::shape

* Add additional imports to topi CPU __init__.

[TESTS] Improve script robustness (apache#2893)

A number of test scripts use the '|| exit 1' idiom. This has two
issues: first, process exit codes are defined to be in the range 0-255;
second, and more importantly, the idiom is fragile because it requires
that every possible failure point be explicitly coded. This patch
removes the idiom in favour of "set -e", as used in the docker scripts,
a more robust mechanism that ensures script failures are always caught
and propagated by default.

[Relay] Fix name of bias in testing.mlp (apache#2892)

winograd_nnpack (apache#2721)

[Relay] Fix Relay ARM CPU depthwise spatial pack schedule alter op layout issue. (apache#2861)

* Fix Relay ARM CPU spatial pack depthwise alter op layout issue.

* Update tune_relay_arm.py

[TESTS] Import script robustness (set -u) (apache#2896)

Adopt the "set -u" idiom from the docker scripts as a mechanism to
improve future robustness.

[DOCKER] Upgrade ci-cpu to latest v0.50 (apache#2901)

Allow linking against MKLML (apache#2902)

[COMMUNITY] ASF mentors (apache#2906)

[Relay] Allow converting keras.layers.Sequential (apache#2842)

* Allow converting keras.layers.Sequential

* Use existing new_var function

* Only update expr when missing

* Add test

[Relay] clean up hd, change tl (apache#2917)

Turn on USE_SORT by default (apache#2916)

[TEST] Cache test data (apache#2921)

Unified error handling in NNVM and Relay frontends (apache#2828)

add support for mxnet smooth_l1 (apache#2905)

[Relay] Add support for TupleGetItem in op fusion (apache#2914)

[Relay, TOPI]  Deformable conv2d (apache#2908)

* [Relay, TOPI] Add deformable conv2d

* Moved to op level2

* Fix lint

* Moved to level2 & bug fix

* Update comments

* Disabled flaky test of conv2d

TVM debugresult dump to Chrome Tracing (apache#2922)

[Relay] add test for second order ad (apache#2754)

* do second order

* add comment

* better name

* use tvm assert all close

* refire ci

Revert "[Relay] add test for second order ad (apache#2754)" (apache#2926)

This reverts commit f5ca991.

[Tutorial] Cache the test data in tutorial (apache#2923)

[AUTOTVM] Refactor measure build func (apache#2927)

Fix intersect of modular set (apache#2904)

Fix comment bugs and code style

[Relay, OpFusion] Fix handling TupleGetItem for nested tuples (apache#2929)

Consistent result of DetectLinearEquation() when an empty vars array is passed (apache#2860)

[FRONTEND][ONNX] Some bug fixes and Shape operator fixed for relay. (apache#2850)

* [FRONTEND][ONNX] Some bug fixes and Shape operator fixed for relay.

* test cases

* ci error

Outdated renaming for flatten in ONNX converter (apache#2843)

[FRONTEND][TENSORFLOW] bug fix for tensorflow official slim models. (apache#2864)

* [FRONTEND][TENSORFLOW] bug fix for tensorflow official slim models.

* review comments

Fix vcvtph2ps codegen (apache#2925)

Further development commits on this branch:

- Port changes
- More fixes
- save ×2
- Changes to schedules and mxnet importer

MarisaKirisame committed Apr 9, 2019
1 parent 46f0b67 commit b96f221
Showing 159 changed files with 6,396 additions and 1,324 deletions.
13 changes: 12 additions & 1 deletion CONTRIBUTORS.md
@@ -1,10 +1,21 @@
TVM Contributors
================
TVM adopts the Apache style model and governs by merit. We believe that it is important to create an inclusive community where everyone can use,
TVM adopts the Apache way and governs by merit. We believe that it is important to create an inclusive community where everyone can use,
contribute to, and influence the direction of the project. We actively invite contributors who have earned the merit to be part of the development community.

See the [community structure document](http://docs.tvm.ai/contribute/community.html) for the explanation of community structure and contribution guidelines.

## Mentors

TVM is now part of the Apache Incubator.
We are fortunate to have the following mentors.

- Markus Weimer @markusweimer
- Sebastian Schelter @sscdotopen
- Byung-Gon Chun @bgchun
- Henry Saputra @hsaputra
- Timothy Chen @tnachen
- Furkan KAMACI @kamaci

## Committers

2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -22,7 +22,7 @@
//
ci_lint = "tvmai/ci-lint:v0.50"
ci_gpu = "tvmai/ci-gpu:v0.51"
ci_cpu = "tvmai/ci-cpu:v0.41"
ci_cpu = "tvmai/ci-cpu:v0.50"
ci_i386 = "tvmai/ci-i386:v0.50"

// tvm libraries
2 changes: 1 addition & 1 deletion apps/howto_deploy/Makefile
@@ -31,4 +31,4 @@ lib/cpp_deploy_pack: cpp_deploy.cc lib/test_addone_sys.o lib/libtvm_runtime_pack
# Deploy using pre-built libtvm_runtime.so
lib/cpp_deploy_normal: cpp_deploy.cc lib/test_addone_sys.o
@mkdir -p $(@D)
$(CXX) $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS) -ltvm_runtime
$(CXX) $(PKG_CFLAGS) -o $@ $^ -ltvm_runtime $(PKG_LDFLAGS)
4 changes: 2 additions & 2 deletions apps/howto_deploy/run_example.sh
@@ -3,8 +3,8 @@ echo "Build the libraries.."
mkdir -p lib
make
echo "Run the example"
export LD_LIBRARY_PATH=../../lib:${LD_LIBRARY_PATH}
export DYLD_LIBRARY_PATH=../../lib:${DYLD_LIBRARY_PATH}
export LD_LIBRARY_PATH=../../build:${LD_LIBRARY_PATH}
export DYLD_LIBRARY_PATH=../../build:${DYLD_LIBRARY_PATH}

echo "Run the deployment with all in one packed library..."
lib/cpp_deploy_pack
2 changes: 1 addition & 1 deletion cmake/config.cmake
@@ -127,7 +127,7 @@ set(USE_MPS OFF)
set(USE_ROCBLAS OFF)

# Whether use contrib sort
set(USE_SORT OFF)
set(USE_SORT ON)

# Build ANTLR parser for Relay text format
set(USE_ANTLR OFF)
2 changes: 2 additions & 0 deletions cmake/modules/SGX.cmake
@@ -48,4 +48,6 @@ if(NOT USE_SGX STREQUAL "OFF")
-L${USE_SGX}/lib64 -l${_urts_lib}
-L${RUST_SGX_SDK}/sgx_ustdc -lsgx_ustdc)
list(APPEND RUNTIME_SRCS ${RUNTIME_SGX_SRCS})

include_directories(${RUST_SGX_SDK}/edl ${RUST_SGX_SDK}/common)
endif()
2 changes: 1 addition & 1 deletion cmake/modules/contrib/BLAS.cmake
@@ -10,7 +10,7 @@ elseif(USE_BLAS STREQUAL "mkl")
if(NOT IS_DIRECTORY ${USE_MKL_PATH})
set(USE_MKL_PATH /opt/intel/mkl)
endif()
find_library(BLAS_LIBRARY mkl_rt ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
include_directories(${USE_MKL_PATH}/include)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY})
list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})
8 changes: 3 additions & 5 deletions docker/install/ubuntu_install_rust.sh
@@ -9,12 +9,10 @@ apt-get update && apt-get install -y --no-install-recommends curl
export RUSTUP_HOME=/opt/rust
export CARGO_HOME=/opt/rust
# this rustc is one supported by the installed version of rust-sgx-sdk
curl -s -S -L https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path --default-toolchain nightly-2019-01-28
curl -s -S -L https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path --default-toolchain nightly-2019-03-24
. $CARGO_HOME/env
rustup component add rust-src
cargo install sccache
cargo install rustfmt-nightly --version 1.0.1 --force
cargo install xargo
rustup component add rustfmt
cargo install sccache --no-default-features

# make rust usable by all users
chmod -R a+w /opt/rust
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_tflite.sh
@@ -5,7 +5,7 @@ set -u
set -o pipefail

# Download, build and install flatbuffers
git clone --depth=1 --recursive https://github.com/google/flatbuffers.git
git clone --branch=v1.10.0 --depth=1 --recursive https://github.com/google/flatbuffers.git
cd flatbuffers
cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release
make install -j8
79 changes: 79 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -155,6 +155,24 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
}
};

/*! \brief Attributes used in winograd weight transformation operators */
struct Conv2DWinogradNNPACKWeightTransformAttrs
: public tvm::AttrsNode<Conv2DWinogradNNPACKWeightTransformAttrs> {
int convolution_algorithm;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DWinogradNNPACKWeightTransformAttrs,
"relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") {
TVM_ATTR_FIELD(convolution_algorithm)
.describe(
"The convolution algorithm for Winograd NNPACK. "
"E.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8 for WT_8x8, "
"tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16 for WT_8x8_FP16");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};

/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
@@ -438,6 +456,67 @@ struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
}
};


/*! \brief Attributes for DeformableConv2D operator */
struct DeformableConv2DAttrs : public tvm::AttrsNode<DeformableConv2DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int deformable_groups;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(DeformableConv2DAttrs, "relay.attrs.DeformableConv2DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(deformable_groups).set_default(1)
.describe("Controls the connections between inputs and offsets."
"Input channels are partitioned into multiple deformable groups. Offsets"
"are shared across input channels in the same deformable group.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution."
" If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
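
For reference, op-registration code typically fills in such an attrs node with make_node; a minimal sketch (the field values below are hypothetical):

// Construct attrs for a hypothetical 3x3 deformable conv with 4 offset groups.
auto attrs = make_node<DeformableConv2DAttrs>();
attrs->strides = Array<IndexExpr>({1, 1});
attrs->kernel_size = Array<IndexExpr>({3, 3});
attrs->deformable_groups = 4;
attrs->groups = 1;
// The attrs node is then attached to the call in the op's make function.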

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
41 changes: 40 additions & 1 deletion include/tvm/relay/expr.h
@@ -166,6 +166,26 @@ class VarNode : public ExprNode {

RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);

/*! \brief Hash a Var by its vid.
* Different VarNodes might have the same vid; in that case they are considered to be the same var.
* Use VarHash to hash a Var by vid.
*/
struct VarHash {
size_t operator()(const Var& v) const {
return v->vid.hash();
}
};

/*! \brief Compare Vars by their vid.
* Different VarNodes might have the same vid; in that case they are considered to be the same var.
* Use VarEqual to compare Vars by vid.
*/
struct VarEqual {
bool operator()(const Var& l, const Var& r) const {
return l->vid.get() == r->vid.get();
}
};
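
Together, these functors let standard containers key Vars by vid instead of by pointer identity. A minimal sketch (v1 and v2 are hypothetical Var handles whose nodes share one vid; <unordered_map> is assumed included):

// Count uses per variable; handles with the same vid share one entry.
std::unordered_map<Var, int, VarHash, VarEqual> use_count;
use_count[v1] += 1;
use_count[v2] += 1;  // if v2 has v1's vid, use_count[v1] is now 2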

/*!
* \brief Global variable that lives in the top-level module.
* This is used to enable recursive calls between functions.
@@ -503,7 +523,7 @@ RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);
* rewriting pass such as layout or type transformation.
*
* Subclass TempExprNode allows us to pattern match on
* specific kind TempExpr and use them for expression rewriting.
* specific kind of TempExpr and use them for expression rewriting.
*
* TempExpr should only be used within a pass,
*/
@@ -521,6 +541,25 @@ class TempExprNode : public ExprNode {

RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);

class Annotate;
class AnnotateNode : public ExprNode {
public:
Expr expr;
NodeRef annotation;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
v->Visit("annotation", &annotation);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static Annotate make(Expr expr, NodeRef annotation);

static constexpr const char* _type_key = "relay.AnnotateNode";
TVM_DECLARE_NODE_TYPE_INFO(AnnotateNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Annotate, AnnotateNode, Expr);
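
A minimal construction sketch using the static make declared above (expr and info are hypothetical: an existing Relay expression and an arbitrary annotation object):

// Attach side-band information to an expression without changing it.
Annotate annotated = AnnotateNode::make(expr, info);
CHECK(annotated->expr.same_as(expr));  // the wrapped expression is stored as-is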

// implementations
inline const Type& ExprNode::checked_type() const {
CHECK(checked_type_.defined()) << "internal error: the type checker has "
5 changes: 5 additions & 0 deletions include/tvm/relay/expr_functor.h
@@ -71,6 +71,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
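// Fail fast on an undefined (null) expression rather than
// crashing inside the vtable dispatch below.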
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
@@ -97,6 +98,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const AnnotateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
@@ -121,6 +123,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
RELAY_EXPR_FUNCTOR_DISPATCH(AnnotateNode);
return vtable;
}
};
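
With AnnotateNode wired into the dispatch table above, a subclass can intercept annotations explicitly. A minimal sketch using the ExprVisitor declared below (assumption: the inherited handler recurses into op->expr):

// Count the AnnotateNodes reachable from an expression.
class AnnotateCounter : public ExprVisitor {
 public:
  int count{0};
  void VisitExpr_(const AnnotateNode* op) final {
    ++count;
    ExprVisitor::VisitExpr_(op);  // continue into the wrapped expression
  }
};
// Usage: AnnotateCounter c; c.VisitExpr(e);  // c.count holds the total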
@@ -151,6 +154,7 @@ class ExprVisitor
void VisitExpr_(const RefWriteNode* op) override;
void VisitExpr_(const ConstructorNode* op) override;
void VisitExpr_(const MatchNode* op) override;
void VisitExpr_(const AnnotateNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitClause(const Clause& c);
virtual void VisitPattern(const Pattern& c);
@@ -193,6 +197,7 @@ class ExprMutator
Expr VisitExpr_(const RefWriteNode* op) override;
Expr VisitExpr_(const ConstructorNode* op) override;
Expr VisitExpr_(const MatchNode* op) override;
Expr VisitExpr_(const AnnotateNode* op) override;

/*!
* \brief Used to visit the types inside of expressions.
34 changes: 26 additions & 8 deletions include/tvm/relay/pass.h
@@ -46,7 +46,7 @@
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>

#include <tvm/relay/adt.h>
#include <string>
#include <vector>

@@ -326,6 +326,17 @@ TVM_DLL bool WellFormed(const Expr& expr);
*/
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);

/*! \brief Get all bound variables from pattern pat.
*
* Bound variables are all variables that got bound by the pat.
* They only have meaning inside that expr, and can only be used in it.
*
* \param pat the Pattern.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);

/*! \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
Expand Down Expand Up @@ -413,12 +424,13 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);

/*! \brief Remove expressions which do not affect the program result.
*
* It will remove let bindings which are not referenced, and branches that will
* not be entered.
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
*
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
* the expression does not depend on a. Another example is `if (true) then 1
* else 2` will be optimized into 1.
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \param e the expression to optimize.
*
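* A small illustration of the single-use inlining rule (sketch):
*
*   let x = f(y) in g(x)   ==>   g(f(y))
*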
@@ -527,7 +539,7 @@ struct StructuralHash {
*
* \return expression in A-Normal Form
*/
Expr ToANormalForm(const Expr& e, const Module& mod);
TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);

/*! \brief Remove let binding and directly share via pointer instead.
*
@@ -538,8 +550,14 @@ Expr ToANormalForm(const Expr& e, const Module& mod);
*
* \return the expression in graph normal form.
*/
Expr ToGraphNormalForm(const Expr& e);
TVM_DLL Expr ToGraphNormalForm(const Expr& e);

/*! \brief Aggressive constant propagation/constant folding/inlining.
* It will do as much computation at compile time as possible.
* It has two benefits: it removes runtime overhead, and it allows more optimization (typically fusion).
* As a side effect, code size will explode.
*/
Expr PartialEval(const Expr& e, const Module& mod);
} // namespace relay
} // namespace tvm
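
A minimal pipeline sketch for the new entry point (assumes an expression e and its Module mod; the DeadCodeElimination signature is assumed here, since its declaration is folded out of this excerpt):

// Partially evaluate, then sweep out any residual bindings.
Expr evaled = PartialEval(e, mod);
Expr cleaned = DeadCodeElimination(evaled);  // assumed signature: Expr(const Expr&)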

1 change: 1 addition & 0 deletions include/tvm/relay/pattern_functor.h
@@ -71,6 +71,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
* \return The result of the call
*/
virtual R VisitPattern(const Pattern& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
20 changes: 20 additions & 0 deletions nnvm/include/nnvm/top/nn.h
@@ -183,6 +183,26 @@ struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTrans
static const constexpr int kWeight = 0;
};

struct WinogradNNPACKWeightTransformParam
: public dmlc::Parameter<WinogradNNPACKWeightTransformParam> {
int convolution_algorithm;
int out_dtype;

DMLC_DECLARE_PARAMETER(WinogradNNPACKWeightTransformParam) {
DMLC_DECLARE_FIELD(convolution_algorithm)
.describe(
"The convolution algorithm for Winograd NNPACK. "
"E.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8 for WT_8x8, "
"tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16 for WT_8x8_FP16");
DMLC_DECLARE_DTYPE_FIELD(out_dtype)
.add_enum("same", -1)
.set_default(-1)
.describe("Output data type, set to explicit type under mixed precision setting");
}

static const constexpr int kWeight = 0;
};

struct WinogradConv2DParam : public dmlc::Parameter<WinogradConv2DParam> {
int channels;
TShape kernel_size;
(Remaining changed files not shown.)