Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Script][TensorIR] update block syntax #9286

Merged
merged 1 commit into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ pip3 install \
pytest-xdist \
requests \
scipy \
synr==0.4.1 \
synr==0.5.0 \
six \
tornado
14 changes: 8 additions & 6 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,10 @@ class LinkedParam : public ObjectRef {
* def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
* A = T.match_buffer(a, (m, n), "float32")
* B = T.match_buffer(b, (m, n), "float32")
*
* with T.block([m, n], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* for i, j in T.grid(m, n):
* with T.block():
* vi, vj = T.axis.remap("SS", [i, j])
* B[vi, vj] = A[vi, vj]
* \endcode
*
* Then we can make it specialized with given shapes or buffers.
Expand All @@ -218,9 +219,10 @@ class LinkedParam : public ObjectRef {
* def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
* A = T.match_buffer(a, (16, 16), "float32")
* B = T.match_buffer(b, (16, 16), "float32")
*
* with T.block([16, 16], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* for i, j in T.grid(16, 16):
* with T.block():
* vi, vj = T.axis.remap("SS", [i, j])
* B[vi, vj] = A[vi, vj]
* \endcode
*/
PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1078,9 +1078,9 @@ class MatchBufferRegion : public ObjectRef {
* \note Block's body is parameterized by iter vars.
* \code
*
* with T.block([extent0, extent1, ...], name) as [v0, v1, ...]:
* T.bind(v0, value0)
* T.bind(v1, value1)
* with T.block(name):
* v0 = T.axis.S(domain, value0)
* v1 = T.axis.R(domain, value1)
* ...
* T.reads([buffer0[start:end, ...], ...])
* T.writes([buffer1[start:end, ...], ...])
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ TVM_DLL Pass ConvertBlocksToOpaque();
* \code
*
* for i in range(0, 16):
* with T.block([]):
* with T.block():
* B = T.alloc_buffer(16, 16)
* for j in range(0, 16):
* B[i, j] = A[i, j] + 1
Expand All @@ -404,7 +404,7 @@ TVM_DLL Pass ConvertBlocksToOpaque();
* \code
*
* for i in range(0, 16):
* with T.block([]):
* with T.block():
* B = T.alloc_buffer(1, 16)
* for j in range(0, 16):
* B[0, j] = A[i, j] + 1
Expand Down
2 changes: 1 addition & 1 deletion python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@
("sphinx_autodoc_annotation", None),
("sphinx_gallery", None),
("sphinx_rtd_theme", None),
("synr", "==0.4.1"),
("synr", "==0.5.0"),
("tensorflow", None),
("tensorflow-estimator", None),
("tflite", None),
Expand Down
29 changes: 15 additions & 14 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@

import tvm
from tvm.ir import Span
from tvm.ir.expr import Range
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from tvm.tir.expr import IterVar
from .tir.node import BufferSlice


Expand All @@ -41,10 +43,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
C = T.match_buffer(a, (16, 16), "float32")

for i, j, k in T.grid(16, 16, 16):
with T.block([16, 16, T.reduce_axis(16)], "matmul") as [vi, vj, vk]:
T.bind(vi, i)
T.bind(vj, j)
T.bind(vk, k) # iter_bindings = {vj: i, vj: j, vk: k}
with T.block("matmul"):
vi = T.axis.S(16, i)
vj = T.axis.S(16, j)
vk = T.axis.R(16, k) # iter_bindings = {vi: i, vj: j, vk: k}

T.where(True) # predicate of the block_realize

Expand Down Expand Up @@ -72,8 +74,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
"""List[Buffer]: list of T.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
iter_bindings: Mapping[Var, PrimExpr] = {}
"""Mapping[Var, PrimExpr]: map of block iter var to its values"""
iter_values: List[PrimExpr] = []
"""List[PrimExpr]: list of binding values for iter vars"""
iter_vars: List[IterVar] = []
"""List[PrimExpr]: list of iter vars in the block"""
reads: Optional[List[BufferSlice]] = None
"""Optional[List[BufferSlice]]:
list of T.reads statements in the block signature, None for not-visited"""
Expand All @@ -91,7 +95,8 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
def __init__(self):
self.alloc_buffers = []
self.match_buffers = []
self.iter_bindings = {}
self.iter_values = []
self.iter_vars = []
self.reads = None
self.writes = None
self.annotations = None
Expand All @@ -112,8 +117,8 @@ class ContextMaintainer:
"""List[List[synr.ast.Node]]: The ast nodes insides the current scope"""
block_info_stack: List[BlockInfo] = []
"""List[BlockInfo]: The block info for the current block scope"""
loop_stack: List[List[Var]] = []
"""List[List[Var]]: List of loop vars inside the current block scope"""
loop_stack: Dict[Var, Range] = {}
"""Dict[Var, Range]: The dict from loop var to its domain outside the block"""
symbols: List[Dict[str, Union[Var, Buffer]]] = []
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""

Expand All @@ -137,7 +142,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
# scope context
self.node_stack = []
self.block_info_stack = []
self.loop_stack = []
self.loop_stack = {}
self.symbols = []
# function context
self.func_params = []
Expand Down Expand Up @@ -183,8 +188,6 @@ def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
The synr AST nodes in new scope
"""
self.enter_scope(nodes)
# Create a new loop stack for the new block
self.loop_stack.append([])
# Create a new BlockInfo for the new block
self.block_info_stack.append(BlockInfo())

Expand All @@ -196,8 +199,6 @@ def exit_scope(self):
def exit_block_scope(self):
"""Pop the inner most block scope, the function will call `exit_scope` implicitly"""
self.exit_scope()
# Pop loop stack
self.loop_stack.pop()
# Pop block_info
self.block_info_stack.pop()

Expand Down
41 changes: 26 additions & 15 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,13 @@ def A():
"""
if len(node.assignments) == 1:
if not (
isinstance(node.assignments[0].lhs, ast.Var)
and node.assignments[0].lhs.id.name == "__tvm_meta__"
len(node.assignments[0].lhs) == 1
and isinstance(node.assignments[0].lhs[0], ast.Var)
and node.assignments[0].lhs[0].id.name == "__tvm_meta__"
):
self.report_error(
"The only top level assignments allowed are `__tvm_meta__ = ...`",
node.assignments[0].lhs.span,
node.assignments[0].span,
)
self.init_meta(
MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context)
Expand Down Expand Up @@ -526,18 +527,19 @@ def transform_Assign(self, node):
return self.parse_body(node)
else:
value = self.transform(node.rhs)
if not isinstance(node.lhs, ast.Var):
if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
# This is a little confusing because it only is true when
# we have taken this branch. We might need to clarify what
# exactly is allowed in Assignments in tvmscript.
self.report_error(
"Left hand side of assignment must be an unqualified variable",
node.lhs.span,
node.span,
)
ast_var = node.lhs[0]
var = tvm.te.var(
node.lhs.id.name,
self.parse_type(node.ty, node.lhs),
span=tvm_span_from_synr(node.lhs.span),
ast_var.id.name,
self.parse_type(node.ty, ast_var),
span=tvm_span_from_synr(ast_var.span),
)
self.context.update_symbol(var.name, var, node)
body = self.parse_body(node)
Expand Down Expand Up @@ -596,7 +598,7 @@ def transform_For(self, node):
For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
By now 1 pattern of For is supported:
1. for scope handler
for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/tir.range()/
for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/
T.grid()/T.thread_binding()
"""

Expand Down Expand Up @@ -892,9 +894,20 @@ def transform_Attr(self, node):
namespace.
"""

if isinstance(node.object, ast.Var):
if self.match_tir_namespace(node.object.id.name):
func_name = "tir." + node.field.name
def get_full_attr_name(node: ast.Attr) -> str:
reverse_field_names = [node.field.name]
while isinstance(node.object, ast.Attr):
node = node.object
reverse_field_names.append(node.field.name)
if isinstance(node.object, ast.Var):
reverse_field_names.append(node.object.id.name)
return ".".join(reversed(reverse_field_names))

if isinstance(node.object, (ast.Var, ast.Attr)):
full_attr_name = get_full_attr_name(node)
attr_object, fields = full_attr_name.split(".", maxsplit=1)
if self.match_tir_namespace(attr_object):
func_name = "tir." + fields
res = Registry.lookup(func_name)
if res is not None:
return res
Expand All @@ -903,9 +916,7 @@ def transform_Attr(self, node):
except TVMError as e:
# Check if we got an attribute error
if e.args[0].find("AttributeError"):
self.report_error(
f"Unregistered function `tir.{node.field.name}`.", node.field.span
)
self.report_error(f"Unregistered function `tir.{fields}`.", node.span)
else:
raise e

Expand Down
Loading