Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME] Remove Extension VTable in favor of Unified Object system. #4578

Merged
merged 1 commit into from
Dec 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions apps/extension/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ TVM_ROOT=$(shell cd ../..; pwd)
PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/include\
-I${TVM_ROOT}/3rdparty/dmlc-core/include\
-I${TVM_ROOT}/3rdparty/dlpack/include\
-I${TVM_ROOT}/3rdparty/HalideIR/src
-I${TVM_ROOT}/3rdparty/dlpack/include

PKG_LDFLAGS =-L${TVM_ROOT}/build
UNAME_S := $(shell uname -s)
Expand Down
16 changes: 2 additions & 14 deletions apps/extension/python/tvm_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,16 @@ def load_lib():
ivec_create = tvm.get_global_func("tvm_ext.ivec_create")
ivec_get = tvm.get_global_func("tvm_ext.ivec_get")

class IntVec(object):
@tvm.register_object("tvm_ext.IntVector")
class IntVec(tvm.Object):
"""Example for using extension class in c++ """
_tvm_tcode = 17

def __init__(self, handle):
self.handle = handle

def __del__(self):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode)

@property
def _tvm_handle(self):
return self.handle.value

def __getitem__(self, idx):
return ivec_get(self, idx)

# Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec)


nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
Expand Down
44 changes: 31 additions & 13 deletions apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -30,17 +30,12 @@
#include <tvm/runtime/device_api.h>

namespace tvm_ext {
using IntVector = std::vector<int>;
class NDSubClass;
} // namespace tvm_ext

namespace tvm {
namespace runtime {
template<>
struct extension_type_info<tvm_ext::IntVector> {
static const int code = 17;
};
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
Expand Down Expand Up @@ -104,24 +99,47 @@ class NDSubClass : public tvm::runtime::NDArray {
return self->addtional_info_;
}
};


/*!
* \brief Introduce additional extension data structures
* by sub-classing TVM's object system.
*/
class IntVectorObj : public Object {
public:
std::vector<int> vec;

static constexpr const char* _type_key = "tvm_ext.IntVector";
TVM_DECLARE_FINAL_OBJECT_INFO(IntVectorObj, Object);
};

/*!
* \brief Int vector reference class.
*/
class IntVector : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(IntVector, ObjectRef, IntVectorObj);
};

TVM_REGISTER_OBJECT_TYPE(IntVectorObj);

} // namespace tvm_ext

namespace tvm_ext {

TVM_REGISTER_EXT_TYPE(IntVector);

TVM_REGISTER_GLOBAL("tvm_ext.ivec_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
IntVector vec;
auto n = tvm::runtime::make_object<IntVectorObj>();
for (int i = 0; i < args.size(); ++i) {
vec.push_back(args[i].operator int());
n->vec.push_back(args[i].operator int());
}
*rv = vec;
*rv = IntVector(n);
});

TVM_REGISTER_GLOBAL("tvm_ext.ivec_get")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = args[0].AsExtension<IntVector>()[args[1].operator int()];
IntVector p = args[0];
*rv = p->vec[args[1].operator int()];
});


Expand Down
8 changes: 0 additions & 8 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
int query_imports,
TVMFunctionHandle *out);

/*!
* \brief Free front-end extension type resource.
* \param handle The extension handle.
* \param type_code The type of of the extension type.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMExtTypeFree(void* handle, int type_code);

/*!
* \brief Free the Module
* \param mod The module to be freed.
Expand Down
102 changes: 5 additions & 97 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,6 @@ inline std::string TVMType2String(TVMType t);
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \

/*!
* \brief Type traits to mark if a class is tvm extension type.
*
Expand All @@ -404,34 +403,6 @@ struct extension_type_info {
static const int code = 0;
};

/*!
* \brief Runtime function table about extension type.
*/
class ExtTypeVTable {
public:
/*! \brief function to be called to delete a handle */
void (*destroy)(void* handle);
/*! \brief function to be called when clone a handle */
void* (*clone)(void* handle);
/*!
* \brief Register type
* \tparam T The type to be register.
* \return The registered vtable.
*/
template <typename T>
static inline ExtTypeVTable* Register_();
/*!
* \brief Get a vtable based on type code.
* \param type_code The type code
* \return The registered vtable.
*/
TVM_DLL static ExtTypeVTable* Get(int type_code);

private:
// Internal registration function.
TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt);
};

