Skip to content

Commit

Permalink
feat[next]: Refactor CSE pass to support ITIR.Program (#1646)
Browse files Browse the repository at this point in the history
Extends the common subexpression elimination to support the new
itir.Program node and pushes the intermediate Fencil -> Program
conversion upwards the pass manager. The CSE pass now uses the type
inference such that only field expressions or composites thereof are
collected in field-view context (i.e. outside of as_fieldop).

This PR was initially meant to be merged into the temporary GTIR branch
and reviewed by @egparedes here: #1570. The only change since then is to
make dace tests pass (see commit 160a616).

---------

Co-authored-by: edopao <[email protected]>
Co-authored-by: Enrique González Paredes <[email protected]>
  • Loading branch information
3 people authored Sep 17, 2024
1 parent c8822c0 commit 9328c50
Show file tree
Hide file tree
Showing 11 changed files with 256 additions and 110 deletions.
152 changes: 106 additions & 46 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import dataclasses
import functools
import math
import operator
import typing
from typing import Callable, Iterable, TypeVar, Union, cast

from gt4py.eve import (
NodeTranslator,
Expand All @@ -20,32 +22,36 @@
VisitorWithSymbolTableTrait,
)
from gt4py.eve.utils import UIDGenerator
from gt4py.next.iterator import ir
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
from gt4py.next.iterator.type_system import inference as itir_type_inference
from gt4py.next.type_system import type_info, type_specifications as ts


@dataclasses.dataclass
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)

expr_map: dict[int, ir.SymRef]
expr_map: dict[int, itir.SymRef]

def visit_Expr(self, node: ir.Node) -> ir.Node:
def visit_Expr(self, node: itir.Node) -> itir.Node:
if id(node) in self.expr_map:
return self.expr_map[id(node)]
return self.generic_visit(node)

def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
node = typing.cast(ir.FunCall, self.visit_Expr(node))
def visit_FunCall(self, node: itir.FunCall) -> itir.Node:
node = cast(itir.FunCall, self.visit_Expr(node))
# If we encounter an expression like:
# (λ(_cs_1) → (λ(a) → a+a)(_cs_1))(outer_expr)
# (non-recursively) inline the lambda to obtain:
# (λ(_cs_1) → _cs_1+_cs_1)(outer_expr)
# This allows identifying more common subexpressions later on
if isinstance(node, ir.FunCall) and isinstance(node.fun, ir.Lambda):
if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda):
eligible_params = []
for arg in node.args:
eligible_params.append(isinstance(arg, ir.SymRef) and arg.id.startswith("_cs"))
eligible_params.append(isinstance(arg, itir.SymRef) and arg.id.startswith("_cs"))
if any(eligible_params):
# note: the inline is opcount preserving anyway so avoid the additional
# effort in the inliner by disabling opcount preservation.
Expand All @@ -55,18 +61,18 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
return node


def _is_collectable_expr(node: ir.Node) -> bool:
if isinstance(node, ir.FunCall):
def _is_collectable_expr(node: itir.Node) -> bool:
if isinstance(node, itir.FunCall):
# do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be
# visited, to ensure symbol dependencies are recognized correctly.
# do also not collect reduce nodes if they are left in the it at this point, this may lead to
# conceptual problems (other parts of the tool chain rely on the arguments being present directly
# on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend
# backend (single pass eager depth first visit approach)
if isinstance(node.fun, ir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]:
if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]:
return False
return True
elif isinstance(node, ir.Lambda):
elif isinstance(node, itir.Lambda):
return True

