Skip to content

Commit

Permalink
[Parser][Printer] update parser and printer for match_shape (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
altanh authored and junrushao committed Feb 5, 2023
1 parent c3aa9c2 commit 3edf15c
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 64 deletions.
101 changes: 72 additions & 29 deletions python/tvm/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,26 @@ def transform_function(self, func: ast.Function, is_global: bool = False) -> rx.
params, new_body, ret_type, name=func_name, span=self.to_tvm_span(func.span)
)

def is_match_shape(self, stmt: ast.Stmt) -> bool:
"""Returns whether or not the given statement is a MatchShape binding.
Parameters
----------
stmt : ast.Stmt
The statement to be parsed.
Returns
-------
bool
Whether or not the statement is a MatchShape binding.
"""
call_op = None
if isinstance(stmt, ast.UnassignedCall):
call_op = self.transform_expr(stmt.call.func_name)
elif isinstance(stmt, ast.Assign) and isinstance(stmt.rhs, ast.Call):
call_op = self.transform_expr(stmt.rhs.func_name)
return call_op == SpecialOp.MATCH_SHAPE

def parse_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> rx.Binding:
"""Parses the input synr statement to the corresponding Relax binding.
Expand All @@ -495,42 +515,62 @@ def parse_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> rx.Binding
The parsed Relax binding
"""
assert isinstance(stmt, (ast.Assign, ast.UnassignedCall))
if isinstance(stmt, ast.Assign):
return self.parse_var_binding(stmt, is_dataflow=is_dataflow)
if self.is_match_shape(stmt):
return self.parse_shape_binding(stmt, is_dataflow=is_dataflow)
else:
return self.parse_shape_binding(stmt)
assert isinstance(stmt, ast.Assign)
return self.parse_var_binding(stmt, is_dataflow=is_dataflow)

def parse_shape_binding(self, stmt: ast.UnassignedCall) -> rx.MatchShape:
def parse_shape_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> rx.MatchShape:
"""Parses the input synr statement to a Relax shape binding.
Parameters
----------
stmt : ast.UnassignedCall
stmt : ast.Stmt
The input synr statement
is_dataflow : bool, optional
Whether or not the bound variable (if any) is a dataflow variable, by default False
Returns
-------
rx.MatchShape
The parsed Relax shape binding
"""
call: synr.ast.Call = stmt.call
var: ast.Var = None
call: ast.Call = None

if isinstance(stmt, ast.UnassignedCall):
# case where only dimension variables are bound, e.g. `match_shape(x.shape, (n, m))`
call = stmt.call
else:
# case where the statement also binds a Relax variable to the value being matched
assert isinstance(stmt, ast.Assign)
if not isinstance(stmt.lhs, ast.Var):
self.report_error(
"the left hand side of a binding must be a variable", stmt.lhs.span
)
var = stmt.lhs
call = stmt.rhs

op = self.transform_expr(call.func_name)
if op != SpecialOp.MATCH_SHAPE:
self.report_error("the results of calls must be bound or used", stmt.span)
if len(stmt.call.params) != 2:
self.report_error(op.value + " takes exactly two arguments", stmt.span)

lhs = stmt.call.params[0]
rhs = stmt.call.params[1]
assert op == SpecialOp.MATCH_SHAPE
if len(call.params) != 2:
self.report_error(op.value + " takes exactly two arguments", call.span)

rhs_expr = self.transform_expr(rhs)
if not isinstance(lhs, ast.Tuple):
self.report_error(
"the pattern (lhs) of " + op.value + " must be a tuple",
lhs.span,
)
lhs_expr = self.parse_shape(lhs, bind_free_vars=True)
return rx.MatchShape(lhs_expr, rhs_expr, self.to_tvm_span(stmt.span))
value, pattern = call.params

value = self.transform_expr(value)
if not isinstance(pattern, ast.Tuple):
self.report_error(f"the pattern of a {op.value} call must be a tuple", pattern.span)
pattern = self.parse_shape(pattern, bind_free_vars=True)

if var is not None:
# TODO(@altanh): keep or discard annotation?
ty, shape = self.transform_type(stmt.ty, bind_free_vars=False)
var = self.decl_var(var.id.name, ty, shape, var.span, is_dataflow=is_dataflow)

return rx.MatchShape(value, pattern, var, self.to_tvm_span(stmt.span))

