Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[next]: new ITIR type inference #1531

Merged
merged 58 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 52 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
aa937da
Use type specification for itir.Literal
tehrengruber Apr 15, 2024
095ae13
Small cleanup
tehrengruber Apr 15, 2024
cace41a
Small fix
tehrengruber Apr 15, 2024
b22f929
Fix failing doctests
tehrengruber Apr 15, 2024
288ea3c
Address review comments
tehrengruber Apr 17, 2024
c695935
New type inference first draft
tehrengruber Apr 17, 2024
72b50ff
Merge origin/main
tehrengruber Apr 17, 2024
beabac2
Fix some tests
tehrengruber Apr 18, 2024
ed4bf27
Cleanup
tehrengruber Apr 18, 2024
1c94bf0
Use is_call_to instead of equality comparison with itir.Ref.
tehrengruber Apr 18, 2024
3799b7a
Cleanup
tehrengruber Apr 18, 2024
f75f0e4
Cleanup
tehrengruber Apr 18, 2024
779ba3d
Address reviewer comments
tehrengruber Apr 18, 2024
c7e1e93
Cleanup
tehrengruber Apr 18, 2024
172982f
Merge origin_tehrengruber/fix_im_ref_comp
tehrengruber Apr 18, 2024
79bee51
Cleanup
tehrengruber Apr 18, 2024
dd5e601
Cleanup
tehrengruber Apr 18, 2024
1b803b7
Merge origin_tehrengruber/fix_im_ref_comp
tehrengruber Apr 18, 2024
f41e032
Merge origin_tehrengruber/fix_im_ref_comp
tehrengruber Apr 18, 2024
520c805
Merge origin/main
tehrengruber Apr 18, 2024
129f6b3
Add test for neighbor / sparse input field
tehrengruber Apr 18, 2024
a99ca9a
Multiple fixes
tehrengruber Apr 19, 2024
896582e
Multiple fixes
tehrengruber Apr 19, 2024
c8fb2d9
Cleanup
tehrengruber Apr 19, 2024
cb0ebbb
Cleanup
tehrengruber Apr 19, 2024
873c60e
Fix dace
tehrengruber Apr 19, 2024
d42fa12
Formatting
tehrengruber Apr 19, 2024
8c632de
Formatting
tehrengruber Apr 19, 2024
1d043d1
Don't use temporaries in dace
tehrengruber Apr 19, 2024
b792a53
Try fixing test failures by removing types
tehrengruber Apr 22, 2024
f0680df
Fix pretty print
tehrengruber Apr 22, 2024
3603adb
Fix lowering tests
tehrengruber Apr 22, 2024
4517d2f
Fix missing fixture import in tests
tehrengruber Apr 22, 2024
dbf1001
Fix failing tests
tehrengruber Apr 22, 2024
6793d58
Fix format
tehrengruber Apr 22, 2024
cdf0a7f
Fix format
tehrengruber Apr 22, 2024
17644dc
Fix doctests
tehrengruber Apr 22, 2024
3b8303f
Fix cpp tests
tehrengruber Apr 22, 2024
0d361a0
Merge origin/main
tehrengruber Apr 22, 2024
9535e72
Add documentation
tehrengruber Apr 24, 2024
2340907
Merge origin/main
tehrengruber Apr 24, 2024
d9a772c
Remove debug code
tehrengruber Apr 25, 2024
5db67fc
Fix temporary codegen in roundtrip
tehrengruber Apr 25, 2024
2d28181
Formatting
tehrengruber Apr 25, 2024
09006e9
Cleanup
tehrengruber Apr 25, 2024
7ebb257
Add support for new itir.Program format
tehrengruber Apr 25, 2024
b23454d
Small fixes
tehrengruber May 2, 2024
953658a
Fix list_get
tehrengruber May 2, 2024
2dba9cf
Fix typing and SetAt type check
tehrengruber May 2, 2024
233c369
Fix program
tehrengruber May 2, 2024
f9dd78b
Fix program
tehrengruber May 2, 2024
18ab7a4
Fix gtfn_module test
tehrengruber May 2, 2024
5947934
Address review comments
tehrengruber May 7, 2024
fb6fac5
Cleanup
tehrengruber Jun 10, 2024
7971fff
Small fix
tehrengruber Jun 10, 2024
876d3df
Fix doctest
tehrengruber Jun 11, 2024
88c4f0d
Address reviewer comments
tehrengruber Jun 21, 2024
6bcc38a
Merge remote-tracking branch 'origin/main' into itir_type_inference
tehrengruber Jun 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/gt4py/next/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
module in question is a submodule, defines `__all__` and exports many public API objects.
"""

from . import common, ffront, iterator, program_processors, type_inference
from . import common, ffront, iterator, program_processors
from .common import (
Dimension,
DimensionKind,
Expand Down Expand Up @@ -62,7 +62,6 @@
"ffront",
"iterator",
"program_processors",
"type_inference",
# from common
"Dimension",
"DimensionKind",
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,10 +831,10 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call:
)

return_type = type_info.apply_to_primitive_constituents(
value.type,
lambda primitive_type: with_altered_scalar_kind(
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
),
value.type,
)
assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType))

Expand Down
8 changes: 1 addition & 7 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator):
>>> lowered.id
SymbolName('fieldop')
>>> lowered.params # doctest: +ELLIPSIS
[Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))]
[Sym(id=SymbolName('inp'))]
"""

uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator)
Expand Down Expand Up @@ -228,12 +228,6 @@ def visit_Assign(
)

def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym:
# TODO(tehrengruber): extend to more types
if isinstance(node.type, ts.FieldType):
kind = "Iterator"
dtype = node.type.dtype.kind.name.lower()
is_list = type_info.is_local_field(node.type)
return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list))
return im.sym(node.id)

def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef:
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/ffront/lowering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def fun(primitive_type: ts.TypeSpec, path: tuple[int, ...]) -> itir.Expr:

return im.let(param, expr)(
type_info.apply_to_primitive_constituents(
arg_type, fun, with_path_arg=True, tuple_constructor=im.make_tuple
fun, arg_type, with_path_arg=True, tuple_constructor=im.make_tuple
)
)

Expand Down Expand Up @@ -96,7 +96,7 @@ def fun(_: Any, path: tuple[int, ...]) -> itir.FunCall:
lift_args.append(arg_expr)

stencil_expr = type_info.apply_to_primitive_constituents(
arg_type, fun, with_path_arg=True, tuple_constructor=im.make_tuple
fun, arg_type, with_path_arg=True, tuple_constructor=im.make_tuple
)
return im.let(param, expr)(im.lift(im.lambda_(*lift_params)(stencil_expr))(*lift_args))

Expand Down
25 changes: 13 additions & 12 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,14 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]:
)
assert all(field_dims == fields_dims[0] for field_dims in fields_dims)
for dim_idx in range(len(fields_dims[0])):
size_params.append(itir.Sym(id=_size_arg_from_field(param.id, dim_idx)))
size_params.append(
itir.Sym(
id=_size_arg_from_field(param.id, dim_idx),
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
),
)
)

return size_params

Expand Down Expand Up @@ -390,11 +397,11 @@ def _compute_field_slice(node: past.Subscript) -> list[past.Slice]:
raise AssertionError(
"Unexpected 'out' argument, must be tuple of slices or slice expression."
)
node_dims_ls = cast(ts.FieldType, node.type).dims
assert isinstance(node_dims_ls, list)
if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims_ls):
node_dims = cast(ts.FieldType, node.type).dims
assert isinstance(node_dims, list)
if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims):
raise ValueError(
f"Too many indices for field '{out_field_name}': field is {len(node_dims_ls)}"
f"Too many indices for field '{out_field_name}': field is {len(node_dims)}"
f"-dimensional, but {len(out_field_slice_)} were indexed."
)
return out_field_slice_
Expand Down Expand Up @@ -466,13 +473,7 @@ def visit_Name(self, node: past.Name, **kwargs: Any) -> itir.SymRef:
return itir.SymRef(id=node.id)

def visit_Symbol(self, node: past.Symbol, **kwargs: Any) -> itir.Sym:
# TODO(tehrengruber): extend to more types
if isinstance(node.type, ts.FieldType):
kind = "Iterator"
dtype = node.type.dtype.kind.name.lower()
is_list = type_info.is_local_field(node.type)
return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list))
return itir.Sym(id=node.id)
return itir.Sym(id=node.id, type=node.type)

