Skip to content

Commit

Permalink
Allow nest parsing (apache#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jul 4, 2022
1 parent 93d6750 commit 6ca54ed
Show file tree
Hide file tree
Showing 16 changed files with 238 additions and 56 deletions.
2 changes: 1 addition & 1 deletion python/tvm/script/builder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
# pylint: disable=unused-import
"""Namespace for the TVMScript Builder API."""
from .builder import Builder, def_, def_many
from .frame import Frame, IRModuleFrame
from .frame import Frame
8 changes: 0 additions & 8 deletions python/tvm/script/builder/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,3 @@ def __enter__(self) -> "Frame":

def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
_ffi_api.FrameExit(self) # pylint: disable=no-member # type: ignore


@_register_object("script.builder.IRModuleFrame")
class IRModuleFrame(Frame):
def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.IRModuleFrame # pylint: disable=no-member # type: ignore
)
19 changes: 19 additions & 0 deletions python/tvm/script/builder/ir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVMScript IR"""

from .ir import IRModuleFrame, ir_module
20 changes: 20 additions & 0 deletions python/tvm/script/builder/ir/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.script.builder.ir"""
import tvm._ffi

tvm._ffi._init_api("script.builder.ir", __name__) # pylint: disable=protected-access
45 changes: 45 additions & 0 deletions python/tvm/script/builder/ir/ir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVMScript IR"""

import inspect
from typing import Optional, Type, Union

from tvm._ffi import register_object as _register_object
from tvm.ir import IRModule

from ..frame import Frame
from . import _ffi_api


@_register_object("script.builder.ir.IRModuleFrame")
class IRModuleFrame(Frame):
...


def ir_module(f: Optional[Type] = None) -> Union[IRModuleFrame, IRModule]:
if f is not None:
from tvm.script.parse import parse # pylint: disable=import-outside-toplevel

if not inspect.isclass(f):
raise TypeError(f"Expect a class, but got: {f}")

return parse(f)
return _ffi_api.IRModule() # pylint: disable=no-member # type: ignore


setattr(ir_module, "dispatch_token", "ir")
22 changes: 19 additions & 3 deletions python/tvm/script/builder/tir/prim_func_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Prim Func Frame"""
import inspect
from typing import Any, Callable, Dict, Optional, Union

from tvm._ffi import register_object as _register_object
from tvm.ir import Type
from tvm.tir.buffer import Buffer
from tvm.tir import Buffer, PrimFunc
from tvm.tir.expr import Var

from . import _ffi_api
Expand All @@ -31,12 +32,27 @@ class PrimFuncFrame(TIRFrame):
...


def prim_func(f: Optional[Callable] = None) -> PrimFuncFrame:
def _is_defined_in_class(frames):
if len(frames) > 2:
maybe_class_frame = frames[2]
statement_list = maybe_class_frame[4]
first_statement = statement_list[0]
if first_statement.strip().startswith("class "):
return True
return False


def prim_func(f: Optional[Callable] = None) -> Union[PrimFuncFrame, PrimFunc, Callable]:
if f is not None:
from tvm.script.parse import parse # pylint: disable=import-outside-toplevel

if not inspect.isfunction(f):
raise TypeError(f"Expect a function, but got: {f}")

if _is_defined_in_class(inspect.stack()):
return f
return parse(f)
return _ffi_api.PrimFuncFrame() # pylint: disable=no-member # type: ignore
return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore


setattr(prim_func, "dispatch_token", "tir")
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/script/builder/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR For Frame"""
import numpy as np
"""TVMScript TIR statements"""
from typing import List, Union

import numpy as np
from tvm._ffi import register_object as _register_object
from tvm.tir import Buffer, IterVar, PrimExpr, Var, BufferRegion, Stmt, StringImm
from tvm.ir import Type, Range
from tvm.runtime import ndarray as nd, Object
from tvm.runtime import Object
from tvm.runtime import ndarray as nd
from tvm.tir import Buffer, BufferRegion, IterVar, PrimExpr, StringImm, Var

from . import _ffi_api
from .. import _ffi_api as _base_ffi_api
from . import _ffi_api
from .base import TIRFrame


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/parse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
# specific language governing permissions and limitations
# under the Licens.
"""The parser"""
from . import dispatch, doc, parser, tir
from . import dispatch, doc, parser, tir, ir
from .entry import parse
2 changes: 1 addition & 1 deletion python/tvm/script/parse/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""The entry point of TVM parser."""
import inspect
from typing import Any, Dict, Optional, Union
from typing import Any, Union

