diff --git a/apps/extension/Makefile b/apps/extension/Makefile index 3a1f8a2160ee..41e9bf621cb6 100644 --- a/apps/extension/Makefile +++ b/apps/extension/Makefile @@ -6,7 +6,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\ -I${TVM_ROOT}/3rdparty/dlpack/include\ -I${TVM_ROOT}/3rdparty/HalideIR/src -PKG_LDFLAGS =-L${TVM_ROOT}/lib +PKG_LDFLAGS =-L${TVM_ROOT}/build UNAME_S := $(shell uname -s) ifeq ($(UNAME_S), Darwin) diff --git a/apps/extension/python/tvm_ext/__init__.py b/apps/extension/python/tvm_ext/__init__.py index 25286f67b4f5..78b407ae9aa1 100644 --- a/apps/extension/python/tvm_ext/__init__.py +++ b/apps/extension/python/tvm_ext/__init__.py @@ -31,7 +31,7 @@ def __init__(self, handle): def __del__(self): # You can also call your own customized # deleter if you can free it via your own FFI. - tvm.nd.free_extension_handle(self.handle, 17) + tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode) @property def _tvm_handle(self): @@ -42,3 +42,30 @@ def __getitem__(self, idx): # Register IntVec extension on python side. tvm.register_extension(IntVec, IntVec) + + +nd_create = tvm.get_global_func("tvm_ext.nd_create") +nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two") +nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info") + +class NDSubClass(tvm.nd.NDArrayBase): + """Example for subclassing TVM's NDArray infrastructure. + + By inheriting TMV's NDArray, external libraries could + leverage TVM's FFI without any modification. + """ + # Should be consistent with the type-trait set in the backend + _array_type_code = 1 + + @staticmethod + def create(addtional_info): + return nd_create(addtional_info) + + @property + def addtional_info(self): + return nd_get_addtional_info(self) + + def __add__(self, other): + return nd_add_two(self, other) + +tvm.register_extension(NDSubClass, NDSubClass) diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index 362ac62dea3d..97e0ada25a2e 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -7,24 +7,87 @@ #include #include #include +#include #include +#include namespace tvm_ext { using IntVector = std::vector; +class NDSubClass; } // namespace tvm_ext namespace tvm { namespace runtime { template<> -struct extension_class_info { +struct extension_type_info { static const int code = 17; }; +template<> +struct array_type_info { + static const int code = 1; +}; } // namespace tvm } // namespace runtime using namespace tvm; using namespace tvm::runtime; +namespace tvm_ext { +/*! + * \brief A subclass of TVM's NDArray. + * + * To use this extension, an external library should + * + * 1) Inherit TVM's NDArray and NDArray container, + * and define the trait `array_type_info` for this class. + * + * 2) Define a constructor in the inherited class that accepts + * a pointer to TVM's Container, which is nullable. + * + * 3) On Python frontend, inherit `tvm.nd.NDArrayBase`, + * define the class attribute `_array_type_code` consistent to + * the C++ type trait, and register the subclass using `tvm.register_extension`. + */ +class NDSubClass : public tvm::runtime::NDArray { + public: + class SubContainer : public NDArray::Container { + public: + SubContainer(int addtional_info) : + addtional_info_(addtional_info) { + array_type_code_ = array_type_info::code; + } + static bool Is(NDArray::Container *container) { + SubContainer *c = static_cast(container); + return c->array_type_code_ == array_type_info::code; + } + int addtional_info_{0}; + }; + NDSubClass(NDArray::Container *container) { + if (container == nullptr) { + data_ = nullptr; + return; + } + CHECK(SubContainer::Is(container)); + container->IncRef(); + data_ = container; + } + ~NDSubClass() { + this->reset(); + } + NDSubClass AddWith(const NDSubClass &other) const { + SubContainer *a = static_cast(data_); + SubContainer *b = static_cast(other.data_); + CHECK(a != nullptr && b != nullptr); + return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_)); + } + int get_additional_info() const { + SubContainer *self = static_cast(data_); + CHECK(self != nullptr); + return self->addtional_info_; + } +}; +} // namespace tvm_ext + namespace tvm_ext { TVM_REGISTER_EXT_TYPE(IntVector); @@ -64,6 +127,26 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))(); }); + +TVM_REGISTER_GLOBAL("tvm_ext.nd_create") +.set_body([](TVMArgs args, TVMRetValue *rv) { + int addtional_info = args[0]; + *rv = NDSubClass(new NDSubClass::SubContainer(addtional_info)); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") +.set_body([](TVMArgs args, TVMRetValue *rv) { + NDSubClass a = args[0]; + NDSubClass b = args[1]; + *rv = a.AddWith(b); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info") +.set_body([](TVMArgs args, TVMRetValue *rv) { + NDSubClass a = args[0]; + *rv = a.get_additional_info(); +}); + } // namespace tvm_ext // External function exposed to runtime. diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py index def30803135e..a6246d6be2e1 100644 --- a/apps/extension/tests/test_ext.py +++ b/apps/extension/tests/test_ext.py @@ -32,6 +32,7 @@ def test_sym_add(): c = tvm_ext.sym_add(a, b) assert c.a == a and c.b == b + def test_ext_vec(): ivec = tvm_ext.ivec_create(1, 2, 3) assert(isinstance(ivec, tvm_ext.IntVec)) @@ -44,6 +45,7 @@ def ivec_cb(v2): tvm.convert(ivec_cb)(ivec) + def test_extract_ext(): fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare) assert fdict["mul"](3, 4) == 12 @@ -68,7 +70,21 @@ def check_llvm(): check_llvm() +def test_nd_subclass(): + a = tvm_ext.NDSubClass.create(addtional_info=3) + b = tvm_ext.NDSubClass.create(addtional_info=5) + c = a + b + d = a + a + e = b + b + assert(a.addtional_info == 3) + assert(b.addtional_info == 5) + assert(c.addtional_info == 8) + assert(d.addtional_info == 6) + assert(e.addtional_info == 10) + + if __name__ == "__main__": + test_nd_subclass() test_extern_call() test_ext_dev() test_ext_vec() diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index e2a447e4235c..2b9674301607 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -178,10 +178,30 @@ class NDArray { Container* data_{nullptr}; // enable internal functions friend struct Internal; + friend class TVMPODValue_; + friend class TVMArgValue; friend class TVMRetValue; friend class TVMArgsSetter; }; +/*! + * \brief The type trait indicates subclass of TVM's NDArray. + * For irrelavant classes, code = -1. + * For TVM NDArray itself, code = 0. + * All subclasses of NDArray should override code > 0. + */ +template +struct array_type_info { + /*! \brief the value of the traits */ + static const int code = -1; +}; + +// Overrides the type trait for tvm's NDArray. +template<> +struct array_type_info { + static const int code = 0; +}; + /*! * \brief Save a DLTensor to stream * \param strm The outpu stream @@ -196,7 +216,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); * the pointer to the NDArrayContainer can be directly * interpreted as a DLTensor* * - * \note: do not use this function directly, use NDArray. + * \note do not use this function directly, use NDArray. */ class NDArray::Container { public: @@ -228,16 +248,19 @@ class NDArray::Container { protected: friend class NDArray; + friend class TVMPODValue_; + friend class TVMArgValue; + friend class TVMRetValue; friend class RPCWrappedFunc; /*! * \brief Type flag used to indicate subclass. * Default value 0 means normal NDArray::Conatainer. * * We can extend a more specialized NDArray::Container - * and use the array_type_index_ to indicate + * and use the array_type_code_ to indicate * the specific array subclass. */ - uint32_t array_type_index_{0}; + int32_t array_type_code_{0}; /*! \brief The internal reference counter */ std::atomic ref_counter_{0}; /*! diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index a3b4a1696bf0..1398da0d748b 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -362,7 +362,7 @@ inline std::string TVMType2String(TVMType t); * \tparam T the typename */ template -struct extension_class_info { +struct extension_type_info { static const int code = 0; }; @@ -455,6 +455,15 @@ class TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); return value_.v_ctx; } + template::value>::type> + TNDArray AsNDArray() const { + if (type_code_ == kNull) return TNDArray(nullptr); + auto *container = static_cast(value_.v_handle); + CHECK_EQ(container->array_type_code_, array_type_info::code); + return TNDArray(container); + } template const TExtension& AsExtension() const { CHECK_LT(type_code_, kExtEnd); @@ -561,7 +570,7 @@ class TVMArgValue : public TVMPODValue_ { inline TNodeRef AsNodeRef() const; template::value>::type> + std::is_class::value>::type> inline operator T() const; template::code != 0>::type> + extension_type_info::code != 0>::type> TVMRetValue& operator=(const T& other) { this->SwitchToClass( - extension_class_info::code, other); + extension_type_info::code, other); return *this; } /*! @@ -1094,7 +1103,7 @@ class TVMArgsSetter { // extension template::code != 0>::type> + extension_type_info::code != 0>::type> inline void operator()(size_t i, const T& value) const; // NodeRef related extenstions: in tvm/packed_func_ext.h inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*) @@ -1212,40 +1221,53 @@ inline R TypedPackedFunc::operator()(Args... args) const { // extension and node type handling namespace detail { -template +template struct TVMValueCast { static T Apply(const TSrc* self) { + static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions"); return self->template AsNodeRef(); } }; template -struct TVMValueCast { +struct TVMValueCast { static T Apply(const TSrc* self) { return self->template AsExtension(); } }; + +template +struct TVMValueCast { + static T Apply(const TSrc* self) { + return self->template AsNDArray(); + } +}; + } // namespace detail template inline TVMArgValue::operator T() const { return detail:: - TVMValueCast::code != 0> + TVMValueCast::code != 0), + (array_type_info::code > 0)> ::Apply(this); } template inline TVMRetValue::operator T() const { return detail:: - TVMValueCast::code != 0> + TVMValueCast::code != 0), + (array_type_info::code > 0)> ::Apply(this); } template inline void TVMArgsSetter::operator()(size_t i, const T& value) const { - static_assert(extension_class_info::code != 0, + static_assert(extension_type_info::code != 0, "Need to have extesion code"); - type_codes_[i] = extension_class_info::code; + type_codes_[i] = extension_type_info::code; values_[i].v_handle = const_cast(&value); } @@ -1262,9 +1284,9 @@ struct ExtTypeInfo { template inline ExtTypeVTable* ExtTypeVTable::Register_() { - const int code = extension_class_info::code; + const int code = extension_type_info::code; static_assert(code != 0, - "require extension_class_info traits to be declared with non-zero code"); + "require extension_type_info traits to be declared with non-zero code"); ExtTypeVTable vt; vt.clone = ExtTypeInfo::clone; vt.destroy = ExtTypeInfo::destroy; diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 9466056a1282..a53a76f4df2e 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -133,7 +133,7 @@ class Registry { /*! * \brief Macro to register extension type. * This must be registered in a cc file - * after the trait extension_class_info is defined. + * after the trait extension_type_info is defined. */ #define TVM_REGISTER_EXT_TYPE(T) \ TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \ diff --git a/nnvm/include/nnvm/compiler/packed_func_ext.h b/nnvm/include/nnvm/compiler/packed_func_ext.h index e289fd4efa59..a79574fa0879 100644 --- a/nnvm/include/nnvm/compiler/packed_func_ext.h +++ b/nnvm/include/nnvm/compiler/packed_func_ext.h @@ -40,17 +40,17 @@ namespace tvm { namespace runtime { template<> -struct extension_class_info { +struct extension_type_info { static const int code = 16; }; template<> -struct extension_class_info { +struct extension_type_info { static const int code = 17; }; template<> -struct extension_class_info { +struct extension_type_info { static const int code = 18; }; diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc index 1a19feabfe8a..8530a5556b64 100644 --- a/nnvm/src/compiler/packed_func_ext.cc +++ b/nnvm/src/compiler/packed_func_ext.cc @@ -76,8 +76,8 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout") if (ret.type_code() == TVMTypeCode::kNull) { return false; } - CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info::code) - << " expected " << "Symbol (code = " << tvm::runtime::extension_class_info::code + CHECK_EQ(ret.type_code(), tvm::runtime::extension_type_info::code) + << " expected " << "Symbol (code = " << tvm::runtime::extension_type_info::code << ") but get code = " << ret.type_code(); *ret_symbol = *(static_cast(ret.value().v_handle)); return true; diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 3c2a7a5f8c9b..5c176f819105 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -223,13 +223,13 @@ def _handle_return_func(x): _node.__init_by_constructor__ = __init_handle_by_constructor__ RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module -RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) +RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True) C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func( _handle_return_func, TypeCode.FUNC_HANDLE) C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( _return_module, TypeCode.MODULE_HANDLE) -C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True) -C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) +C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True, False) +C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True) _CLASS_MODULE = None _CLASS_FUNCTION = None diff --git a/python/tvm/_ffi/_ctypes/ndarray.py b/python/tvm/_ffi/_ctypes/ndarray.py index 8b88e7dc98ea..37a18cbe4051 100644 --- a/python/tvm/_ffi/_ctypes/ndarray.py +++ b/python/tvm/_ffi/_ctypes/ndarray.py @@ -4,7 +4,7 @@ import ctypes from ..base import _LIB, check_call, c_str -from ..runtime_ctypes import TVMArrayHandle +from ..runtime_ctypes import TVMArrayHandle, TVMNDArrayContainerHandle from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle @@ -28,7 +28,7 @@ def _from_dlpack(dltensor): check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle))) ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor) ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0)) - return _make_array(handle, False) + return _make_array(handle, False, False) raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once") @@ -77,9 +77,15 @@ def to_dlpack(self): return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter) -def _make_array(handle, is_view): +def _make_array(handle, is_view, is_container): + global _TVM_ND_CLS handle = ctypes.cast(handle, TVMArrayHandle) - return _CLASS_NDARRAY(handle, is_view) + fcreate = _CLASS_NDARRAY + if is_container and _TVM_ND_CLS: + array_type_info = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_info.value + if array_type_info > 0: + fcreate = _TVM_ND_CLS[array_type_info] + return fcreate(handle, is_view) _TVM_COMPATS = () @@ -91,6 +97,11 @@ def _reg_extension(cls, fcreate): RETURN_SWITCH[cls._tvm_tcode] = fret C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode) +_TVM_ND_CLS = {} + +def _reg_ndarray(cls, fcreate): + global _TVM_ND_CLS + _TVM_ND_CLS[cls._array_type_code] = fcreate _CLASS_NDARRAY = None diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index ac5532835c47..feb2fffebd23 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -2,7 +2,7 @@ from ..base import TVMError from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule -from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t +from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t import ctypes cdef enum TVMTypeCode: @@ -61,6 +61,14 @@ ctypedef void* TVMRetValueHandle ctypedef void* TVMFunctionHandle ctypedef void* NodeHandle +ctypedef struct TVMNDArrayContainer: + DLTensor dl_tensor + void* manager_ctx + void (*deleter)(DLManagedTensor* self) + int32_t array_type_info + +ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle + ctypedef int (*TVMPackedCFunc)( TVMValue* args, int* type_codes, diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index dcbf4c665e66..9995aea6357a 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -33,7 +33,7 @@ cdef int tvm_callback(TVMValue* args, if tcode != kArrayHandle: pyargs.append(make_ret(value, tcode)) else: - pyargs.append(c_make_array(value.v_handle, True)) + pyargs.append(c_make_array(value.v_handle, True, False)) try: rv = local_pyfunc(*pyargs) except Exception: @@ -175,7 +175,7 @@ cdef inline object make_ret(TVMValue value, int tcode): elif tcode == kFloat: return value.v_float64 elif tcode == kNDArrayContainer: - return c_make_array(value.v_handle, False) + return c_make_array(value.v_handle, False, True) elif tcode == kStr: return py_str(value.v_str) elif tcode == kBytes: diff --git a/python/tvm/_ffi/_cython/ndarray.pxi b/python/tvm/_ffi/_cython/ndarray.pxi index 0a507affec1c..4cd6709a0118 100644 --- a/python/tvm/_ffi/_cython/ndarray.pxi +++ b/python/tvm/_ffi/_cython/ndarray.pxi @@ -20,7 +20,7 @@ def _from_dlpack(object dltensor): # set name and destructor to be empty pycapsule.PyCapsule_SetDestructor(dltensor, NULL) pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor) - return c_make_array(chandle, 0) + return c_make_array(chandle, False, False) raise ValueError("Expect a dltensor field, pycapsule.PyCapsule can only be consumed once") @@ -73,8 +73,15 @@ cdef class NDArrayBase: return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) -cdef c_make_array(void* chandle, is_view): - ret = _CLASS_NDARRAY(None, is_view) +cdef c_make_array(void* chandle, is_view, is_container): + global _TVM_ND_CLS + cdef int32_t array_type_info + fcreate = _CLASS_NDARRAY + if is_container and len(_TVM_ND_CLS) > 0: + array_type_info = (chandle).array_type_info + if array_type_info > 0: + fcreate = _TVM_ND_CLS[array_type_info] + ret = fcreate(None, is_view) (ret).chandle = chandle return ret @@ -89,11 +96,16 @@ def _reg_extension(cls, fcreate): if fcreate: _TVM_EXT_RET[cls._tvm_tcode] = fcreate +cdef _TVM_ND_CLS = {} -def _make_array(handle, is_view): +def _reg_ndarray(cls, fcreate): + global _TVM_ND_CLS + _TVM_ND_CLS[cls._array_type_code] = fcreate + +def _make_array(handle, is_view, is_container): cdef unsigned long long ptr ptr = ctypes.cast(handle, ctypes.c_void_p).value - return c_make_array(ptr, is_view) + return c_make_array(ptr, is_view, is_container) cdef object _CLASS_NDARRAY = None diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index e49c3b62f473..3c5b170bdca7 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -17,15 +17,18 @@ if _FFI_MODE == "ctypes": raise ImportError() if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack + from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy3.core import NDArrayBase as _NDArrayBase + from ._cy3.core import _reg_extension, _reg_ndarray else: - from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack + from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy2.core import NDArrayBase as _NDArrayBase + from ._cy2.core import _reg_extension, _reg_ndarray except IMPORT_EXCEPT: # pylint: disable=wrong-import-position - from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack + from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack from ._ctypes.ndarray import NDArrayBase as _NDArrayBase + from ._ctypes.ndarray import _reg_extension, _reg_ndarray def context(dev_type, dev_id=0): @@ -111,7 +114,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)): ctx.device_type, ctx.device_id, ctypes.byref(handle))) - return _make_array(handle, False) + return _make_array(handle, False, False) def from_dlpack(dltensor): @@ -295,6 +298,7 @@ def free_extension_handle(handle, type_code): """ check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code))) + def register_extension(cls, fcreate=None): """Register a extension class to TVM. @@ -306,21 +310,26 @@ def register_extension(cls, fcreate=None): cls : class The class object to be registered as extension. + fcreate : function, optional + The creation function to create a class object given handle value. + Note ---- - The registered class is requires one property: _tvm_handle and a class attribute _tvm_tcode. + The registered class is requires one property: _tvm_handle. + + If the registered class is a subclass of NDArray, + it is required to have a class attribute _array_type_code. + Otherwise, it is required to have a class attribute _tvm_tcode. - ```_tvm_handle``` returns integer represents the address of the handle. - - ```_tvm_tcode``` gives integer represents type code of the class. + - ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type + code of the class. Returns ------- cls : class The class being registered. - fcreate : function, optional - The creation function to create a class object given handle value. - Example ------- The following code registers user defined class @@ -339,7 +348,13 @@ def __init__(self): def _tvm_handle(self): return self.handle.value """ - if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: - raise ValueError("Cannot register create when extension tcode is same as buildin") - _reg_extension(cls, fcreate) + if issubclass(cls, _NDArrayBase): + assert fcreate is not None + assert hasattr(cls, "_array_type_code") + _reg_ndarray(cls, fcreate) + else: + assert hasattr(cls, "_tvm_tcode") + if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: + raise ValueError("Cannot register create when extension tcode is same as buildin") + _reg_extension(cls, fcreate) return cls diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index ef5316b5e267..e1b78735a97d 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -240,3 +240,12 @@ class TVMArray(ctypes.Structure): ("byte_offset", ctypes.c_uint64)] TVMArrayHandle = ctypes.POINTER(TVMArray) + +class TVMNDArrayContainer(ctypes.Structure): + """TVM NDArray::Container""" + _fields_ = [("dl_tensor", TVMArray), + ("manager_ctx", ctypes.c_void_p), + ("deleter", ctypes.c_void_p), + ("array_type_info", ctypes.c_int32)] + +TVMNDArrayContainerHandle = ctypes.POINTER(TVMNDArrayContainer) diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py index 448e5f6d8bdb..8691d7a03900 100644 --- a/python/tvm/ndarray.py +++ b/python/tvm/ndarray.py @@ -15,7 +15,7 @@ class NDArray(NDArrayBase): """Lightweight NDArray class of TVM runtime. - Strictly this is only an Array Container(a buffer object) + Strictly this is only an Array Container (a buffer object) No arthimetic operations are defined. All operations are performed by TVM functions. diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index abe26fabe9ea..83c0ba602927 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -168,7 +168,7 @@ namespace tvm { namespace runtime { template<> -struct extension_class_info { +struct extension_type_info { static const int code = kExtBegin + 1; }; } // runtime