Skip to content

Commit

Permalink
refactor[next]: Fencil to itir.Program for gtfn (#1524)
Browse files Browse the repository at this point in the history
First PR for preparing itir for a combined field view + iterator
representation.

Adds new nodes:
- itir.Program (to replace itir.FencilDefinition)
- itir.SetAt (to replace itir.StencilClosure)
and a new builtin `as_fieldop`.

The semantic of `SetAt` is that the `expr` is directly computed into the
`target` field. `as_fieldop` aka map, takes an itir stencil (a function
from iterators to values) and promotes it to a field_operator, i.e. a
function from fields to fields.

The idea for `itir.Program` is to align and ultimately merge with
`past.Program`.

In this PR, the transition from Fencil to Program is implemented only
for GTFN with a separate pass before lowering to gtfn.

Additionally the pretty_printer/parser is extended with the new
nodes/builtin.

---------

Co-authored-by: Till Ehrengruber <[email protected]>
  • Loading branch information
havogt and tehrengruber authored Apr 16, 2024
1 parent d5d59d2 commit a603bfe
Show file tree
Hide file tree
Showing 16 changed files with 414 additions and 138 deletions.
39 changes: 38 additions & 1 deletion src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import typing
from typing import ClassVar, List, Optional, Union
from typing import Any, ClassVar, List, Optional, Union

import gt4py.eve as eve
from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels
Expand All @@ -22,6 +22,11 @@
from gt4py.eve.utils import noninstantiable


# TODO(havogt):
# After completion of refactoring to GTIR, FencilDefinition and StencilClosure should be removed everywhere.
# During transition, we lower to FencilDefinitions and apply a transformation to GTIR-style afterwards.


@noninstantiable
class Node(eve.Node):
location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False)
Expand Down Expand Up @@ -202,6 +207,13 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
*TYPEBUILTINS,
}

# only used in `Program`` not `FencilDefinition`
# TODO(havogt): restructure after refactoring to GTIR
GTIR_BUILTINS = {
*BUILTINS,
"as_fieldop", # `as_fieldop(stencil)` creates field_operator from stencil
}


class FencilDefinition(Node, ValidatedSymbolTableTrait):
id: Coerced[SymbolName]
Expand All @@ -212,6 +224,31 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait):
_NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS]


class Stmt(Node): ...


class SetAt(Stmt): # from JAX array.at[...].set()
expr: Expr # only `as_fieldop(stencil)(inp0, ...)` in first refactoring
domain: Expr
target: Expr # `make_tuple` or SymRef


class Temporary(Node):
id: Coerced[eve.SymbolName]
domain: Optional[Expr] = None
dtype: Optional[Any] = None # TODO


class Program(Node, ValidatedSymbolTableTrait):
id: Coerced[SymbolName]
function_definitions: List[FunctionDefinition]
params: List[Sym]
declarations: List[Temporary]
body: List[Stmt]

_NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in GTIR_BUILTINS]


