Skip to content

Commit

Permalink
[IR] Add storage scope to PointerType (#8017)
Browse files Browse the repository at this point in the history
* Add storage scope to PointerType.

* Apply suggestions from code review

Co-authored-by: Siyuan Feng <[email protected]>
  • Loading branch information
csullivan and Hzfengsy authored May 18, 2021
1 parent 365484e commit c510c2b
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 11 deletions.
12 changes: 10 additions & 2 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,15 @@ class PointerTypeNode : public TypeNode {
* \brief The type of the element which the pointer points to.
*/
Type element_type;
/*!
* \brief The storage scope of the pointer
*/
String storage_scope;

void VisitAttrs(AttrVisitor* v) { v->Visit("element_type", &element_type); }
void VisitAttrs(AttrVisitor* v) {
v->Visit("element_type", &element_type);
v->Visit("storage_scope", &storage_scope);
}

bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
return equal(element_type, other->element_type);
Expand All @@ -175,8 +182,9 @@ class PointerType : public Type {
/*!
* \brief Constructor
* \param element_type The type of the element which the pointer points to.
* \param storage_scope The storage scope into which the pointer addresses
*/
TVM_DLL explicit PointerType(Type element_type);
TVM_DLL explicit PointerType(Type element_type, String storage_scope = "");

TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
};
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,13 @@ class PointerType(Type):
----------
element_type : tvm.ir.Type
The type of pointer's element.
storage_scope : str
The storage scope into which the pointer addresses.
"""

def __init__(self, element_type):
self.__init_handle_by_constructor__(_ffi_api.PointerType, element_type)
def __init__(self, element_type, storage_scope=""):
self.__init_handle_by_constructor__(_ffi_api.PointerType, element_type, storage_scope)


@tvm._ffi.register_object("TypeVar")
Expand Down
13 changes: 9 additions & 4 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,26 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << node->dtype;
});

PointerType::PointerType(Type element_type) {
PointerType::PointerType(Type element_type, String storage_scope) {
ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>();
n->element_type = std::move(element_type);
n->storage_scope = std::move(storage_scope);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(PointerTypeNode);

TVM_REGISTER_GLOBAL("ir.PointerType").set_body_typed([](Type element_type) {
return PointerType(element_type);
});
TVM_REGISTER_GLOBAL("ir.PointerType")
.set_body_typed([](Type element_type, String storage_scope = "") {
return PointerType(element_type, storage_scope);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PointerTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PointerTypeNode*>(ref.get());
if (!node->storage_scope.empty()) {
p->stream << node->storage_scope << " ";
}
p->Print(node->element_type);
p->stream << '*';
});
Expand Down
2 changes: 1 addition & 1 deletion src/ir/type_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ Type TypeMutator::VisitType_(const PointerTypeNode* op) {
if (element_type.same_as(op->element_type)) {
return GetRef<Type>(op);
} else {
return PointerType(element_type);
return PointerType(element_type, op->storage_scope);
}
}

Expand Down
6 changes: 5 additions & 1 deletion src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,11 @@ Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) {

Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) {
Doc doc;
doc << "Pointer(" << Print(node->element_type) << ")";
doc << "Pointer(";
if (!node->storage_scope.empty()) {
doc << node->storage_scope << " ";
}
doc << Print(node->element_type) << ")";
return doc;
}

Expand Down
6 changes: 5 additions & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,11 @@ Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {

Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) {
Doc doc;
doc << "ty.Ptr[" << Print(node->element_type) << "]";
doc << "ty.Ptr[";
if (!node->storage_scope.empty()) {
doc << node->storage_scope << " ";
}
doc << Print(node->element_type) << "]";
return doc;
}

Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_tir_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ def test_stmt_constructor():
assert x.buffer_var == buffer_var
assert x.body == nop

storage_scope = "global.texture"
buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope))
x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
assert isinstance(x, tvm.tir.Allocate)
assert x.dtype == "float32"
assert x.buffer_var == buffer_var
assert x.buffer_var.type_annotation.storage_scope == storage_scope
assert x.body == nop

x = tvm.tir.AttrStmt(buffer_var, "xyz", 1, nop)
assert isinstance(x, tvm.tir.AttrStmt)
assert x.node == buffer_var
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_tir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,17 @@ def test_vars():
assert isinstance(ptype.element_type, tvm.ir.PrimType)


def test_scoped_storage_vars():
dtype = "float"
storage_scope = "global.texture"
ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
x = tvm.tir.Var("xyz", ptype)
assert x.dtype == "handle"
assert x.type_annotation == ptype
assert x.type_annotation.storage_scope == storage_scope
assert isinstance(ptype.element_type, tvm.ir.PrimType)


def test_buffer_load_store():
b = tvm.tir.decl_buffer((10,), "float32")
x = tvm.tir.BufferLoad(b, [0])
Expand Down Expand Up @@ -460,6 +471,7 @@ def test_block_blockrealize():
test_intimm_cond()
test_buffer_load_store()
test_vars()
test_scoped_storage_var()
test_prim_func()
test_cast()
test_attr()
Expand Down

0 comments on commit c510c2b

Please sign in to comment.