def visit_BinOp(self, node: past.BinOp, **kwargs: Any) -> itir.FunCall:
return itir.FunCall(
Expand Down
10 changes: 5 additions & 5 deletions src/gt4py/next/ffront/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def promote_el(type_el: ts.TypeSpec) -> ts.TypeSpec:
return ts.FieldType(dims=[], dtype=type_el)
return type_el

return type_info.apply_to_primitive_constituents(type_, promote_el)
return type_info.apply_to_primitive_constituents(promote_el, type_)


def promote_zero_dims(
Expand Down Expand Up @@ -69,11 +69,11 @@ def _as_field(arg_el: ts.TypeSpec, path: tuple[int, ...]) -> ts.TypeSpec:
raise ValueError(f"'{arg_el}' is not compatible with '{param_el}'.")
return arg_el

return type_info.apply_to_primitive_constituents(arg, _as_field, with_path_arg=True)
return type_info.apply_to_primitive_constituents(_as_field, arg, with_path_arg=True)

new_args = [*args]
for i, (param, arg) in enumerate(
zip(function_type.pos_only_args + list(function_type.pos_or_kw_args.values()), args)
zip(list(function_type.pos_only_args) + list(function_type.pos_or_kw_args.values()), args)
egparedes marked this conversation as resolved.
Show resolved Hide resolved
):
new_args[i] = promote_arg(param, arg)
new_kwargs = {**kwargs}
Expand Down Expand Up @@ -192,7 +192,7 @@ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType:
# TODO: we want some generic field type here, but our type system does not support it yet.
return ts.FieldType(dims=[common.Dimension("...")], dtype=dtype)

res = type_info.apply_to_primitive_constituents(param, _as_field, with_path_arg=True)
res = type_info.apply_to_primitive_constituents(_as_field, param, with_path_arg=True)
assert isinstance(res, (ts.FieldType, ts.TupleType))
return res

Expand Down Expand Up @@ -309,5 +309,5 @@ def return_type_scanop(
[callable_type.axis],
)
return type_info.apply_to_primitive_constituents(
carry_dtype, lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg))
lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)), carry_dtype
)
36 changes: 8 additions & 28 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import typing
from typing import Any, ClassVar, List, Optional, Union
from typing import ClassVar, List, Optional, Union

import gt4py.eve as eve
from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels
Expand All @@ -32,6 +31,9 @@
class Node(eve.Node):
location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False)

# TODO(tehrengruber): include in comparison if value is not None
egparedes marked this conversation as resolved.
Show resolved Hide resolved
type: Optional[ts.TypeSpec] = eve.field(default=None, repr=False, compare=False)

def __str__(self) -> str:
from gt4py.next.iterator.pretty_printer import pformat

Expand All @@ -48,24 +50,6 @@ def __hash__(self) -> int:

class Sym(Node): # helper
id: Coerced[SymbolName]
# TODO(tehrengruber): Revisit. Using strings is a workaround to avoid coupling with the
# type inference.
kind: typing.Literal["Iterator", "Value", None] = None
dtype: Optional[tuple[str, bool]] = (
None # format: name of primitive type, boolean indicating if it is a list
)

@datamodels.validator("kind")
def _kind_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str):
if value and value not in ["Iterator", "Value"]:
raise ValueError(f"Invalid kind '{value}', must be one of 'Iterator', 'Value'.")

@datamodels.validator("dtype")
def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str):
if value and value[0] not in TYPEBUILTINS:
raise ValueError(
f"Invalid dtype '{value}', must be one of '{', '.join(TYPEBUILTINS)}'."
)


@noninstantiable
Expand Down Expand Up @@ -177,18 +161,14 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
FLOATING_POINT_BUILTINS = {"float32", "float64"}
TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"}