# TODO(fthaler): just use hashable types in nodes (tuples instead of lists)
Sym.__hash__ = Node.__hash__ # type: ignore[method-assign]
Expr.__hash__ = Node.__hash__ # type: ignore[method-assign]
Expand Down
42 changes: 42 additions & 0 deletions src/gt4py/next/iterator/pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
GRAMMAR = """
start: fencil_definition
| function_definition
| declaration
| stencil_closure
| set_at
| program
| prec0
SYM: CNAME
Expand Down Expand Up @@ -64,6 +67,7 @@
| "·" prec7 -> deref
| "¬" prec7 -> bool_not
| "↑" prec7 -> lift
| "⇑" prec7 -> as_fieldop
?prec8: prec9
| prec8 "[" prec0 "]" -> tuple_get
Expand All @@ -80,8 +84,11 @@
named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")"
function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";"
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" prec0 ")" ";"
stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";"
set_at: prec0 "@" prec0 "←" prec1 ";"
fencil_definition: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( stencil_closure )+ "}"
program: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( declaration )* ( set_at )+ "}"
%import common (CNAME, SIGNED_FLOAT, SIGNED_INT, WS)
%ignore WS
Expand Down Expand Up @@ -167,6 +174,9 @@ def deref(self, arg: ir.Expr) -> ir.FunCall:
def lift(self, arg: ir.Expr) -> ir.FunCall:
return ir.FunCall(fun=ir.SymRef(id="lift"), args=[arg])

def as_fieldop(self, arg: ir.Expr) -> ir.FunCall:
return ir.FunCall(fun=ir.SymRef(id="as_fieldop"), args=[arg])

def astype(self, arg: ir.Expr) -> ir.FunCall:
return ir.FunCall(fun=ir.SymRef(id="cast_"), args=[arg])

Expand Down Expand Up @@ -202,6 +212,15 @@ def stencil_closure(self, *args: ir.Expr) -> ir.StencilClosure:
output, stencil, *inputs, domain = args
return ir.StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs)

def declaration(self, *args: ir.Expr) -> ir.Temporary:
tid, domain, dtype = args
return ir.Temporary(id=tid, domain=domain, dtype=dtype)

def set_at(self, *args: ir.Expr) -> ir.SetAt:
target, domain, expr = args
return ir.SetAt(expr=expr, domain=domain, target=target)

# TODO(havogt): remove after refactoring.
def fencil_definition(self, fid: str, *args: ir.Node) -> ir.FencilDefinition:
params = []
function_definitions = []
Expand All @@ -218,6 +237,29 @@ def fencil_definition(self, fid: str, *args: ir.Node) -> ir.FencilDefinition:
id=fid, function_definitions=function_definitions, params=params, closures=closures
)

def program(self, fid: str, *args: ir.Node) -> ir.Program:
params = []
function_definitions = []
body = []
declarations = []
for arg in args:
if isinstance(arg, ir.Sym):
params.append(arg)
elif isinstance(arg, ir.FunctionDefinition):
function_definitions.append(arg)
elif isinstance(arg, ir.Temporary):
declarations.append(arg)
else:
assert isinstance(arg, ir.SetAt)
body.append(arg)
return ir.Program(
id=fid,
function_definitions=function_definitions,
params=params,
body=body,
declarations=declarations,
)

def start(self, arg: ir.Node) -> ir.Node:
return arg

Expand Down
57 changes: 56 additions & 1 deletion src/gt4py/next/iterator/pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
}

# replacements for builtin unary operations
UNARY_OPS: Final = {"deref": "·", "lift": "↑", "not_": "¬"}
UNARY_OPS: Final = {"deref": "·", "lift": "↑", "not_": "¬", "as_fieldop": "⇑"}

# operator precedence
PRECEDENCE: Final = {
Expand All @@ -63,6 +63,7 @@
"deref": 7,
"not_": 7,
"lift": 7,
"as_fieldop": 7,
"tuple_get": 8,
"__call__": 8,
}
Expand Down Expand Up @@ -272,6 +273,35 @@ def visit_StencilClosure(self, node: ir.StencilClosure, *, prec: int) -> list[st
)
return self._optimum(h, v)

def visit_Temporary(self, node: ir.Temporary, *, prec: int) -> list[str]:
start, end = [node.id + " = temporary("], [");"]
args = []
if node.domain is not None:
args.append(self._hmerge(["domain="], self.visit(node.domain, prec=0)))
if node.dtype is not None:
args.append(self._hmerge(["dtype="], [str(node.dtype)]))
hargs = self._hmerge(*self._hinterleave(args, ", "))
vargs = self._vmerge(*self._hinterleave(args, ","))
oargs = self._optimum(hargs, vargs)
h = self._hmerge(start, oargs, end)
v = self._vmerge(start, self._indent(oargs), end)
return self._optimum(h, v)

def visit_SetAt(self, node: ir.SetAt, *, prec: int) -> list[str]:
expr = self.visit(node.expr, prec=0)
domain = self.visit(node.domain, prec=0)
target = self.visit(node.target, prec=0)

head = self._hmerge(target, [" @ "], domain)
foot = self._hmerge([" ← "], expr, [";"])

h = self._hmerge(head, foot)
v = self._vmerge(
head,
self._indent(self._indent(foot)),
)
return self._optimum(h, v)

def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> list[str]:
assert prec == 0
function_definitions = self.visit(node.function_definitions, prec=0)
Expand All @@ -291,6 +321,31 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> lis
params, self._indent(function_definitions), self._indent(closures), ["}"]
)

def visit_Program(self, node: ir.Program, *, prec: int) -> list[str]:
assert prec == 0
function_definitions = self.visit(node.function_definitions, prec=0)
body = self.visit(node.body, prec=0)
declarations = self.visit(node.declarations, prec=0)
params = self.visit(node.params, prec=0)

hparams = self._hmerge([node.id + "("], *self._hinterleave(params, ", "), [") {"])
vparams = self._vmerge(
[node.id + "("], *self._hinterleave(params, ",", indent=True), [") {"]
)
params = self._optimum(hparams, vparams)

function_definitions = self._vmerge(*function_definitions)
declarations = self._vmerge(*declarations)
body = self._vmerge(*body)

return self._vmerge(
params,
self._indent(function_definitions),
self._indent(declarations),
self._indent(body),
["}"],
)

@classmethod
def apply(cls, node: ir.Node, indent: int, width: int) -> str:
return "\n".join(cls(indent=indent, width=width).visit(node, prec=0))
Expand Down
46 changes: 46 additions & 0 deletions src/gt4py/next/iterator/transforms/fencil_to_program.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py import eve
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import global_tmps


class FencilToProgram(eve.NodeTranslator):
@classmethod
def apply(cls, node: itir.FencilDefinition | global_tmps.FencilWithTemporaries) -> itir.Program:
return cls().visit(node)

def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt:
as_fieldop = im.call(im.call("as_fieldop")(node.stencil, node.domain))(*node.inputs)
return itir.SetAt(expr=as_fieldop, domain=node.domain, target=node.output)

def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program:
return itir.Program(
id=node.id,
function_definitions=node.function_definitions,
params=node.params,
declarations=[],
body=self.visit(node.closures),
)

def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries) -> itir.Program:
return itir.Program(
id=node.fencil.id,
function_definitions=node.fencil.function_definitions,
params=node.params,
declarations=node.tmps,
body=self.visit(node.fencil.closures),
)
38 changes: 8 additions & 30 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
from collections.abc import Mapping
from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence

import gt4py.eve as eve
import gt4py.next as gtx
from gt4py.eve import Coerced, NodeTranslator, PreserveLocationVisitor
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
Expand Down Expand Up @@ -54,40 +53,19 @@
# Iterator IR extension nodes


class Temporary(ir.Node):
"""Iterator IR extension: declaration of a temporary buffer."""

id: Coerced[eve.SymbolName]
domain: Optional[ir.Expr] = None
dtype: Optional[Any] = None


class FencilWithTemporaries(ir.Node, SymbolTableTrait):
class FencilWithTemporaries(
ir.Node, SymbolTableTrait
): # TODO(havogt): remove and use new `itir.Program` instead.
"""Iterator IR extension: declaration of a fencil with temporary buffers."""

fencil: ir.FencilDefinition
params: list[ir.Sym]
tmps: list[Temporary]
tmps: list[ir.Temporary]


# Extensions for `PrettyPrinter` for easier debugging


def pformat_Temporary(printer: PrettyPrinter, node: Temporary, *, prec: int) -> list[str]:
start, end = [node.id + " = temporary("], [");"]
args = []
if node.domain is not None:
args.append(printer._hmerge(["domain="], printer.visit(node.domain, prec=0)))
if node.dtype is not None:
args.append(printer._hmerge(["dtype="], [str(node.dtype)]))
hargs = printer._hmerge(*printer._hinterleave(args, ", "))
vargs = printer._vmerge(*printer._hinterleave(args, ","))
oargs = printer._optimum(hargs, vargs)
h = printer._hmerge(start, oargs, end)
v = printer._vmerge(start, printer._indent(oargs), end)
return printer._optimum(h, v)


def pformat_FencilWithTemporaries(
printer: PrettyPrinter, node: FencilWithTemporaries, *, prec: int
) -> list[str]:
Expand Down Expand Up @@ -117,7 +95,6 @@ def pformat_FencilWithTemporaries(
return printer._vmerge(params, printer._indent(body), ["}"])


PrettyPrinter.visit_Temporary = pformat_Temporary # type: ignore
PrettyPrinter.visit_FencilWithTemporaries = pformat_FencilWithTemporaries # type: ignore


Expand Down Expand Up @@ -367,7 +344,7 @@ def always_extract_heuristics(_):
location=node.location,
),
params=node.params,
tmps=[Temporary(id=tmp.id) for tmp in tmps],
tmps=[ir.Temporary(id=tmp.id) for tmp in tmps],
)


Expand Down Expand Up @@ -638,7 +615,8 @@ def convert_type(dtype):
fencil=node.fencil,
params=node.params,
tmps=[
Temporary(id=tmp.id, domain=domains[tmp.id], dtype=types[tmp.id]) for tmp in node.tmps
ir.Temporary(id=tmp.id, domain=domains[tmp.id], dtype=types[tmp.id])
for tmp in node.tmps
],
)

Expand Down
5 changes: 3 additions & 2 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination
from gt4py.next.iterator.transforms.eta_reduction import EtaReduction
from gt4py.next.iterator.transforms.fuse_maps import FuseMaps
from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps
from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps, FencilWithTemporaries
from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars
from gt4py.next.iterator.transforms.inline_fundefs import InlineFundefs, PruneUnreferencedFundefs
from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan
Expand Down Expand Up @@ -88,7 +88,7 @@ def apply_common_transforms(
Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
] = None,
symbolic_domain_sizes: Optional[dict[str, str]] = None,
):
) -> itir.FencilDefinition | FencilWithTemporaries:
icdlv_uids = eve_utils.UIDGenerator()

if lift_mode is None:
Expand Down Expand Up @@ -203,4 +203,5 @@ def apply_common_transforms(
ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args
)

assert isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries))
return ir
Loading

0 comments on commit a603bfe

Please sign in to comment.