Skip to content

Commit

Permalink
[RUNTIME][OBJECT] Introduce static slots for common objects.
Browse files Browse the repository at this point in the history
The _type_child_slots can be used to enable quick type checking optimization
by checking the whether the type index is within the bound.

This PR enables these static slots:

- Introduce a static assert to avoid the scenario when a developer forget to
  _type_child_slots when the field is set for the type's parent.
- Revamp and assign static type index to common runtime objects
- Add a DumpTypeTable call to allow developer monitor the current situation
  of type table and offers suggestions for the slots(ideally the slots equals
  the number of children so there is no overflow.
  • Loading branch information
tqchen committed Apr 23, 2020
1 parent 3ab3751 commit 4f2661f
Show file tree
Hide file tree
Showing 21 changed files with 93 additions and 32 deletions.
7 changes: 5 additions & 2 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ namespace tvm {
*/
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
static constexpr const char* _type_key = "BaseExpr";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 58;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};

Expand Down Expand Up @@ -88,6 +89,7 @@ class PrimExprNode : public BaseExprNode {
DataType dtype;

static constexpr const char* _type_key = "PrimExpr";
static constexpr const uint32_t _type_child_slots = 34;
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};

Expand Down Expand Up @@ -161,7 +163,8 @@ class RelayExprNode : public BaseExprNode {
template<typename TTypeNode>
inline const TTypeNode* type_as() const;

static constexpr const char* _type_key = "relay.Expr";
static constexpr const char* _type_key = "RelayExpr";
static constexpr const uint32_t _type_child_slots = 22;
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class BaseFuncNode : public RelayExprNode {
}

static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir/tensor_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace tvm {
class BaseTensorTypeNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.BaseTensorType";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
};

Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class TypeNode : public Object {
static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 14;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};

Expand Down Expand Up @@ -391,6 +392,7 @@ inline bool IsVoidType(const Type& type) {
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "TypeConstraint";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ class TempExprNode : public ExprNode {
static constexpr const char* _type_key = "relay.TempExpr";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
static constexpr const uint32_t _type_child_slots = 0;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
};

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
uint32_t size;
// The fields of the structure follows directly in memory.

static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
static constexpr const char* _type_key = "vm.ADT";
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT;
static constexpr const char* _type_key = "runtime.ADT";
TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);

private:
Expand Down Expand Up @@ -314,7 +314,7 @@ class StringObj : public Object {
/*! \brief The length of the string object. */
uint64_t size;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
static constexpr const char* _type_key = "runtime.String";
TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,10 @@ class NDArray::Container :
using Object::IncRef;

// Information for object protocol.
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeNDArray;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const uint32_t _type_child_slots_can_overflow = true;
static constexpr const char* _type_key = "NDArray";
static constexpr const char* _type_key = "runtime.NDArray";
TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object);

protected:
Expand Down
43 changes: 31 additions & 12 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,31 @@
namespace tvm {
namespace runtime {

/*! \brief list of the type index. */
enum TypeIndex {
/*! \brief Root object type. */
kRoot = 0,
kClosure = 1,
kVMADT = 2,
kRuntimeModule = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
};
/*!
* \brief Namespace for the list of type index.
* \note Use struct so that we have to use TypeIndex::ENumName to refer to
* the constant, but still able to use enum.
*/
struct TypeIndex {
enum {
/*! \brief Root object type. */
kRoot = 0,
// Standard static index assignments,
// Frontends can take benefit of these constants.
/*! \brief runtime::Module. */
kRuntimeModule = 1,
/*! \brief runtime::NDArray. */
kRuntimeNDArray = 2,
/*! \brief runtime::String. */
kRuntimeString = 3,
// static assignments that may subject to change.
kRuntimeClosure,
kRuntimeADT,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
};
}; // namespace TypeIndex

/*!
* \brief base class of all object containers.
Expand Down Expand Up @@ -198,7 +212,7 @@ class Object {
using RefCounterType = int32_t;
#endif

static constexpr const char* _type_key = "Object";
static constexpr const char* _type_key = "runtime.Object";

static uint32_t _GetOrAllocRuntimeTypeIndex() {
return TypeIndex::kRoot;
Expand Down Expand Up @@ -675,6 +689,10 @@ struct ObjectEqual {
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj maked as final"); \
static uint32_t RuntimeTypeIndex() { \
static_assert(TypeName::_type_child_slots == 0 || \
ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \
} \
Expand All @@ -690,6 +708,7 @@ struct ObjectEqual {
return tidx; \
} \


/*!
* \brief helper macro to declare type information in a final class.
* \param TypeName The name of the current type.
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,8 @@ struct unpack_call_dispatcher<void, 0, index, F> {

template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(nargs, args.size())
<< "Expect " << nargs << " arguments but get " << args.size();
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
}

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ namespace vm {
*/
class ClosureObj : public Object {
public:
static constexpr const uint32_t _type_index = TypeIndex::kClosure;
static constexpr const char* _type_key = "Closure";
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure;
static constexpr const char* _type_key = "runtime.Closure";
TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class StmtNode : public Object {
static constexpr const char* _type_key = "Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 15;
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class VarNode : public PrimExprNode {
}

static constexpr const char* _type_key = "tir.Var";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __enter__(self):
return self

def __exit__(self, ptype, value, trace):
_quantize._ExitQConfigScope(self)
_quantize._ExitQConfigScope()

def __setattr__(self, name, value):
if name in QConfig._node_defaults:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def getitem_helper(obj, elem_getter, length, idx):
return elem_getter(obj, idx)


@tvm._ffi.register_object("vm.ADT")
@tvm._ffi.register_object("runtime.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from tvm._ffi._ctypes.ndarray import NDArrayBase


@tvm._ffi.register_object
@tvm._ffi.register_object("runtime.NDArray")
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
Expand Down
1 change: 1 addition & 0 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class CanonicalExprNode : public PrimExprNode {
}

static constexpr const char* _type_key = "arith.CanonicalExpr";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
};

Expand Down
2 changes: 2 additions & 0 deletions src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,5 +188,7 @@ TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
.set_body_typed([](Module mod, std::string name, std::string fmt) {
mod->SaveToFile(name, fmt);
});

TVM_REGISTER_OBJECT_TYPE(ModuleNode);
} // namespace runtime
} // namespace tvm
34 changes: 30 additions & 4 deletions src/runtime/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ class TypeContext {
return it->second;
}
// try to allocate from parent's type table.
CHECK_LT(parent_tindex, type_table_.size());
CHECK_LT(parent_tindex, type_table_.size())
<< " skey= " << skey << "static_index=" << static_tindex;
TypeInfo& pinfo = type_table_[parent_tindex];
CHECK_EQ(pinfo.index, parent_tindex);

Expand All @@ -108,7 +109,7 @@ class TypeContext {
<< " between " << type_table_[allocated_tindex].name
<< " and "
<< skey;
} else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
} else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
// allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots;
// update parent's state
Expand All @@ -119,8 +120,8 @@ class TypeContext {
// allocate new entries.
allocated_tindex = type_counter_;
type_counter_ += num_slots;
CHECK_LE(type_table_.size(), allocated_tindex);
type_table_.resize(allocated_tindex + 1, TypeInfo());
CHECK_LE(type_table_.size(), type_counter_);
type_table_.resize(type_counter_, TypeInfo());
}
CHECK_GT(allocated_tindex, parent_tindex);
// initialize the slot.
Expand Down Expand Up @@ -161,6 +162,25 @@ class TypeContext {
return it->second;
}

void Dump(int min_children_count) {
std::vector<int> num_children(type_table_.size(), 0);
// reverse accumulation so we can get total counts in a bottom-up manner.
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
if (it->index != 0) {
num_children[it->parent_index] += num_children[it->index] + 1;
}
}

for (const auto& info : type_table_) {
if (info.index != 0 && num_children[info.index] >= min_children_count) {
std::cerr <<'[' << info.index << "] "<< info.name
<< "\tparent=" << type_table_[info.parent_index].name
<< "\tnum_child_slots=" << info.num_slots - 1
<< "\tnum_children=" << num_children[info.index] << std::endl;
}
}
}

static TypeContext* Global() {
static TypeContext inst;
return &inst;
Expand All @@ -169,6 +189,7 @@ class TypeContext {
private:
TypeContext() {
type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
type_table_[0].name = "runtime.Object";
}
// mutex to avoid registration from multiple threads.
std::mutex mutex_;
Expand Down Expand Up @@ -208,6 +229,11 @@ TVM_REGISTER_GLOBAL("runtime.ObjectHash")
.set_body_typed([](ObjectRef obj) {
return static_cast<int64_t>(ObjectHash()(obj));
});

TVM_REGISTER_GLOBAL("runtime.DumpTypeTable")
.set_body_typed([](int min_child_count) {
TypeContext::Global()->Dump(min_child_count);
});
} // namespace runtime
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion src/target/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) {
this->Push(op->body);
}

runtime::Module BuildStackVM(const IRModule& mod) {
runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) {
std::unordered_map<std::string, StackVM> fmap;
std::string entry_func;

Expand Down
1 change: 1 addition & 0 deletions tests/cpp/object_protocol_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class ObjA : public ObjBase {
class ObjB : public ObjBase {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const char* _type_key = "test.ObjB";
TVM_DECLARE_BASE_OBJECT_INFO(ObjB, ObjBase);
};
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_te_schedule_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_attach_path():

def test_fix_pt():
body = tvm.te.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.spatial_axis_[0]].value != 0)

def test_scan_fix_point():
Expand All @@ -57,7 +57,7 @@ def test_scan0():
lambda t, i, j: x[t, j, i] + s_state[t-1, i, j], name="update")
s_scan = tvm.te.scan(s_init, s_update, s_state)
body = tvm.te.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 1)

Expand All @@ -66,7 +66,7 @@ def test_scan1():
lambda t, i, j: x[t, j, i] + s_state[t-1, j, i], name="update")
s_scan = tvm.te.scan(s_init, s_update, s_state)
body = tvm.te.schedule.ScanGetBody(s_scan.op)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body)
fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op)
assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)

Expand Down

0 comments on commit 4f2661f

Please sign in to comment.