Skip to content

Commit

Permalink
refactor[next]: new ITIR type inference (#1531)
Browse files Browse the repository at this point in the history
New type inference algorithm on ITIR unifying the type system with the one used in the frontend. Types are stored directly in the ITIR nodes. This replaces the constraint based type inference giving significant performance and usability improvements. Types of builtins are expressing using simple to write `TypeSynthesizer` of the form:
```python
@_register_builtin_type_synthesizer
def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType:
    return base
```
  • Loading branch information
tehrengruber committed Jun 27, 2024
1 parent b8f7f72 commit 3dfbf3f
Show file tree
Hide file tree
Showing 45 changed files with 1,915 additions and 2,999 deletions.
3 changes: 1 addition & 2 deletions src/gt4py/next/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
module in question is a submodule, defines `__all__` and exports many public API objects.
"""

from . import common, ffront, iterator, program_processors, type_inference
from . import common, ffront, iterator, program_processors
from .common import (
Dimension,
DimensionKind,
Expand Down Expand Up @@ -62,7 +62,6 @@
"ffront",
"iterator",
"program_processors",
"type_inference",
# from common
"Dimension",
"DimensionKind",
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,10 +834,10 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call:
)

return_type = type_info.apply_to_primitive_constituents(
value.type,
lambda primitive_type: with_altered_scalar_kind(
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
),
value.type,
)
assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType))

Expand Down
8 changes: 1 addition & 7 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator):
>>> lowered.id
SymbolName('fieldop')
>>> lowered.params # doctest: +ELLIPSIS
[Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))]
[Sym(id=SymbolName('inp'))]
"""

uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator)
Expand Down Expand Up @@ -233,12 +233,6 @@ def visit_Assign(
)

def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym:
# TODO(tehrengruber): extend to more types
if isinstance(node.type, ts.FieldType):
kind = "Iterator"
dtype = node.type.dtype.kind.name.lower()
is_list = type_info.is_local_field(node.type)
return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list))
return im.sym(node.id)

def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef:
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/ffront/lowering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def fun(primitive_type: ts.TypeSpec, path: tuple[int, ...]) -> itir.Expr:

return im.let(param, expr)(
type_info.apply_to_primitive_constituents(
arg_type, fun, with_path_arg=True, tuple_constructor=im.make_tuple
fun, arg_type, with_path_arg=True, tuple_constructor=im.make_tuple
)
)

Expand Down Expand Up @@ -96,7 +96,7 @@ def fun(_: Any, path: tuple[int, ...]) -> itir.FunCall:
lift_args.append(arg_expr)

stencil_expr = type_info.apply_to_primitive_constituents(
arg_type, fun, with_path_arg=True, tuple_constructor=im.make_tuple
fun, arg_type, with_path_arg=True, tuple_constructor=im.make_tuple
)
return im.let(param, expr)(im.lift(im.lambda_(*lift_params)(stencil_expr))(*lift_args))

Expand Down
25 changes: 13 additions & 12 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,14 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]:
)
assert all(field_dims == fields_dims[0] for field_dims in fields_dims)
for dim_idx in range(len(fields_dims[0])):
size_params.append(itir.Sym(id=_size_arg_from_field(param.id, dim_idx)))
size_params.append(
itir.Sym(
id=_size_arg_from_field(param.id, dim_idx),
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
),
)
)

return size_params

Expand Down Expand Up @@ -390,11 +397,11 @@ def _compute_field_slice(node: past.Subscript) -> list[past.Slice]:
raise AssertionError(
"Unexpected 'out' argument, must be tuple of slices or slice expression."
)
node_dims_ls = cast(ts.FieldType, node.type).dims
assert isinstance(node_dims_ls, list)
if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims_ls):
node_dims = cast(ts.FieldType, node.type).dims
assert isinstance(node_dims, list)
if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims):
raise ValueError(
f"Too many indices for field '{out_field_name}': field is {len(node_dims_ls)}"
f"Too many indices for field '{out_field_name}': field is {len(node_dims)}"
f"-dimensional, but {len(out_field_slice_)} were indexed."
)
return out_field_slice_
Expand Down Expand Up @@ -466,13 +473,7 @@ def visit_Name(self, node: past.Name, **kwargs: Any) -> itir.SymRef:
return itir.SymRef(id=node.id)