def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False) -> rx.VarBinding:
"""Parses the input synr assignment to a Relax variable binding.
Expand All @@ -540,12 +580,12 @@ def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False) -> rx.VarBindin
stmt : ast.Assign
The input synr assignment
is_dataflow : bool, optional
Whether or not the binding is in a dataflow block, by default False
Whether or not the bound variable is a dataflow variable, by default False
Returns
-------
rx.VarBinding
The prased Relax variable binding
The parsed Relax variable binding
"""
if not isinstance(stmt.lhs, ast.Var):
self.report_error(
Expand Down Expand Up @@ -644,8 +684,10 @@ def transform_stmt(self, stmt: ast.Stmt) -> Union[rx.Expr, rx.Binding, rx.Datafl
return self.transform_expr(stmt.value)

elif isinstance(stmt, ast.UnassignedCall):
# FIXME: when we add ref support, ref_write can be unassigned
return self.parse_shape_binding(stmt)
if self.transform_expr(stmt.call.func_name) == SpecialOp.MATCH_SHAPE:
return self.parse_shape_binding(stmt)
else:
self.report_error("the results of normal function calls must be bound", stmt.span)

elif isinstance(stmt, ast.With):
if not isinstance(stmt.rhs, ast.Call):
Expand Down Expand Up @@ -727,19 +769,20 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock:
"only bindings are supported in dataflow blocks",
binding_stmt.span,
)
is_match_shape = isinstance(binding_stmt, ast.UnassignedCall)
is_dataflow = not is_match_shape and (
binding_stmt.lhs.id.name not in output_var_names
is_match_shape = self.is_match_shape(binding_stmt)
is_dataflow = (
isinstance(binding_stmt, ast.Assign)
and binding_stmt.lhs.id.name not in output_var_names
)
binding = self.parse_binding(binding_stmt, is_dataflow=is_dataflow)
bindings.append(binding)
if not is_dataflow:
if is_match_shape:
for var in binding.pattern:
output_vars.append(var)
else:
if binding.var is not None:
output_vars.append(binding.var)
unbound_output_vars.pop(binding_stmt.lhs.id.name)
unbound_output_vars.pop(binding.var.name_hint)

# check that the output variables are all bound locally
for unbound_var in unbound_output_vars.values():
Expand Down
51 changes: 27 additions & 24 deletions src/relay/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
Doc PrintIfStmt(const relax::Var& var, const relay::If& ite);
Doc PrintFunctionDef(const Doc& name, const relax::Function& func);

Doc PrintVarAnnotation(const relax::Var& var);
Doc PrintTensorAnnotation(const relax::DynTensorType& ty, const Optional<ObjectRef>& shape);

Doc VisitType_(const relax::ShapeTypeNode* node) override;
Expand Down Expand Up @@ -238,9 +239,12 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::ShapeExprNode* op) {

Doc RelaxScriptPrinter::VisitNode_(const relax::MatchShapeNode* op) {
Doc doc;
if (op->var.defined()) {
doc << Print(op->var) << PrintVarAnnotation(op->var) << " = ";
}
doc << "relax.match_shape(";
// TODO(@altanh): maybe op->pattern should just be a ShapeExpr?
doc << Print(relax::ShapeExpr(op->pattern)) << ", " << Print(op->value);
doc << Print(op->value) << ", " << Print(relax::ShapeExpr(op->pattern));
doc << ")";
return doc;
}
Expand All @@ -260,16 +264,7 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::VarBindingNode* op) {
return tir::AsTVMScriptDoc(mod, false, prim_func_ref);
} else {
Doc doc;
doc << Print(op->var);
if (op->var->type_annotation.defined()) {
doc << ": ";
if (const relax::DynTensorTypeNode* tty =
op->var->type_annotation.as<relax::DynTensorTypeNode>()) {
doc << PrintTensorAnnotation(GetRef<DynTensorType>(tty), op->var->shape_);
} else {
doc << Print(op->var->type_annotation);
}
}
doc << Print(op->var) << PrintVarAnnotation(op->var);
doc << " = " << Print(op->value);
return doc;
}
Expand All @@ -289,10 +284,14 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) {
std::vector<Doc> return_vars;
for (const relax::Binding& binding : op->bindings) {
body << Print(binding) << Doc::NewLine();
Var var;
if (const relax::VarBindingNode* var_binding = binding.as<relax::VarBindingNode>()) {
if (!var_binding->var.as<relax::DataflowVarNode>()) {
return_vars.push_back(Print(var_binding->var));
}
var = var_binding->var;
} else if (const relax::MatchShapeNode* shape_binding = binding.as<relax::MatchShapeNode>()) {
var = shape_binding->var;
}
if (var.defined() && !var.as<relax::DataflowVarNode>()) {
return_vars.push_back(Print(var));
}
}
ICHECK(!return_vars.empty()) << "dataflow blocks should have at least one output variable";
Expand Down Expand Up @@ -444,16 +443,7 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
for (size_t i = 0; i < func->params.size(); ++i) {
relax::Var var = func->params[i];
Doc param;
param << Print(var);
if (var->type_annotation.defined()) {
param << ": ";
if (const relax::DynTensorTypeNode* tty =
var->type_annotation.as<relax::DynTensorTypeNode>()) {
param << PrintTensorAnnotation(GetRef<DynTensorType>(tty), var->shape_);
} else {
param << Print(var->type_annotation);
}
}
param << Print(var) << PrintVarAnnotation(var);
params.push_back(param);
}

Expand All @@ -471,6 +461,19 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
return doc;
}