from ..builder import Builder
from . import doc
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/script/parse/ir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from . import ir
28 changes: 28 additions & 0 deletions python/tvm/script/parse/ir/ir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from ...builder import Frame
from ...builder import ir as I
from .. import dispatch, doc
from ..parser import Parser


@dispatch.register(token="ir", type_name="ClassDef")
def visit_class_def(self: Parser, node: doc.ClassDef) -> None:
with self.var_table.with_frame():
with I.ir_module():
with self.with_dispatch_token("ir"):
self.visit_body(node.body)
40 changes: 26 additions & 14 deletions python/tvm/script/parse/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,6 @@ def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
return lambda self, node: self.generic_visit(node)


def _handle_function(self: "Parser", node: doc.FunctionDef) -> None:
if not node.decorator_list:
self.report_error(node, "Function must be decorated")
# TODO: only the last decorator is parsed
decorator = self.eval_expr(node.decorator_list[-1])
if hasattr(decorator, "dispatch_token"):
token = decorator.dispatch_token
func = dispatch.get(token=token, type_name="FunctionDef", default=None)
if func is not None:
func(self, node)
return
self.report_error(node, "The parser does not understand the decorator")


class Parser(doc.NodeVisitor):
"""The TVMScript parser"""

Expand Down Expand Up @@ -91,6 +77,9 @@ def report_error(self, node: doc.AST, msg: str) -> None: # pylint: disable=no-s
def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name
_handle_function(self, node)

def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name
_handle_class(self, node)

def visit_body(self, node: List[doc.stmt]) -> Any:
for stmt in node:
self.visit(stmt)
Expand All @@ -106,3 +95,26 @@ def visit_With(self, node: doc.With) -> Any: # pylint: disable=invalid-name

def visit_Assign(self, node: doc.Assign) -> Any: # pylint: disable=invalid-name
_dispatch(self, "Assign")(self, node)


def _handle_function(self: Parser, node: doc.FunctionDef) -> None:
if not node.decorator_list:
self.report_error(node, "Function must be decorated")
# TODO: only the last decorator is parsed
decorator = self.eval_expr(node.decorator_list[-1])
if hasattr(decorator, "dispatch_token"):
token = decorator.dispatch_token
func = dispatch.get(token=token, type_name="FunctionDef", default=None)
if func is not None:
func(self, node)
return
self.report_error(node, "The parser does not understand the decorator")


def _handle_class(self: Parser, node: doc.ClassDef) -> None:
# TODO: assume IRModule
func = dispatch.get(token="ir", type_name="ClassDef", default=None)
if func is not None:
func(self, node)
return
self.report_error(node, "The parser does not understand the decorator")
5 changes: 3 additions & 2 deletions src/script/builder/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ namespace script {
namespace builder {
namespace ir {

IRModuleFrame::IRModuleFrame() {
IRModuleFrame IRModule() {
ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
n->global_vars.clear();
n->functions.clear();
data_ = std::move(n);
return IRModuleFrame(n);
}

void IRModuleFrameNode::ExitWithScope() {
Expand All @@ -45,6 +45,7 @@ void IRModuleFrameNode::ExitWithScope() {
}

TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
TVM_REGISTER_GLOBAL("script.builder.ir.IRModule").set_body_typed(IRModule);

} // namespace ir
} // namespace builder
Expand Down
3 changes: 1 addition & 2 deletions src/script/builder/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ class IRModuleFrameNode : public FrameNode {

class IRModuleFrame : public Frame {
public:
IRModuleFrame();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, Frame, IRModuleFrameNode);
};

IRModuleFrame ir_module();
IRModuleFrame IRModule();

} // namespace ir
} // namespace builder
Expand Down
2 changes: 1 addition & 1 deletion src/script/builder/tir/prim_func_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array<PrimExpr> s
};

TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
TVM_REGISTER_GLOBAL("script.builder.tir.PrimFuncFrame").set_body_typed(PrimFunc);
TVM_REGISTER_GLOBAL("script.builder.tir.PrimFunc").set_body_typed(PrimFunc);
TVM_REGISTER_GLOBAL("script.builder.tir.Arg")
.set_body_typed([](String name, ObjectRef obj) -> ObjectRef {
using namespace tvm::tir;
Expand Down
Loading

0 comments on commit 6ca54ed

Please sign in to comment.