def visit_Symbol(self, node: past.Symbol, **kwargs: Any) -> itir.Sym:
# TODO(tehrengruber): extend to more types
if isinstance(node.type, ts.FieldType):
kind = "Iterator"
dtype = node.type.dtype.kind.name.lower()
is_list = type_info.is_local_field(node.type)
return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list))
return itir.Sym(id=node.id)
return itir.Sym(id=node.id, type=node.type)

def visit_BinOp(self, node: past.BinOp, **kwargs: Any) -> itir.FunCall:
return itir.FunCall(
Expand Down
14 changes: 9 additions & 5 deletions src/gt4py/next/ffront/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def promote_el(type_el: ts.TypeSpec) -> ts.TypeSpec:
return ts.FieldType(dims=[], dtype=type_el)
return type_el

return type_info.apply_to_primitive_constituents(type_, promote_el)
return type_info.apply_to_primitive_constituents(promote_el, type_)


def promote_zero_dims(
Expand Down Expand Up @@ -69,11 +69,15 @@ def _as_field(arg_el: ts.TypeSpec, path: tuple[int, ...]) -> ts.TypeSpec:
raise ValueError(f"'{arg_el}' is not compatible with '{param_el}'.")
return arg_el

return type_info.apply_to_primitive_constituents(arg, _as_field, with_path_arg=True)
return type_info.apply_to_primitive_constituents(_as_field, arg, with_path_arg=True)

new_args = [*args]
for i, (param, arg) in enumerate(
zip(function_type.pos_only_args + list(function_type.pos_or_kw_args.values()), args)
zip(
list(function_type.pos_only_args) + list(function_type.pos_or_kw_args.values()),
args,
strict=True,
)
):
new_args[i] = promote_arg(param, arg)
new_kwargs = {**kwargs}
Expand Down Expand Up @@ -192,7 +196,7 @@ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType:
# TODO: we want some generic field type here, but our type system does not support it yet.
return ts.FieldType(dims=[common.Dimension("...")], dtype=dtype)

res = type_info.apply_to_primitive_constituents(param, _as_field, with_path_arg=True)
res = type_info.apply_to_primitive_constituents(_as_field, param, with_path_arg=True)
assert isinstance(res, (ts.FieldType, ts.TupleType))
return res

Expand Down Expand Up @@ -309,5 +313,5 @@ def return_type_scanop(
[callable_type.axis],
)
return type_info.apply_to_primitive_constituents(
carry_dtype, lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg))
lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)), carry_dtype
)
36 changes: 8 additions & 28 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import typing
from typing import Any, ClassVar, List, Optional, Union
from typing import ClassVar, List, Optional, Union

import gt4py.eve as eve
from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels
Expand All @@ -32,6 +31,9 @@
class Node(eve.Node):
location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False)

# TODO(tehrengruber): include in comparison if value is not None
type: Optional[ts.TypeSpec] = eve.field(default=None, repr=False, compare=False)

def __str__(self) -> str:
from gt4py.next.iterator.pretty_printer import pformat

Expand All @@ -48,24 +50,6 @@ def __hash__(self) -> int:

class Sym(Node): # helper
id: Coerced[SymbolName]
# TODO(tehrengruber): Revisit. Using strings is a workaround to avoid coupling with the
# type inference.
kind: typing.Literal["Iterator", "Value", None] = None
dtype: Optional[tuple[str, bool]] = (
None # format: name of primitive type, boolean indicating if it is a list
)

@datamodels.validator("kind")
def _kind_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str):
if value and value not in ["Iterator", "Value"]:
raise ValueError(f"Invalid kind '{value}', must be one of 'Iterator', 'Value'.")

@datamodels.validator("dtype")
def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str):
if value and value[0] not in TYPEBUILTINS:
raise ValueError(
f"Invalid dtype '{value}', must be one of '{', '.join(TYPEBUILTINS)}'."
)


@noninstantiable
Expand Down Expand Up @@ -177,18 +161,14 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
FLOATING_POINT_BUILTINS = {"float32", "float64"}
TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"}