GRAMMAR_BUILTINS = {
BUILTINS = {
"tuple_get",
"cast_",
"cartesian_domain",
"unstructured_domain",
"make_tuple",
"tuple_get",
"shift",
"neighbors",
"cast_",
}

BUILTINS = {
*GRAMMAR_BUILTINS,
"named_range",
"list_get",
"map_",
Expand Down Expand Up @@ -232,7 +212,7 @@ class SetAt(Stmt): # from JAX array.at[...].set()
class Temporary(Node):
id: Coerced[eve.SymbolName]
domain: Optional[Expr] = None
dtype: Optional[Any] = None # TODO
dtype: Optional[ts.ScalarType | ts.TupleType] = None


class Program(Node, ValidatedSymbolTableTrait):
Expand Down
16 changes: 9 additions & 7 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,25 @@
from gt4py.next.type_system import type_specifications as ts, type_translation


def sym(sym_or_name: Union[str, itir.Sym]) -> itir.Sym:
def sym(sym_or_name: Union[str, itir.Sym], type_=None) -> itir.Sym:
"""
Convert to Sym if necessary.

Examples
--------
>>> sym("a")
Sym(id=SymbolName('a'), kind=None, dtype=None)
Sym(id=SymbolName('a'))
egparedes marked this conversation as resolved.
Show resolved Hide resolved

>>> sym(itir.Sym(id="b"))
Sym(id=SymbolName('b'), kind=None, dtype=None)
Sym(id=SymbolName('b'))
"""
if isinstance(sym_or_name, itir.Sym):
assert not type_
return sym_or_name
return itir.Sym(id=sym_or_name)
return itir.Sym(id=sym_or_name, type=ensure_type(type_))


def ref(ref_or_name: Union[str, itir.SymRef]) -> itir.SymRef:
def ref(ref_or_name: Union[str, itir.SymRef], type_=None) -> itir.SymRef:
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
"""
Convert to SymRef if necessary.

Expand All @@ -50,8 +51,9 @@ def ref(ref_or_name: Union[str, itir.SymRef]) -> itir.SymRef:
SymRef(id=SymbolRef('b'))
"""
if isinstance(ref_or_name, itir.SymRef):
assert not type_
return ref_or_name
return itir.SymRef(id=ref_or_name)
return itir.SymRef(id=ref_or_name, type=ensure_type(type_))


def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> itir.Expr:
Expand Down Expand Up @@ -108,7 +110,7 @@ class lambda_:
Examples
--------
>>> lambda_("a")(deref("a")) # doctest: +ELLIPSIS
Lambda(params=[Sym(id=SymbolName('a'), kind=None, dtype=None)], expr=FunCall(fun=SymRef(id=SymbolRef('deref')), args=[SymRef(id=SymbolRef('a'))]))
Lambda(params=[Sym(id=SymbolName('a'))], expr=FunCall(fun=SymRef(id=SymbolRef('deref')), args=[SymRef(id=SymbolRef('a'))]))
"""

def __init__(self, *args):
Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/next/iterator/pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.type_system import type_specifications as ts


GRAMMAR = """
Expand All @@ -31,6 +32,7 @@

SYM: CNAME
SYM_REF: CNAME
TYPE_LITERAL: CNAME
INT_LITERAL: SIGNED_INT
FLOAT_LITERAL: SIGNED_FLOAT
OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ"
Expand Down Expand Up @@ -84,7 +86,7 @@

named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")"
function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";"
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" prec0 ")" ";"
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";"
stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";"
set_at: prec0 "@" prec0 "←" prec1 ";"
fencil_definition: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( stencil_closure )+ "}"
Expand All @@ -111,6 +113,11 @@ def INT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal:
def FLOAT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal:
return im.literal(value.value, "float64")

def TYPE_LITERAL(self, value: lark_lexer.Token) -> ts.TypeSpec:
if hasattr(ts.ScalarKind, value.upper()):
return ts.ScalarType(kind=getattr(ts.ScalarKind, value.upper()))
raise NotImplementedError(f"Type {value} not supported.")

def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral:
v: Union[int, str] = value.value[:-1]
try:
Expand Down
38 changes: 12 additions & 26 deletions src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
SymRef,
)
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.type_system import type_info, type_specifications, type_translation
from gt4py.next.type_system import type_specifications as ts, type_translation


TRACING = "tracing"
Expand Down Expand Up @@ -252,7 +252,7 @@ def _contains_tuple_dtype_field(arg):
return isinstance(arg, common.Field) and any(dim is None for dim in arg.domain.dims)


def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]:
def _make_fencil_params(fun, args) -> list[Sym]:
params: list[Sym] = []
param_infos = list(inspect.signature(fun).parameters.values())

Expand All @@ -277,40 +277,26 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]:
"Only 'POSITIONAL_OR_KEYWORD' or 'VAR_POSITIONAL' parameters are supported."
)

kind, dtype = None, None
if use_arg_types:
# TODO(tehrengruber): Fields of tuples are not supported yet. Just ignore them for now.
if not _contains_tuple_dtype_field(arg):
arg_type = type_translation.from_value(arg)
# TODO(tehrengruber): Support more types.
if isinstance(arg_type, type_specifications.FieldType):
kind = "Iterator"
dtype = (
arg_type.dtype.kind.name.lower(), # actual dtype
type_info.is_local_field(arg_type), # is list
)

params.append(Sym(id=param_name, kind=kind, dtype=dtype))
arg_type = None
if isinstance(arg, ts.TypeSpec):
arg_type = arg
else:
arg_type = type_translation.from_value(arg)

params.append(Sym(id=param_name, type=arg_type))
return params


def trace_fencil_definition(
fun: typing.Callable, args: typing.Iterable, *, use_arg_types=True
) -> FencilDefinition:
def trace_fencil_definition(fun: typing.Callable, args: typing.Iterable) -> FencilDefinition:
"""
Transform fencil given as a callable into `itir.FencilDefinition` using tracing.

Arguments:
fun: The fencil / callable to trace.
args: A list of arguments, e.g. fields, scalars or composites thereof. If `use_arg_types`
is `False` may also be dummy values.

Keyword arguments:
use_arg_types: Deduce type of the arguments and add them to the fencil parameter nodes
(i.e. `itir.Sym`s).
args: A list of arguments, e.g. fields, scalars, composites thereof, or directly a type.
"""
with TracerContext() as _:
params = _make_fencil_params(fun, args, use_arg_types=use_arg_types)
params = _make_fencil_params(fun, args)
trace_function_call(fun, args=(_s(param.id) for param in params))

return FencilDefinition(
Expand Down
Loading
Loading