Skip to content

Commit

Permalink
update block syntax (apache#9286)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy authored and ylc committed Jan 7, 2022
1 parent 043f355 commit 6d78aba
Show file tree
Hide file tree
Showing 56 changed files with 2,967 additions and 2,094 deletions.
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ pip3 install \
pytest-xdist \
requests \
scipy \
synr==0.4.1 \
synr==0.5.0 \
six \
tornado
14 changes: 8 additions & 6 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,10 @@ class LinkedParam : public ObjectRef {
* def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
* A = T.match_buffer(a, (m, n), "float32")
* B = T.match_buffer(b, (m, n), "float32")
*
* with T.block([m, n], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* for i, j in T.grid(m, n):
* with T.block():
* vi, vj = T.axis.remap("SS", [i, j])
* B[vi, vj] = A[vi, vj]
* \endcode
*
* Then we can make it specialized with given shapes or buffers.
Expand All @@ -218,9 +219,10 @@ class LinkedParam : public ObjectRef {
* def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
* A = T.match_buffer(a, (16, 16), "float32")
* B = T.match_buffer(b, (16, 16), "float32")
*
* with T.block([16, 16], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* for i, j in T.grid(16, 16):
* with T.block():
* vi, vj = T.axis.remap("SS", [i, j])
* B[vi, vj] = A[vi, vj]
* \endcode
*/
PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1078,9 +1078,9 @@ class MatchBufferRegion : public ObjectRef {
* \note Block's body is parameterized by iter vars.
* \code
*
* with T.block([extent0, extent1, ...], name) as [v0, v1, ...]:
* T.bind(v0, value0)
* T.bind(v1, value1)
* with T.block(name):
* v0 = T.axis.S(domain, value0)
* v1 = T.axis.R(domain, value1)
* ...
* T.reads([buffer0[start:end, ...], ...])
* T.writes([buffer1[start:end, ...], ...])
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ TVM_DLL Pass ConvertBlocksToOpaque();
* \code
*
* for i in range(0, 16):
* with T.block([]):
* with T.block():
* B = T.alloc_buffer(16, 16)
* for j in range(0, 16):
* B[i, j] = A[i, j] + 1
Expand All @@ -404,7 +404,7 @@ TVM_DLL Pass ConvertBlocksToOpaque();
* \code
*
* for i in range(0, 16):
* with T.block([]):
* with T.block():
* B = T.alloc_buffer(1, 16)
* for j in range(0, 16):
* B[0, j] = A[i, j] + 1
Expand Down
2 changes: 1 addition & 1 deletion python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@
("sphinx_autodoc_annotation", None),
("sphinx_gallery", None),
("sphinx_rtd_theme", None),
("synr", "==0.4.1"),
("synr", "==0.5.0"),
("tensorflow", None),
("tensorflow-estimator", None),
("tflite", None),
Expand Down
29 changes: 15 additions & 14 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@

import tvm
from tvm.ir import Span
from tvm.ir.expr import Range
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from tvm.tir.expr import IterVar
from .tir.node import BufferSlice


Expand All @@ -41,10 +43,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
C = T.match_buffer(a, (16, 16), "float32")
for i, j, k in T.grid(16, 16, 16):
with T.block([16, 16, T.reduce_axis(16)], "matmul") as [vi, vj, vk]:
T.bind(vi, i)
T.bind(vj, j)
T.bind(vk, k) # iter_bindings = {vj: i, vj: j, vk: k}
with T.block("matmul"):
vi = T.axis.S(16, i)
vj = T.axis.S(16, j)
vk = T.axis.R(16, k) # iter_bindings = {vj: i, vj: j, vk: k}
T.where(True) # predicate of the block_realize
Expand Down Expand Up @@ -72,8 +74,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
"""List[Buffer]: list of T.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
iter_bindings: Mapping[Var, PrimExpr] = {}
"""Mapping[Var, PrimExpr]: map of block iter var to its values"""
iter_values: List[PrimExpr] = []
"""List[PrimExpr]: list of binding values for iter vars"""
iter_vars: List[IterVar] = []
"""List[PrimExpr]: list of iter vars in the block"""
reads: Optional[List[BufferSlice]] = None
"""Optional[List[BufferSlice]]:
list of T.reads statements in the block signature, None for not-visited"""
Expand All @@ -91,7 +95,8 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
def __init__(self):
self.alloc_buffers = []
self.match_buffers = []
self.iter_bindings = {}
self.iter_values = []
self.iter_vars = []
self.reads = None
self.writes = None
self.annotations = None
Expand All @@ -112,8 +117,8 @@ class ContextMaintainer:
"""List[List[synr.ast.Node]]: The ast nodes insides the current scope"""
block_info_stack: List[BlockInfo] = []
"""List[BlockInfo]: The block info for the current block scope"""
loop_stack: List[List[Var]] = []
"""List[List[Var]]: List of loop vars inside the current block scope"""
loop_stack: Dict[Var, Range] = {}
"""Dict[Var, Range]: The dict from loop var to its domain outside the block"""
symbols: List[Dict[str, Union[Var, Buffer]]] = []
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""

Expand All @@ -137,7 +142,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
# scope context
self.node_stack = []
self.block_info_stack = []
self.loop_stack = []
self.loop_stack = {}
self.symbols = []
# function context
self.func_params = []
Expand Down Expand Up @@ -183,8 +188,6 @@ def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
The synr AST nodes in new scope
"""
self.enter_scope(nodes)
# Create a new loop stack for the new block
self.loop_stack.append([])
# Create a new BlockInfo for the new block
self.block_info_stack.append(BlockInfo())

Expand All @@ -196,8 +199,6 @@ def exit_scope(self):
def exit_block_scope(self):
"""Pop the inner most block scope, the function will call `exit_scope` implicitly"""
self.exit_scope()
# Pop loop stack
self.loop_stack.pop()
# Pop block_info
self.block_info_stack.pop()

Expand Down
41 changes: 26 additions & 15 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,13 @@ def A():
"""
if len(node.assignments) == 1:
if not (
isinstance(node.assignments[0].lhs, ast.Var)
and node.assignments[0].lhs.id.name == "__tvm_meta__"
len(node.assignments[0].lhs) == 1
and isinstance(node.assignments[0].lhs[0], ast.Var)
and node.assignments[0].lhs[0].id.name == "__tvm_meta__"
):
self.report_error(
"The only top level assignments allowed are `__tvm_meta__ = ...`",
node.assignments[0].lhs.span,
node.assignments[0].span,
)
self.init_meta(
MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context)
Expand Down Expand Up @@ -526,18 +527,19 @@ def transform_Assign(self, node):
return self.parse_body(node)
else:
value = self.transform(node.rhs)
if not isinstance(node.lhs, ast.Var):
if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
# This is a little confusing because it only is true when
# we have taken this branch. We might need to clarify what
# exectly is allowed in Assignments in tvmscript.
self.report_error(
"Left hand side of assignment must be an unqualified variable",
node.lhs.span,
node.span,
)
ast_var = node.lhs[0]
var = tvm.te.var(
node.lhs.id.name,
self.parse_type(node.ty, node.lhs),
span=tvm_span_from_synr(node.lhs.span),
ast_var.id.name,
self.parse_type(node.ty, ast_var),
span=tvm_span_from_synr(ast_var.span),
)
self.context.update_symbol(var.name, var, node)
body = self.parse_body(node)
Expand Down Expand Up @@ -596,7 +598,7 @@ def transform_For(self, node):
For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
By now 1 pattern of For is supported:
1. for scope handler
for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/tir.range()/
for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/
T.grid()/T.thread_binding()
"""

Expand Down Expand Up @@ -892,9 +894,20 @@ def transform_Attr(self, node):
namespace.
"""

if isinstance(node.object, ast.Var):
if self.match_tir_namespace(node.object.id.name):
func_name = "tir." + node.field.name
def get_full_attr_name(node: ast.Attr) -> str:
reverse_field_names = [node.field.name]
while isinstance(node.object, ast.Attr):
node = node.object
reverse_field_names.append(node.field.name)
if isinstance(node.object, ast.Var):
reverse_field_names.append(node.object.id.name)
return ".".join(reversed(reverse_field_names))

if isinstance(node.object, (ast.Var, ast.Attr)):
full_attr_name = get_full_attr_name(node)
attr_object, fields = full_attr_name.split(".", maxsplit=1)
if self.match_tir_namespace(attr_object):
func_name = "tir." + fields
res = Registry.lookup(func_name)
if res is not None:
return res
Expand All @@ -903,9 +916,7 @@ def transform_Attr(self, node):
except TVMError as e:
# Check if we got an attribute error
if e.args[0].find("AttributeError"):
self.report_error(
f"Unregistered function `tir.{node.field.name}`.", node.field.span
)
self.report_error(f"Unregistered function `tir.{fields}`.", node.span)
else:
raise e

Expand Down
Loading

0 comments on commit 6d78aba

Please sign in to comment.