Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Prepare TraceShift pass for GTIR #1592

Merged
merged 8 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,13 @@ class SimpleTemporaryExtractionHeuristics:

closure: ir.StencilClosure

@functools.cached_property
def closure_shifts(
self,
) -> dict[int, set[tuple[ir.OffsetLiteral, ...]]]:
return trace_shifts.TraceShifts.apply(self.closure, inputs_only=False) # type: ignore[return-value] # TODO fix weird `apply` overloads
def __post_init__(self) -> None:
trace_shifts.trace_stencil(
self.closure.stencil, num_args=len(self.closure.inputs), save_to_annex=True
)

def __call__(self, expr: ir.Expr) -> bool:
shifts = self.closure_shifts[id(expr)]
shifts = expr.annex.recorded_shifts
if len(shifts) > 1:
return True
return False
Expand Down Expand Up @@ -564,8 +563,9 @@ def update_domains(

closures.append(closure)

local_shifts = trace_shifts.TraceShifts.apply(closure)
for param, shift_chains in local_shifts.items():
local_shifts = trace_shifts.trace_stencil(closure.stencil, num_args=len(closure.inputs))
for param_sym, shift_chains in zip(closure.inputs, local_shifts):
param = param_sym.id
assert isinstance(param, str)
consumed_domains: list[SymbolicDomain] = (
[SymbolicDomain.from_expr(domains[param])] if param in domains else []
Expand Down
21 changes: 3 additions & 18 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from gt4py.next.common import Dimension
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms import trace_shifts
from gt4py.next.iterator.transforms.global_tmps import AUTO_DOMAIN, SymbolicDomain, domain_union
from gt4py.next.iterator.transforms.trace_shifts import TraceShifts


def _merge_domains(
Expand All @@ -28,31 +28,16 @@ def _merge_domains(
return new_domains


# FIXME[#1582](tehrengruber): Use new TraceShift API when #1592 is merged.
def trace_shifts(
stencil: itir.Expr, input_ids: list[str], domain: itir.Expr
) -> dict[str, set[tuple[itir.OffsetLiteral, ...]]]:
node = itir.StencilClosure(
stencil=stencil,
inputs=[im.ref(id_) for id_ in input_ids],
output=im.ref("__dummy"),
domain=domain,
)
return TraceShifts.apply(node, inputs_only=True) # type: ignore[return-value] # ensured by inputs_only=True


def extract_shifts_and_translate_domains(
stencil: itir.Expr,
input_ids: list[str],
target_domain: SymbolicDomain,
offset_provider: Dict[str, Dimension],
accessed_domains: Dict[str, SymbolicDomain],
):
shifts_results = trace_shifts(stencil, input_ids, SymbolicDomain.as_expr(target_domain))

for in_field_id in input_ids:
shifts_list = shifts_results[in_field_id]
shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids))

for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True):
new_domains = [
SymbolicDomain.translate(target_domain, shift, offset_provider) for shift in shifts_list
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from gt4py.eve import utils as eve_utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import trace_shifts
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
from gt4py.next.iterator.transforms.inline_lifts import InlineLifts
from gt4py.next.iterator.transforms.trace_shifts import TraceShifts, copy_recorded_shifts


def is_center_derefed_only(node: itir.Node) -> bool:
Expand Down Expand Up @@ -58,7 +58,7 @@ def apply(cls, node: itir.FencilDefinition, uids: Optional[eve_utils.UIDGenerato
def visit_StencilClosure(self, node: itir.StencilClosure, **kwargs):
# TODO(tehrengruber): move the analysis out of this pass and just make it a requirement
# such that we don't need to run in multiple times if multiple passes use it.
TraceShifts.apply(node, save_to_annex=True)
trace_shifts.trace_stencil(node.stencil, num_args=len(node.inputs), save_to_annex=True)
return self.generic_visit(node, **kwargs)

def visit_FunCall(self, node: itir.FunCall, **kwargs):
Expand All @@ -74,7 +74,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
eligible_params[i] = True
bound_arg_name = self.uids.sequential_id(prefix="_icdlv")
capture_lift = im.promote_to_const_iterator(bound_arg_name)
copy_recorded_shifts(from_=param, to=capture_lift)
trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift)
new_args.append(capture_lift)
# since we deref an applied lift here we can (but don't need to) immediately
# inline
Expand Down
109 changes: 60 additions & 49 deletions src/gt4py/next/iterator/transforms/trace_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import dataclasses
import sys
from collections.abc import Callable
from typing import Any, Final, Iterable, Literal
from typing import Any, Final, Iterable, Literal, Optional

from gt4py import eve
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift


Expand Down Expand Up @@ -76,7 +77,7 @@ def __call__(self, inp: ir.Expr | ir.Sym, offsets: tuple[ir.OffsetLiteral, ...])


# for performance reasons (`isinstance` is slow otherwise) we don't use abc here
class IteratorTracer:
class Tracer:
def deref(self):
raise NotImplementedError()

Expand All @@ -85,13 +86,13 @@ def shift(self, offsets: tuple[ir.OffsetLiteral, ...]):


@dataclasses.dataclass(frozen=True)
class IteratorArgTracer(IteratorTracer):
class ArgTracer(Tracer):
arg: ir.Expr | ir.Sym
shift_recorder: ShiftRecorder | ForwardingShiftRecorder
offsets: tuple[ir.OffsetLiteral, ...] = ()

def shift(self, offsets: tuple[ir.OffsetLiteral, ...]):
return IteratorArgTracer(
return ArgTracer(
arg=self.arg, shift_recorder=self.shift_recorder, offsets=self.offsets + tuple(offsets)
)

Expand All @@ -103,8 +104,8 @@ def deref(self):
# This class is only needed because we currently allow conditionals on iterators. Since this is
# not supported in the C++ backend it can likely be removed again in the future.
@dataclasses.dataclass(frozen=True)
class CombinedTracer(IteratorTracer):
its: tuple[IteratorTracer, ...]
class CombinedTracer(Tracer):
its: tuple[Tracer, ...]

def shift(self, offsets: tuple[ir.OffsetLiteral, ...]):
return CombinedTracer(tuple(_shift(*offsets)(it) for it in self.its))
Expand Down Expand Up @@ -142,16 +143,16 @@ def _shift(*offsets):
)

def apply(arg):
assert isinstance(arg, IteratorTracer)
assert isinstance(arg, Tracer)
return arg.shift(offsets)

return apply


@dataclasses.dataclass(frozen=True)
class AppliedLift(IteratorTracer):
class AppliedLift(Tracer):
stencil: Callable
its: tuple[IteratorTracer, ...]
its: tuple[Tracer, ...]

def shift(self, offsets):
return AppliedLift(self.stencil, tuple(_shift(*offsets)(it) for it in self.its))
Expand All @@ -162,7 +163,7 @@ def deref(self):

def _lift(f):
def apply(*its):
if not all(isinstance(it, IteratorTracer) for it in its):
if not all(isinstance(it, Tracer) for it in its):
raise AssertionError("All arguments must be iterators.")
return AppliedLift(f, its)

Expand All @@ -189,20 +190,18 @@ def apply(*args):


def _primitive_constituents(
val: Literal[Sentinel.VALUE] | IteratorTracer | tuple,
) -> Iterable[Literal[Sentinel.VALUE] | IteratorTracer]:
if val is Sentinel.VALUE or isinstance(val, IteratorTracer):
val: Literal[Sentinel.VALUE] | Tracer | tuple,
) -> Iterable[Literal[Sentinel.VALUE] | Tracer]:
if val is Sentinel.VALUE or isinstance(val, Tracer):
yield val
elif isinstance(val, tuple):
for el in val:
if isinstance(el, tuple):
yield from _primitive_constituents(el)
elif el is Sentinel.VALUE or isinstance(el, IteratorTracer):
elif el is Sentinel.VALUE or isinstance(el, Tracer):
yield el
else:
raise AssertionError(
"Expected a `Sentinel.VALUE`, `IteratorTracer` or tuple thereof."
)
raise AssertionError("Expected a `Sentinel.VALUE`, `Tracer` or tuple thereof.")
else:
raise ValueError()

Expand All @@ -225,9 +224,7 @@ def _if(cond: Literal[Sentinel.VALUE], true_branch, false_branch):
result.append(_if(Sentinel.VALUE, el_true_branch, el_false_branch))
return tuple(result)

is_iterator_arg = tuple(
isinstance(arg, IteratorTracer) for arg in (cond, true_branch, false_branch)
)
is_iterator_arg = tuple(isinstance(arg, Tracer) for arg in (cond, true_branch, false_branch))
if is_iterator_arg == (False, True, True):
return CombinedTracer((true_branch, false_branch))
assert is_iterator_arg == (False, False, False) and all(
Expand All @@ -247,7 +244,15 @@ def _tuple_get(index, tuple_val):
return Sentinel.VALUE


# Tracing stand-in registered under the "as_fieldop" key of `_START_CTX`.
# The `domain` argument is accepted but never used here; the returned callable
# simply forwards all of its positional arguments to `stencil` and returns
# whatever `stencil` returns.
def _as_fieldop(stencil, domain=None):
def applied_as_fieldop(*args):
return stencil(*args)

return applied_as_fieldop


_START_CTX: Final = {
"as_fieldop": _as_fieldop,
"deref": _deref,
"can_deref": _can_deref,
"shift": _shift,
Expand Down Expand Up @@ -291,11 +296,11 @@ def visit_FunCall(self, node: ir.FunCall, *, ctx: dict[str, Any]) -> Any:

def visit(self, node, **kwargs):
result = super().visit(node, **kwargs)
if isinstance(result, IteratorTracer):
if isinstance(result, Tracer):
assert isinstance(node, (ir.Sym, ir.Expr))

self.shift_recorder.register_node(node)
result = IteratorArgTracer(
result = ArgTracer(
arg=node, shift_recorder=ForwardingShiftRecorder(result, self.shift_recorder)
)
return result
Expand All @@ -304,10 +309,10 @@ def visit_Lambda(self, node: ir.Lambda, *, ctx: dict[str, Any]) -> Callable:
def fun(*args):
new_args = []
for param, arg in zip(node.params, args, strict=True):
if isinstance(arg, IteratorTracer):
if isinstance(arg, Tracer):
self.shift_recorder.register_node(param)
new_args.append(
IteratorArgTracer(
ArgTracer(
arg=param,
shift_recorder=ForwardingShiftRecorder(arg, self.shift_recorder),
)
Expand All @@ -321,46 +326,49 @@ def fun(*args):

return fun

def visit_StencilClosure(self, node: ir.StencilClosure):
tracers = []
for inp in node.inputs:
self.shift_recorder.register_node(inp)
tracers.append(IteratorArgTracer(arg=inp, shift_recorder=self.shift_recorder))

result = self.visit(node.stencil, ctx=_START_CTX)(*tracers)
assert all(el is Sentinel.VALUE for el in _primitive_constituents(result))
return node

@classmethod
def apply(
cls, node: ir.StencilClosure | ir.FencilDefinition, *, inputs_only=True, save_to_annex=False
) -> (
dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]]
def trace_stencil(
cls, stencil: ir.Expr, *, num_args: Optional[int] = None, save_to_annex: bool = False
):
# If we get a lambda we can deduce the number of arguments.
if isinstance(stencil, ir.Lambda):
assert num_args is None or num_args == len(stencil.params)
num_args = len(stencil.params)
Comment on lines +334 to +336
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's special about lambdas? What else can the stencil be?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For lambdas we can deduce the number of arguments, so the `num_args` parameter is optional in that case. This is essentially a convenience feature for testing. I've added a comment.

if not isinstance(num_args, int):
raise ValueError("Stencil must be an 'itir.Lambda' or `num_args` is given.")
assert isinstance(num_args, int)

args = [im.ref(f"__arg{i}") for i in range(num_args)]

old_recursionlimit = sys.getrecursionlimit()
sys.setrecursionlimit(100000000)

instance = cls()
instance.visit(node)

# initialize shift recorder & context with all built-ins and the iterator argument tracers
ctx: dict[str, Any] = {**_START_CTX}
for arg in args:
instance.shift_recorder.register_node(arg)
ctx[arg.id] = ArgTracer(arg=arg, shift_recorder=instance.shift_recorder)

# actually trace stencil
instance.visit(im.call(stencil)(*args), ctx=ctx)

sys.setrecursionlimit(old_recursionlimit)

recorded_shifts = instance.shift_recorder.recorded_shifts

param_shifts = []
for arg in args:
param_shifts.append(recorded_shifts[id(arg)])

if save_to_annex:
_save_to_annex(node, recorded_shifts)
_save_to_annex(stencil, recorded_shifts)

if __debug__:
ValidateRecordedShiftsAnnex().visit(node)
return param_shifts

if inputs_only:
assert isinstance(node, ir.StencilClosure)
inputs_shifts = {}
for inp in node.inputs:
inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)]
return inputs_shifts

return recorded_shifts
trace_stencil = TraceShifts.trace_stencil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it also a classmethod?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To allow subclassing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you just make this up or do you have a use-case in mind?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To stick to your wording: I made this up. I can change it if you like; I can also make up unlikely cases where this might be useful.



def _save_to_annex(
Expand All @@ -369,3 +377,6 @@ def _save_to_annex(
for child_node in node.pre_walk_values():
if id(child_node) in recorded_shifts:
child_node.annex.recorded_shifts = recorded_shifts[id(child_node)]

if __debug__:
ValidateRecordedShiftsAnnex().visit(node)
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,15 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG:
self.itir, offset_provider=offset_provider
)
for closure in itir_tmp.closures: # type: ignore[union-attr]
shifts = itir_transforms.trace_shifts.TraceShifts.apply(closure)
for k, v in shifts.items():
if not isinstance(k, str):
params_shifts = itir_transforms.trace_shifts.trace_stencil(
closure.stencil, num_args=len(closure.inputs)
)
for param, shifts in zip(closure.inputs, params_shifts):
if not isinstance(param.id, str):
continue
if k not in sdfg.gt4py_program_input_fields:
if param.id not in sdfg.gt4py_program_input_fields:
continue
sdfg.offset_providers_per_input_field.setdefault(k, []).extend(list(v))
sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts))

return sdfg

Expand Down
Loading
Loading