Handle int/float arguments for cpp codegen in inductor (#95533)
This is a little questionable because we don't actually know what the dtype of the sympy expression is, and it's not clear we can rely on the assumptions.

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch/pytorch#95533
Approved by: https://github.com/ngimel, https://github.com/jansel
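
The "assumptions" the description refers to are sympy's symbolic assumptions: the size/int symbols dynamo creates typically carry an integer assumption, so `expr.is_integer` is the only dtype signal codegen has. A minimal standalone sketch (not part of this PR) of how that heuristic behaves:

```python
import sympy

s0 = sympy.Symbol("s0", integer=True)  # size/int symbols are usually created with an integer assumption
s1 = sympy.Symbol("s1")                # no assumption at all

print((s0 + 1).is_integer)  # True  -> the codegen change below picks torch.int64
print((s0 / 2).is_integer)  # None  -> falsy, so it falls back to torch.float64
print(s1.is_integer)        # None  -> same fallback, even if s1 happens to be an int
```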
ezyang authored and cyyever committed Mar 5, 2023
1 parent 5cde15e commit f3227cc
Showing 7 changed files with 62 additions and 12 deletions.
10 changes: 10 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -5539,6 +5539,16 @@ def fn(x, y):
             [torch.randn((4, 2)), torch.randn((4))],
         )
 
+    @torch._dynamo.config.patch(dynamic_shapes=True)
+    def test_int_input_dynamic_shapes(self):
+        @torch.compile(dynamic=True)
+        def fn(x, i):
+            y = x * i
+            return y
+
+        # Constant must not get matched as constant
+        self.common(fn, [torch.randn(3, 1, 1, 1, 1), 9132])
+
     @unittest.skipIf(HAS_CUDA, "test in_out_ptr for CppKernel")
     def test_in_out_buffer(self):
         def fn(x, y):
4 changes: 3 additions & 1 deletion torch/_functorch/aot_autograd.py
@@ -1700,7 +1700,9 @@ def aot_wrapper_dedupe(
     ok = True
 
     for i, a in enumerate(flat_args):
-        if a not in args_set:
+        if not isinstance(a, torch.Tensor):
+            leaf_flat_args.append(a)
+        elif a not in args_set:
             args_set.add(a)
             leaf_flat_args.append(a)
         elif not fw_metadata.input_info[i].mutates_data and not fw_metadata.input_info[i].mutates_metadata:
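
A minimal standalone sketch of what the dedupe loop above now does with a mixed argument list (simplified: the real pass also consults fw_metadata for mutated duplicates). Non-tensor inputs such as the int from the new test are forwarded untouched, while duplicate tensors are still collapsed:

```python
import torch

def dedupe_args(flat_args):
    args_set = set()
    leaf_flat_args = []
    for a in flat_args:
        if not isinstance(a, torch.Tensor):
            leaf_flat_args.append(a)   # ints/floats skip the dedupe bookkeeping
        elif a not in args_set:
            args_set.add(a)            # tensors are deduped by identity
            leaf_flat_args.append(a)
    return leaf_flat_args

t = torch.randn(3, 1, 1, 1, 1)
print(len(dedupe_args([t, t, 9132])))  # 2 -> one tensor leaf plus the scalar
```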
13 changes: 10 additions & 3 deletions torch/_inductor/codegen/common.py
@@ -9,6 +9,8 @@
 import sympy
 from sympy.printing.printer import Printer
 
+import torch
+
 from .. import metrics
 from ..utils import (
     DeferredLineBase,
@@ -305,9 +307,14 @@ def cpp_argdefs(self):
 
         # TODO(jansel): replace this with data from scheduler
         buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers}
-        buffer_types.update(
-            {name: val.get_dtype() for name, val in V.graph.graph_inputs.items()}
-        )
+        for name, val in V.graph.graph_inputs.items():
+            if isinstance(val, sympy.Expr):
+                if val.is_integer:
+                    buffer_types[name] = torch.int64
+                else:
+                    buffer_types[name] = torch.float64
+            else:
+                buffer_types[name] = val.get_dtype()
         buffer_types.update(
             {name: val.dtype for name, val in V.graph.constants.items()}
         )
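
The new branch is essentially a tiny dtype-inference rule for scalar graph inputs. A minimal sketch (hypothetical helper name, mirroring the heuristic above) of what the C++ argdefs end up seeing:

```python
import sympy
import torch

def dtype_for_input(val):
    # Mirrors the loop above: sympy exprs get int64/float64 from the is_integer
    # assumption; real buffers/tensors still report their own dtype.
    if isinstance(val, sympy.Expr):
        return torch.int64 if val.is_integer else torch.float64
    return val.get_dtype()

i = sympy.Symbol("i0", integer=True)
print(dtype_for_input(i))      # torch.int64
print(dtype_for_input(i / 3))  # torch.float64 (is_integer is None, treated as float)
```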
18 changes: 13 additions & 5 deletions torch/_inductor/codegen/wrapper.py
@@ -6,6 +6,8 @@
 from itertools import count
 from typing import Any, Dict, List
 
+import sympy
+
 from torch._dynamo.utils import dynamo_timed
 
 from .. import codecache, config, ir
@@ -572,6 +574,9 @@ def add_fake_input(name, shape, stride, device, dtype):
                 f"device='{device}', dtype={dtype})"
             )
 
+        def add_expr_input(name, val):
+            output.writeline(f"{name} = {val}")
+
         output.writelines(["", "", 'if __name__ == "__main__":'])
         with output.indent():
             output.splice(
@@ -588,11 +593,14 @@ def add_fake_input(name, shape, stride, device, dtype):
             )
 
             for name, value in V.graph.graph_inputs.items():
-                shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
-                stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
-                add_fake_input(
-                    name, shape, stride, value.get_device(), value.get_dtype()
-                )
+                if isinstance(value, sympy.Expr):  # Don't need to add symbolic
+                    add_expr_input(name, V.graph.sizevars.size_hint(value))
+                else:
+                    shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
+                    stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
+                    add_fake_input(
+                        name, shape, stride, value.get_device(), value.get_dtype()
+                    )
 
             output.writeline(
                 f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
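
The effect on the generated benchmark stub is that a symbolic scalar input becomes a plain assignment instead of a rand_strided allocation. A rough standalone sketch (simplified stand-in for the wrapper's output buffer; input names are illustrative):

```python
import io
import sympy

output = io.StringIO()  # stand-in for the wrapper's IndentedBuffer

def add_fake_input(name, shape, stride, device, dtype):
    output.write(f"{name} = rand_strided({shape}, {stride}, device='{device}', dtype={dtype})\n")

def add_expr_input(name, val):
    output.write(f"{name} = {val}\n")

graph_inputs = {
    "arg0_1": ("tensor", (3, 1, 1, 1, 1), (1, 1, 1, 1, 1)),
    "arg1_1": sympy.Integer(9132),
}
for name, value in graph_inputs.items():
    if isinstance(value, sympy.Expr):  # symbolic input: just bind the hinted value
        add_expr_input(name, int(value))
    else:                              # tensor input: keep the fake allocation
        _, shape, stride = value
        add_fake_input(name, shape, stride, "cpu", "torch.float32")

print(output.getvalue())
# arg0_1 = rand_strided((3, 1, 1, 1, 1), (1, 1, 1, 1, 1), device='cpu', dtype=torch.float32)
# arg1_1 = 9132
```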
3 changes: 2 additions & 1 deletion torch/_inductor/compile_fx.py
@@ -232,7 +232,8 @@ def is_aligned(storage_offset, dtype):
     check_inputs = [
         i
         for i in range(len(inputs))
-        if (
+        if isinstance(inputs[i], torch.Tensor)
+        and (
             i not in static_input_idxs
             or not is_aligned(inputs[i].storage_offset(), inputs[i].dtype)
         )
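
The extra isinstance guard matters because the alignment probe only makes sense for tensors; an int or float input has no storage_offset() or dtype to inspect. A minimal sketch (the ALIGNMENT constant and helper names here are illustrative, not inductor's actual definitions):

```python
import torch

ALIGNMENT = 16  # illustrative value; inductor defines its own constant

def is_aligned(storage_offset, dtype):
    return (storage_offset * torch.empty((), dtype=dtype).element_size()) % ALIGNMENT == 0

def needs_copy(inp, is_static_input):
    if not isinstance(inp, torch.Tensor):
        return False  # scalars pass straight through, nothing to align
    return not is_static_input or not is_aligned(inp.storage_offset(), inp.dtype)

print(needs_copy(9132, is_static_input=True))                # False
print(needs_copy(torch.randn(4)[1:], is_static_input=True))  # True: the offset view is misaligned
```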
8 changes: 8 additions & 0 deletions torch/_inductor/graph.py
@@ -16,6 +16,7 @@
     magic_methods,
     method_to_operator,
     ShapeEnv,
+    SymTypes,
 )
 from torch.utils._mode_utils import no_dispatch
 
@@ -278,6 +279,10 @@ def constant_name(self, name: str, device_override: torch.device):
 
     def placeholder(self, target: str, args, kwargs):
         example: torch.Tensor = super().placeholder(target, args, kwargs)
+        if isinstance(example, SymTypes):
+            expr = example.node.expr
+            self.graph_inputs[target] = expr
+            return expr
         # todo(chilli): We can remove the last check once we turn buffers into
         # static shape tensors. That's a hack to workaround Inductor believing
         # the buffer should be static but us passing in a fake tensor with
@@ -384,6 +389,9 @@ def output(self, target, args, kwargs):
         ), result
         self.graph_outputs = [ir.ExternKernel.realize_input(x) for x in result]
         for name, value in self.graph_inputs.items():
+            assert isinstance(value, (TensorBox, sympy.Expr))
+            if not isinstance(value, TensorBox):
+                continue
             value.realize()
             assert isinstance(value, TensorBox)
             value = value.data
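
placeholder() can take this shortcut because a SymInt/SymFloat example value carries its sympy expression on .node.expr, and output() then has to skip these non-TensorBox inputs when realizing. A small sketch of the dispatch, using hypothetical stand-ins since building a real torch.SymInt requires a ShapeEnv:

```python
import sympy

class _Node:
    def __init__(self, expr):
        self.expr = expr

class _SymInt:  # plays the role of torch.SymInt / torch.SymFloat (graph.py's SymTypes)
    def __init__(self, expr):
        self.node = _Node(expr)

def lower_placeholder(example):
    if isinstance(example, _SymInt):
        return example.node.expr          # recorded directly as the graph input
    return f"TensorBox for {example!r}"   # normal tensor lowering path (sketch)

print(lower_placeholder(_SymInt(sympy.Symbol("s0", integer=True))))  # s0
print(lower_placeholder("a fake tensor"))                            # TensorBox for 'a fake tensor'
```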
18 changes: 16 additions & 2 deletions torch/_inductor/sizevars.py
@@ -460,7 +460,21 @@ def strideof(name):
         # Assign all symbolic shapes needed to local variables
         needed = set(self.var_to_val.keys()) - set(self.replacements.keys())
 
-        for name, value in graph_inputs.items():
+        def is_expr(x):
+            return isinstance(x[1], sympy.Expr)
+
+        graph_inputs_expr = list(filter(is_expr, graph_inputs.items()))
+        graph_inputs_tensors = list(
+            filter(lambda x: not is_expr(x), graph_inputs.items())
+        )
+
+        for name, shape in graph_inputs_expr:
+            shape = self.simplify(shape)
+            if shape in needed:
+                needed.remove(shape)
+                code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
+
+        for name, value in graph_inputs_tensors:
             shapes = value.get_size()
             for dim, shape in enumerate(shapes):
                 shape = self.simplify(shape)
@@ -470,7 +484,7 @@ def strideof(name):
                         f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
                     )
 
-        for name, value in graph_inputs.items():
+        for name, value in graph_inputs_tensors:
             shapes = value.get_stride()
             for dim, shape in enumerate(shapes):
                 shape = self.simplify(shape)
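
Downstream, the size-variable prologue changes shape: a symbolic scalar input now binds its size variable straight to the argument, while tensor inputs keep reading sizes (and strides) out of their metadata. A rough standalone sketch of the generated declarations, with hypothetical input names and the expr loop running first as in the diff above:

```python
import sympy

s0 = sympy.Symbol("s0", integer=True)
graph_inputs = {
    "arg0_1": s0,                      # symbolic scalar input (sympy.Expr), handled by the first loop
    "arg1_1": [sympy.Integer(8), s0],  # tensor input, represented here just by its sizes
}

needed = {s0}
lines = []
for name, value in graph_inputs.items():
    if isinstance(value, sympy.Expr):      # scalar: bind the symbol to the argument itself
        value = sympy.simplify(value)
        if value in needed:
            needed.remove(value)
            lines.append(f"{value} = {name}")
    else:                                  # tensor: bind remaining symbols from the size tuple
        for dim, shape in enumerate(value):
            if shape in needed:
                needed.remove(shape)
                lines.append(f"{shape} = {name}_size[{dim}]")

print("\n".join(lines))  # s0 = arg0_1 ; the tensor loop then skips s0, it is no longer needed
```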