return False
Expand All @@ -87,7 +93,7 @@ class SubexpressionData:
class State:
#: A dictionary mapping a node to a list of node ids with equal hash and some additional
#: information. See `SubexpressionData` for more information.
subexprs: dict[ir.Node, "CollectSubexpressions.SubexpressionData"] = dataclasses.field(
subexprs: dict[itir.Node, CollectSubexpressions.SubexpressionData] = dataclasses.field(
default_factory=dict
)
# TODO(tehrengruber): Revisit if this makes sense or if we can just recompute the collected
Expand All @@ -97,7 +103,7 @@ class State:
#: The ids of all nodes declaring a symbol which are referenced (using a `SymRef`)
used_symbol_ids: set[int] = dataclasses.field(default_factory=set)

def remove_subexprs(self, nodes: typing.Iterable[ir.Node]) -> None:
def remove_subexprs(self, nodes: Iterable[itir.Node]) -> None:
node_ids_to_remove: set[int] = set()
for node in nodes:
subexpr_data = self.subexprs.pop(node, None)
Expand All @@ -109,22 +115,22 @@ def remove_subexprs(self, nodes: typing.Iterable[ir.Node]) -> None:
collected_child_node_ids -= node_ids_to_remove

@classmethod
def apply(cls, node: ir.Node) -> dict[ir.Node, list[tuple[int, set[int]]]]:
def apply(cls, node: itir.Node) -> dict[itir.Node, list[tuple[int, set[int]]]]:
state = cls.State()
obj = cls()
obj.visit(node, state=state, depth=-1)
# Return subexpression such that the nodes closer to the root come first and skip the root
# node itself.
subexprs_sorted: list[tuple[ir.Node, "CollectSubexpressions.SubexpressionData"]] = sorted(
subexprs_sorted: list[tuple[itir.Node, CollectSubexpressions.SubexpressionData]] = sorted(
state.subexprs.items(), key=lambda el: el[1].max_depth
)
return {k: v.subexprs for k, v in subexprs_sorted if k is not node}

def generic_visit(self, *args, **kwargs):
def generic_visit(self, node, **kwargs):
depth = kwargs.pop("depth")
return super().generic_visit(*args, depth=depth + 1, **kwargs)
return super().generic_visit(node, depth=depth + 1, **kwargs)

def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here.
def visit(self, node: itir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here.
if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node):
return super().visit(node, **kwargs)

Expand All @@ -136,7 +142,7 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
# Special handling of `if_(condition, true_branch, false_branch)` like expressions that
# avoids extracting subexpressions unless they are used in either the condition or both
# branches.
if isinstance(node, ir.FunCall) and node.fun == ir.SymRef(id="if_"):
if isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="if_"):
assert len(node.args) == 3
# collect subexpressions for all arguments to the `if_`
arg_states = [self.State() for _ in node.args]
Expand All @@ -152,7 +158,7 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
arg_state.remove_subexprs(arg_state.subexprs.keys() - eligible_subexprs)

# merge the states of the three arguments
subexprs: dict[ir.Node, CollectSubexpressions.SubexpressionData] = {}
subexprs: dict[itir.Node, CollectSubexpressions.SubexpressionData] = {}
for state in arg_states:
for subexpr, data in state.subexprs.items():
merged_data = subexprs.setdefault(subexpr, self.SubexpressionData())
Expand Down Expand Up @@ -198,19 +204,19 @@ def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # su
parent_state.collected_child_node_ids.update(collected_child_node_ids)

def visit_SymRef(
self, node: ir.SymRef, *, symtable: dict[str, ir.Node], state: State, **kwargs
self, node: itir.SymRef, *, symtable: dict[str, itir.Node], state: State, **kwargs
) -> None:
if node.id in symtable: # root symbol otherwise
state.used_symbol_ids.add(id(symtable[node.id]))


def extract_subexpression(
node: ir.Expr,
predicate: typing.Callable[[ir.Expr, int], bool],
node: itir.Expr,
predicate: Callable[[itir.Expr, int], bool],
uid_generator: UIDGenerator,
once_only: bool = False,
deepest_expr_first: bool = False,
) -> tuple[ir.Expr, typing.Union[dict[ir.Sym, ir.Expr], None], bool]:
) -> tuple[itir.Expr, Union[dict[itir.Sym, itir.Expr], None], bool]:
"""
Given an expression extract all subexprs and return a new expr with the subexprs replaced.
Expand Down Expand Up @@ -307,20 +313,20 @@ def extract_subexpression(
)

ignored_children = False
extracted = dict[ir.Sym, ir.Expr]()
extracted = dict[itir.Sym, itir.Expr]()

# collect expressions
subexprs = CollectSubexpressions.apply(node)

# collect multiple occurrences and map them to fresh symbols
expr_map = dict[int, ir.SymRef]()
expr_map = dict[int, itir.SymRef]()
ignored_ids = set()
for expr, subexpr_entry in (
subexprs.items() if not deepest_expr_first else reversed(subexprs.items())
):
# just to make mypy happy when calling the predicate. Every subnode and hence subexpression
# is an expr anyway.
assert isinstance(expr, ir.Expr)
assert isinstance(expr, itir.Expr)

if not predicate(expr, len(subexpr_entry)):
continue
Expand All @@ -340,8 +346,8 @@ def extract_subexpression(
continue

expr_id = uid_generator.sequential_id()
extracted[ir.Sym(id=expr_id)] = expr
expr_ref = ir.SymRef(id=expr_id)
extracted[itir.Sym(id=expr_id)] = expr
expr_ref = itir.SymRef(id=expr_id)
for id_ in eligible_ids:
expr_map[id_] = expr_ref

Expand All @@ -354,17 +360,26 @@ def extract_subexpression(
return _NodeReplacer(expr_map).visit(node), extracted, ignored_children


ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.FencilDefinition | itir.Expr)


@dataclasses.dataclass(frozen=True)
class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):
"""
Perform common subexpression elimination.
Examples:
>>> x = ir.SymRef(id="x")
>>> plus = lambda a, b: ir.FunCall(fun=ir.SymRef(id=("plus")), args=[a, b])
>>> x = itir.SymRef(id="x")
>>> plus = lambda a, b: itir.FunCall(fun=itir.SymRef(id=("plus")), args=[a, b])
>>> expr = plus(plus(x, x), plus(x, x))
>>> print(CommonSubexpressionElimination().visit(expr))
>>> print(CommonSubexpressionElimination.apply(expr, is_local_view=True))
(λ(_cs_1) → _cs_1 + _cs_1)(x + x)
The pass visits the tree top-down starting from the root node, e.g. an itir.Program.
For each node we extract (eligible) subexpressions occuring more than once using
:ref:`extract_subexpression`. In field-view context we only extract expression when they are
fields (or composites thereof), in local view everything is eligible. Since the visit is
top-down, extracted expressions always land up in the outermost scope they can appear in.
"""

# we use one UID generator per instance such that the generated ids are
Expand All @@ -375,23 +390,68 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):

collect_all: bool = dataclasses.field(default=False)

def visit_FunCall(self, node: ir.FunCall):
if isinstance(node.fun, ir.SymRef) and node.fun.id in [
"cartesian_domain",
"unstructured_domain",
]:
return node
@classmethod
def apply(
cls,
node: ProgramOrExpr,
is_local_view: bool | None = None,
offset_provider: common.OffsetProvider | None = None,
) -> ProgramOrExpr:
is_program = isinstance(node, (itir.Program, itir.FencilDefinition))
if is_program:
assert is_local_view is None
is_local_view = False
else:
assert (
is_local_view is not None
), "The expression's context must be specified using `is_local_view`."

new_expr, extracted, ignored_children = extract_subexpression(
node, lambda subexpr, num_occurences: num_occurences > 1, self.uids
offset_provider = offset_provider or {}
node = itir_type_inference.infer(
node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program
)
return cls().visit(node, is_local_view=is_local_view)

def generic_visit(self, node, **kwargs):
if cpm.is_call_to("as_fieldop", node):
assert not kwargs.get("is_local_view")
is_local_view = cpm.is_call_to("as_fieldop", node) or kwargs.get("is_local_view")

return super().generic_visit(node, **(kwargs | {"is_local_view": is_local_view}))

def visit_FunCall(self, node: itir.FunCall, **kwargs):
is_local_view = kwargs["is_local_view"]

if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")):
return node

def predicate(subexpr: itir.Expr, num_occurences: int):
# note: be careful here with the syntatic context: the expression might be in local
# view, even though the syntactic context `node` is in field view.
# note: what is extracted is sketched in the docstring above. keep it updated.
if num_occurences > 1:
if is_local_view:
return True
else:
# only extract fields outside of `as_fieldop`
# `as_fieldop(...)(field_expr, field_expr)`
# -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)`
assert isinstance(subexpr.type, ts.TypeSpec)
if all(
isinstance(stype, ts.FieldType)
for stype in type_info.primitive_constituents(subexpr.type)
):
return True
return False

new_expr, extracted, ignored_children = extract_subexpression(node, predicate, self.uids)

if not extracted:
return self.generic_visit(node)
return self.generic_visit(node, **kwargs)

# apply remapping
result = ir.FunCall(
fun=ir.Lambda(params=list(extracted.keys()), expr=new_expr),
result = itir.FunCall(
fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr),
args=list(extracted.values()),
)

Expand All @@ -401,6 +461,6 @@ def visit_FunCall(self, node: ir.FunCall):
# inside of subexpressions directly. This would require a different order of replacement
# (from lower to higher level).
if ignored_children:
return self.visit(result)
return self.visit(result, **kwargs)

return self.generic_visit(result)
return self.generic_visit(result, **kwargs)
12 changes: 9 additions & 3 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from gt4py.eve import utils as eve_utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import fencil_to_program
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
from gt4py.next.iterator.transforms.constant_folding import ConstantFolding
Expand Down Expand Up @@ -78,7 +79,7 @@ def apply_common_transforms(
Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
] = None,
symbolic_domain_sizes: Optional[dict[str, str]] = None,
) -> itir.FencilDefinition | FencilWithTemporaries | itir.Program:
) -> itir.Program:
if isinstance(ir, itir.Program):
# TODO(havogt): during refactoring to GTIR, we bypass transformations in case we already translated to itir.Program
# (currently the case when using the roundtrip backend)
Expand Down Expand Up @@ -174,6 +175,11 @@ def apply_common_transforms(
ir = FuseMaps().visit(ir)
ir = CollapseListGet().visit(ir)

assert isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries))
ir = fencil_to_program.FencilToProgram().apply(
ir
) # FIXME[#1582](havogt): should be removed after refactoring to combined IR

if unroll_reduce:
for _ in range(10):
unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider)
Expand All @@ -191,12 +197,12 @@ def apply_common_transforms(
ir = ScanEtaReduction().visit(ir)

if common_subexpression_elimination:
ir = CommonSubexpressionElimination().visit(ir)
ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program
ir = MergeLet().visit(ir)

ir = InlineLambdas.apply(
ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args
)

assert isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries))
assert isinstance(ir, itir.Program)
return ir
31 changes: 31 additions & 0 deletions src/gt4py/next/iterator/transforms/program_to_fencil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm


def program_to_fencil(node: itir.Program) -> itir.FencilDefinition:
assert not node.declarations
closures = []
for stmt in node.body:
assert isinstance(stmt, itir.SetAt)
assert isinstance(stmt.expr, itir.FunCall) and cpm.is_call_to(stmt.expr.fun, "as_fieldop")
stencil, domain = stmt.expr.fun.args
inputs = stmt.expr.args
assert all(isinstance(inp, itir.SymRef) for inp in inputs)
closures.append(
itir.StencilClosure(domain=domain, stencil=stencil, output=stmt.target, inputs=inputs)
)

return itir.FencilDefinition(
id=node.id,
function_definitions=node.function_definitions,
params=node.params,
closures=closures,
)
Loading

0 comments on commit 9328c50

Please sign in to comment.