Skip to content

Commit

Permalink
[INTERPRETER] Refactor function rewriter (#4325)
Browse files Browse the repository at this point in the history
1. Use the builtin `ast.increment_lineno` function to make it more
robust
2. Clean up function rewrite logic
3. Resolve global variable reference issues
4. Enable line info tests
  • Loading branch information
Jokeren authored Jul 16, 2024
1 parent 79297ec commit 7c42f6b
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 64 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ jobs:
run: |
cd python/test/unit
python3 -m pytest -s -n 16 -m interpreter language/test_core.py language/test_standard.py \
language/test_random.py language/test_block_pointer.py language/test_subprocess.py runtime/test_autotuner::test_kwargs[False]\
language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \
runtime/test_autotuner::test_kwargs[False]\
../../tutorials/06-fused-attention.py::test_op --device cpu
- name: Run C++ unittests
run: |
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/integration-tests.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ jobs:
run: |
cd python/test/unit
python3 -m pytest -s -n 16 -m interpreter language/test_core.py language/test_standard.py \
language/test_random.py language/test_block_pointer.py language/test_subprocess.py runtime/test_autotuner::test_kwargs[False]\
language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \
runtime/test_autotuner::test_kwargs[False]\
../../tutorials/06-fused-attention.py::test_op --device cpu

- &run-cpp-unittests-step
Expand Down
52 changes: 38 additions & 14 deletions python/test/unit/language/test_line_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def kernel_dot_combine(x):
tl.device_print("", d)


# Call another jit function (cdiv) not in this file. Line-pinned: tests expect def on line 72, check line 75.
@triton.jit
def kernel_cdiv(x):  # x is not read by the body
    c = tl.full((32, 32), 4, dtype=tl.int8)  # (32, 32) int8 block filled with 4
    d = tl.cdiv(c, 4)  # the call into the jit function defined outside this file
    tl.device_print("", d)


def get_disassembler_command_and_debug_line_format():
"""Gets backend specific disassembler information.
Expand Down Expand Up @@ -125,11 +133,19 @@ def check_file_lines(file_lines, file_name, lineno, should_contain=True):
return not should_contain


func_types = ["single", "call", "call_noinline", "autotune", "dot_combine"]
func_types = ["single", "call", "call_noinline", "autotune", "dot_combine", "cdiv"]


def is_interpreter():
    """Tell whether interpreter mode is requested through the environment.

    Reads the ``TRITON_INTERPRET`` environment variable, treating it as
    ``'0'`` when unset, and returns True only when the value is exactly
    ``'1'``.
    """
    import os
    flag = os.environ.get('TRITON_INTERPRET', '0')
    return flag == '1'


@pytest.mark.parametrize("func", func_types)
def test_line_info(func: str):
if is_interpreter():
pytest.skip("interpreter does not support warmup compilation")

try:
obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format()
except BaseException:
Expand All @@ -147,6 +163,8 @@ def test_line_info(func: str):
kernel_info = kernel_autotune.warmup(torch.float32, torch.float32, SIZE=shape[0], grid=(1,))[0]
elif func == "dot_combine":
kernel_info = kernel_dot_combine.warmup(20, grid=(1,))
elif func == "cdiv":
kernel_info = kernel_cdiv.warmup(20, grid=(1,))

file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind])
if func == "single":
Expand All @@ -168,11 +186,8 @@ def test_line_info(func: str):
elif func == "dot_combine":
assert (check_file_lines(file_lines, "test_line_info.py", 65))
assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False))


def is_interpreter():
import os
return os.environ.get('TRITON_INTERPRET', '0') == '1'
elif func == "cdiv":
assert (check_file_lines(file_lines, "test_line_info.py", 75))


@pytest.mark.interpreter
Expand All @@ -182,21 +197,30 @@ def test_line_info_interpreter(func: str):
pytest.skip("interpreter is not enabled")

kernel = None
expected_offset = 0
expected_def_lineno = 0
if func == "single":
kernel = kernel_single
expected_offset = 12
expected_def_lineno = 12
elif func == "call":
kernel = kernel_call
expected_offset = 25
expected_def_lineno = 25
elif func == "call_noinline":
kernel = kernel_call_noinline
expected_offset = 41
expected_def_lineno = 41
elif func == "autotune":
kernel = kernel_autotune.fn
expected_offset = 52
expected_def_lineno = 52
elif func == "dot_combine":
kernel = kernel_dot_combine
expected_offset = 62
kernel._rewrite_ast()
assert kernel.ast_transformer.offset == expected_offset
expected_def_lineno = 62
elif func == "cdiv":
kernel = kernel_cdiv
expected_def_lineno = 72
kernel.rewrite()
assert kernel.rewriter.def_file_lineno == expected_def_lineno
if func == "autotune":
assert kernel.rewriter.last_decorator_lineno == 7
assert kernel.rewriter.def_lineno == 8
else:
assert kernel.rewriter.last_decorator_lineno == 1
assert kernel.rewriter.def_lineno == 2
125 changes: 77 additions & 48 deletions python/triton/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ def _set_attr(input, values, name):

def _patch_lang(fn):
lang = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]]
assert len(lang) == 1, "triton.language must be visible from within jit'd function"
assert len(lang) >= 1, "triton.language must be visible from within jit'd function"
_patch_builtin(lang[0], interpreter_builder)
_patch_builtin(lang[0].tensor, interpreter_builder)
if lang[0] == tl:
Expand Down Expand Up @@ -1099,9 +1099,6 @@ def __call__(self, *args_dev, **kwargs):

class ASTTransformer(ast.NodeTransformer):

def __init__(self) -> None:
self.offset = 0

def visit_Assign(self, node):
names = []
for target in node.targets:
Expand All @@ -1119,82 +1116,114 @@ def visit_Assign(self, node):
ast.Constant(value=False)], keywords=[])
return node

def generic_visit(self, node):
# Adjust the begin line number of the node
if hasattr(node, 'lineno') and node.lineno is not None:
node.lineno += self.offset
if hasattr(node, 'end_lineno') and node.end_lineno is not None:
node.end_lineno += self.offset
return super().generic_visit(node)


class InterpretedFunction:
rewritted_fn = {}
class FunctionRewriter:
ast_transformer = ASTTransformer()

def __init__(self, fn, **kwargs) -> None:
def __init__(self, fn, **kwargs):
self.fn = fn

def run(*args, **kwargs):
grid = kwargs["grid"]
fn = self._rewrite_ast()
return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs)

self.run = run
self.kwargs = kwargs
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]

def _rewrite_ast(self):
if self.fn in self.rewritted_fn:
return self.rewritted_fn[self.fn]
self.filename: str = ""
# Absolute line number in the file
self.def_file_lineno: int = 0
# Relative line numbers from the beginning of the function
self.last_decorator_lineno: int = 0
self.def_lineno: int = 0

def rewrite_ast(self):
# If an exception is raised, the function has no source code available
# (e.g., it was dynamically generated), so it cannot be rewritten; just return the original function
try:
lines, lineno = inspect.getsourcelines(self.fn)
lines, _ = inspect.getsourcelines(self.fn)
except Exception:
self.rewritted_fn[self.fn] = self.fn
return self.fn
from .jit import get_jit_fn_file_line, JITFunction
filename, lineno = get_jit_fn_file_line(JITFunction(self.fn))

# truncate lines before @triton.jit, which is the last decorator
# @triton.autotune(...)
# ...
# @triton.jit <- this line is the last decorator, which must be a triton.jit
#
# def foo(...):
last_decorator_line = 0
# def foo(...): <- this line is the function definition
self.filename, self.def_file_lineno = self._get_jit_fn_file_line()
self.last_decorator_lineno, self.def_lineno = self._find_decorator_and_def(lines)
src = self._prepare_source(lines)
transformed_ast = self._transform_ast(src)
return self._compile_and_exec(transformed_ast)

def _get_jit_fn_file_line(self):
    """Look up the source location of the wrapped function.

    Wraps ``self.fn`` in a ``JITFunction`` and returns the result of
    ``get_jit_fn_file_line`` for it unchanged; ``rewrite_ast`` unpacks that
    result into ``self.filename`` and ``self.def_file_lineno``.
    """
    # Imported inside the method rather than at module level
    # (presumably to avoid an import cycle with .jit -- TODO confirm).
    from .jit import JITFunction, get_jit_fn_file_line
    jit_fn = JITFunction(self.fn)
    return get_jit_fn_file_line(jit_fn)

def _find_decorator_and_def(self, lines):
last_decorator_lineno = 0
def_lineno = 0
# Line numbers start from 1
for i, line in enumerate(lines):
if line.strip().startswith("@"):
last_decorator_line = i
lines = lines[last_decorator_line:]
last_decorator_lineno = i + 1
if line.strip().startswith("def "):
def_lineno = i + 1
return last_decorator_lineno, def_lineno

def _prepare_source(self, lines):
lines = lines[self.last_decorator_lineno - 1:]
src = ''.join(lines)
src = textwrap.dedent(src)
return textwrap.dedent(src)

def _transform_ast(self, src):
parsed_ast = ast.parse(src)
self.ast_transformer.offset = lineno
transformed_ast = self.ast_transformer.visit(parsed_ast)
transformed_ast = ast.fix_missing_locations(transformed_ast)
compiled_code = compile(transformed_ast, filename=filename, mode='exec')
ast.fix_missing_locations(transformed_ast)
# Line numbers start from 1 by default, so subtract 1 from the difference
inc_lineno = (self.def_file_lineno - 1) - (self.def_lineno - self.last_decorator_lineno)
ast.increment_lineno(transformed_ast, inc_lineno)
return transformed_ast

def _compile_and_exec(self, transformed_ast):
compiled_code = compile(transformed_ast, filename=self.filename, mode='exec')
local_namespace = {**self.kwargs}
if self.fn.__name__ in local_namespace:
raise ValueError(f"Function name {self.fn.__name__} is reserved")
exec(compiled_code, globals(), local_namespace)
fn = local_namespace[self.fn.__name__].fn
self.rewritted_fn[self.fn] = fn
return fn
# Overwrite globals using the current global namespace
fn_globals = self.fn.__globals__
for key, value in globals().items():
fn_globals[key] = value
exec(compiled_code, fn_globals, local_namespace)
return local_namespace[self.fn.__name__].fn


class InterpretedFunction:
# Cache all rewritten functions
rewritten_fn = {}

def __init__(self, fn, **kwargs) -> None:
self.fn = fn
self.rewriter = FunctionRewriter(fn, **kwargs)

def run(*args, **kwargs):
grid = kwargs["grid"]
fn = self.rewrite()
return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs)

self.run = run
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]

def rewrite(self):
if self.fn not in self.rewritten_fn:
self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast()
return self.rewritten_fn[self.fn]

@property
def __name__(self):
return self.fn.__name__

def __getitem__(self, grid):
fn = self._rewrite_ast()
fn = self.rewrite()
return GridExecutor(fn, self.arg_names, grid)

def __call__(self, *args, **kwargs):
# This is a device function call
_patch_lang(self.fn)
fn = self._rewrite_ast()
fn = self.rewrite()
try:
return fn(*args, **kwargs)
except Exception as e:
Expand Down

0 comments on commit 7c42f6b

Please sign in to comment.