Skip to content

Commit

Permalink
[dynamo 3.11] changes to LOAD_GLOBAL and function calls (pytorch#94098)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#94098
Approved by: https://github.com/albanD
  • Loading branch information
williamwen42 authored and pytorchmergebot committed Feb 21, 2023
1 parent da98053 commit 055a9e4
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 117 deletions.
74 changes: 72 additions & 2 deletions torch/_dynamo/bytecode_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ def create_jump_absolute(target):
return create_instruction(inst, target=target)


def create_load_global(name, arg, push_null):
    """
    Create a LOAD_GLOBAL instruction.

    `name` is the name of the global to be loaded.
    `arg` is the index of `name` in the global name table.
    `push_null` specifies whether or not a NULL should be pushed to the stack
    before the global (Python 3.11+ only).

    Python 3.11 changed the LOAD_GLOBAL instruction so that the low bit of
    the arg specifies whether a NULL should be pushed to the stack before the
    global; the remaining bits of arg contain the name index. See
    `create_call_function` for why this NULL is needed.
    """
    if sys.version_info >= (3, 11):
        # Normalize push_null to exactly 0 or 1 so that a truthy non-bool
        # value (e.g. an int > 1) cannot corrupt the shifted name index.
        arg = (arg << 1) | bool(push_null)
    return create_instruction("LOAD_GLOBAL", arg, name)


def create_dup_top():
if sys.version_info >= (3, 11):
return create_instruction("COPY", 1)
Expand Down Expand Up @@ -98,6 +115,40 @@ def create_rot_n(n):
return [create_instruction("ROT_N", n)]


def create_call_function(nargs, push_null):
    """
    Build the instruction sequence that performs a function call.

    `push_null` is used on Python 3.11+ only. When the call is meant to use
    the NULL + fn calling convention and the NULL has not been pushed yet,
    pass True: a NULL is pushed and rotated underneath the callable (and its
    `nargs` arguments) immediately before the call. Pass False only when the
    NULL was already emitted earlier, for example:
        create_instruction("LOAD_GLOBAL", 1, "math")  # pushes a null
        create_instruction("LOAD_ATTR", argval="sqrt")
        create_instruction("LOAD_CONST", argval=25)
        create_call_function(1, False)
    """
    if sys.version_info < (3, 11):
        return [create_instruction("CALL_FUNCTION", nargs)]
    insts = []
    if push_null:
        # Push the NULL, then rotate it below the callable and its nargs
        # arguments (nargs + 2 stack slots in total).
        insts.append(create_instruction("PUSH_NULL"))
        insts.extend(create_rot_n(nargs + 2))
    insts.append(create_instruction("PRECALL", nargs))
    insts.append(create_instruction("CALL", nargs))
    return insts


def create_call_method(nargs):
    """Build the instruction sequence for a method call with `nargs` args."""
    if sys.version_info < (3, 11):
        return [create_instruction("CALL_METHOD", nargs)]
    # Python 3.11 replaced CALL_METHOD with a PRECALL/CALL pair.
    return [create_instruction("PRECALL", nargs), create_instruction("CALL", nargs)]


def lnotab_writer(lineno, byteno=0):
"""
Used to create typing.CodeType.co_lnotab
Expand Down Expand Up @@ -276,7 +327,7 @@ def explicit_super(code: types.CodeType, instructions: List[Instruction]):
output.append(inst)
if inst.opname == "LOAD_GLOBAL" and inst.argval == "super":
nexti = instructions[idx + 1]
if nexti.opname == "CALL_FUNCTION" and nexti.arg == 0:
if nexti.opname in ("CALL_FUNCTION", "PRECALL") and nexti.arg == 0:
assert "__class__" in cell_and_free
output.append(
create_instruction(
Expand All @@ -294,6 +345,11 @@ def explicit_super(code: types.CodeType, instructions: List[Instruction]):
output.append(create_instruction("LOAD_FAST", 0, first_var))
nexti.arg = 2
nexti.argval = 2
if nexti.opname == "PRECALL":
# also update the following CALL instruction
call_inst = instructions[idx + 2]
call_inst.arg = 2
call_inst.argval = 2

instructions[:] = output

Expand Down Expand Up @@ -394,11 +450,24 @@ def fix_vars(instructions: List[Instruction], code_options):
varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])}
names = {name: idx for idx, name in enumerate(code_options["co_names"])}
for i in range(len(instructions)):
if sys.version_info >= (3, 11) and instructions[i].opname == "LOAD_GLOBAL":
# LOAD_GLOBAL is in HAS_NAME, so instructions[i].arg will be overwritten.
# So we must compute push_null earlier.
assert instructions[i].arg is not None
shift = 1
push_null = instructions[i].arg % 2
else:
shift = 0
push_null = 0

if instructions[i].opcode in HAS_LOCAL:
instructions[i].arg = varnames[instructions[i].argval]
elif instructions[i].opcode in HAS_NAME:
instructions[i].arg = names[instructions[i].argval]

if instructions[i].arg is not None:
instructions[i].arg = (instructions[i].arg << shift) + push_null


def transform_code_object(code, transformations, safe=False):
# Python 3.11 changes to code keys are not fully documented.
Expand Down Expand Up @@ -483,7 +552,8 @@ def cleaned_instructions(code, safe=False):
virtualize_jumps(instructions)
strip_extended_args(instructions)
if not safe:
remove_load_call_method(instructions)
if sys.version_info < (3, 11):
remove_load_call_method(instructions)
explicit_super(code, instructions)
return instructions

Expand Down
73 changes: 50 additions & 23 deletions torch/_dynamo/codegen.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import collections
import dataclasses
import re
import sys
import types
from typing import List

import torch.nn

from .bytecode_transformation import (
create_call_function,
create_dup_top,
create_instruction,
create_load_global,
create_rot_n,
Instruction,
)
Expand Down Expand Up @@ -123,10 +126,7 @@ def __call__(self, value, allow_cache=True):

if isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
output.extend(
[
self.create_load_attr("item"),
create_instruction("CALL_FUNCTION", 0),
]
[self.create_load_attr("item")] + create_call_function(0, True)
)
elif isinstance(value, NNModuleVariable):
parts = value.module_key.split(".")
Expand Down Expand Up @@ -161,15 +161,15 @@ def foreach(self, items):
for i in items:
self(i)

def setup_globally_cached(self, name, value):
def setup_globally_cached(self, name, value, push_null):
"""Store value in a new global"""
name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
f_globals = self.tx.f_globals
if name in f_globals:
assert id(f_globals[name]) == id(value)
else:
f_globals[name] = value
return [self.create_load_global(name, add=True)]
return [self.create_load_global(name, push_null, add=True)]

def clear_tos(self):
self.top_of_stack = None
Expand Down Expand Up @@ -213,12 +213,12 @@ def create_store(self, name):
"STORE_FAST", self.code_options["co_varnames"].index(name), name
)

def create_load_global(self, name, add=False):
def create_load_global(self, name, push_null, add=False):
if add:
self.tx.output.update_co_names(name)
assert name in self.code_options["co_names"], f"{name} not in co_names"
return create_instruction(
"LOAD_GLOBAL", self.code_options["co_names"].index(name), name
return create_load_global(
name, self.code_options["co_names"].index(name), push_null
)

def create_load_const(self, value):
Expand Down Expand Up @@ -256,11 +256,18 @@ def create_load_attr(self, name):
def create_load_attrs(self, names):
return [self.create_load_attr(name) for name in names.split(".")]

def load_function_name(self, fn_name, num_on_stack=0):
def load_function_name(self, fn_name, push_null, num_on_stack=0):
"""Load the global fn_name on the stack num_on_stack down"""
return [self.create_load_global(fn_name, add=True)] + self.rot_n(
num_on_stack + 1
output = []
if push_null and sys.version_info >= (3, 11):
output.extend(
[create_instruction("PUSH_NULL")] + self.rot_n(num_on_stack + 1)
)
output.extend(
[self.create_load_global(fn_name, False, add=True)]
+ self.rot_n(num_on_stack + 1)
)
return output

def rot_n(self, n):
try:
Expand All @@ -279,6 +286,16 @@ def rot_n(self, n):
]
)

def pop_null(self):
    """Pop a NULL (pushed for the 3.11+ calling convention) off the stack.

    POP_TOP does not work on a NULL, so instead we load a do-nothing
    function, call it (the call consumes the NULL sitting under the
    callable), and then pop the function's return value.
    """
    assert sys.version_info >= (3, 11)
    insts = [self._create_load_const(lambda: None)]
    insts += create_call_function(0, False)
    insts.append(create_instruction("POP_TOP"))
    return insts

def make_function_with_closure(
self, fn_name: str, code: types.CodeType, num_on_stack=0
):
Expand All @@ -299,42 +316,38 @@ def make_function_with_closure(
output.extend(self.rot_n(num_on_stack + 1))
self.clear_tos()

def create_load_python_module(self, mod):
def create_load_python_module(self, mod, push_null):
"""
Generate a LOAD_GLOBAL instruction to fetch a given python module.
"""
root_globals = self.tx.output.root_globals
name = re.sub(r"^.*[.]", "", mod.__name__)
if root_globals.get(name, None) is mod:
return self.create_load_global(name, add=True)
return self.create_load_global(name, push_null, add=True)
mangled_name = f"___module_{name}_{id(mod)}"
if mangled_name not in root_globals:
self.tx.output.install_global(mangled_name, mod)
return self.create_load_global(mangled_name, add=True)
return self.create_load_global(mangled_name, push_null, add=True)

def make_call_generated_code(self, fn_name: str) -> List[Instruction]:
"""Call the generated code function stored in fn_name"""
self.extend_output(self.load_function_name(fn_name))
self.extend_output(self.load_function_name(fn_name, True))

graphargs = self.tx.output.graphargs
for arg in graphargs:
if arg.is_unspecialized:
self.extend_output(
[
self.create_load_python_module(torch),
self.create_load_python_module(torch, True),
self.create_load_attr("tensor"),
]
)
self.extend_output(arg.load(self))
self.extend_output(
[
create_instruction("CALL_FUNCTION", 1),
]
)
self.extend_output(create_call_function(1, False))
else:
self.extend_output(arg.load(self))

self.append_output(create_instruction("CALL_FUNCTION", len(graphargs)))
self.extend_output(create_call_function(len(graphargs), False))

def load_import_from(self, module_name, object_name):
self.extend_output(
Expand All @@ -345,3 +358,17 @@ def load_import_from(self, module_name, object_name):

def create_begin_finally(self):
return create_instruction("BEGIN_FINALLY")

def create_call_function_kw(self, nargs, kw_names, push_null):
    """Codegen a call where the trailing arguments are keyword arguments
    named by the tuple `kw_names`."""
    if sys.version_info < (3, 11):
        return [
            self.create_load_const(kw_names),
            create_instruction("CALL_FUNCTION_KW", nargs),
        ]
    # On 3.11+, a KW_NAMES instruction must immediately precede PRECALL.
    insts = create_call_function(nargs, push_null)
    assert insts[-2].opname == "PRECALL"
    kw_inst = create_instruction(
        "KW_NAMES", self.get_const_index(self.code_options, kw_names)
    )
    insts.insert(-2, kw_inst)
    return insts
21 changes: 13 additions & 8 deletions torch/_dynamo/output_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@

from . import config, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import create_instruction, Instruction, unique_id
from .bytecode_transformation import (
create_call_function,
create_instruction,
Instruction,
unique_id,
)
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented
from .guards import GuardBuilder
Expand Down Expand Up @@ -517,18 +522,18 @@ def compile_subgraph(
codegen = PyCodegen(tx, root)
random_calls_instructions.extend(
[
codegen.create_load_global("random", add=True),
codegen.create_load_global("random", True, add=True),
codegen.create_load_attr("setstate"),
codegen.create_load_const(tx.output.initial_random_state),
create_instruction("CALL_FUNCTION", 1),
]
+ create_call_function(1, False),
)
random_calls_instructions.extend(codegen.load_function_name(rand_fn_name))
random_calls_instructions.extend(
[
create_instruction("CALL_FUNCTION", 0),
codegen.create_store(tx.output.random_values_var),
]
codegen.load_function_name(rand_fn_name, True)
)
random_calls_instructions.extend(create_call_function(0, False))
random_calls_instructions.append(
codegen.create_store(tx.output.random_values_var),
)
self.add_output_instructions(random_calls_instructions)

Expand Down
Loading

0 comments on commit 055a9e4

Please sign in to comment.