/*!
* \brief Internal base class to
* handle conversion to POD values.
Expand Down Expand Up @@ -518,11 +489,6 @@ class TVMPODValue_ {
CHECK_EQ(container->array_type_code_, array_type_info<TNDArray>::code);
return TNDArray(container);
}
template<typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
return static_cast<TExtension*>(value_.v_handle)[0];
}
template<typename TObjectRef,
typename = typename std::enable_if<
std::is_class<TObjectRef>::value>::type>
Expand Down Expand Up @@ -867,20 +833,8 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
default: {
if (other.type_code() < kExtBegin) {
SwitchToPOD(other.type_code());
value_ = other.value_;
} else {
#if TVM_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type";
#else
this->Clear();
type_code_ = other.type_code();
value_.v_handle =
(*(ExtTypeVTable::Get(other.type_code())->clone))(
other.value().v_handle);
#endif
}
SwitchToPOD(other.type_code());
value_ = other.value_;
break;
}
}
Expand Down Expand Up @@ -931,13 +885,6 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
}
if (type_code_ > kExtBegin) {
#if TVM_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type";
#else
(*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
#endif
}
type_code_ = kNull;
}
};
Expand Down Expand Up @@ -1317,23 +1264,16 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {

// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_ext, bool is_nd>
template<typename T, typename TSrc, bool is_nd>
struct TVMValueCast {
static T Apply(const TSrc* self) {
static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
static_assert(!is_nd, "The default case accepts only non-extensions");
return self->template AsObjectRef<T>();
}
};

template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true, false> {
static T Apply(const TSrc* self) {
return self->template AsExtension<T>();
}
};

template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, false, true> {
struct TVMValueCast<T, TSrc, true> {
static T Apply(const TSrc* self) {
return self->template AsNDArray<T>();
}
Expand All @@ -1345,7 +1285,6 @@ template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
}
Expand All @@ -1354,19 +1293,10 @@ template<typename T, typename>
inline TVMRetValue::operator T() const {
return detail::
TVMValueCast<T, TVMRetValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
}

template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_type_info<T>::code != 0,
"Need to have extesion code");
type_codes_[i] = extension_type_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}

// PackedFunc support
inline TVMRetValue& TVMRetValue::operator=(const DataType& t) {
return this->operator=(t.operator DLDataType());
Expand All @@ -1385,28 +1315,6 @@ inline void TVMArgsSetter::operator()(
this->operator()(i, t.operator DLDataType());
}

// extension type handling
template<typename T>
struct ExtTypeInfo {
static void destroy(void* handle) {
delete static_cast<T*>(handle);
}
static void* clone(void* handle) {
return new T(*static_cast<T*>(handle));
}
};

template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_type_info<T>::code;
static_assert(code != 0,
"require extension_type_info traits to be declared with non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
return ExtTypeVTable::RegisterInternal(code, vt);
}

inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
return (*this)->GetFunction(name, query_imports);
}
Expand Down
9 changes: 0 additions & 9 deletions include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,15 +311,6 @@ class Registry {
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::Registry::Register(OpName)

/*!
* \brief Macro to register extension type.
* This must be registered in a cc file
* after the trait extension_type_info is defined.
*/
#define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::ExtTypeVTable::Register_<T>()

} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_REGISTRY_H_
14 changes: 0 additions & 14 deletions python/tvm/_ffi/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,20 +299,6 @@ def copyto(self, target):
raise ValueError("Unsupported target type %s" % str(type(target)))


def free_extension_handle(handle, type_code):
"""Free c++ extension type handle

Parameters
----------
handle : ctypes.c_void_p
The handle to the extension type.

type_code : int
The tyoe code
"""
check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code)))


def register_extension(cls, fcreate=None):
"""Register a extension class to TVM.

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension, free_extension_handle
from ._ffi.ndarray import register_extension

class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
Expand Down
Loading