GRAMMAR_BUILTINS = {
BUILTINS = {
"tuple_get",
"cast_",
"cartesian_domain",
"unstructured_domain",
"make_tuple",
"tuple_get",
"shift",
"neighbors",
"cast_",
}

BUILTINS = {
*GRAMMAR_BUILTINS,
"named_range",
"list_get",
"map_",
Expand Down Expand Up @@ -232,7 +212,7 @@ class SetAt(Stmt): # from JAX array.at[...].set()
class Temporary(Node):
id: Coerced[eve.SymbolName]
domain: Optional[Expr] = None
dtype: Optional[Any] = None # TODO
dtype: Optional[ts.ScalarType | ts.TupleType] = None


class Program(Node, ValidatedSymbolTableTrait):
Expand Down
26 changes: 19 additions & 7 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,31 @@
from gt4py.next.type_system import type_specifications as ts, type_translation


def sym(sym_or_name: Union[str, itir.Sym]) -> itir.Sym:
def sym(sym_or_name: Union[str, itir.Sym], type_: str | ts.TypeSpec | None = None) -> itir.Sym:
"""
Convert to Sym if necessary.
Examples
--------
>>> sym("a")
Sym(id=SymbolName('a'), kind=None, dtype=None)
Sym(id=SymbolName('a'))
>>> sym(itir.Sym(id="b"))
Sym(id=SymbolName('b'), kind=None, dtype=None)
Sym(id=SymbolName('b'))
>>> a = sym("a", "float32")
>>> a.id, a.type
(SymbolName('a'), ScalarType(kind=<ScalarKind.FLOAT32: 1032>, shape=None))
"""
if isinstance(sym_or_name, itir.Sym):
assert not type_
return sym_or_name
return itir.Sym(id=sym_or_name)
return itir.Sym(id=sym_or_name, type=ensure_type(type_))


def ref(ref_or_name: Union[str, itir.SymRef]) -> itir.SymRef:
def ref(
ref_or_name: Union[str, itir.SymRef], type_: str | ts.TypeSpec | None = None
) -> itir.SymRef:
"""
Convert to SymRef if necessary.
Expand All @@ -48,10 +55,15 @@ def ref(ref_or_name: Union[str, itir.SymRef]) -> itir.SymRef:
>>> ref(itir.SymRef(id="b"))
SymRef(id=SymbolRef('b'))
>>> a = ref("a", "float32")
>>> a.id, a.type
(SymbolRef('a'), ScalarType(kind=<ScalarKind.FLOAT32: 1032>, shape=None))
"""
if isinstance(ref_or_name, itir.SymRef):
assert not type_
return ref_or_name
return itir.SymRef(id=ref_or_name)
return itir.SymRef(id=ref_or_name, type=ensure_type(type_))


def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> itir.Expr:
Expand Down Expand Up @@ -108,7 +120,7 @@ class lambda_:
Examples
--------
>>> lambda_("a")(deref("a")) # doctest: +ELLIPSIS
Lambda(params=[Sym(id=SymbolName('a'), kind=None, dtype=None)], expr=FunCall(fun=SymRef(id=SymbolRef('deref')), args=[SymRef(id=SymbolRef('a'))]))
Lambda(params=[Sym(id=SymbolName('a'))], expr=FunCall(fun=SymRef(id=SymbolRef('deref')), args=[SymRef(id=SymbolRef('a'))]))
"""

def __init__(self, *args):
Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/next/iterator/pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.type_system import type_specifications as ts


GRAMMAR = """
Expand All @@ -31,6 +32,7 @@
SYM: CNAME
SYM_REF: CNAME
TYPE_LITERAL: CNAME
INT_LITERAL: SIGNED_INT
FLOAT_LITERAL: SIGNED_FLOAT
OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ"
Expand Down Expand Up @@ -84,7 +86,7 @@
named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")"
function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";"
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" prec0 ")" ";"
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";"
stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";"
set_at: prec0 "@" prec0 "←" prec1 ";"
fencil_definition: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( stencil_closure )+ "}"
Expand All @@ -111,6 +113,11 @@ def INT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal:
def FLOAT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal:
return im.literal(value.value, "float64")

def TYPE_LITERAL(self, value: lark_lexer.Token) -> ts.TypeSpec:
if hasattr(ts.ScalarKind, value.upper()):
return ts.ScalarType(kind=getattr(ts.ScalarKind, value.upper()))
raise NotImplementedError(f"Type {value} not supported.")

def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral:
v: Union[int, str] = value.value[:-1]
try:
Expand Down
Loading

0 comments on commit 3dfbf3f

Please sign in to comment.