Doc RelaxScriptPrinter::PrintVarAnnotation(const relax::Var& var) {
Doc doc;
if (var->type_annotation.defined()) {
doc << ": ";
if (const relax::DynTensorTypeNode* tty = var->type_annotation.as<relax::DynTensorTypeNode>()) {
doc << PrintTensorAnnotation(GetRef<DynTensorType>(tty), var->shape_);
} else {
doc << Print(var->type_annotation);
}
}
return doc;
}

Doc RelaxScriptPrinter::PrintTensorAnnotation(const relax::DynTensorType& ty,
const Optional<ObjectRef>& shape) {
Doc doc;
Expand Down
28 changes: 23 additions & 5 deletions tests/python/relax/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
def test_match_shape():
@rx.script
def foo(x: Tensor[_, "float32"]):
relax.match_shape((n, m), x.shape)
relax.match_shape(x.shape, (n, m))
y: Tensor[(n, m), "float32"] = add(x, x)
return x

Expand Down Expand Up @@ -289,13 +289,31 @@ def test_dataflow_match_shape():
@rx.script
def foo(x: Tensor[_, _]):
with relax.dataflow():
y = add(x, x)
x2: Tensor[(n, m), _] = relax.match_shape(x, (n, m))
y = add(x2, x2)
z = multiply(y, x)
relax.match_shape((n, m), z.shape)
relax.match_shape(z.shape, (n, m))
w: Tensor[(n, m), _] = subtract(z, x)
relax.output(y, w)
relax.output(y, w, x2)
t: Tensor[(n, m), _] = divide(y, w)
return t
q: Tensor[(n, m), _] = add(t, x2)
return q

f = rx_func(foo)
x = f.params[0]
df_block = f.body.blocks[0]
x2_bind = df_block.bindings[0]
z_shape_bind = df_block.bindings[3]
q_bind = f.body.blocks[1].bindings[1]

assert x2_bind.var.name_hint == "x2"
check_tensor_var(x2_bind.var, ("n", "m"), "")
check_shape(x2_bind.pattern, ("n", "m"))
assert x2_bind.value == x

check_shape(z_shape_bind.pattern, ("n", "m"))

assert q_bind.value.args[1] == x2_bind.var


@pytest.mark.xfail
Expand Down
13 changes: 7 additions & 6 deletions tests/python/relax/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@ def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
def test_match_shape():
@rx.script
def foo(x: Tensor[_, "float32"]):
relax.match_shape((n, m), x.shape)
relax.match_shape(x.shape, (n, m))
y: Tensor[(n, m), "float32"] = add(x, x)
return x

check_roundtrip(foo)



def test_if():
@rx.script
def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
Expand Down Expand Up @@ -94,13 +93,15 @@ def test_dataflow_match_shape():
@rx.script
def foo(x: Tensor[_, _]):
with relax.dataflow():
y = add(x, x)
x2: Tensor[(n, m), _] = relax.match_shape(x, (n, m))
y = add(x2, x2)
z = multiply(y, x)
relax.match_shape((n, m), z.shape)
relax.match_shape(z.shape, (n, m))
w: Tensor[(n, m), _] = subtract(z, x)
relax.output(y, w)
relax.output(y, w, x2)
t: Tensor[(n, m), _] = divide(y, w)
return t
q: Tensor[(n, m), _] = add(t, x2)
return q

check_roundtrip(foo)

Expand Down

0 comments on commit 3edf15c

Please sign in to comment.