From 7c42f6bf02c2583f73eccefb351f17b61bed1dfb Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 16 Jul 2024 06:02:05 -0400 Subject: [PATCH] [INTERPRETER] Refactor function rewriter (#4325) 1. Use the builtin `ast.increment_lineno` function to make it more robust 2. Clean up function rewrite logic 3. Resolve global variable reference issues 4. Enable line info tests --- .github/workflows/integration-tests.yml | 3 +- .github/workflows/integration-tests.yml.in | 3 +- python/test/unit/language/test_line_info.py | 52 +++++--- python/triton/runtime/interpreter.py | 125 ++++++++++++-------- 4 files changed, 119 insertions(+), 64 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 66fef041464a..6bb0351982bd 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -256,7 +256,8 @@ jobs: run: | cd python/test/unit python3 -m pytest -s -n 16 -m interpreter language/test_core.py language/test_standard.py \ - language/test_random.py language/test_block_pointer.py language/test_subprocess.py runtime/test_autotuner::test_kwargs[False]\ + language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \ + runtime/test_autotuner::test_kwargs[False]\ ../../tutorials/06-fused-attention.py::test_op --device cpu - name: Run C++ unittests run: | diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index 5dfea720992c..1bead68e7933 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -292,7 +292,8 @@ jobs: run: | cd python/test/unit python3 -m pytest -s -n 16 -m interpreter language/test_core.py language/test_standard.py \ - language/test_random.py language/test_block_pointer.py language/test_subprocess.py runtime/test_autotuner::test_kwargs[False]\ + language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \ + runtime/test_autotuner::test_kwargs[False]\ ../../tutorials/06-fused-attention.py::test_op --device cpu - &run-cpp-unittests-step diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index b00e10d4b83a..ebd1a7235f36 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -67,6 +67,14 @@ def kernel_dot_combine(x): tl.device_print("", d) +# Call another jit function (cdiv) not in this file +@triton.jit +def kernel_cdiv(x): + c = tl.full((32, 32), 4, dtype=tl.int8) + d = tl.cdiv(c, 4) + tl.device_print("", d) + + def get_disassembler_command_and_debug_line_format(): """Gets backend specific disassembler information. @@ -125,11 +133,19 @@ def check_file_lines(file_lines, file_name, lineno, should_contain=True): return not should_contain -func_types = ["single", "call", "call_noinline", "autotune", "dot_combine"] +func_types = ["single", "call", "call_noinline", "autotune", "dot_combine", "cdiv"] + + +def is_interpreter(): + import os + return os.environ.get('TRITON_INTERPRET', '0') == '1' @pytest.mark.parametrize("func", func_types) def test_line_info(func: str): + if is_interpreter(): + pytest.skip("interpreter does not support warmup compilation") + try: obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() except BaseException: @@ -147,6 +163,8 @@ def test_line_info(func: str): kernel_info = kernel_autotune.warmup(torch.float32, torch.float32, SIZE=shape[0], grid=(1,))[0] elif func == "dot_combine": kernel_info = kernel_dot_combine.warmup(20, grid=(1,)) + elif func == "cdiv": + kernel_info = kernel_cdiv.warmup(20, grid=(1,)) file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) if func == "single": @@ -168,11 +186,8 @@ def test_line_info(func: str): elif func == "dot_combine": assert (check_file_lines(file_lines, "test_line_info.py", 65)) assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False)) - - -def is_interpreter(): - import os - return os.environ.get('TRITON_INTERPRET', '0') == '1' + elif func == "cdiv": + assert (check_file_lines(file_lines, "test_line_info.py", 75)) @pytest.mark.interpreter @@ -182,21 +197,30 @@ def test_line_info_interpreter(func: str): pytest.skip("interpreter is not enabled") kernel = None - expected_offset = 0 + expected_def_lineno = 0 if func == "single": kernel = kernel_single - expected_offset = 12 + expected_def_lineno = 12 elif func == "call": kernel = kernel_call - expected_offset = 25 + expected_def_lineno = 25 elif func == "call_noinline": kernel = kernel_call_noinline - expected_offset = 41 + expected_def_lineno = 41 elif func == "autotune": kernel = kernel_autotune.fn - expected_offset = 52 + expected_def_lineno = 52 elif func == "dot_combine": kernel = kernel_dot_combine - expected_offset = 62 - kernel._rewrite_ast() - assert kernel.ast_transformer.offset == expected_offset + expected_def_lineno = 62 + elif func == "cdiv": + kernel = kernel_cdiv + expected_def_lineno = 72 + kernel.rewrite() + assert kernel.rewriter.def_file_lineno == expected_def_lineno + if func == "autotune": + assert kernel.rewriter.last_decorator_lineno == 7 + assert kernel.rewriter.def_lineno == 8 + else: + assert kernel.rewriter.last_decorator_lineno == 1 + assert kernel.rewriter.def_lineno == 2 diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index b59352d02132..192248503952 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -990,7 +990,7 @@ def _set_attr(input, values, name): def _patch_lang(fn): lang = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]] - assert len(lang) == 1, "triton.language must be visible from within jit'd function" + assert len(lang) >= 1, "triton.language must be visible from within jit'd function" _patch_builtin(lang[0], interpreter_builder) _patch_builtin(lang[0].tensor, interpreter_builder) if lang[0] == tl: @@ -1099,9 +1099,6 @@ def __call__(self, *args_dev, **kwargs): class ASTTransformer(ast.NodeTransformer): - def __init__(self) -> None: - self.offset = 0 - def visit_Assign(self, node): names = [] for target in node.targets: @@ -1119,82 +1116,114 @@ def visit_Assign(self, node): ast.Constant(value=False)], keywords=[]) return node - def generic_visit(self, node): - # Adjust the begin line number of the node - if hasattr(node, 'lineno') and node.lineno is not None: - node.lineno += self.offset - if hasattr(node, 'end_lineno') and node.end_lineno is not None: - node.end_lineno += self.offset - return super().generic_visit(node) - -class InterpretedFunction: - rewritted_fn = {} +class FunctionRewriter: ast_transformer = ASTTransformer() - def __init__(self, fn, **kwargs) -> None: + def __init__(self, fn, **kwargs): self.fn = fn - - def run(*args, **kwargs): - grid = kwargs["grid"] - fn = self._rewrite_ast() - return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs) - - self.run = run self.kwargs = kwargs - signature = inspect.signature(fn) - self.arg_names = [v.name for v in signature.parameters.values()] - - def _rewrite_ast(self): - if self.fn in self.rewritted_fn: - return self.rewritted_fn[self.fn] + self.filename: str = "" + # Absolute line number in the file + self.def_file_lineno: int = 0 + # Relative line numbers from the beginning of the function + self.last_decorator_lineno: int = 0 + self.def_lineno: int = 0 + + def rewrite_ast(self): # If exception is raise, it means the function does not have source code available, # e.g., dynamically generated functions, we cannot rewrite it so just return the original function try: - lines, lineno = inspect.getsourcelines(self.fn) + lines, _ = inspect.getsourcelines(self.fn) except Exception: - self.rewritted_fn[self.fn] = self.fn return self.fn - from .jit import get_jit_fn_file_line, JITFunction - filename, lineno = get_jit_fn_file_line(JITFunction(self.fn)) + # truncate lines before @triton.jit, which is the last decorator # @triton.autotune(...) # ... # @triton.jit <- this line is the last decorator, which must be a triton.jit # - # def foo(...): - last_decorator_line = 0 + # def foo(...): <- this line is the function definition + self.filename, self.def_file_lineno = self._get_jit_fn_file_line() + self.last_decorator_lineno, self.def_lineno = self._find_decorator_and_def(lines) + src = self._prepare_source(lines) + transformed_ast = self._transform_ast(src) + return self._compile_and_exec(transformed_ast) + + def _get_jit_fn_file_line(self): + from .jit import get_jit_fn_file_line, JITFunction + return get_jit_fn_file_line(JITFunction(self.fn)) + + def _find_decorator_and_def(self, lines): + last_decorator_lineno = 0 + def_lineno = 0 + # Line numbers start from 1 for i, line in enumerate(lines): if line.strip().startswith("@"): - last_decorator_line = i - lines = lines[last_decorator_line:] + last_decorator_lineno = i + 1 + if line.strip().startswith("def "): + def_lineno = i + 1 + return last_decorator_lineno, def_lineno + + def _prepare_source(self, lines): + lines = lines[self.last_decorator_lineno - 1:] src = ''.join(lines) - src = textwrap.dedent(src) + return textwrap.dedent(src) + + def _transform_ast(self, src): parsed_ast = ast.parse(src) - self.ast_transformer.offset = lineno transformed_ast = self.ast_transformer.visit(parsed_ast) - transformed_ast = ast.fix_missing_locations(transformed_ast) - compiled_code = compile(transformed_ast, filename=filename, mode='exec') + ast.fix_missing_locations(transformed_ast) + # Default line numbers start from 1, so the difference should -1 + inc_lineno = (self.def_file_lineno - 1) - (self.def_lineno - self.last_decorator_lineno) + ast.increment_lineno(transformed_ast, inc_lineno) + return transformed_ast + + def _compile_and_exec(self, transformed_ast): + compiled_code = compile(transformed_ast, filename=self.filename, mode='exec') local_namespace = {**self.kwargs} - if self.fn.__name__ in local_namespace: - raise ValueError(f"Function name {self.fn.__name__} is reserved") - exec(compiled_code, globals(), local_namespace) - fn = local_namespace[self.fn.__name__].fn - self.rewritted_fn[self.fn] = fn - return fn + # Overwrite globals using the current global namespace + fn_globals = self.fn.__globals__ + for key, value in globals().items(): + fn_globals[key] = value + exec(compiled_code, fn_globals, local_namespace) + return local_namespace[self.fn.__name__].fn + + +class InterpretedFunction: + # Cache all rewritten functions + rewritten_fn = {} + + def __init__(self, fn, **kwargs) -> None: + self.fn = fn + self.rewriter = FunctionRewriter(fn, **kwargs) + + def run(*args, **kwargs): + grid = kwargs["grid"] + fn = self.rewrite() + return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs) + + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + def rewrite(self): + if self.fn not in self.rewritten_fn: + self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast() + return self.rewritten_fn[self.fn] @property def __name__(self): return self.fn.__name__ def __getitem__(self, grid): - fn = self._rewrite_ast() + fn = self.rewrite() return GridExecutor(fn, self.arg_names, grid) def __call__(self, *args, **kwargs): # This is a device function call _patch_lang(self.fn) - fn = self._rewrite_ast() + fn = self.rewrite() try: return fn(*args, **kwargs) except Exception as e: