Handle int/float arguments for cpp codegen in inductor #95533

Closed. Wants to merge 6 commits.
Changes from all commits
10 changes: 10 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -5527,6 +5527,16 @@ def fn(x, y):
[torch.randn((4, 2)), torch.randn((4))],
)

@torch._dynamo.config.patch(dynamic_shapes=True)
def test_int_input_dynamic_shapes(self):
@torch.compile(dynamic=True)
def fn(x, i):
y = x * i
return y

# The int argument must not get baked in as a compile-time constant
self.common(fn, [torch.randn(3, 1, 1, 1, 1), 9132])

@unittest.skipIf(HAS_CUDA, "test in_out_ptr for CppKernel")
def test_in_out_buffer(self):
def fn(x, y):
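Reviewer note: a minimal standalone sketch of what the new test exercises, outside the CommonTemplate/self.common harness. The tensor shape and int value mirror the test; everything else is illustrative.

import torch

@torch.compile(dynamic=True)
def fn(x, i):
    # `i` is a plain Python int; with dynamic shapes it should be passed
    # through to the generated kernel instead of being baked in as a constant.
    return x * i

x = torch.randn(3, 1, 1, 1, 1)
out = fn(x, 9132)
assert torch.allclose(out, x * 9132)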
4 changes: 3 additions & 1 deletion torch/_functorch/aot_autograd.py
@@ -1700,7 +1700,9 @@ def aot_wrapper_dedupe(
ok = True

for i, a in enumerate(flat_args):
if a not in args_set:
if not isinstance(a, torch.Tensor):
leaf_flat_args.append(a)
elif a not in args_set:
args_set.add(a)
leaf_flat_args.append(a)
elif not fw_metadata.input_info[i].mutates_data and not fw_metadata.input_info[i].mutates_metadata:
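For context, a hedged sketch of the deduplication filter after this change. Only the branching shown in the hunk above is real; the surrounding bookkeeping of aot_wrapper_dedupe is elided, and the function name here is illustrative.

import torch

def filter_leaf_args(flat_args):
    args_set = set()
    leaf_flat_args = []
    for a in flat_args:
        if not isinstance(a, torch.Tensor):
            # Non-tensor args (ints, floats, symbolic scalars) carry no storage,
            # so aliasing/dedup checks don't apply; pass them straight through.
            leaf_flat_args.append(a)
        elif a not in args_set:
            args_set.add(a)
            leaf_flat_args.append(a)
        # duplicated tensors fall into the mutation-aware branch elided here
    return leaf_flat_args

print(filter_leaf_args([torch.randn(2), 9132, 2.5]))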
13 changes: 10 additions & 3 deletions torch/_inductor/codegen/common.py
@@ -9,6 +9,8 @@
import sympy
from sympy.printing.printer import Printer

import torch

from .. import metrics
from ..utils import (
DeferredLineBase,
@@ -305,9 +307,14 @@ def cpp_argdefs(self):

# TODO(jansel): replace this with data from scheduler
buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers}
buffer_types.update(
{name: val.get_dtype() for name, val in V.graph.graph_inputs.items()}
)
for name, val in V.graph.graph_inputs.items():
if isinstance(val, sympy.Expr):
if val.is_integer:
buffer_types[name] = torch.int64
else:
buffer_types[name] = torch.float64
else:
buffer_types[name] = val.get_dtype()
buffer_types.update(
{name: val.dtype for name, val in V.graph.constants.items()}
)
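A hedged sketch of the dtype selection added above: a symbolic scalar input has no tensor dtype, so the C++ argument type falls back to int64 or float64 based on the sympy expression's integer-ness. The function name and example symbols below are illustrative.

import sympy
import torch

def scalar_input_dtype(val):
    if isinstance(val, sympy.Expr):
        # Integer-valued expressions map to int64, everything else to float64.
        return torch.int64 if val.is_integer else torch.float64
    return val.get_dtype()  # tensor inputs keep their own dtype

s0 = sympy.Symbol("s0", integer=True)
f0 = sympy.Symbol("f0")  # no integer assumption -> treated as float
print(scalar_input_dtype(s0))  # torch.int64
print(scalar_input_dtype(f0))  # torch.float64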
18 changes: 13 additions & 5 deletions torch/_inductor/codegen/wrapper.py
@@ -6,6 +6,8 @@
from itertools import count
from typing import Any, Dict, List

import sympy

from torch._dynamo.utils import dynamo_timed

from .. import codecache, config, ir
@@ -572,6 +574,9 @@ def add_fake_input(name, shape, stride, device, dtype):
f"device='{device}', dtype={dtype})"
)

def add_expr_input(name, val):
output.writeline(f"{name} = {val}")

output.writelines(["", "", 'if __name__ == "__main__":'])
with output.indent():
output.splice(
@@ -588,11 +593,14 @@ def add_fake_input(name, shape, stride, device, dtype):
)

for name, value in V.graph.graph_inputs.items():
shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
add_fake_input(
name, shape, stride, value.get_device(), value.get_dtype()
)
if isinstance(value, sympy.Expr):  # symbolic scalar input; no fake tensor needed
add_expr_input(name, V.graph.sizevars.size_hint(value))
else:
shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
add_fake_input(
name, shape, stride, value.get_device(), value.get_dtype()
)

output.writeline(
f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
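To illustrate the effect on the generated benchmark harness, a hedged sketch of the two emitters. The output buffer is stubbed with print (the real code writes into Inductor's wrapper buffer), rand_strided is assumed as the tensor materializer used elsewhere in Inductor repros, and the argument names are illustrative.

def add_fake_input(name, shape, stride, device, dtype):
    # Tensor inputs become randomly-filled fake tensors in the repro script.
    print(f"{name} = rand_strided({shape}, {stride}, device='{device}', dtype={dtype})")

def add_expr_input(name, val):
    # Symbolic scalar inputs become a plain assignment of their size hint.
    print(f"{name} = {val}")

add_fake_input("arg0_1", [3, 1, 1, 1, 1], [1, 1, 1, 1, 1], "cpu", "torch.float32")
add_expr_input("arg1_1", 9132)
# arg0_1 = rand_strided([3, 1, 1, 1, 1], [1, 1, 1, 1, 1], device='cpu', dtype=torch.float32)
# arg1_1 = 9132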
3 changes: 2 additions & 1 deletion torch/_inductor/compile_fx.py
@@ -231,7 +231,8 @@ def is_aligned(storage_offset, dtype):
check_inputs = [
i
for i in range(len(inputs))
if (
if isinstance(inputs[i], torch.Tensor)
and (
i not in static_input_idxs
or not is_aligned(inputs[i].storage_offset(), inputs[i].dtype)
)
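A hedged sketch of the predicate behind the comprehension above: only tensor inputs can be misaligned, so scalar arguments never land in check_inputs. The alignment check is passed in as a parameter here because its exact definition is elided from the hunk; the function name is illustrative.

import torch

def needs_realignment(inp, idx, static_input_idxs, is_aligned):
    return isinstance(inp, torch.Tensor) and (
        idx not in static_input_idxs
        or not is_aligned(inp.storage_offset(), inp.dtype)
    )

always_aligned = lambda offset, dtype: True
print(needs_realignment(torch.randn(4), 0, set(), always_aligned))  # True: non-static tensor
print(needs_realignment(9132, 1, set(), always_aligned))            # False: scalar input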
8 changes: 8 additions & 0 deletions torch/_inductor/graph.py
@@ -16,6 +16,7 @@
magic_methods,
method_to_operator,
ShapeEnv,
SymTypes,
)
from torch.utils._mode_utils import no_dispatch

@@ -278,6 +279,10 @@ def constant_name(self, name: str, device_override: torch.device):

def placeholder(self, target: str, args, kwargs):
example: torch.Tensor = super().placeholder(target, args, kwargs)
if isinstance(example, SymTypes):
expr = example.node.expr
self.graph_inputs[target] = expr
return expr
# todo(chilli): We can remove the last check once we turn buffers into
# static shape tensors. That's a hack to workaround Inductor believing
# the buffer should be static but us passing in a fake tensor with
@@ -384,6 +389,9 @@ def output(self, target, args, kwargs):
), result
self.graph_outputs = [ir.ExternKernel.realize_input(x) for x in result]
for name, value in self.graph_inputs.items():
assert isinstance(value, (TensorBox, sympy.Expr))
if not isinstance(value, TensorBox):
continue
value.realize()
assert isinstance(value, TensorBox)
value = value.data
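A hedged, self-contained sketch of the two branches added above: symbolic scalar placeholders are recorded in graph_inputs as their underlying sympy expression, and at output time only tensor inputs are realized. GraphLowering internals are stubbed out (FakeTensorBox and the SimpleNamespace-based fake SymInt are stand-ins); only the branching mirrors the diff.

import sympy
from types import SimpleNamespace

class FakeTensorBox:
    # Stand-in for Inductor's TensorBox; only realize() matters here.
    def realize(self):
        print("realized tensor input")

def record_placeholder(graph_inputs, target, example):
    # SymInt/SymFloat examples are stored as their underlying sympy expression.
    if hasattr(example, "node") and hasattr(example.node, "expr"):
        graph_inputs[target] = example.node.expr
        return graph_inputs[target]
    graph_inputs[target] = FakeTensorBox()
    return graph_inputs[target]

def realize_graph_inputs(graph_inputs):
    for name, value in graph_inputs.items():
        assert isinstance(value, (FakeTensorBox, sympy.Expr))
        if not isinstance(value, FakeTensorBox):
            continue  # symbolic scalar inputs have nothing to realize
        value.realize()

fake_symint = SimpleNamespace(node=SimpleNamespace(expr=sympy.Symbol("s0", integer=True)))
inputs = {}
record_placeholder(inputs, "arg1_1", fake_symint)  # stored as the sympy expr s0
record_placeholder(inputs, "arg0_1", object())     # stored as a FakeTensorBox
realize_graph_inputs(inputs)                       # realizes only the tensor input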
18 changes: 16 additions & 2 deletions torch/_inductor/sizevars.py
@@ -457,7 +457,21 @@ def strideof(name):
# Assign all symbolic shapes needed to local variables
needed = set(self.var_to_val.keys()) - set(self.replacements.keys())

for name, value in graph_inputs.items():
def is_expr(x):
return isinstance(x[1], sympy.Expr)

graph_inputs_expr = list(filter(is_expr, graph_inputs.items()))
graph_inputs_tensors = list(
filter(lambda x: not is_expr(x), graph_inputs.items())
)

for name, shape in graph_inputs_expr:
shape = self.simplify(shape)
if shape in needed:
needed.remove(shape)
code.writeline(f"{self.declare}{shape} = {name}{self.ending}")

for name, value in graph_inputs_tensors:
shapes = value.get_size()
for dim, shape in enumerate(shapes):
shape = self.simplify(shape)
@@ -467,7 +481,7 @@ def strideof(name):
f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
)

for name, value in graph_inputs.items():
for name, value in graph_inputs_tensors:
shapes = value.get_stride()
for dim, shape in enumerate(shapes):
shape = self.simplify(shape)
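Finally, a hedged sketch of the partition performed above: a scalar (sympy.Expr) graph input binds its symbol directly to the kernel argument, while tensor inputs keep deriving symbols from sizes and strides. The declare/ending defaults here correspond to the Python backend, and the function and argument names are illustrative.

import sympy

def codegen_scalar_inputs(graph_inputs, needed, declare="", ending=""):
    lines = []
    exprs = [(n, v) for n, v in graph_inputs.items() if isinstance(v, sympy.Expr)]
    for name, shape in exprs:
        if shape in needed:
            # The whole input *is* the symbol, so bind it straight to the argument.
            needed.remove(shape)
            lines.append(f"{declare}{shape} = {name}{ending}")
    return lines

s0 = sympy.Symbol("s0", integer=True)
print(codegen_scalar_inputs({"arg1_1": s0}, {s0}))  # ['s0 = arg1_1']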