From 5e3f2510fa763567a93a9dba6fe98365c42739e4 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 23 Dec 2019 16:37:12 -0800 Subject: [PATCH] [RUNTIME] Remove Extension VTable in favor of Unified Object system. Before the unified object protocol, we support pass additional extension objects around by declaring a type as an extension type. The old extension mechanism requires the types to register their constructor and deleter to a VTable and does not enjoy the benefit of the self-contained deletion property of the new Object system. This PR upgrades the extension example to make use of the new object system and removed the old Extension VTable. Note that the register_extension funtion in the python side continues to work when the passed argument does not require explicit container copy/deletion, which covers the current usecases of the extension mechanism. --- apps/extension/Makefile | 3 +- apps/extension/python/tvm_ext/__init__.py | 16 +--- apps/extension/src/tvm_ext.cc | 44 +++++++--- include/tvm/runtime/c_runtime_api.h | 8 -- include/tvm/runtime/packed_func.h | 102 ++-------------------- include/tvm/runtime/registry.h | 9 -- python/tvm/_ffi/ndarray.py | 14 --- python/tvm/ndarray.py | 2 +- src/runtime/registry.cc | 33 +------ tests/cpp/packed_func_test.cc | 50 ----------- 10 files changed, 42 insertions(+), 239 deletions(-) diff --git a/apps/extension/Makefile b/apps/extension/Makefile index 14c71d92ca20..1680a003e06f 100644 --- a/apps/extension/Makefile +++ b/apps/extension/Makefile @@ -20,8 +20,7 @@ TVM_ROOT=$(shell cd ../..; pwd) PKG_CFLAGS = -std=c++11 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\ - -I${TVM_ROOT}/3rdparty/dlpack/include\ - -I${TVM_ROOT}/3rdparty/HalideIR/src + -I${TVM_ROOT}/3rdparty/dlpack/include PKG_LDFLAGS =-L${TVM_ROOT}/build UNAME_S := $(shell uname -s) diff --git a/apps/extension/python/tvm_ext/__init__.py b/apps/extension/python/tvm_ext/__init__.py index 38d511eeb617..7404a717f778 100644 --- a/apps/extension/python/tvm_ext/__init__.py +++ b/apps/extension/python/tvm_ext/__init__.py @@ -38,18 +38,9 @@ def load_lib(): ivec_create = tvm.get_global_func("tvm_ext.ivec_create") ivec_get = tvm.get_global_func("tvm_ext.ivec_get") -class IntVec(object): +@tvm.register_object("tvm_ext.IntVector") +class IntVec(tvm.Object): """Example for using extension class in c++ """ - _tvm_tcode = 17 - - def __init__(self, handle): - self.handle = handle - - def __del__(self): - # You can also call your own customized - # deleter if you can free it via your own FFI. - tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode) - @property def _tvm_handle(self): return self.handle.value @@ -57,9 +48,6 @@ def _tvm_handle(self): def __getitem__(self, idx): return ivec_get(self, idx) -# Register IntVec extension on python side. -tvm.register_extension(IntVec, IntVec) - nd_create = tvm.get_global_func("tvm_ext.nd_create") nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two") diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index 8655fa7d0c30..788c28da18d3 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,17 +30,12 @@ #include namespace tvm_ext { -using IntVector = std::vector; class NDSubClass; } // namespace tvm_ext namespace tvm { namespace runtime { template<> -struct extension_type_info { - static const int code = 17; -}; -template<> struct array_type_info { static const int code = 1; }; @@ -104,24 +99,47 @@ class NDSubClass : public tvm::runtime::NDArray { return self->addtional_info_; } }; + + +/*! + * \brief Introduce additional extension data structures + * by sub-classing TVM's object system. + */ +class IntVectorObj : public Object { + public: + std::vector vec; + + static constexpr const char* _type_key = "tvm_ext.IntVector"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntVectorObj, Object); +}; + +/*! + * \brief Int vector reference class. + */ +class IntVector : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(IntVector, ObjectRef, IntVectorObj); +}; + +TVM_REGISTER_OBJECT_TYPE(IntVectorObj); + } // namespace tvm_ext namespace tvm_ext { -TVM_REGISTER_EXT_TYPE(IntVector); - TVM_REGISTER_GLOBAL("tvm_ext.ivec_create") .set_body([](TVMArgs args, TVMRetValue *rv) { - IntVector vec; + auto n = tvm::runtime::make_object(); for (int i = 0; i < args.size(); ++i) { - vec.push_back(args[i].operator int()); + n->vec.push_back(args[i].operator int()); } - *rv = vec; + *rv = IntVector(n); }); TVM_REGISTER_GLOBAL("tvm_ext.ivec_get") .set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = args[0].AsExtension()[args[1].operator int()]; + IntVector p = args[0]; + *rv = p->vec[args[1].operator int()]; }); diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 5053326058bc..dda2a98dac22 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -234,14 +234,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, int query_imports, TVMFunctionHandle *out); -/*! - * \brief Free front-end extension type resource. - * \param handle The extension handle. - * \param type_code The type of of the extension type. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMExtTypeFree(void* handle, int type_code); - /*! * \brief Free the Module * \param mod The module to be freed. diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 1d7db66ec570..27dcb4130b4c 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -387,7 +387,6 @@ inline std::string TVMType2String(TVMType t); #define TVM_CHECK_TYPE_CODE(CODE, T) \ CHECK_EQ(CODE, T) << " expected " \ << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ - /*! * \brief Type traits to mark if a class is tvm extension type. * @@ -404,34 +403,6 @@ struct extension_type_info { static const int code = 0; }; -/*! - * \brief Runtime function table about extension type. - */ -class ExtTypeVTable { - public: - /*! \brief function to be called to delete a handle */ - void (*destroy)(void* handle); - /*! \brief function to be called when clone a handle */ - void* (*clone)(void* handle); - /*! - * \brief Register type - * \tparam T The type to be register. - * \return The registered vtable. - */ - template - static inline ExtTypeVTable* Register_(); - /*! - * \brief Get a vtable based on type code. - * \param type_code The type code - * \return The registered vtable. - */ - TVM_DLL static ExtTypeVTable* Get(int type_code); - - private: - // Internal registration function. - TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt); -}; - /*! * \brief Internal base class to * handle conversion to POD values. @@ -518,11 +489,6 @@ class TVMPODValue_ { CHECK_EQ(container->array_type_code_, array_type_info::code); return TNDArray(container); } - template - const TExtension& AsExtension() const { - CHECK_LT(type_code_, kExtEnd); - return static_cast(value_.v_handle)[0]; - } template::value>::type> @@ -867,20 +833,8 @@ class TVMRetValue : public TVMPODValue_ { break; } default: { - if (other.type_code() < kExtBegin) { - SwitchToPOD(other.type_code()); - value_ = other.value_; - } else { -#if TVM_RUNTIME_HEADER_ONLY - LOG(FATAL) << "Header only mode do not support ext type"; -#else - this->Clear(); - type_code_ = other.type_code(); - value_.v_handle = - (*(ExtTypeVTable::Get(other.type_code())->clone))( - other.value().v_handle); -#endif - } + SwitchToPOD(other.type_code()); + value_ = other.value_; break; } } @@ -931,13 +885,6 @@ class TVMRetValue : public TVMPODValue_ { break; } } - if (type_code_ > kExtBegin) { -#if TVM_RUNTIME_HEADER_ONLY - LOG(FATAL) << "Header only mode do not support ext type"; -#else - (*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle); -#endif - } type_code_ = kNull; } }; @@ -1317,23 +1264,16 @@ inline R TypedPackedFunc::operator()(Args... args) const { // extension and node type handling namespace detail { -template +template struct TVMValueCast { static T Apply(const TSrc* self) { - static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions"); + static_assert(!is_nd, "The default case accepts only non-extensions"); return self->template AsObjectRef(); } }; template -struct TVMValueCast { - static T Apply(const TSrc* self) { - return self->template AsExtension(); - } -}; - -template -struct TVMValueCast { +struct TVMValueCast { static T Apply(const TSrc* self) { return self->template AsNDArray(); } @@ -1345,7 +1285,6 @@ template inline TVMArgValue::operator T() const { return detail:: TVMValueCast::code != 0), (array_type_info::code > 0)> ::Apply(this); } @@ -1354,19 +1293,10 @@ template inline TVMRetValue::operator T() const { return detail:: TVMValueCast::code != 0), (array_type_info::code > 0)> ::Apply(this); } -template -inline void TVMArgsSetter::operator()(size_t i, const T& value) const { - static_assert(extension_type_info::code != 0, - "Need to have extesion code"); - type_codes_[i] = extension_type_info::code; - values_[i].v_handle = const_cast(&value); -} - // PackedFunc support inline TVMRetValue& TVMRetValue::operator=(const DataType& t) { return this->operator=(t.operator DLDataType()); @@ -1385,28 +1315,6 @@ inline void TVMArgsSetter::operator()( this->operator()(i, t.operator DLDataType()); } -// extension type handling -template -struct ExtTypeInfo { - static void destroy(void* handle) { - delete static_cast(handle); - } - static void* clone(void* handle) { - return new T(*static_cast(handle)); - } -}; - -template -inline ExtTypeVTable* ExtTypeVTable::Register_() { - const int code = extension_type_info::code; - static_assert(code != 0, - "require extension_type_info traits to be declared with non-zero code"); - ExtTypeVTable vt; - vt.clone = ExtTypeInfo::clone; - vt.destroy = ExtTypeInfo::destroy; - return ExtTypeVTable::RegisterInternal(code, vt); -} - inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { return (*this)->GetFunction(name, query_imports); } diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index d668984f50e2..3500a7e4e398 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -311,15 +311,6 @@ class Registry { TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \ ::tvm::runtime::Registry::Register(OpName) -/*! - * \brief Macro to register extension type. - * This must be registered in a cc file - * after the trait extension_type_info is defined. - */ -#define TVM_REGISTER_EXT_TYPE(T) \ - TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::runtime::ExtTypeVTable::Register_() - } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_REGISTRY_H_ diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index 56bf4a00080c..1773d916722b 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -299,20 +299,6 @@ def copyto(self, target): raise ValueError("Unsupported target type %s" % str(type(target))) -def free_extension_handle(handle, type_code): - """Free c++ extension type handle - - Parameters - ---------- - handle : ctypes.c_void_p - The handle to the extension type. - - type_code : int - The tyoe code - """ - check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code))) - - def register_extension(cls, fcreate=None): """Register a extension class to TVM. diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py index 9a00f78eb77f..2a7a532e660e 100644 --- a/python/tvm/ndarray.py +++ b/python/tvm/ndarray.py @@ -26,7 +26,7 @@ from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase from ._ffi.ndarray import context, empty, from_dlpack from ._ffi.ndarray import _set_class_ndarray -from ._ffi.ndarray import register_extension, free_extension_handle +from ._ffi.ndarray import register_extension class NDArray(NDArrayBase): """Lightweight NDArray class of TVM runtime. diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index ce6a281a6ead..4717d89e33c1 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -40,15 +40,10 @@ struct Registry::Manager { // and the resource can become invalid because of indeterminstic order of destruction. // The resources will only be recycled during program exit. std::unordered_map fmap; - // vtable for extension type - std::array ext_vtable; // mutex std::mutex mutex; Manager() { - for (auto& x : ext_vtable) { - x.destroy = nullptr; - } } static Manager* Global() { @@ -109,24 +104,6 @@ std::vector Registry::ListNames() { return keys; } -ExtTypeVTable* ExtTypeVTable::Get(int type_code) { - CHECK(type_code > kExtBegin && type_code < kExtEnd); - Registry::Manager* m = Registry::Manager::Global(); - ExtTypeVTable* vt = &(m->ext_vtable[type_code]); - CHECK(vt->destroy != nullptr) - << "Extension type not registered"; - return vt; -} - -ExtTypeVTable* ExtTypeVTable::RegisterInternal( - int type_code, const ExtTypeVTable& vt) { - CHECK(type_code > kExtBegin && type_code < kExtEnd); - Registry::Manager* m = Registry::Manager::Global(); - std::lock_guard lock(m->mutex); - ExtTypeVTable* pvt = &(m->ext_vtable[type_code]); - pvt[0] = vt; - return pvt; -} } // namespace runtime } // namespace tvm @@ -141,12 +118,6 @@ struct TVMFuncThreadLocalEntry { /*! \brief Thread local store that can be used to hold return values. */ typedef dmlc::ThreadLocalStore TVMFuncThreadLocalStore; -int TVMExtTypeFree(void* handle, int type_code) { - API_BEGIN(); - tvm::runtime::ExtTypeVTable::Get(type_code)->destroy(handle); - API_END(); -} - int TVMFuncRegisterGlobal( const char* name, TVMFunctionHandle f, int override) { API_BEGIN(); diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index e5f3b0d72277..f6f3b8f90e37 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -178,56 +178,6 @@ TEST(TypedPackedFunc, HighOrder) { CHECK_EQ(f1(3), 4); } -// new namespoace -namespace test { -// register int vector as extension type -using IntVector = std::vector; -} // namespace test - -namespace tvm { -namespace runtime { - -template<> -struct extension_type_info { - static const int code = kExtBegin + 1; -}; -} // runtime -} // tvm - -// do registration, this need to be in cc file -TVM_REGISTER_EXT_TYPE(test::IntVector); - -TEST(PackedFunc, ExtensionType) { - using namespace tvm; - using namespace tvm::runtime; - // note: class are copy by value. - test::IntVector vec{1, 2, 4}; - - auto copy_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - // copy by value - const test::IntVector& v = args[0].AsExtension(); - CHECK(&v == &vec); - test::IntVector v2 = args[0]; - CHECK_EQ(v2.size(), 3U); - CHECK_EQ(v[2], 4); - // return copy by value - *rv = v2; - }); - - auto pass_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - // copy by value - *rv = args[0]; - }); - - test::IntVector vret1 = copy_vec(vec); - test::IntVector vret2 = pass_vec(copy_vec(vec)); - CHECK_EQ(vret1.size(), 3U); - CHECK_EQ(vret2.size(), 3U); - CHECK_EQ(vret1[2], 4); - CHECK_EQ(vret2[2], 4); -} - - int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";