From b171cf1db9ecc01964d2f43675be7d612c6bd6f5 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 16 Jan 2020 20:18:57 -0800 Subject: [PATCH] [REFACTOR] Polish runtime (#4729) - Remove operator bool from base object ref macro - Raitionale: operator bool can be dangerous for sub-classes that also overloads other operators(e.g. ==). - If bool is still needed, use explicit operator bool. - Use absolute include when necessary - Move type related util to data_type - Isolate stackvm code from compiler --- include/tvm/ir.h | 20 ++++- include/tvm/ir/expr.h | 59 ++------------ include/tvm/runtime/c_backend_api.h | 2 +- include/tvm/runtime/data_type.h | 19 ++++- include/tvm/runtime/device_api.h | 4 +- include/tvm/runtime/memory.h | 2 +- include/tvm/runtime/object.h | 3 - include/tvm/runtime/serializer.h | 4 +- include/tvm/runtime/util.h | 79 ------------------- src/codegen/stackvm/codegen_stackvm.cc | 26 +++++- src/pass/hoist_if_then_else.cc | 4 +- src/relay/op/type_relations.cc | 21 ++--- src/relay/pass/partial_eval.cc | 2 +- src/runtime/contrib/cblas/cblas.cc | 2 +- src/runtime/contrib/cblas/gemm_common.h | 2 +- src/runtime/contrib/cublas/cublas.cc | 2 +- src/runtime/contrib/cudnn/conv_forward.cc | 2 +- src/runtime/contrib/miopen/conv_forward.cc | 2 +- src/runtime/contrib/mps/mps_utils.h | 2 +- src/runtime/contrib/nnpack/convolution.cc | 2 +- src/runtime/contrib/nnpack/fully_connected.cc | 6 +- src/runtime/contrib/nnpack/nnpack_utils.h | 2 +- src/runtime/contrib/random/random.cc | 6 +- src/runtime/contrib/rocblas/rocblas.cc | 6 +- src/runtime/contrib/sort/sort.cc | 1 - src/runtime/stackvm/stackvm.cc | 49 ++++++------ src/runtime/stackvm/stackvm.h | 20 +++++ 27 files changed, 144 insertions(+), 205 deletions(-) delete mode 100644 include/tvm/runtime/util.h diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 4e36332e004e..ff4b47ffca12 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -30,7 +30,6 @@ #include #include #include "expr.h" -#include "runtime/util.h" namespace tvm { namespace ir { @@ -1677,6 +1676,25 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; */ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; +/*! \brief The kind of structure field info used in intrinsic */ +enum TVMStructFieldKind : int { + // array head address + kArrAddr, + kArrData, + kArrShape, + kArrStrides, + kArrNDim, + kArrTypeCode, + kArrTypeBits, + kArrTypeLanes, + kArrByteOffset, + kArrDeviceId, + kArrDeviceType, + kArrKindBound_, + // TVMValue field + kTVMValueContent, + kTVMValueKindBound_ +}; } // namespace intrinsic /*! diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index ddb5f80ca2f1..87122e802db5 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -49,15 +49,7 @@ class BaseExprNode : public Object { */ class BaseExpr : public ObjectRef { public: - /*! \brief Cosntructor */ - BaseExpr() {} - /*! - * \brief Cosntructor from object ptr. - * \param ptr The object pointer. - */ - explicit BaseExpr(ObjectPtr ptr) : ObjectRef(ptr) {} - /*! \brief The container type. */ - using ContainerType = BaseExprNode; + TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode); }; /*! @@ -100,13 +92,6 @@ class PrimExprNode : public BaseExprNode { */ class PrimExpr : public BaseExpr { public: - /*! \brief Cosntructor */ - PrimExpr() {} - /*! - * \brief Cosntructor from object ptr. - * \param ptr The object pointer. - */ - explicit PrimExpr(ObjectPtr ptr) : BaseExpr(ptr) {} /*! * \brief construct from integer. * \param value The value to be constructed. @@ -127,8 +112,8 @@ class PrimExpr : public BaseExpr { DataType dtype() const { return static_cast(get())->dtype; } - /*! \brief The container type. */ - using ContainerType = PrimExprNode; + + TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); }; /*! @@ -156,29 +141,14 @@ class IntImmNode : public PrimExprNode { */ class IntImm : public PrimExpr { public: - /*! - * \brief Constructor - */ - IntImm() {} - /*! - * \brief constructor from node. - */ - explicit IntImm(ObjectPtr node) : PrimExpr(node) {} /*! * \brief Constructor. * \param dtype The data type of the value. * \param value The internal value. */ TVM_DLL IntImm(DataType dtype, int64_t value); - /*! - * \brief Get pointer to the internal value. - * \return the content of the integer. - */ - const IntImmNode* operator->() const { - return static_cast(get()); - } - /*! \brief type indicate the container type */ - using ContainerType = IntImmNode; + + TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode); }; /*! @@ -206,29 +176,14 @@ class FloatImmNode : public PrimExprNode { */ class FloatImm : public PrimExpr { public: - /*! - * \brief Constructor - */ - FloatImm() {} - /*! - * \brief constructor from node. - */ - explicit FloatImm(ObjectPtr node) : PrimExpr(node) {} /*! * \brief Constructor. * \param dtype The data type of the value. * \param value The internal value. */ TVM_DLL FloatImm(DataType dtype, double value); - /*! - * \brief Get pointer to the container. - * \return The pointer. - */ - const FloatImmNode* operator->() const { - return static_cast(get()); - } - /*! \brief type indicate the container type */ - using ContainerType = FloatImmNode; + + TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); }; /*! diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index ffd13eccdac5..abfc792d574f 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -28,7 +28,7 @@ #ifndef TVM_RUNTIME_C_BACKEND_API_H_ #define TVM_RUNTIME_C_BACKEND_API_H_ -#include "c_runtime_api.h" +#include #ifdef __cplusplus extern "C" { diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index c91c2cf82452..cb58e9741d1f 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -28,7 +28,6 @@ #include #include - namespace tvm { namespace runtime { /*! @@ -233,6 +232,24 @@ inline int GetVectorBytes(DataType dtype) { return data_bits / 8; } +/*! + * \brief Check whether type matches the given spec. + * \param t The type + * \param code The type code. + * \param bits The number of bits to be matched. + * \param lanes The number of lanes in the type. + */ +inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) { + return t.code == code && t.bits == bits && t.lanes == lanes; +} +/*! + * \brief Check whether two types are equal . + * \param lhs The left operand. + * \param rhs The right operand. + */ +inline bool TypeEqual(DLDataType lhs, DLDataType rhs) { + return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; +} } // namespace runtime using DataType = runtime::DataType; diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 7212aad5aa20..00508a11b042 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -24,9 +24,9 @@ #ifndef TVM_RUNTIME_DEVICE_API_H_ #define TVM_RUNTIME_DEVICE_API_H_ +#include +#include #include -#include "packed_func.h" -#include "c_runtime_api.h" namespace tvm { namespace runtime { diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 10e8be35ef74..121dbdde37a6 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -23,10 +23,10 @@ #ifndef TVM_RUNTIME_MEMORY_H_ #define TVM_RUNTIME_MEMORY_H_ +#include #include #include #include -#include "object.h" namespace tvm { namespace runtime { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index a2e9188fcd2b..8ef9cb449d1b 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -29,7 +29,6 @@ #include #include - /*! * \brief Whether or not use atomic reference counter. * If the reference counter is not atomic, @@ -715,7 +714,6 @@ struct ObjectEqual { const ObjectName* operator->() const { \ return static_cast(data_.get()); \ } \ - operator bool() const { return data_ != nullptr; } \ using ContainerType = ObjectName; /* @@ -734,7 +732,6 @@ struct ObjectEqual { ObjectName* operator->() const { \ return static_cast(data_.get()); \ } \ - operator bool() const { return data_ != nullptr; } \ using ContainerType = ObjectName; /*! diff --git a/include/tvm/runtime/serializer.h b/include/tvm/runtime/serializer.h index ca968c4b58f4..37bb95f54655 100644 --- a/include/tvm/runtime/serializer.h +++ b/include/tvm/runtime/serializer.h @@ -27,8 +27,8 @@ #include #include -#include "c_runtime_api.h" -#include "ndarray.h" +#include +#include namespace dmlc { namespace serializer { diff --git a/include/tvm/runtime/util.h b/include/tvm/runtime/util.h deleted file mode 100644 index 8e213dd146b8..000000000000 --- a/include/tvm/runtime/util.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/util.h - * \brief Useful runtime util. - */ -#ifndef TVM_RUNTIME_UTIL_H_ -#define TVM_RUNTIME_UTIL_H_ - -#include "c_runtime_api.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief Check whether type matches the given spec. - * \param t The type - * \param code The type code. - * \param bits The number of bits to be matched. - * \param lanes The number of lanes in the type. - */ -inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) { - return t.code == code && t.bits == bits && t.lanes == lanes; -} -/*! - * \brief Check whether two types are equal . - * \param lhs The left operand. - * \param rhs The right operand. - */ -inline bool TypeEqual(DLDataType lhs, DLDataType rhs) { - return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; -} -} // namespace runtime -} // namespace tvm -// Forward declare the intrinsic id we need -// in structure fetch to enable stackvm in runtime -namespace tvm { -namespace ir { -namespace intrinsic { -/*! \brief The kind of structure field info used in intrinsic */ -enum TVMStructFieldKind : int { - // array head address - kArrAddr, - kArrData, - kArrShape, - kArrStrides, - kArrNDim, - kArrTypeCode, - kArrTypeBits, - kArrTypeLanes, - kArrByteOffset, - kArrDeviceId, - kArrDeviceType, - kArrKindBound_, - // TVMValue field - kTVMValueContent, - kTVMValueKindBound_ -}; -} // namespace intrinsic -} // namespace ir -} // namespace tvm -#endif // TVM_RUNTIME_UTIL_H_ diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index f4b8fbe7ff1d..8253007fff72 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -32,6 +32,28 @@ namespace codegen { using namespace ir; +// map struct field kind to runtime variants +// We keep two separate enums to ensure runtime/compiler isolation. +StackVM::StructFieldKind MapFieldKind(int64_t kind) { + auto val = static_cast(kind); + switch (val) { + case intrinsic::kArrData: return StackVM::kArrData; + case intrinsic::kArrShape: return StackVM::kArrShape; + case intrinsic::kArrAddr: return StackVM::kArrAddr; + case intrinsic::kArrStrides: return StackVM::kArrStrides; + case intrinsic::kArrNDim: return StackVM::kArrNDim; + case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode; + case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits; + case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes; + case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset; + case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId; + case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType; + case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent; + default: LOG(FATAL) << "Do not know how to map field " << kind; + } + return StackVM::kArrData; +} + StackVM CodeGenStackVM::Compile(LoweredFunc f) { for (size_t i = 0; i < f->args.size(); ++i) { Var v = f->args[i]; @@ -163,7 +185,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { vm_.code.push_back(code); code.v_int = index->value; vm_.code.push_back(code); - code.v_int = kind; + code.v_int = MapFieldKind(kind); vm_.code.push_back(code); } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { CHECK_GE(op->args.size(), 5U); @@ -431,7 +453,7 @@ void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) { vm_.code.push_back(code); code.v_int = index->value; vm_.code.push_back(code); - code.v_int = op->args[2].as()->value; + code.v_int = MapFieldKind(op->args[2].as()->value); vm_.code.push_back(code); } else { this->Push(ev->value); diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc index 6d4df47a730c..78d743ac8669 100644 --- a/src/pass/hoist_if_then_else.cc +++ b/src/pass/hoist_if_then_else.cc @@ -189,7 +189,7 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { then_for = IRTransform(for_stmt, nullptr, replace_then_case, {PrimExpr("IfThenElse")}); - if (if_stmt.as()->else_case) { + if (if_stmt.as()->else_case.defined()) { else_for = IRTransform(for_stmt, nullptr, replace_else_case, {PrimExpr("IfThenElse")}); } @@ -221,7 +221,7 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { for2if_map_[for_stmt.get()].push_back(head); const IfThenElseNode* if_node = head.as(); tracker.push(if_node->then_case); - if (if_node->else_case) { + if (if_node->else_case.defined()) { tracker.push(if_node->else_case); } diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index fbaf66507ef1..a1e2a6630f05 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -31,14 +31,6 @@ namespace tvm { namespace relay { -TensorType ToTensorType(const Type& t) { - if (const auto* tt_node = t.as()) { - return GetRef(tt_node); - } else { - return TensorType(nullptr); - } -} - bool IdentityRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -115,11 +107,11 @@ bool BroadcastRel(const Array& types, CHECK_EQ(types.size(), 3); // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] // << ",Out:" << types[2] << std::endl; - if (auto t0 = ToTensorType(types[0])) { - if (auto t1 = ToTensorType(types[1])) { + if (auto* t0 = types[0].as()) { + if (auto* t1 = types[1].as()) { CHECK_EQ(t0->dtype, t1->dtype); reporter->Assign(types[2], - ConcreteBroadcast(t0, t1, t0->dtype)); + ConcreteBroadcast(GetRef(t0), GetRef(t1), t0->dtype)); return true; } } @@ -133,10 +125,11 @@ bool BroadcastCompRel(const Array& types, CHECK_EQ(types.size(), 3); // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] // << ",Out:" << types[2] << std::endl; - if (auto t0 = ToTensorType(types[0])) { - if (auto t1 = ToTensorType(types[1])) { + if (auto* t0 = types[0].as()) { + if (auto* t1 = types[1].as()) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::DataType::Bool())); + reporter->Assign(types[2], + ConcreteBroadcast(GetRef(t0), GetRef(t1), DataType::Bool())); return true; } } diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 4c343bd30330..c7935c49dfaf 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -749,7 +749,7 @@ class PartialEvaluator : public ExprFunctor PStatic r = VisitExpr(op->ref, ll); if (r->pstatic.defined()) { PStatic ret = store_.Lookup(r->pstatic.as()); - if (ret) { + if (ret.defined()) { return ret; } } diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index ef9f5d6564de..d4959be64cf1 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -22,7 +22,7 @@ */ #include #include -#include +#include #include "gemm_common.h" extern "C" { diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index e35e4311730c..b73ababbbade 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -24,7 +24,7 @@ #pragma once #include -#include +#include #include namespace tvm { diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 2cb677729654..5424f4cdcddf 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -21,7 +21,7 @@ * \file Use external cblas library call. */ #include -#include +#include #include #include "../cblas/gemm_common.h" #include "cublas_utils.h" diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index b9609b9d1047..95811332bbfa 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -21,7 +21,7 @@ * \file Use external cudnn utils function */ #include -#include +#include #include #include "cudnn_utils.h" diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index 5094cef60f92..d4575484320b 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -21,7 +21,7 @@ * \file Use external miopen utils function */ #include -#include +#include #include #include "miopen_utils.h" diff --git a/src/runtime/contrib/mps/mps_utils.h b/src/runtime/contrib/mps/mps_utils.h index 728646c537b9..f1fff95c1df3 100644 --- a/src/runtime/contrib/mps/mps_utils.h +++ b/src/runtime/contrib/mps/mps_utils.h @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include "../../metal/metal_common.h" diff --git a/src/runtime/contrib/nnpack/convolution.cc b/src/runtime/contrib/nnpack/convolution.cc index 8934693973b4..79ea19175d65 100644 --- a/src/runtime/contrib/nnpack/convolution.cc +++ b/src/runtime/contrib/nnpack/convolution.cc @@ -22,7 +22,7 @@ */ #include #include -#include +#include #include #include #include "nnpack_utils.h" diff --git a/src/runtime/contrib/nnpack/fully_connected.cc b/src/runtime/contrib/nnpack/fully_connected.cc index b0d72fe5744b..5f111efac4df 100644 --- a/src/runtime/contrib/nnpack/fully_connected.cc +++ b/src/runtime/contrib/nnpack/fully_connected.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,7 +21,7 @@ * \file Use external nnpack library call. */ #include -#include +#include #include #include #include "nnpack_utils.h" diff --git a/src/runtime/contrib/nnpack/nnpack_utils.h b/src/runtime/contrib/nnpack/nnpack_utils.h index 551cff2957b2..4ba586fe08ac 100644 --- a/src/runtime/contrib/nnpack/nnpack_utils.h +++ b/src/runtime/contrib/nnpack/nnpack_utils.h @@ -23,7 +23,7 @@ #ifndef TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_ #define TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_ #include -#include +#include #include #include #include diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index 3da2e16a44b3..46a14e61f937 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,7 +21,7 @@ * \file External random functions for tensor. */ #include -#include +#include #include #include #include diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index 813f4c691523..dda4ee30fde5 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,7 +21,7 @@ * \file Use external rocblas library call. */ #include -#include +#include #include #include "rocblas.h" diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 68f70c15b4d6..0c9c57533dbe 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -22,7 +22,6 @@ */ #include -#include #include #include #include diff --git a/src/runtime/stackvm/stackvm.cc b/src/runtime/stackvm/stackvm.cc index 06b154e8d4df..0f17f9e4b4a2 100644 --- a/src/runtime/stackvm/stackvm.cc +++ b/src/runtime/stackvm/stackvm.cc @@ -22,7 +22,6 @@ * \file stackvm.cc */ #include -#include #include #include #include "stackvm.h" @@ -392,50 +391,49 @@ void StackVM::Run(State* s) const { } // intrinsics case TVM_STRUCT_GET: { - using namespace ir; int index = code[pc + 1].v_int; int kind = code[pc + 2].v_int; DLTensor* arr = static_cast(stack[sp].v_handle); switch (kind) { - case intrinsic::kArrData: { + case StackVM::kArrData: { stack[sp].v_handle = arr[index].data; break; } - case intrinsic::kArrShape: { + case StackVM::kArrShape: { stack[sp].v_handle = arr[index].shape; break; } - case intrinsic::kArrStrides: { + case StackVM::kArrStrides: { stack[sp].v_handle = arr[index].strides; break; } - case intrinsic::kArrNDim: { + case StackVM::kArrNDim: { stack[sp].v_int64 = arr[index].ndim; break; } - case intrinsic::kArrTypeCode: { + case StackVM::kArrTypeCode: { stack[sp].v_int64 = static_cast( arr[index].dtype.code); break; } - case intrinsic::kArrTypeBits: { + case StackVM::kArrTypeBits: { stack[sp].v_int64 = static_cast( arr[index].dtype.bits); break; } - case intrinsic::kArrTypeLanes: { + case StackVM::kArrTypeLanes: { stack[sp].v_int64 = static_cast( arr[index].dtype.lanes); break; } - case intrinsic::kArrByteOffset: { + case StackVM::kArrByteOffset: { stack[sp].v_int64 = static_cast( arr[index].byte_offset); break; } - case intrinsic::kArrDeviceId: { + case StackVM::kArrDeviceId: { stack[sp].v_int64 = arr[index].ctx.device_id; break; } - case intrinsic::kArrDeviceType: { + case StackVM::kArrDeviceType: { stack[sp].v_int64 = static_cast( arr[index].ctx.device_type); break; } - case intrinsic::kArrAddr: { + case StackVM::kArrAddr: { stack[sp].v_handle = arr + index; break; } - case intrinsic::kTVMValueContent: { + case StackVM::kTVMValueContent: { stack[sp] = static_cast(stack[sp].v_handle)[index]; break; } default: LOG(FATAL) << "unhandled get " << kind; @@ -444,51 +442,50 @@ void StackVM::Run(State* s) const { break; } case TVM_STRUCT_SET: { - using namespace ir; int index = code[pc + 1].v_int; int kind = code[pc + 2].v_int; DLTensor* arr = static_cast(stack[sp - 1].v_handle); switch (kind) { - case intrinsic::kArrData: { + case StackVM::kArrData: { arr[index].data = stack[sp].v_handle; break; } - case intrinsic::kArrShape: { + case StackVM::kArrShape: { arr[index].shape = static_cast(stack[sp].v_handle); break; } - case intrinsic::kArrStrides: { + case StackVM::kArrStrides: { arr[index].strides = static_cast(stack[sp].v_handle); break; } - case intrinsic::kArrNDim: { + case StackVM::kArrNDim: { arr[index].ndim = static_cast(stack[sp].v_int64); break; } - case intrinsic::kArrTypeCode: { + case StackVM::kArrTypeCode: { arr[index].dtype.code = static_cast(stack[sp].v_int64); break; } - case intrinsic::kArrTypeBits: { + case StackVM::kArrTypeBits: { arr[index].dtype.bits = static_cast(stack[sp].v_int64); break; } - case intrinsic::kArrTypeLanes: { + case StackVM::kArrTypeLanes: { arr[index].dtype.lanes = static_cast(stack[sp].v_int64); break; } - case intrinsic::kArrByteOffset: { + case StackVM::kArrByteOffset: { arr[index].byte_offset = static_cast(stack[sp].v_int64); break; } - case intrinsic::kArrDeviceId: { + case StackVM::kArrDeviceId: { arr[index].ctx.device_id = static_cast(stack[sp].v_int64); break; } - case intrinsic::kArrDeviceType: { + case StackVM::kArrDeviceType: { arr[index].ctx.device_type = static_cast(stack[sp].v_int64); break; } - case intrinsic::kTVMValueContent: { + case StackVM::kTVMValueContent: { static_cast(stack[sp - 1].v_handle)[index] = stack[sp]; break; } default: LOG(FATAL) << "unhandled tvm_struct_set " << kind; diff --git a/src/runtime/stackvm/stackvm.h b/src/runtime/stackvm/stackvm.h index 6ed9647e2da9..f36e171cdf3e 100644 --- a/src/runtime/stackvm/stackvm.h +++ b/src/runtime/stackvm/stackvm.h @@ -38,6 +38,7 @@ namespace tvm { namespace runtime { using runtime::operator<<; + /*! * \brief A simple stack-based virtual machine program. */ @@ -283,6 +284,25 @@ class StackVM { */ TVM_STRUCT_SET }; + /*! \brief The kind of structure field info */ + enum StructFieldKind : int { + // array head address + kArrAddr, + kArrData, + kArrShape, + kArrStrides, + kArrNDim, + kArrTypeCode, + kArrTypeBits, + kArrTypeLanes, + kArrByteOffset, + kArrDeviceId, + kArrDeviceType, + kArrKindBound_, + // TVMValue field + kTVMValueContent, + kTVMValueKindBound_ + }; /*! \brief The code structure */ union Code { OpCode op_code;