From aa937da9a33d98f7c42d178e31be5ad8b25326a5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 15 Apr 2024 07:43:12 +0200 Subject: [PATCH 01/52] Use type specification for itir.Literal --- src/gt4py/next/ffront/past_to_itir.py | 4 +- src/gt4py/next/iterator/ir.py | 8 +--- src/gt4py/next/iterator/ir_utils/ir_makers.py | 10 ++++- src/gt4py/next/iterator/pretty_parser.py | 4 +- .../iterator/transforms/collapse_tuple.py | 3 +- .../next/iterator/transforms/inline_lifts.py | 2 +- src/gt4py/next/iterator/type_inference.py | 8 ++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 7 ++-- .../runners/dace_iterator/itir_to_sdfg.py | 4 +- src/gt4py/next/type_system/type_info.py | 21 +++++++++- .../ffront_tests/test_past_to_itir.py | 19 +++++++-- .../iterator_tests/test_pretty_parser.py | 11 ++--- .../iterator_tests/test_pretty_printer.py | 15 +++---- .../iterator_tests/test_type_inference.py | 40 +++++++++---------- .../test_collapse_list_get.py | 11 +++-- .../transforms_tests/test_global_tmps.py | 16 ++++---- .../transforms_tests/test_inline_into_scan.py | 5 ++- .../test_scan_eta_reduction.py | 5 ++- .../test_simple_inline_heuristic.py | 5 ++- .../transforms_tests/test_unroll_reduce.py | 9 +++-- .../gtfn_tests/test_gtfn_module.py | 7 ++-- 21 files changed, 126 insertions(+), 88 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a7e9751c4e..fb5c1a6882 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -328,7 +328,7 @@ def _construct_itir_domain_arg( else: lower = self._visit_slice_bound( slices[dim_i].lower if slices else None, - itir.Literal(value="0", type=itir.INTEGER_INDEX_BUILTIN), + im.literal("0", itir.INTEGER_INDEX_BUILTIN), dim_size, ) upper = self._visit_slice_bound( @@ -458,7 +458,7 @@ def visit_Constant(self, node: past.Constant, **kwargs: Any) -> itir.Literal: f"Scalars of kind '{node.type.kind}' not supported currently." 
) typename = node.type.kind.name.lower() - return itir.Literal(value=str(node.value), type=typename) + return im.literal(str(node.value), typename) raise NotImplementedError("Only scalar literals supported currently.") diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 56f931f451..5fa259ec64 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -20,6 +20,7 @@ from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable +from gt4py.next.type_system import type_specifications as ts @noninstantiable @@ -68,12 +69,7 @@ class Expr(Node): ... class Literal(Expr): value: str - type: str - - @datamodels.validator("type") - def _type_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if value not in TYPEBUILTINS: - raise ValueError(f"'{value}' is not a valid builtin type.") + type: ts.ScalarType class NoneLiteral(Expr): diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 8e505be0ec..83ca3322d0 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -20,6 +20,12 @@ from gt4py.next.type_system import type_specifications as ts, type_translation +def ensure_type(type_: str | ts.TypeSpec | None): + if isinstance(type_, str): + return ts.ScalarType(kind=getattr(ts.ScalarKind, type_.upper())) + return type_ + + def sym(sym_or_name: Union[str, itir.Sym]) -> itir.Sym: """ Convert to Sym if necessary. 
@@ -292,7 +298,7 @@ def shift(offset, value=None): def literal(value: str, typename: str): - return itir.Literal(value=value, type=typename) + return itir.Literal(value=value, type=ensure_type(typename)) def literal_from_value(val: core_defs.Scalar) -> itir.Literal: @@ -321,7 +327,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: typename = type_spec.kind.name.lower() assert typename in itir.TYPEBUILTINS - return itir.Literal(value=str(val), type=typename) + return literal(str(val), typename) def neighbors(offset, it): diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 9dd96b076e..d460e2b4bf 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -95,14 +95,14 @@ def SYM(self, value: lark_lexer.Token) -> ir.Sym: def SYM_REF(self, value: lark_lexer.Token) -> Union[ir.SymRef, ir.Literal]: if value.value in ("True", "False"): - return ir.Literal(value=value.value, type="bool") + return im.literal(value.value, "bool") return ir.SymRef(id=value.value) def INT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal: return im.literal_from_value(int(value.value)) def FLOAT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal: - return ir.Literal(value=value.value, type="float64") + return im.literal(value.value, "float64") def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral: v: Union[int, str] = value.value[:-1] diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 2bc33e85e1..4b8182a781 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -26,6 +26,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda 
+from gt4py.next.type_system import type_info class UnknownLength: @@ -232,7 +233,7 @@ def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ and isinstance(node.args[0], ir.Literal) ): # `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` - assert node.args[0].type in ir.INTEGER_BUILTINS + assert type_info.is_integer(node.args[0].type) make_tuple_call = node.args[1] idx = int(node.args[0].value) assert idx < len( diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index bf56186253..74ef37fa0c 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -201,7 +201,7 @@ def visit_FunCall( assert len(node.args[0].fun.args) == 1 args = node.args[0].args if len(args) == 0: - return ir.Literal(value="True", type="bool") + return im.literal_from_value(True) res = ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[args[0]]) for arg in args[1:]: diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 1aae474c4c..89fed49551 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -23,6 +23,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries from gt4py.next.type_inference import Type, TypeVar, freshen, reindex_vars, unify +from gt4py.next.type_system import type_info """Constraint-based inference for the iterator IR.""" @@ -643,7 +644,7 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type: return TypeVar.fresh() def visit_Literal(self, node: ir.Literal, **kwargs) -> Val: - return Val(kind=Value(), dtype=Primitive(name=node.type)) + return Val(kind=Value(), dtype=Primitive(name=node.type.kind.name.lower())) def visit_AxisLiteral(self, node: ir.AxisLiteral, **kwargs) -> Val: return Val(kind=Value(), dtype=AXIS_DTYPE, size=Scalar()) 
@@ -672,10 +673,7 @@ def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type: # Calls to `tuple_get` are handled as being part of the grammar, not as function calls. if len(node.args) != 2: raise TypeError("'tuple_get' requires exactly two arguments.") - if ( - not isinstance(node.args[0], ir.Literal) - or node.args[0].type != ir.INTEGER_INDEX_BUILTIN - ): + if not isinstance(node.args[0], ir.Literal) or not type_info.is_integer(node.args[0].type): raise TypeError( f"The first argument to 'tuple_get' must be a literal of type '{ir.INTEGER_INDEX_BUILTIN}'." ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 4617e54eae..7ec980aad1 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -46,6 +46,7 @@ UnstructuredDomain, ) from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef +from gt4py.next.type_system import type_info def pytype_to_cpptype(t: str) -> Optional[str]: @@ -184,7 +185,7 @@ def _collect_offset_definitions( def _literal_as_integral_constant(node: itir.Literal) -> IntegralConstant: - assert node.type in itir.INTEGER_BUILTINS + assert type_info.is_integer(node.type) return IntegralConstant(value=int(node.value)) @@ -194,7 +195,7 @@ def _is_scan(node: itir.Node) -> TypeGuard[itir.FunCall]: def _bool_from_literal(node: itir.Node) -> bool: assert isinstance(node, itir.Literal) - assert node.type == "bool" and node.value in ("True", "False") + assert type_info.is_logical(node.type) and node.value in ("True", "False") return node.value == "True" @@ -293,7 +294,7 @@ def visit_Lambda( ) def visit_Literal(self, node: itir.Literal, **kwargs: Any) -> Literal: - return Literal(value=node.value, type=node.type) + return Literal(value=node.value, type=node.type.kind.name.lower()) def visit_OffsetLiteral(self, node: 
itir.OffsetLiteral, **kwargs: Any) -> OffsetLiteral: return OffsetLiteral(value=node.value) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 39a5440eaf..566b4b5a20 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -26,7 +26,7 @@ type_inference as itir_typing, ) from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef -from gt4py.next.type_system import type_specifications as ts, type_translation as tt +from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt from .itir_to_tasklet import ( Context, @@ -68,7 +68,7 @@ def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]: """ stencil_fobj = cast(FunCall, stencil) is_forward = stencil_fobj.args[1] - assert isinstance(is_forward, Literal) and is_forward.type == "bool" + assert isinstance(is_forward, Literal) and type_info.is_logical(is_forward.type) init_carry = stencil_fobj.args[2] assert isinstance(init_carry, Literal) return is_forward.value == "True", init_carry diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index b235e6f26d..cc9523b10d 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -223,6 +223,25 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: return extract_dtype(symbol_type).kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] +def is_integer(symbol_type: ts.TypeSpec) -> bool: + """ + Check if ``symbol_type`` is an integral type. 
+ + Examples: + --------- + >>> is_integer(ts.ScalarType(kind=ts.ScalarKind.INT32)) + True + >>> is_integer(ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + False + >>> is_integer(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) + False + """ + return isinstance(symbol_type, ts.ScalarType) and symbol_type.kind in [ + ts.ScalarKind.INT32, + ts.ScalarKind.INT64, + ] + + def is_integral(symbol_type: ts.TypeSpec) -> bool: """ Check if the dtype of ``symbol_type`` is an integral type. @@ -236,7 +255,7 @@ def is_integral(symbol_type: ts.TypeSpec) -> bool: >>> is_integral(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) True """ - return extract_dtype(symbol_type).kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] + return is_integer(extract_dtype(symbol_type)) def is_number(symbol_type: ts.TypeSpec) -> bool: diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index 49c5b11b20..3d296b6377 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -23,6 +23,7 @@ from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts from next_tests.past_common_fixtures import ( IDim, @@ -59,7 +60,7 @@ def test_copy_lowering(copy_program_def, itir_identity_fundef): fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), args=[ P(itir.AxisLiteral, value="IDim"), - P(itir.Literal, value="0", type="int32"), + P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), P(itir.SymRef, id=eve.SymbolRef("__out_size_0")), ], ) @@ -118,8 +119,20 @@ def test_copy_restrict_lowering(copy_restrict_program_def, itir_identity_fundef) fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), args=[ P(itir.AxisLiteral, 
value="IDim"), - P(itir.Literal, value="1", type=itir.INTEGER_INDEX_BUILTIN), - P(itir.Literal, value="2", type=itir.INTEGER_INDEX_BUILTIN), + P( + itir.Literal, + value="1", + type=ts.ScalarType( + kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ), + ), + P( + itir.Literal, + value="2", + type=ts.ScalarType( + kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ), + ), ], ) ], diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index b3a9ba8001..d753b13fc0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -14,6 +14,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.pretty_parser import pparse +from gt4py.next.iterator.ir_utils import ir_makers as im def test_symref(): @@ -41,14 +42,14 @@ def test_arithmetic(): ir.FunCall( fun=ir.SymRef(id="plus"), args=[ - ir.Literal(value="1", type="int32"), - ir.Literal(value="2", type="int32"), + im.literal("1", "int32"), + im.literal("2", "int32"), ], ), - ir.Literal(value="3", type="int32"), + im.literal("3", "int32"), ], ), - ir.Literal(value="4", type="int32"), + im.literal("4", "int32"), ], ) actual = pparse(testee) @@ -108,7 +109,7 @@ def test_tuple_get(): testee = "x[42]" expected = ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[ir.Literal(value="42", type=ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ) actual = pparse(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 844c905e8e..70f365a56a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -14,6 +14,7 @@ from 
gt4py.next.iterator import ir from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat +from gt4py.next.iterator.ir_utils import ir_makers as im def test_hmerge(): @@ -111,14 +112,14 @@ def test_arithmetic(): ir.FunCall( fun=ir.SymRef(id="plus"), args=[ - ir.Literal(value="1", type="int64"), - ir.Literal(value="2", type="int64"), + im.literal("1", "int64"), + im.literal("2", "int64"), ], ), - ir.Literal(value="3", type="int64"), + im.literal("3", "int64"), ], ), - ir.Literal(value="4", type="int64"), + im.literal("4", "int64"), ], ) expected = "(1 + 2) × 3 / 4" @@ -132,11 +133,11 @@ def test_associativity(): args=[ ir.FunCall( fun=ir.SymRef(id="plus"), - args=[ir.Literal(value="1", type="int64"), ir.Literal(value="2", type="int64")], + args=[im.literal("1", "int64"), im.literal("2", "int64")], ), ir.FunCall( fun=ir.SymRef(id="plus"), - args=[ir.Literal(value="3", type="int64"), ir.Literal(value="4", type="int64")], + args=[im.literal("3", "int64"), im.literal("4", "int64")], ), ], ) @@ -197,7 +198,7 @@ def test_shift(): def test_tuple_get(): testee = ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[ir.Literal(value="42", type=ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ) expected = "x[42]" actual = pformat(testee) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7beda20d31..731c163343 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -80,7 +80,7 @@ def test_sym_ref(): def test_bool_literal(): - testee = ir.Literal(value="False", type="bool") + testee = im.literal_from_value(False) expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="bool"), size=ti.TypeVar(idx=0)) inferred = ti.infer(testee) assert inferred == expected @@ -88,7 +88,7 @@ def test_bool_literal(): def 
test_int_literal(): - testee = ir.Literal(value="3", type="int32") + testee = im.literal("3", "int32") expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int32"), size=ti.TypeVar(idx=0)) inferred = ti.infer(testee) assert inferred == expected @@ -96,7 +96,7 @@ def test_int_literal(): def test_float_literal(): - testee = ir.Literal(value="3.0", type="float64") + testee = im.literal("3.0", "float64") expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="float64"), size=ti.TypeVar(idx=0)) inferred = ti.infer(testee) assert inferred == expected @@ -223,7 +223,7 @@ def test_and(): def test_cast(): testee = ir.FunCall( fun=ir.SymRef(id="cast_"), - args=[ir.Literal(value="1.", type="float64"), ir.SymRef(id="int64")], + args=[im.literal("1.", "float64"), ir.SymRef(id="int64")], ) expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int64"), size=ti.TypeVar(idx=0)) inferred = ti.infer(testee) @@ -342,8 +342,8 @@ def test_make_tuple(): testee = ir.FunCall( fun=ir.SymRef(id="make_tuple"), args=[ - ir.Literal(value="True", type="bool"), - ir.Literal(value="42.0", type="float64"), + im.literal("True", "bool"), + im.literal("42.0", "float64"), ir.SymRef(id="x"), ], ) @@ -363,12 +363,12 @@ def test_tuple_get(): testee = ir.FunCall( fun=ir.SymRef(id="tuple_get"), args=[ - ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("1", ir.INTEGER_INDEX_BUILTIN), ir.FunCall( fun=ir.SymRef(id="make_tuple"), args=[ - ir.Literal(value="True", type="bool"), - ir.Literal(value="42.0", type="float64"), + im.literal("True", "bool"), + im.literal("42.0", "float64"), ], ), ], @@ -384,7 +384,7 @@ def test_tuple_get_in_lambda(): params=[ir.Sym(id="x")], expr=ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("1", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ), ) expected = ti.FunctionType( @@ -449,9 +449,7 @@ def test_reduce(): ], ), ) - testee = ir.FunCall( - 
fun=ir.SymRef(id="reduce"), args=[reduction_f, ir.Literal(value="0", type="int64")] - ) + testee = ir.FunCall(fun=ir.SymRef(id="reduce"), args=[reduction_f, im.literal("0", "int64")]) expected = ti.FunctionType( args=ti.ValListTuple( kind=ti.Value(), @@ -486,7 +484,7 @@ def test_scan(): ) testee = ir.FunCall( fun=ir.SymRef(id="scan"), - args=[scan_f, ir.Literal(value="True", type="bool"), ir.Literal(value="0", type="int64")], + args=[scan_f, im.literal("True", "bool"), im.literal("0", "int64")], ) expected = ti.FunctionType( args=ti.Tuple.from_elems( @@ -697,7 +695,7 @@ def test_dynamic_offset(): fun=ir.SymRef(id="named_range"), args=[ ir.AxisLiteral(value="IDim"), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="i"), ], ), @@ -705,7 +703,7 @@ def test_dynamic_offset(): fun=ir.SymRef(id="named_range"), args=[ ir.AxisLiteral(value="JDim"), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="j"), ], ), @@ -713,7 +711,7 @@ def test_dynamic_offset(): fun=ir.SymRef(id="named_range"), args=[ ir.AxisLiteral(value="KDim"), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="k"), ], ), @@ -839,8 +837,8 @@ def test_fencil_definition_same_closure_input(): domain=im.call("unstructured_domain")( im.call("named_range")( ir.AxisLiteral(value="Edge"), - ir.Literal(value="0", type="int32"), - ir.Literal(value="10", type="int32"), + im.literal("0", "int32"), + im.literal("10", "int32"), ) ), stencil=im.ref("f1"), @@ -851,8 +849,8 @@ def test_fencil_definition_same_closure_input(): domain=im.call("unstructured_domain")( im.call("named_range")( ir.AxisLiteral(value="Vertex"), - ir.Literal(value="0", type="int32"), - ir.Literal(value="10", type="int32"), + im.literal("0", "int32"), + im.literal("10", "int32"), ) ), stencil=im.ref("f2"), diff --git 
a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py index b6463ba0d5..87ed414393 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_list_get.py @@ -14,6 +14,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet +from gt4py.next.iterator.ir_utils import ir_makers as im def _list_get(index: ir.Expr, lst: ir.Expr) -> ir.FunCall: @@ -26,7 +27,7 @@ def _neighbors(offset: ir.Expr, it: ir.Expr) -> ir.FunCall: def test_list_get_neighbors(): testee = _list_get( - ir.Literal(value="42", type="int32"), + im.literal("42", "int32"), _neighbors(ir.OffsetLiteral(value="foo"), ir.SymRef(id="bar")), ) @@ -49,13 +50,11 @@ def test_list_get_neighbors(): def test_list_get_make_const_list(): testee = _list_get( - ir.Literal(value="42", type="int32"), - ir.FunCall( - fun=ir.SymRef(id="make_const_list"), args=[ir.Literal(value="3.14", type="float64")] - ), + im.literal("42", "int32"), + ir.FunCall(fun=ir.SymRef(id="make_const_list"), args=[im.literal("3.14", "float64")]), ) - expected = ir.Literal(value="3.14", type="float64") + expected = im.literal("3.14", "float64") actual = CollapseListGet().visit(testee) assert expected == actual diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 0521b0414b..8bad361a20 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -242,7 +242,7 @@ def test_update_cartesian_domains(): *( im.call("named_range")( ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + 
im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.ref(s), ) for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) @@ -270,7 +270,7 @@ def test_update_cartesian_domains(): im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.literal("1", ir.INTEGER_INDEX_BUILTIN), ), - im.plus(im.ref("i"), ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN)), + im.plus(im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)), ], ) ] @@ -297,7 +297,7 @@ def test_update_cartesian_domains(): im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.literal("1", ir.INTEGER_INDEX_BUILTIN), ), - im.plus(im.ref("i"), ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN)), + im.plus(im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)), ], ) ] @@ -306,7 +306,7 @@ def test_update_cartesian_domains(): fun=im.ref("named_range"), args=[ ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.ref(s), ], ) @@ -325,10 +325,10 @@ def test_collect_tmps_info(): fun=im.ref("named_range"), args=[ ir.AxisLiteral(value="IDim"), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), ir.FunCall( fun=im.ref("plus"), - args=[im.ref("i"), ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN)], + args=[im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)], ), ], ) @@ -338,7 +338,7 @@ def test_collect_tmps_info(): fun=im.ref("named_range"), args=[ ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.ref(s), ], ) @@ -382,7 +382,7 @@ def test_collect_tmps_info(): fun=im.ref("named_range"), args=[ ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), im.ref(s), ], ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_into_scan.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_into_scan.py index 
50a3b3ecab..c8a61a3b2f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_into_scan.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_into_scan.py @@ -14,6 +14,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan +from gt4py.next.iterator.ir_utils import ir_makers as im # TODO(havogt): remove duplication with test_eta_reduction @@ -25,8 +26,8 @@ def _make_scan(*args: list[str], scanpass_body: ir.Expr) -> ir.Expr: params=[ir.Sym(id="state")] + [ir.Sym(id=f"{arg}") for arg in args], expr=scanpass_body, ), - ir.Literal(value="0.0", type="float64"), - ir.Literal(value="True", type="bool"), + im.literal("0.0", "float64"), + im.literal("True", "bool"), ], ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py index 5a9d3a676b..53678d278e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_scan_eta_reduction.py @@ -14,6 +14,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction +from gt4py.next.iterator.ir_utils import ir_makers as im def _make_scan(*args: list[str]): @@ -24,8 +25,8 @@ def _make_scan(*args: list[str]): params=[ir.Sym(id="state")] + [ir.Sym(id=f"{arg}") for arg in args], expr=ir.SymRef(id="foo"), ), - ir.Literal(value="0.0", type="float64"), - ir.Literal(value="True", type="bool"), + im.literal("0.0", "float64"), + im.literal("True", "bool"), ], ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py index 685625e9e7..e236b7dd49 100644 --- 
a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_simple_inline_heuristic.py @@ -16,6 +16,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.simple_inline_heuristic import is_eligible_for_inlining +from gt4py.next.iterator.ir_utils import ir_makers as im @pytest.fixture @@ -33,8 +34,8 @@ def scan(): ], ), ), - ir.Literal(value="True", type="bool"), - ir.Literal(value="0.0", type="float64"), + im.literal("True", "bool"), + im.literal("0.0", "float64"), ], ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index ba4a91e6b5..054e7fac12 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -19,6 +19,7 @@ from gt4py.eve.utils import UIDs from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags +from gt4py.next.iterator.ir_utils import ir_makers as im from next_tests.unit_tests.conftest import DummyConnectivity @@ -34,7 +35,7 @@ def basic_reduction(): return ir.FunCall( fun=ir.FunCall( fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), ir.Literal(value="0.0", type="float64")], + args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], ), args=[ ir.FunCall( @@ -51,7 +52,7 @@ def reduction_with_shift_on_second_arg(): return ir.FunCall( fun=ir.FunCall( fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), ir.Literal(value="0.0", type="float64")], + args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], ), args=[ ir.SymRef(id="x"), @@ -69,7 +70,7 @@ def reduction_with_incompatible_shifts(): return ir.FunCall( fun=ir.FunCall( fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), 
ir.Literal(value="0.0", type="float64")], + args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], ), args=[ ir.FunCall( @@ -90,7 +91,7 @@ def reduction_with_irrelevant_full_shift(): return ir.FunCall( fun=ir.FunCall( fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), ir.Literal(value="0.0", type="float64")], + args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], ), args=[ ir.FunCall( diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index be7a9ff81e..d3651d3084 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -19,6 +19,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.otf import languages, stages from gt4py.next.program_processors.codegens.gtfn import gtfn_module +from gt4py.next.iterator.ir_utils import ir_makers as im @pytest.fixture @@ -30,8 +31,8 @@ def fencil_example(): fun=itir.SymRef(id="named_range"), args=[ itir.AxisLiteral(value="X"), - itir.Literal(value="0", type=itir.INTEGER_INDEX_BUILTIN), - itir.Literal(value="10", type=itir.INTEGER_INDEX_BUILTIN), + im.literal("0", itir.INTEGER_INDEX_BUILTIN), + im.literal("10", itir.INTEGER_INDEX_BUILTIN), ], ) ], @@ -43,7 +44,7 @@ def fencil_example(): itir.FunctionDefinition( id="stencil", params=[itir.Sym(id="buf"), itir.Sym(id="sc")], - expr=itir.Literal(value="1", type="float64"), + expr=im.literal("1", "float64"), ) ], closures=[ From 095ae13b91385a0145c4ed22a5e58dffd7eba3be Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 15 Apr 2024 07:48:05 +0200 Subject: [PATCH 02/52] Small cleanup --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git 
a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 83ca3322d0..9251da3c82 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -20,12 +20,6 @@ from gt4py.next.type_system import type_specifications as ts, type_translation -def ensure_type(type_: str | ts.TypeSpec | None): - if isinstance(type_, str): - return ts.ScalarType(kind=getattr(ts.ScalarKind, type_.upper())) - return type_ - - def sym(sym_or_name: Union[str, itir.Sym]) -> itir.Sym: """ Convert to Sym if necessary. @@ -100,6 +94,13 @@ def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.Of return str_or_offset +def ensure_type(type_: str | ts.TypeSpec | None): + if isinstance(type_, str): + return ts.ScalarType(kind=getattr(ts.ScalarKind, type_.upper())) + assert isinstance(type_, ts.TypeSpec) + return type_ + + class lambda_: """ Create a lambda from params and an expression. From cace41a1ca079ebd95c0e5bce4bb1001c8e6e3c9 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 15 Apr 2024 07:59:18 +0200 Subject: [PATCH 03/52] Small fix --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 9251da3c82..fc941052b3 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -97,7 +97,7 @@ def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.Of def ensure_type(type_: str | ts.TypeSpec | None): if isinstance(type_, str): return ts.ScalarType(kind=getattr(ts.ScalarKind, type_.upper())) - assert isinstance(type_, ts.TypeSpec) + assert isinstance(type_, ts.TypeSpec) or type_ is None return type_ From b22f92999f2e01808cf02d66f1d07a492315adbc Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 15 Apr 2024 10:09:34 +0200 Subject: [PATCH 
04/52] Fix failing doctests --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index fc941052b3..9924721a9d 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -64,7 +64,7 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti SymRef(id=SymbolRef('a')) >>> ensure_expr(3) - Literal(value='3', type='int32') + Literal(value='3', type=ScalarType(kind=<ScalarKind.INT32: 32>, shape=None)) >>> ensure_expr(itir.OffsetLiteral(value="i")) OffsetLiteral(value='i') @@ -125,7 +125,7 @@ class call: Examples -------- >>> call("plus")(1, 1) - FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type='int32'), Literal(value='1', type='int32')]) + FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type=ScalarType(kind=<ScalarKind.INT32: 32>, shape=None)), Literal(value='1', type=ScalarType(kind=<ScalarKind.INT32: 32>, shape=None))]) """ def __init__(self, expr): @@ -307,13 +307,13 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: Make a literal node from a value.
>>> literal_from_value(1.0) - Literal(value='1.0', type='float64') + Literal(value='1.0', type=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None)) >>> literal_from_value(1) - Literal(value='1', type='int32') + Literal(value='1', type=ScalarType(kind=<ScalarKind.INT32: 32>, shape=None)) >>> literal_from_value(2147483648) - Literal(value='2147483648', type='int64') + Literal(value='2147483648', type=ScalarType(kind=<ScalarKind.INT64: 64>, shape=None)) >>> literal_from_value(True) - Literal(value='True', type='bool') + Literal(value='True', type=ScalarType(kind=<ScalarKind.BOOL: 1>, shape=None)) """ if not isinstance(val, core_defs.Scalar): # type: ignore[arg-type] # mypy bug #11673 raise ValueError(f"Value must be a scalar, got '{type(val).__name__}'.") From 288ea3c7e9e5bc77996c81312909d252aa7838e2 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 17 Apr 2024 13:13:13 +0200 Subject: [PATCH 05/52] Address review comments --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 4 ++-- src/gt4py/next/type_system/type_info.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 9924721a9d..7fe05594ad 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -94,7 +94,7 @@ def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.Of return str_or_offset -def ensure_type(type_: str | ts.TypeSpec | None): +def ensure_type(type_: str | ts.TypeSpec | None) -> ts.TypeSpec | None: if isinstance(type_, str): return ts.ScalarType(kind=getattr(ts.ScalarKind, type_.upper())) assert isinstance(type_, ts.TypeSpec) or type_ is None @@ -298,7 +298,7 @@ def shift(offset, value=None): return call(call("shift")(*args)) -def literal(value: str, typename: str): +def literal(value: str, typename: str) -> itir.Literal: return itir.Literal(value=value, type=ensure_type(typename)) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index
cc9523b10d..a05b9afde8 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -236,10 +236,10 @@ def is_integer(symbol_type: ts.TypeSpec) -> bool: >>> is_integer(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) False """ - return isinstance(symbol_type, ts.ScalarType) and symbol_type.kind in [ + return isinstance(symbol_type, ts.ScalarType) and symbol_type.kind in { ts.ScalarKind.INT32, ts.ScalarKind.INT64, - ] + } def is_integral(symbol_type: ts.TypeSpec) -> bool: From c695935fd76c58fec6a9635497371d33a3aa298a Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 17 Apr 2024 21:41:54 +0200 Subject: [PATCH 06/52] New type inference first draft --- src/gt4py/next/__init__.py | 3 +- src/gt4py/next/config.py | 1 + .../ffront/foast_passes/type_deduction.py | 2 +- src/gt4py/next/ffront/foast_to_itir.py | 6 - src/gt4py/next/ffront/lowering_utils.py | 4 +- src/gt4py/next/ffront/past_to_itir.py | 23 +- src/gt4py/next/ffront/type_info.py | 10 +- src/gt4py/next/iterator/ir.py | 31 +- .../ir_utils/common_pattern_matcher.py | 10 + src/gt4py/next/iterator/ir_utils/ir_makers.py | 10 +- src/gt4py/next/iterator/tracing.py | 3 +- .../iterator/transforms/collapse_tuple.py | 51 +- .../next/iterator/transforms/global_tmps.py | 46 +- .../next/iterator/transforms/pass_manager.py | 7 +- src/gt4py/next/iterator/type_inference.py | 1123 --------------- .../next/iterator/type_system/__init__.py | 13 + .../next/iterator/type_system/inference.py | 412 ++++++ src/gt4py/next/iterator/type_system/rules.py | 250 ++++ .../type_system/type_specifications.py | 62 + .../formatters/type_check.py | 32 - .../runners/dace_iterator/itir_to_sdfg.py | 22 +- .../runners/dace_iterator/itir_to_tasklet.py | 56 +- src/gt4py/next/type_inference.py | 353 ----- src/gt4py/next/type_system/type_info.py | 74 +- .../next/type_system/type_specifications.py | 20 +- .../next/type_system/type_translation.py | 4 +- tests/next_tests/definitions.py | 1 - 
.../ffront_tests/ffront_test_utils.py | 1 + .../test_horizontal_indirection.py | 5 - .../iterator_tests/test_anton_toy.py | 11 + tests/next_tests/unit_tests/conftest.py | 1 - .../iterator_tests/test_type_inference.py | 1219 +++-------------- .../transforms_tests/test_collapse_tuple.py | 10 +- .../transforms_tests/test_global_tmps.py | 99 +- .../unit_tests/test_type_inference.py | 84 -- 35 files changed, 1219 insertions(+), 2840 deletions(-) delete mode 100644 src/gt4py/next/iterator/type_inference.py create mode 100644 src/gt4py/next/iterator/type_system/__init__.py create mode 100644 src/gt4py/next/iterator/type_system/inference.py create mode 100644 src/gt4py/next/iterator/type_system/rules.py create mode 100644 src/gt4py/next/iterator/type_system/type_specifications.py delete mode 100644 src/gt4py/next/program_processors/formatters/type_check.py delete mode 100644 src/gt4py/next/type_inference.py delete mode 100644 tests/next_tests/unit_tests/test_type_inference.py diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index e79e2f5517..f8882f8c5f 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -23,7 +23,7 @@ module in question is a submodule, defines `__all__` and exports many public API objects. """ -from . import common, ffront, iterator, program_processors, type_inference +from . 
import common, ffront, iterator, program_processors from .common import ( Dimension, DimensionKind, @@ -62,7 +62,6 @@ "ffront", "iterator", "program_processors", - "type_inference", # from common "Dimension", "DimensionKind", diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index 682d5254e5..048f7ff773 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -76,6 +76,7 @@ def env_flag_to_bool(name: str, default: bool) -> bool: pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir())) / "gt4py_cache" ) +print(BUILD_CACHE_DIR) #: Whether generated code projects should be kept around between runs. diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 471840ff1b..f74a50b3f6 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -831,10 +831,10 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) return_type = type_info.apply_to_primitive_constituents( - value.type, lambda primitive_type: with_altered_scalar_kind( primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) ), + value.type, ) assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType)) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 80c0f1fea3..53ccbf56a1 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -228,12 +228,6 @@ def visit_Assign( ) def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: - # TODO(tehrengruber): extend to more types - if isinstance(node.type, ts.FieldType): - kind = "Iterator" - dtype = node.type.dtype.kind.name.lower() - is_list = type_info.is_local_field(node.type) - return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list)) return im.sym(node.id) def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: diff 
--git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 72182b7d31..cde34f315a 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -47,7 +47,7 @@ def fun(primitive_type: ts.TypeSpec, path: tuple[int, ...]) -> itir.Expr: return im.let(param, expr)( type_info.apply_to_primitive_constituents( - arg_type, fun, with_path_arg=True, tuple_constructor=im.make_tuple + fun, arg_type, with_path_arg=True, tuple_constructor=im.make_tuple ) ) @@ -96,7 +96,7 @@ def fun(_: Any, path: tuple[int, ...]) -> itir.FunCall: lift_args.append(arg_expr) stencil_expr = type_info.apply_to_primitive_constituents( - arg_type, fun, with_path_arg=True, tuple_constructor=im.make_tuple + fun, arg_type, with_path_arg=True, tuple_constructor=im.make_tuple ) return im.let(param, expr)(im.lift(im.lambda_(*lift_params)(stencil_expr))(*lift_args)) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index fb5c1a6882..fe1bffb42c 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -175,7 +175,12 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: ) assert all(field_dims == fields_dims[0] for field_dims in fields_dims) for dim_idx in range(len(fields_dims[0])): - size_params.append(itir.Sym(id=_size_arg_from_field(param.id, dim_idx))) + size_params.append( + itir.Sym( + id=_size_arg_from_field(param.id, dim_idx), + type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ) + ) return size_params @@ -390,11 +395,11 @@ def _compute_field_slice(node: past.Subscript) -> list[past.Slice]: raise AssertionError( "Unexpected 'out' argument, must be tuple of slices or slice expression." 
) - node_dims_ls = cast(ts.FieldType, node.type).dims - assert isinstance(node_dims_ls, list) - if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims_ls): + node_dims = cast(ts.FieldType, node.type).dims + assert isinstance(node_dims, list) + if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims): raise ValueError( - f"Too many indices for field '{out_field_name}': field is {len(node_dims_ls)}" + f"Too many indices for field '{out_field_name}': field is {len(node_dims)}" f"-dimensional, but {len(out_field_slice_)} were indexed." ) return out_field_slice_ @@ -466,13 +471,7 @@ def visit_Name(self, node: past.Name, **kwargs: Any) -> itir.SymRef: return itir.SymRef(id=node.id) def visit_Symbol(self, node: past.Symbol, **kwargs: Any) -> itir.Sym: - # TODO(tehrengruber): extend to more types - if isinstance(node.type, ts.FieldType): - kind = "Iterator" - dtype = node.type.dtype.kind.name.lower() - is_list = type_info.is_local_field(node.type) - return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list)) - return itir.Sym(id=node.id) + return itir.Sym(id=node.id, type=node.type) def visit_BinOp(self, node: past.BinOp, **kwargs: Any) -> itir.FunCall: return itir.FunCall( diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 2bd4f21993..8ceb405486 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -37,7 +37,7 @@ def promote_el(type_el: ts.TypeSpec) -> ts.TypeSpec: return ts.FieldType(dims=[], dtype=type_el) return type_el - return type_info.apply_to_primitive_constituents(type_, promote_el) + return type_info.apply_to_primitive_constituents(promote_el, type_) def promote_zero_dims( @@ -69,11 +69,11 @@ def _as_field(arg_el: ts.TypeSpec, path: tuple[int, ...]) -> ts.TypeSpec: raise ValueError(f"'{arg_el}' is not compatible with '{param_el}'.") return arg_el - return type_info.apply_to_primitive_constituents(arg, _as_field, 
with_path_arg=True) + return type_info.apply_to_primitive_constituents(_as_field, arg, with_path_arg=True) new_args = [*args] for i, (param, arg) in enumerate( - zip(function_type.pos_only_args + list(function_type.pos_or_kw_args.values()), args) + zip(list(function_type.pos_only_args) + list(function_type.pos_or_kw_args.values()), args) ): new_args[i] = promote_arg(param, arg) new_kwargs = {**kwargs} @@ -192,7 +192,7 @@ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType: # TODO: we want some generic field type here, but our type system does not support it yet. return ts.FieldType(dims=[common.Dimension("...")], dtype=dtype) - res = type_info.apply_to_primitive_constituents(param, _as_field, with_path_arg=True) + res = type_info.apply_to_primitive_constituents(_as_field, param, with_path_arg=True) assert isinstance(res, (ts.FieldType, ts.TupleType)) return res @@ -309,5 +309,5 @@ def return_type_scanop( [callable_type.axis], ) return type_info.apply_to_primitive_constituents( - carry_dtype, lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)) + lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)), carry_dtype ) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 5fa259ec64..75c9c8a866 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -12,7 +12,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import typing from typing import ClassVar, List, Optional, Union import gt4py.eve as eve @@ -27,6 +26,8 @@ class Node(eve.Node): location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False) + type: Optional[ts.TypeSpec] = eve.field(default=None, repr=False, compare=False) + def __str__(self) -> str: from gt4py.next.iterator.pretty_printer import pformat @@ -43,24 +44,6 @@ def __hash__(self) -> int: class Sym(Node): # helper id: Coerced[SymbolName] - # TODO(tehrengruber): Revisit. 
Using strings is a workaround to avoid coupling with the - # type inference. - kind: typing.Literal["Iterator", "Value", None] = None - dtype: Optional[tuple[str, bool]] = ( - None # format: name of primitive type, boolean indicating if it is a list - ) - - @datamodels.validator("kind") - def _kind_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str): - if value and value not in ["Iterator", "Value"]: - raise ValueError(f"Invalid kind '{value}', must be one of 'Iterator', 'Value'.") - - @datamodels.validator("dtype") - def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str): - if value and value[0] not in TYPEBUILTINS: - raise ValueError( - f"Invalid dtype '{value}', must be one of '{', '.join(TYPEBUILTINS)}'." - ) @noninstantiable @@ -172,18 +155,14 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib FLOATING_POINT_BUILTINS = {"float32", "float64"} TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} -GRAMMAR_BUILTINS = { +BUILTINS = { + "tuple_get", + "cast_", "cartesian_domain", "unstructured_domain", "make_tuple", - "tuple_get", "shift", "neighbors", - "cast_", -} - -BUILTINS = { - *GRAMMAR_BUILTINS, "named_range", "list_get", "map_", diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index a4b074a4b6..f7b80ced4f 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -35,3 +35,13 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]: """Match expression of the form `if_(cond, true_branch, false_branch)`.""" return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") + + +def is_call_to(node: itir.Node, fun: str | list[str]) -> TypeGuard[itir.FunCall]: + if isinstance(fun, (list, tuple, set)): + return 
any((is_call_to(node, f) for f in fun)) + # TODO: fix in all places that we don't do node.fun == im.ref(...) because this breaks + # when the lhs has a type + return ( + isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun + ) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 7fe05594ad..4fa93546a6 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -20,7 +20,7 @@ from gt4py.next.type_system import type_specifications as ts, type_translation -def sym(sym_or_name: Union[str, itir.Sym]) -> itir.Sym: +def sym(sym_or_name: Union[str, itir.Sym], type_=None) -> itir.Sym: """ Convert to Sym if necessary. @@ -33,11 +33,12 @@ def sym(sym_or_name: Union[str, itir.Sym]) -> itir.Sym: Sym(id=SymbolName('b'), kind=None, dtype=None) """ if isinstance(sym_or_name, itir.Sym): + assert not type_ return sym_or_name - return itir.Sym(id=sym_or_name) + return itir.Sym(id=sym_or_name, type=ensure_type(type_)) -def ref(ref_or_name: Union[str, itir.SymRef]) -> itir.SymRef: +def ref(ref_or_name: Union[str, itir.SymRef], type_=None) -> itir.SymRef: """ Convert to SymRef if necessary. @@ -50,8 +51,9 @@ def ref(ref_or_name: Union[str, itir.SymRef]) -> itir.SymRef: SymRef(id=SymbolRef('b')) """ if isinstance(ref_or_name, itir.SymRef): + assert not type_ return ref_or_name - return itir.SymRef(id=ref_or_name) + return itir.SymRef(id=ref_or_name, type=ensure_type(type_)) def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> itir.Expr: diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index d6dbb47ee9..41159b30e7 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -277,6 +277,7 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: "Only 'POSITIONAL_OR_KEYWORD' or 'VAR_POSITIONAL' parameters are supported." 
) + arg_type = None kind, dtype = None, None if use_arg_types: # TODO(tehrengruber): Fields of tuples are not supported yet. Just ignore them for now. @@ -290,7 +291,7 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: type_info.is_local_field(arg_type), # is list ) - params.append(Sym(id=param_name, kind=kind, dtype=dtype)) + params.append(Sym(id=param_name, type=arg_type, kind=kind, dtype=dtype)) return params diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 4b8182a781..5783b36546 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -21,40 +21,11 @@ from gt4py import eve from gt4py.eve import utils as eve_utils -from gt4py.next import type_inference -from gt4py.next.iterator import ir, type_inference as it_type_inference +from gt4py.next.iterator import ir, type_system as itir_type_inference from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda -from gt4py.next.type_system import type_info - - -class UnknownLength: - pass - - -def _get_tuple_size(elem: ir.Node, use_global_information: bool) -> int | type[UnknownLength]: - if use_global_information: - type_ = elem.annex.type - # global inference should always give a length, fail otherwise - assert isinstance(type_, it_type_inference.Val) and isinstance( - type_.dtype, it_type_inference.Tuple - ) - else: - # use local type inference if no global information is available - assert isinstance(elem, ir.Node) - type_ = it_type_inference.infer(elem) - - if not ( - isinstance(type_, it_type_inference.Val) - and isinstance(type_.dtype, it_type_inference.Tuple) - ): - return UnknownLength - - if not type_.dtype.has_known_length: - return UnknownLength - - return 
len(type_.dtype) +from gt4py.next.type_system import type_info, type_specifications as ts def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): @@ -117,7 +88,6 @@ def all(self) -> CollapseTuple.Flag: return functools.reduce(operator.or_, self.__members__.values()) ignore_tuple_size: bool - use_global_type_inference: bool flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] PRESERVED_ANNEX_ATTRS = ("type",) @@ -128,16 +98,14 @@ def all(self) -> CollapseTuple.Flag: init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el") ) - _node_types: Optional[dict[int, type_inference.Type]] = None - @classmethod def apply( cls, node: ir.Node, *, ignore_tuple_size: bool = False, - use_global_type_inference: bool = False, remove_letified_make_tuple_elements: bool = True, + offset_provider=None, # manually passing flags is mostly for allowing separate testing of the modes flags=None, ) -> ir.Node: @@ -150,18 +118,18 @@ def apply( Keyword arguments: ignore_tuple_size: Apply the transformation even if length of the inner tuple is greater than the length of the outer tuple. - use_global_type_inference: Run global type inference to determine tuple sizes. remove_letified_make_tuple_elements: Run `InlineLambdas` as a post-processing step to remove left-overs from `LETIFY_MAKE_TUPLE_ELEMENTS` transformation. 
`(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` """ flags = flags or cls.flags - if use_global_type_inference: - it_type_inference.infer_all(node, save_to_annex=True) + offset_provider = offset_provider or {} + + if not ignore_tuple_size: + node = itir_type_inference.infer(node, offset_provider=offset_provider) new_node = cls( ignore_tuple_size=ignore_tuple_size, - use_global_type_inference=use_global_type_inference, flags=flags, ).visit(node) @@ -219,9 +187,8 @@ def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ # tuple argument differs, just continue with the rest of the tree return None - if self.ignore_tuple_size or _get_tuple_size( - first_expr, self.use_global_type_inference - ) == len(node.args): + assert self.ignore_tuple_size or isinstance(first_expr.type, ts.TupleType) + if self.ignore_tuple_size or len(first_expr.type.types) == len(node.args): # type: ignore[union-attr] # ensured by assert above return first_expr return None diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index d099272b2b..1b19ca4bd9 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -23,7 +23,7 @@ from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator from gt4py.next import common -from gt4py.next.iterator import ir, type_inference +from gt4py.next.iterator import ir, type_system as itir_type_inference from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.pretty_printer import PrettyPrinter @@ -33,6 +33,8 @@ from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs +from 
gt4py.next.iterator.type_system import type_specifications as it_ts +from gt4py.next.type_system import type_specifications as ts """Iterator IR extension for global temporaries. @@ -59,7 +61,7 @@ class Temporary(ir.Node): id: Coerced[eve.SymbolName] domain: Optional[ir.Expr] = None - dtype: Optional[Any] = None + dtype: Optional[ts.ScalarType | ts.TupleType] = None class FencilWithTemporaries(ir.Node, SymbolTableTrait): @@ -166,7 +168,7 @@ def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: return False # do not extract when the result is a list (i.e. a lift expression used in a `reduce` call) # as we can not create temporaries for these stencils - if isinstance(expr.annex.type.dtype, type_inference.List): + if isinstance(expr.type, it_ts.ListType): return False if self.heuristics and not self.heuristics(expr): return False @@ -255,9 +257,9 @@ def always_extract_heuristics(_): uid_gen_tmps = UIDGenerator(prefix="_tmp") - type_inference.infer_all(node, offset_provider=offset_provider, save_to_annex=True) + node = itir_type_inference.infer(node, offset_provider=offset_provider) - tmps: list[ir.Sym] = [] + tmps: list[tuple[str, ts.DataType]] = [] closures: list[ir.StencilClosure] = [] for closure in reversed(node.closures): @@ -309,7 +311,9 @@ def always_extract_heuristics(_): stencil: ir.Node = lift_expr.fun.args[0] # usually an ir.Lambda or scan # allocate a new temporary - tmps.append(tmp_sym) + assert isinstance(stencil.type, ts.FunctionType) + assert isinstance(stencil.type.returns, ts.DataType) + tmps.append((tmp_sym.id, stencil.type.returns)) # create a new closure that executes the stencil of the applied lift and # writes the result to the newly created temporary @@ -361,13 +365,13 @@ def always_extract_heuristics(_): id=node.id, function_definitions=node.function_definitions, params=node.params - + [ir.Sym(id=tmp.id) for tmp in tmps] - + [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant + + [im.sym(name, 
type_) for name, type_ in tmps] + + [im.sym(AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant closures=list(reversed(closures)), location=node.location, ), params=node.params, - tmps=[Temporary(id=tmp.id) for tmp in tmps], + tmps=[Temporary(id=name, dtype=type_) for name, type_ in tmps], ) @@ -614,32 +618,10 @@ def collect_tmps_info(node: FencilWithTemporaries, *, offset_provider) -> Fencil assert output_field.id not in domains or domains[output_field.id] == closure.domain domains[output_field.id] = closure.domain - def convert_type(dtype): - if isinstance(dtype, type_inference.Primitive): - return dtype.name - elif isinstance(dtype, type_inference.Tuple): - return tuple(convert_type(el) for el in dtype) - elif isinstance(dtype, type_inference.List): - raise NotImplementedError("Temporaries with dtype list not supported.") - raise AssertionError() - - all_types = type_inference.infer_all(node.fencil, offset_provider=offset_provider) - fencil_type = all_types[id(node.fencil)] - assert isinstance(fencil_type, type_inference.FencilDefinitionType) - assert isinstance(fencil_type.params, type_inference.Tuple) - types = dict[str, ir.Expr]() - for param in node.fencil.params: - if param.id in tmps: - dtype = all_types[id(param)] - assert isinstance(dtype, type_inference.Val) - types[param.id] = convert_type(dtype.dtype) - return FencilWithTemporaries( fencil=node.fencil, params=node.params, - tmps=[ - Temporary(id=tmp.id, domain=domains[tmp.id], dtype=types[tmp.id]) for tmp in node.tmps - ], + tmps=[Temporary(id=tmp.id, domain=domains[tmp.id], dtype=tmp.dtype) for tmp in node.tmps], ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 5852ba9ae5..925b7d7587 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -96,6 +96,7 @@ def apply_common_transforms( assert isinstance(lift_mode, LiftMode) ir = 
MergeLet().visit(ir) ir = InlineFundefs().visit(ir) + ir = PruneUnreferencedFundefs().visit(ir) ir = PropagateDeref.apply(ir) ir = NormalizeShifts().visit(ir) @@ -119,8 +120,7 @@ def apply_common_transforms( # is constant-folded the surrounding tuple_get calls can be removed. inlined = CollapseTuple.apply( inlined, - # to limit number of times global type inference is executed, only in the last iterations. - use_global_type_inference=inlined == ir, + offset_provider=offset_provider, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) @@ -166,7 +166,8 @@ def apply_common_transforms( if unconditionally_collapse_tuples: ir = CollapseTuple.apply( ir, - ignore_tuple_size=unconditionally_collapse_tuples, + ignore_tuple_size=True, + offset_provider=offset_provider, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py deleted file mode 100644 index 89fed49551..0000000000 --- a/src/gt4py/next/iterator/type_inference.py +++ /dev/null @@ -1,1123 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . 
-# -# SPDX-License-Identifier: GPL-3.0-or-later - -import dataclasses -import typing -from collections import abc -from typing import Optional - -import gt4py.eve as eve -import gt4py.next as gtx -from gt4py.next.common import Connectivity -from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries -from gt4py.next.type_inference import Type, TypeVar, freshen, reindex_vars, unify -from gt4py.next.type_system import type_info - - -"""Constraint-based inference for the iterator IR.""" - -T = typing.TypeVar("T", bound="Type") - -# list of nodes that have a type -TYPED_IR_NODES: typing.Final = ( - ir.Expr, - ir.FunctionDefinition, - ir.StencilClosure, - ir.FencilDefinition, - ir.Sym, -) - - -class UnsatisfiableConstraintsError(Exception): - unsatisfiable_constraints: list[tuple[Type, Type]] - - def __init__(self, unsatisfiable_constraints: list[tuple[Type, Type]]): - self.unsatisfiable_constraints = unsatisfiable_constraints - msg = "Type inference failed: Can not satisfy constraints:" - for lhs, rhs in unsatisfiable_constraints: - msg += f"\n {lhs} ≡ {rhs}" - super().__init__(msg) - - -class EmptyTuple(Type): - def __iter__(self) -> abc.Iterator[Type]: - return - yield - - def __len__(self) -> int: - return 0 - - -class Tuple(Type): - """Tuple type with arbitrary number of elements.""" - - front: Type - others: Type - - @classmethod - def from_elems(cls: typing.Type[T], *elems: Type) -> typing.Union[T, EmptyTuple]: - tup: typing.Union[T, EmptyTuple] = EmptyTuple() - for e in reversed(elems): - tup = cls(front=e, others=tup) - return tup - - def __iter__(self) -> abc.Iterator[Type]: - yield self.front - if not isinstance(self.others, (Tuple, EmptyTuple)): - raise ValueError(f"Can not iterate over partially defined tuple '{self}'.") - yield from self.others - - @property - def has_known_length(self): - return isinstance(self.others, EmptyTuple) or ( - isinstance(self.others, Tuple) and self.others.has_known_length - 
) - - def __len__(self) -> int: - return sum(1 for _ in self) - - -class FunctionType(Type): - """Function type. - - Note: the type inference algorithm always infers a tuple-like type for - `args`, even for single-argument functions. - """ - - args: Type = eve.field(default_factory=TypeVar.fresh) - ret: Type = eve.field(default_factory=TypeVar.fresh) - - -class Location(Type): - """Location type.""" - - name: str - - -ANYWHERE = Location(name="ANYWHERE") - - -class Val(Type): - """The main type for representing values and iterators. - - Each `Val` consists of the following three things: - - A `kind` which is either `Value()`, `Iterator()`, or a variable - - A `dtype` which is either a `Primitive` or a variable - - A `size` which is either `Scalar()`, `Column()`, or a variable - """ - - kind: Type = eve.field(default_factory=TypeVar.fresh) - dtype: Type = eve.field(default_factory=TypeVar.fresh) - size: Type = eve.field(default_factory=TypeVar.fresh) - current_loc: Type = ANYWHERE - defined_loc: Type = ANYWHERE - - -class ValTuple(Type): - """A tuple of `Val` where all items have the same `kind` and `size`, but different dtypes.""" - - kind: Type = eve.field(default_factory=TypeVar.fresh) - dtypes: Type = eve.field(default_factory=TypeVar.fresh) - size: Type = eve.field(default_factory=TypeVar.fresh) - current_loc: Type = eve.field(default_factory=TypeVar.fresh) - defined_locs: Type = eve.field(default_factory=TypeVar.fresh) - - def __eq__(self, other: typing.Any) -> bool: - if ( - isinstance(self.dtypes, Tuple) - and isinstance(self.defined_locs, Tuple) - and isinstance(other, Tuple) - ): - dtypes: Type = self.dtypes - defined_locs: Type = self.defined_locs - elems: Type = other - while ( - isinstance(dtypes, Tuple) - and isinstance(defined_locs, Tuple) - and isinstance(elems, Tuple) - and Val( - kind=self.kind, - dtype=dtypes.front, - size=self.size, - current_loc=self.current_loc, - defined_loc=defined_locs.front, - ) - == elems.front - ): - dtypes = 
dtypes.others - defined_locs = defined_locs.others - elems = elems.others - return dtypes == defined_locs == elems == EmptyTuple() - - return ( - isinstance(other, ValTuple) - and self.kind == other.kind - and self.dtypes == other.dtypes - and self.size == other.size - and self.current_loc == other.current_loc - and self.defined_locs == other.defined_locs - ) - - def handle_constraint( - self, other: Type, add_constraint: abc.Callable[[Type, Type], None] - ) -> bool: - if isinstance(other, Tuple): - dtypes = [TypeVar.fresh() for _ in other] - defined_locs = [TypeVar.fresh() for _ in other] - expanded = [ - Val( - kind=self.kind, - dtype=dtype, - size=self.size, - current_loc=self.current_loc, - defined_loc=defined_loc, - ) - for dtype, defined_loc in zip(dtypes, defined_locs) - ] - add_constraint(self.dtypes, Tuple.from_elems(*dtypes)) - add_constraint(self.defined_locs, Tuple.from_elems(*defined_locs)) - add_constraint(Tuple.from_elems(*expanded), other) - return True - if isinstance(other, EmptyTuple): - add_constraint(self.dtypes, EmptyTuple()) - add_constraint(self.defined_locs, EmptyTuple()) - return True - return False - - -class ValListTuple(Type): - """ - A tuple of `Val` that contains `List`s. - - All items have: - - the same `kind` and `size`; - - `dtype` is `List` with different `list_dtypes`, but same `max_length`, and `has_skip_values`. 
- """ - - kind: Type = eve.field(default_factory=TypeVar.fresh) - list_dtypes: Type = eve.field(default_factory=TypeVar.fresh) - max_length: Type = eve.field(default_factory=TypeVar.fresh) - has_skip_values: Type = eve.field(default_factory=TypeVar.fresh) - size: Type = eve.field(default_factory=TypeVar.fresh) - - def __eq__(self, other: typing.Any) -> bool: - if isinstance(self.list_dtypes, Tuple) and isinstance(other, Tuple): - list_dtypes: Type = self.list_dtypes - elems: Type = other - while ( - isinstance(list_dtypes, Tuple) - and isinstance(elems, Tuple) - and Val( - kind=self.kind, - dtype=List( - dtype=list_dtypes.front, - max_length=self.max_length, - has_skip_values=self.has_skip_values, - ), - size=self.size, - ) - == elems.front - ): - list_dtypes = list_dtypes.others - elems = elems.others - return list_dtypes == elems == EmptyTuple() - - return ( - isinstance(other, ValListTuple) - and self.kind == other.kind - and self.list_dtypes == other.list_dtypes - and self.max_length == other.max_length - and self.has_skip_values == other.has_skip_values - and self.size == other.size - ) - - def handle_constraint( - self, other: Type, add_constraint: abc.Callable[[Type, Type], None] - ) -> bool: - if isinstance(other, Tuple): - list_dtypes = [TypeVar.fresh() for _ in other] - expanded = [ - Val( - kind=self.kind, - dtype=List( - dtype=dtype, - max_length=self.max_length, - has_skip_values=self.has_skip_values, - ), - size=self.size, - ) - for dtype in list_dtypes - ] - add_constraint(self.list_dtypes, Tuple.from_elems(*list_dtypes)) - add_constraint(Tuple.from_elems(*expanded), other) - return True - if isinstance(other, EmptyTuple): - add_constraint(self.list_dtypes, EmptyTuple()) - return True - return False - - -class Column(Type): - """Marker for column-sized values/iterators.""" - - ... - - -class Scalar(Type): - """Marker for scalar-sized values/iterators.""" - - ... 
- - -class Primitive(Type): - """Primitive type used in values/iterators.""" - - name: str - - def handle_constraint( - self, other: Type, add_constraint: abc.Callable[[Type, Type], None] - ) -> bool: - if not isinstance(other, Primitive): - return False - - if self.name != other.name: - raise TypeError( - f"Can not satisfy constraint on primitive types: '{self.name}' ≡ '{other.name}'." - ) - return True - - -class UnionPrimitive(Type): - """Union of primitive types.""" - - names: tuple[str, ...] - - def handle_constraint( - self, other: Type, add_constraint: abc.Callable[[Type, Type], None] - ) -> bool: - if isinstance(other, UnionPrimitive): - raise AssertionError("'UnionPrimitive' may only appear on one side of a constraint.") - if not isinstance(other, Primitive): - return False - - return other.name in self.names - - -class Value(Type): - """Marker for values.""" - - ... - - -class Iterator(Type): - """Marker for iterators.""" - - ... - - -class Length(Type): - length: int - - -class BoolType(Type): - value: bool - - -class List(Type): - dtype: Type = eve.field(default_factory=TypeVar.fresh) - max_length: Type = eve.field(default_factory=TypeVar.fresh) - has_skip_values: Type = eve.field(default_factory=TypeVar.fresh) - - -class Closure(Type): - """Stencil closure type.""" - - output: Type - inputs: Type - - -class FunctionDefinitionType(Type): - """Function definition type.""" - - name: str - fun: FunctionType - - -class FencilDefinitionType(Type): - """Fencil definition type.""" - - name: str - fundefs: Type - params: Type - - -class LetPolymorphic(Type): - """ - Wrapper for let-polymorphic types. - - Used for fencil-level function definitions. 
- """ - - dtype: Type - - -def _default_constraints(): - return { - (FLOAT_DTYPE, UnionPrimitive(names=("float32", "float64"))), - (INT_DTYPE, UnionPrimitive(names=("int32", "int64"))), - } - - -BOOL_DTYPE = Primitive(name="bool") -INT_DTYPE = TypeVar.fresh() -FLOAT_DTYPE = TypeVar.fresh() -AXIS_DTYPE = Primitive(name="axis") -NAMED_RANGE_DTYPE = Primitive(name="named_range") -DOMAIN_DTYPE = Primitive(name="domain") -OFFSET_TAG_DTYPE = Primitive(name="offset_tag") - -# Some helpers to define the builtins' types -T0 = TypeVar.fresh() -T1 = TypeVar.fresh() -T2 = TypeVar.fresh() -T3 = TypeVar.fresh() -T4 = TypeVar.fresh() -T5 = TypeVar.fresh() -Val_T0_T1 = Val(kind=Value(), dtype=T0, size=T1) -Val_T0_Scalar = Val(kind=Value(), dtype=T0, size=Scalar()) -Val_BOOL_T1 = Val(kind=Value(), dtype=BOOL_DTYPE, size=T1) - -BUILTIN_CATEGORY_MAPPING = ( - ( - ir.UNARY_MATH_FP_BUILTINS, - FunctionType( - args=Tuple.from_elems(Val(kind=Value(), dtype=FLOAT_DTYPE, size=T0)), - ret=Val(kind=Value(), dtype=FLOAT_DTYPE, size=T0), - ), - ), - (ir.UNARY_MATH_NUMBER_BUILTINS, FunctionType(args=Tuple.from_elems(Val_T0_T1), ret=Val_T0_T1)), - ( - {"power"}, - FunctionType( - args=Tuple.from_elems(Val_T0_T1, Val(kind=Value(), dtype=T2, size=T1)), ret=Val_T0_T1 - ), - ), - ( - ir.BINARY_MATH_NUMBER_BUILTINS, - FunctionType(args=Tuple.from_elems(Val_T0_T1, Val_T0_T1), ret=Val_T0_T1), - ), - ( - ir.UNARY_MATH_FP_PREDICATE_BUILTINS, - FunctionType( - args=Tuple.from_elems(Val(kind=Value(), dtype=FLOAT_DTYPE, size=T0)), - ret=Val(kind=Value(), dtype=BOOL_DTYPE, size=T0), - ), - ), - ( - ir.BINARY_MATH_COMPARISON_BUILTINS, - FunctionType(args=Tuple.from_elems(Val_T0_T1, Val_T0_T1), ret=Val_BOOL_T1), - ), - ( - ir.BINARY_LOGICAL_BUILTINS, - FunctionType(args=Tuple.from_elems(Val_BOOL_T1, Val_BOOL_T1), ret=Val_BOOL_T1), - ), - (ir.UNARY_LOGICAL_BUILTINS, FunctionType(args=Tuple.from_elems(Val_BOOL_T1), ret=Val_BOOL_T1)), -) - -BUILTIN_TYPES: dict[str, Type] = { - **{builtin: type_ for category, 
type_ in BUILTIN_CATEGORY_MAPPING for builtin in category}, - "deref": FunctionType( - args=Tuple.from_elems( - Val(kind=Iterator(), dtype=T0, size=T1, current_loc=T2, defined_loc=T2) - ), - ret=Val_T0_T1, - ), - "can_deref": FunctionType( - args=Tuple.from_elems( - Val(kind=Iterator(), dtype=T0, size=T1, current_loc=T2, defined_loc=T3) - ), - ret=Val_BOOL_T1, - ), - "if_": FunctionType(args=Tuple.from_elems(Val_BOOL_T1, T2, T2), ret=T2), - "lift": FunctionType( - args=Tuple.from_elems( - FunctionType( - args=ValTuple(kind=Iterator(), dtypes=T2, size=T1, current_loc=T3, defined_locs=T4), - ret=Val_T0_T1, - ) - ), - ret=FunctionType( - args=ValTuple(kind=Iterator(), dtypes=T2, size=T1, current_loc=T5, defined_locs=T4), - ret=Val(kind=Iterator(), dtype=T0, size=T1, current_loc=T5, defined_loc=T3), - ), - ), - "map_": FunctionType( - args=Tuple.from_elems( - FunctionType(args=ValTuple(kind=Value(), dtypes=T2, size=T1), ret=Val_T0_T1) - ), - ret=FunctionType( - args=ValListTuple(kind=Value(), list_dtypes=T2, size=T1), - ret=Val(kind=Value(), dtype=List(dtype=T0, max_length=T4, has_skip_values=T5), size=T1), - ), - ), - "reduce": FunctionType( - args=Tuple.from_elems( - FunctionType( - args=Tuple(front=Val_T0_T1, others=ValTuple(kind=Value(), dtypes=T2, size=T1)), - ret=Val_T0_T1, - ), - Val_T0_T1, - ), - ret=FunctionType( - args=ValListTuple( - kind=Value(), list_dtypes=T2, max_length=T4, has_skip_values=T5, size=T1 - ), - ret=Val_T0_T1, - ), - ), - "make_const_list": FunctionType( - args=Tuple.from_elems(Val_T0_T1), - ret=Val(kind=Value(), dtype=List(dtype=T0, max_length=T2, has_skip_values=T3), size=T1), - ), - "list_get": FunctionType( - args=Tuple.from_elems( - Val(kind=Value(), dtype=INT_DTYPE, size=Scalar()), - Val(kind=Value(), dtype=List(dtype=T0, max_length=T2, has_skip_values=T3), size=T1), - ), - ret=Val_T0_T1, - ), - "scan": FunctionType( - args=Tuple.from_elems( - FunctionType( - args=Tuple( - front=Val_T0_Scalar, - others=ValTuple( - kind=Iterator(), 
dtypes=T2, size=Scalar(), current_loc=T3, defined_locs=T4 - ), - ), - ret=Val_T0_Scalar, - ), - Val(kind=Value(), dtype=BOOL_DTYPE, size=Scalar()), - Val_T0_Scalar, - ), - ret=FunctionType( - args=ValTuple( - kind=Iterator(), dtypes=T2, size=Column(), current_loc=T3, defined_locs=T4 - ), - ret=Val(kind=Value(), dtype=T0, size=Column()), - ), - ), - "named_range": FunctionType( - args=Tuple.from_elems( - Val(kind=Value(), dtype=AXIS_DTYPE, size=Scalar()), - Val(kind=Value(), dtype=INT_DTYPE, size=Scalar()), - Val(kind=Value(), dtype=INT_DTYPE, size=Scalar()), - ), - ret=Val(kind=Value(), dtype=NAMED_RANGE_DTYPE, size=Scalar()), - ), -} - - -del T0, T1, T2, T3, T4, T5, Val_T0_T1, Val_T0_Scalar, Val_BOOL_T1 - - -def _infer_shift_location_types(shift_args, offset_provider, constraints): - current_loc_in = TypeVar.fresh() - if offset_provider: - current_loc_out = current_loc_in - for arg in shift_args: - if not isinstance(arg, ir.OffsetLiteral): - # probably some dynamically computed offset, thus we assume it's a number not an axis and just ignore it (see comment below) - continue - offset = arg.value - if isinstance(offset, int): - continue # ignore 'application' of (partial) shifts - else: - assert isinstance(offset, str) - axis = offset_provider[offset] - if isinstance(axis, gtx.Dimension): - continue # Cartesian shifts don't change the location type - elif isinstance(axis, Connectivity): - assert ( - axis.origin_axis.kind - == axis.neighbor_axis.kind - == gtx.DimensionKind.HORIZONTAL - ) - constraints.add((current_loc_out, Location(name=axis.origin_axis.value))) - current_loc_out = Location(name=axis.neighbor_axis.value) - else: - raise NotImplementedError() - elif not shift_args: - current_loc_out = current_loc_in - else: - current_loc_out = TypeVar.fresh() - return current_loc_in, current_loc_out - - -@dataclasses.dataclass -class _TypeInferrer(eve.traits.VisitorWithSymbolTableTrait, eve.NodeTranslator): - """ - Visit the full iterator IR tree, convert nodes to 
respective types and generate constraints. - - Attributes: - collected_types: Mapping from the (Python) id of a node to its type. - constraints: Set of constraints, where a constraint is a pair of types that need to agree. - See `unify` for more information. - """ - - offset_provider: Optional[dict[str, Connectivity | gtx.Dimension]] - collected_types: dict[int, Type] = dataclasses.field(default_factory=dict) - constraints: set[tuple[Type, Type]] = dataclasses.field(default_factory=_default_constraints) - - def visit(self, node, **kwargs) -> typing.Any: - result = super().visit(node, **kwargs) - if isinstance(node, TYPED_IR_NODES): - assert isinstance(result, Type) - if not ( - id(node) not in self.collected_types or self.collected_types[id(node)] == result - ): - # using the same node in multiple places is fine as long as the type is the same - # for all occurences - self.constraints.add((result, self.collected_types[id(node)])) - self.collected_types[id(node)] = result - - return result - - def visit_Sym(self, node: ir.Sym, **kwargs) -> Type: - result = TypeVar.fresh() - if node.kind: - kind = {"Iterator": Iterator(), "Value": Value()}[node.kind] - self.constraints.add( - (Val(kind=kind, current_loc=TypeVar.fresh(), defined_loc=TypeVar.fresh()), result) - ) - if node.dtype: - assert node.dtype is not None - dtype: Primitive | List = Primitive(name=node.dtype[0]) - if node.dtype[1]: - dtype = List(dtype=dtype) - self.constraints.add( - (Val(dtype=dtype, current_loc=TypeVar.fresh(), defined_loc=TypeVar.fresh()), result) - ) - return result - - def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type: - if node.id in ir.BUILTINS: - if node.id in BUILTIN_TYPES: - return freshen(BUILTIN_TYPES[node.id]) - elif node.id in ir.GRAMMAR_BUILTINS: - raise TypeError( - f"Builtin '{node.id}' is only allowed as applied/called function by the type " - "inference." 
- ) - elif node.id in ir.TYPEBUILTINS: - # TODO(tehrengruber): Implement propagating types of values referring to types, e.g. - # >>> my_int = int64 - # ... cast_(expr, my_int) - # One way to support this is by introducing a "type of type" similar to pythons - # `typing.Type`. - raise NotImplementedError( - f"Type builtin '{node.id}' is only supported as literal argument by the " - "type inference." - ) - else: - raise NotImplementedError(f"Missing type definition for builtin '{node.id}'.") - elif node.id in symtable: - sym_decl = symtable[node.id] - assert isinstance(sym_decl, TYPED_IR_NODES) - res = self.collected_types[id(sym_decl)] - if isinstance(res, LetPolymorphic): - return freshen(res.dtype) - return res - - return TypeVar.fresh() - - def visit_Literal(self, node: ir.Literal, **kwargs) -> Val: - return Val(kind=Value(), dtype=Primitive(name=node.type.kind.name.lower())) - - def visit_AxisLiteral(self, node: ir.AxisLiteral, **kwargs) -> Val: - return Val(kind=Value(), dtype=AXIS_DTYPE, size=Scalar()) - - def visit_OffsetLiteral(self, node: ir.OffsetLiteral, **kwargs) -> TypeVar: - return TypeVar.fresh() - - def visit_Lambda(self, node: ir.Lambda, **kwargs) -> FunctionType: - ptypes = {p.id: self.visit(p, **kwargs) for p in node.params} - ret = self.visit(node.expr, **kwargs) - return FunctionType(args=Tuple.from_elems(*(ptypes[p.id] for p in node.params)), ret=ret) - - def _visit_make_tuple(self, node: ir.FunCall, **kwargs) -> Type: - # Calls to `make_tuple` are handled as being part of the grammar, not as function calls. 
- argtypes = self.visit(node.args, **kwargs) - kind = ( - TypeVar.fresh() - ) # `kind == Iterator()` means zipping iterators into an iterator of tuples - size = TypeVar.fresh() - dtype = Tuple.from_elems(*(TypeVar.fresh() for _ in argtypes)) - for d, a in zip(dtype, argtypes): - self.constraints.add((Val(kind=kind, dtype=d, size=size), a)) - return Val(kind=kind, dtype=dtype, size=size) - - def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type: - # Calls to `tuple_get` are handled as being part of the grammar, not as function calls. - if len(node.args) != 2: - raise TypeError("'tuple_get' requires exactly two arguments.") - if not isinstance(node.args[0], ir.Literal) or not type_info.is_integer(node.args[0].type): - raise TypeError( - f"The first argument to 'tuple_get' must be a literal of type '{ir.INTEGER_INDEX_BUILTIN}'." - ) - self.visit(node.args[0], **kwargs) # visit index so that its type is collected - idx = int(node.args[0].value) - tup = self.visit(node.args[1], **kwargs) - kind = TypeVar.fresh() # `kind == Iterator()` means splitting an iterator of tuples - elem = TypeVar.fresh() - size = TypeVar.fresh() - - dtype = Tuple(front=elem, others=TypeVar.fresh()) - for _ in range(idx): - dtype = Tuple(front=TypeVar.fresh(), others=dtype) - - val = Val(kind=kind, dtype=dtype, size=size) - self.constraints.add((tup, val)) - return Val(kind=kind, dtype=elem, size=size) - - def _visit_neighbors(self, node: ir.FunCall, **kwargs) -> Type: - if len(node.args) != 2: - raise TypeError("'neighbors' requires exactly two arguments.") - if not (isinstance(node.args[0], ir.OffsetLiteral) and isinstance(node.args[0].value, str)): - raise TypeError("The first argument to 'neighbors' must be an 'OffsetLiteral' tag.") - - # Visit arguments such that their type is also inferred - self.visit(node.args, **kwargs) - - max_length: Type = TypeVar.fresh() - has_skip_values: Type = TypeVar.fresh() - if self.offset_provider: - connectivity = 
self.offset_provider[node.args[0].value] - assert isinstance(connectivity, Connectivity) - max_length = Length(length=connectivity.max_neighbors) - has_skip_values = BoolType(value=connectivity.has_skip_values) - current_loc_in, current_loc_out = _infer_shift_location_types( - [node.args[0]], self.offset_provider, self.constraints - ) - dtype_ = TypeVar.fresh() - size = TypeVar.fresh() - it = self.visit(node.args[1], **kwargs) - self.constraints.add( - ( - it, - Val( - kind=Iterator(), - dtype=dtype_, - size=size, - current_loc=current_loc_in, - defined_loc=current_loc_out, - ), - ) - ) - lst = List(dtype=dtype_, max_length=max_length, has_skip_values=has_skip_values) - return Val(kind=Value(), dtype=lst, size=size) - - def _visit_cast_(self, node: ir.FunCall, **kwargs) -> Type: - if len(node.args) != 2: - raise TypeError("'cast_' requires exactly two arguments.") - val_arg_type = self.visit(node.args[0], **kwargs) - type_arg = node.args[1] - if not isinstance(type_arg, ir.SymRef) or type_arg.id not in ir.TYPEBUILTINS: - raise TypeError("The second argument to 'cast_' must be a type literal.") - - size = TypeVar.fresh() - - self.constraints.add((val_arg_type, Val(kind=Value(), dtype=TypeVar.fresh(), size=size))) - - return Val(kind=Value(), dtype=Primitive(name=type_arg.id), size=size) - - def _visit_shift(self, node: ir.FunCall, **kwargs) -> Type: - # Calls to shift are handled as being part of the grammar, not - # as function calls, as the type depends on the offset provider. 
- - # Visit arguments such that their type is also inferred (particularly important for - # dynamic offsets) - self.visit(node.args) - - current_loc_in, current_loc_out = _infer_shift_location_types( - node.args, self.offset_provider, self.constraints - ) - defined_loc = TypeVar.fresh() - dtype_ = TypeVar.fresh() - size = TypeVar.fresh() - return FunctionType( - args=Tuple.from_elems( - Val( - kind=Iterator(), - dtype=dtype_, - size=size, - current_loc=current_loc_in, - defined_loc=defined_loc, - ) - ), - ret=Val( - kind=Iterator(), - dtype=dtype_, - size=size, - current_loc=current_loc_out, - defined_loc=defined_loc, - ), - ) - - def _visit_domain(self, node: ir.FunCall, **kwargs) -> Type: - for arg in node.args: - self.constraints.add( - ( - Val(kind=Value(), dtype=NAMED_RANGE_DTYPE, size=Scalar()), - self.visit(arg, **kwargs), - ) - ) - return Val(kind=Value(), dtype=DOMAIN_DTYPE, size=Scalar()) - - def _visit_cartesian_domain(self, node: ir.FunCall, **kwargs) -> Type: - return self._visit_domain(node, **kwargs) - - def _visit_unstructured_domain(self, node: ir.FunCall, **kwargs) -> Type: - return self._visit_domain(node, **kwargs) - - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> Type: - if isinstance(node.fun, ir.SymRef) and node.fun.id in ir.GRAMMAR_BUILTINS: - # builtins that are treated as part of the grammar are handled in `_visit_` - return getattr(self, f"_visit_{node.fun.id}")(node, **kwargs) - elif isinstance(node.fun, ir.SymRef) and node.fun.id in ir.TYPEBUILTINS: - return Val(kind=Value(), dtype=Primitive(name=node.fun.id)) - - fun = self.visit(node.fun, **kwargs) - args = Tuple.from_elems(*self.visit(node.args, **kwargs)) - ret = TypeVar.fresh() - self.constraints.add((fun, FunctionType(args=args, ret=ret))) - return ret - - def visit_FunctionDefinition(self, node: ir.FunctionDefinition, **kwargs) -> LetPolymorphic: - fun = ir.Lambda(params=node.params, expr=node.expr) - - # Since functions defined in a function definition are 
let-polymorphic we don't want - # their parameters to inherit the constraints of the arguments in a call to them. A simple - # way to do this is to run the type inference on the function itself and reindex its type - # vars when referencing the function, i.e. in a `SymRef`. - collected_types = infer_all(fun, offset_provider=self.offset_provider, reindex=False) - fun_type = LetPolymorphic(dtype=collected_types.pop(id(fun))) - assert not set(self.collected_types.keys()) & set(collected_types.keys()) - self.collected_types = {**self.collected_types, **collected_types} - - return fun_type - - def visit_StencilClosure(self, node: ir.StencilClosure, **kwargs) -> Closure: - domain = self.visit(node.domain, **kwargs) - stencil = self.visit(node.stencil, **kwargs) - output = self.visit(node.output, **kwargs) - output_dtype = TypeVar.fresh() - output_loc = TypeVar.fresh() - self.constraints.add( - (domain, Val(kind=Value(), dtype=Primitive(name="domain"), size=Scalar())) - ) - self.constraints.add( - ( - output, - Val(kind=Iterator(), dtype=output_dtype, size=Column(), defined_loc=output_loc), - ) - ) - - inputs: list[Type] = self.visit(node.inputs, **kwargs) - stencil_params = [] - for input_ in inputs: - stencil_param = Val(current_loc=output_loc, defined_loc=TypeVar.fresh()) - self.constraints.add( - ( - input_, - Val( - kind=stencil_param.kind, - dtype=stencil_param.dtype, - size=stencil_param.size, - # closure input and stencil param differ in `current_loc` - current_loc=ANYWHERE, - # TODO(tehrengruber): Seems to break for scalars. Use `TypeVar.fresh()`? 
- defined_loc=stencil_param.defined_loc, - ), - ) - ) - stencil_params.append(stencil_param) - - self.constraints.add( - ( - stencil, - FunctionType( - args=Tuple.from_elems(*stencil_params), - ret=Val(kind=Value(), dtype=output_dtype, size=Column()), - ), - ) - ) - return Closure(output=output, inputs=Tuple.from_elems(*inputs)) - - def visit_FencilWithTemporaries(self, node: FencilWithTemporaries, **kwargs): - return self.visit(node.fencil, **kwargs) - - def visit_FencilDefinition(self, node: ir.FencilDefinition, **kwargs) -> FencilDefinitionType: - ftypes = [] - # Note: functions have to be ordered according to Lisp/Scheme `let*` - # statements; that is, functions can only reference other functions - # that are defined before - for fun_def in node.function_definitions: - fun_type: LetPolymorphic = self.visit(fun_def, **kwargs) - ftype = FunctionDefinitionType(name=fun_def.id, fun=fun_type.dtype) - ftypes.append(ftype) - - params = [self.visit(p, **kwargs) for p in node.params] - self.visit(node.closures, **kwargs) - return FencilDefinitionType( - name=str(node.id), fundefs=Tuple.from_elems(*ftypes), params=Tuple.from_elems(*params) - ) - - -def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None: - for child_node in node.pre_walk_values().if_isinstance(*TYPED_IR_NODES): - try: - child_node.annex.type = types[id(child_node)] - except KeyError as ex: - if not ( - isinstance(child_node, ir.SymRef) - and child_node.id in ir.GRAMMAR_BUILTINS | ir.TYPEBUILTINS - ): - raise AssertionError( - f"Expected a type to be inferred for node '{child_node}', but none was found." - ) from ex - - -def infer_all( - node: ir.Node, - *, - offset_provider: Optional[dict[str, Connectivity | gtx.Dimension]] = None, - reindex: bool = True, - save_to_annex=False, -) -> dict[int, Type]: - """ - Infer the types of the child expressions of a given iterator IR expression. - - The result is a dictionary mapping the (Python) id of child nodes to their type. 
- - The `save_to_annex` flag should only be used as a last resort when the return dictionary is - not enough. - """ - # Collect preliminary types of all nodes and constraints on them - inferrer = _TypeInferrer(offset_provider=offset_provider) - inferrer.visit(node) - - # Ensure dict order is pre-order of the tree - collected_types = dict(reversed(inferrer.collected_types.items())) - - # Compute the most general type that satisfies all constraints - unified_types, unsatisfiable_constraints = unify( - list(collected_types.values()), inferrer.constraints - ) - - if reindex: - unified_types, unsatisfiable_constraints = reindex_vars( - (unified_types, unsatisfiable_constraints) - ) - - result = { - id_: unified_type - for id_, unified_type in zip(collected_types.keys(), unified_types, strict=True) - } - - if save_to_annex: - _save_types_to_annex(node, result) - - if unsatisfiable_constraints: - raise UnsatisfiableConstraintsError(unsatisfiable_constraints) - - return result - - -def infer( - expr: ir.Node, - offset_provider: typing.Optional[dict[str, typing.Any]] = None, - save_to_annex: bool = False, -) -> Type: - """Infer the type of the given iterator IR expression.""" - inferred_types = infer_all(expr, offset_provider=offset_provider, save_to_annex=save_to_annex) - return inferred_types[id(expr)] - - -class PrettyPrinter(eve.NodeTranslator): - """Pretty-printer for type expressions.""" - - @staticmethod - def _subscript(i: int) -> str: - return "".join("₀₁₂₃₄₅₆₇₈₉"[int(d)] for d in str(i)) - - @staticmethod - def _superscript(i: int) -> str: - return "".join("⁰¹²³⁴⁵⁶⁷⁸⁹"[int(d)] for d in str(i)) - - def _fmt_size(self, size: Type) -> str: - if size == Column(): - return "ᶜ" - if size == Scalar(): - return "ˢ" - assert isinstance(size, TypeVar) - return self._superscript(size.idx) - - def _fmt_dtype( - self, - kind: Type, - dtype_str: str, - current_loc: typing.Optional[str] = None, - defined_loc: typing.Optional[str] = None, - ) -> str: - if kind == Value(): - 
return dtype_str - if kind == Iterator(): - if current_loc == defined_loc == "ANYWHERE" or current_loc is defined_loc is None: - locs = "" - else: - assert isinstance(current_loc, str) and isinstance(defined_loc, str) - locs = current_loc + ", " + defined_loc + ", " - return "It[" + locs + dtype_str + "]" - assert isinstance(kind, TypeVar) - return "ItOrVal" + self._subscript(kind.idx) + "[" + dtype_str + "]" - - def visit_EmptyTuple(self, node: EmptyTuple) -> str: - return "()" - - def visit_Tuple(self, node: Tuple) -> str: - s = "(" + self.visit(node.front) - while isinstance(node.others, Tuple): - node = node.others - s += ", " + self.visit(node.front) - s += ")" - if not isinstance(node.others, EmptyTuple): - s += ":" + self.visit(node.others) - return s - - def visit_Location(self, node: Location): - return node.name - - def visit_FunctionType(self, node: FunctionType) -> str: - return self.visit(node.args) + " → " + self.visit(node.ret) - - def visit_Val(self, node: Val) -> str: - return self._fmt_dtype( - node.kind, - self.visit(node.dtype) + self._fmt_size(node.size), - self.visit(node.current_loc), - self.visit(node.defined_loc), - ) - - def visit_Primitive(self, node: Primitive) -> str: - return node.name - - def visit_List(self, node: List) -> str: - return f"L[{self.visit(node.dtype)}, {self.visit(node.max_length)}, {self.visit(node.has_skip_values)}]" - - def visit_FunctionDefinitionType(self, node: FunctionDefinitionType) -> str: - return node.name + " :: " + self.visit(node.fun) - - def visit_Closure(self, node: Closure) -> str: - return self.visit(node.inputs) + " ⇒ " + self.visit(node.output) - - def visit_FencilDefinitionType(self, node: FencilDefinitionType) -> str: - assert isinstance(node.fundefs, (Tuple, EmptyTuple)) - assert isinstance(node.params, (Tuple, EmptyTuple)) - return ( - "{" - + "".join(self.visit(f) + ", " for f in node.fundefs) - + node.name - + "(" - + ", ".join(self.visit(p) for p in node.params) - + ")}" - ) - - def 
visit_ValTuple(self, node: ValTuple) -> str: - if isinstance(node.dtypes, TypeVar): - assert isinstance(node.defined_locs, TypeVar) - return ( - "(" - + self._fmt_dtype( - node.kind, - "T" + self._fmt_size(node.size), - self.visit(node.current_loc), - "…" + self._subscript(node.defined_locs.idx), - ) - + ", …)" - + self._subscript(node.dtypes.idx) - ) - assert isinstance(node.dtypes, (Tuple, EmptyTuple)) - if isinstance(node.defined_locs, (Tuple, EmptyTuple)): - defined_locs = node.defined_locs - else: - defined_locs = Tuple.from_elems(*(Location(name="_") for _ in node.dtypes)) - return ( - "(" - + ", ".join( - self.visit( - Val( - kind=node.kind, - dtype=dtype, - size=node.size, - current_loc=node.current_loc, - defined_loc=defined_loc, - ) - ) - for dtype, defined_loc in zip(node.dtypes, defined_locs) - ) - + ")" - ) - - def visit_ValListTuple(self, node: ValListTuple) -> str: - if isinstance(node.list_dtypes, TypeVar): - return f"(L[…{self._subscript(node.list_dtypes.idx)}, {self.visit(node.max_length)}, {self.visit(node.has_skip_values)}]{self._fmt_size(node.size)}, …)" - assert isinstance(node.list_dtypes, (Tuple, EmptyTuple)) - return ( - "(" - + ", ".join( - self.visit( - Val( - kind=Value(), - dtype=List( - dtype=dtype, - max_length=node.max_length, - has_skip_values=node.has_skip_values, - ), - size=node.size, - ) - ) - for dtype in node.list_dtypes - ) - + ")" - ) - - def visit_TypeVar(self, node: TypeVar) -> str: - return "T" + self._subscript(node.idx) - - def visit_Type(self, node: Type) -> str: - return ( - node.__class__.__name__ - + "(" - + ", ".join(f"{k}={v}" for k, v in node.iter_children_items()) - + ")" - ) - - -pformat = PrettyPrinter().visit - - -def pprint(x: Type) -> None: - print(pformat(x)) diff --git a/src/gt4py/next/iterator/type_system/__init__.py b/src/gt4py/next/iterator/type_system/__init__.py new file mode 100644 index 0000000000..6c43e2f12a --- /dev/null +++ b/src/gt4py/next/iterator/type_system/__init__.py @@ -0,0 +1,13 @@ +# 
GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py new file mode 100644 index 0000000000..818647637e --- /dev/null +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -0,0 +1,412 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import copy +import dataclasses +import functools +import inspect + +from gt4py import eve +from gt4py.eve import concepts +from gt4py.eve.extended_typing import Any, Callable, Optional, TypeVar, Union +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to +from gt4py.next.iterator.type_system import rules, type_specifications as it_ts +from gt4py.next.type_system import type_info, type_specifications as ts + + +def _is_representable_as_int(s: int | str) -> bool: + try: + int(s) + return True + except ValueError: + return False + + +def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): + """ + Predicate to determine if two types are compatible. + + This function gracefully handles iterators with unknown positions which are considered + compatible to any other positions of another iterator. Beside that this function + simply checks for equality of types. + + >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) + >>> IDim = common.Dimension(value="IDim") + >>> type_on_i_of_i_it = it_ts.IteratorType( + ... position_dims=[IDim], defined_dims=[IDim], element_type=bool_type + ... ) + >>> type_on_undefined_of_i_it = it_ts.IteratorType( + ... position_dims="unknown", defined_dims=[IDim], element_type=bool_type + ... ) + >>> _is_compatible_type(type_on_i_of_i_it, type_on_undefined_of_i_it) + True + + >>> JDim = common.Dimension(value="JDim") + >>> type_on_j_of_j_it = it_ts.IteratorType( + ... position_dims=[JDim], defined_dims=[JDim], element_type=bool_type + ... 
) + >>> _is_compatible_type(type_on_i_of_i_it, type_on_j_of_j_it) + False + """ + is_compatible = True + + def is_compatible_element(el_type_a: ts.TypeSpec, el_type_b: ts.TypeSpec): + nonlocal is_compatible + if isinstance(el_type_a, it_ts.IteratorType) and isinstance(el_type_b, it_ts.IteratorType): + if not any(el_type.position_dims == "unknown" for el_type in [el_type_a, el_type_b]): + is_compatible &= el_type_a.position_dims == el_type_b.position_dims + is_compatible &= el_type_a.defined_dims == el_type_b.defined_dims + is_compatible &= el_type_a.element_type == el_type_b.element_type + else: + is_compatible &= el_type_a == el_type_b + + type_info.apply_to_primitive_constituents(is_compatible_element, type_a, type_b) + + return is_compatible + + +# Problems: +# - how to get the kind of the dimension in here? X +# maybe directly attach the type to an axis literal? +# - lift X (also mention to Hannes) +# - is_compatible +# - late offset literal in (also mention to Hannes) +# tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +# - what happens when we get a lambda function whose params are already typed +# - write back params type in lambda +# - documentation +# describe why lambda can only have one type. Describe idea to solve e.g. +# let("f", lambda x: x)(f(1)+f(1.)) +# -> let("f_int", lambda x: x, "f_float", lambda x: x)(f_int(1)+f_float(1.)) +# - make types hashable +# - ~~either Eve with Coercion and no runtime checking,~~ dataclass hash with cached property +# - document how scans are handled (also mention to Hannes) +# - types are stored in the node, but will be incomplete after some passes +# - deferred type for testing +# - visit_FunctionDefinition + + +# Design decisions +# Only the parameters of fencils need to be typed. +# Lambda functions are not polymorphic. 
+ + +def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, "DeferredFunctionType"]) -> None: + """ + Execute `callback` as soon as all `args` have a type. + """ + ready_args = [False] * len(args) + inferred_args = [None] * len(args) + + def mark_ready(i, type_): + ready_args[i] = True + inferred_args[i] = type_ + if all(ready_args): + callback(*inferred_args) + + for i, arg in enumerate(args): + if isinstance(arg, DeferredFunctionType): + arg.on_type_ready(functools.partial(mark_ready, i)) + else: + assert isinstance(arg, ts.TypeSpec) + mark_ready(i, arg) + + +@dataclasses.dataclass +class DeferredFunctionType: + """ + This class wraps a raw type inference rule to handle typing of functions. + + As functions are represented by type inference rules + """ + + #: type rule that given a set of types or type rules returns the return type or a type rule + type_rule: rules.TypeInferenceRule + #: offset provider used by some type rules + offset_provider: Any + #: node that has this type + node: Optional[itir.Node] = None + #: list of references to this function + aliases: list[itir.SymRef] = dataclasses.field(default_factory=list) + #: list of callbacks executed as soon as the type is ready + callbacks: list[Any] = dataclasses.field(default_factory=list) + #: the inferred type when ready and None until then + inferred_type: Optional[ts.FunctionType] = None + #: whether to store the type in the node or not + store_inferred_type_in_node: bool = False + + def infer_type( + self, return_type: ts.DataType | ts.DeferredType, *args: ts.DataType | ts.DeferredType + ) -> ts.FunctionType: + return ts.FunctionType( + pos_only_args=list(args), pos_or_kw_args={}, kw_only_args={}, returns=return_type + ) + + def _infer_type_listener(self, return_type: ts.TypeSpec, *args: ts.TypeSpec) -> None: + self.inferred_type = self.infer_type(return_type, *args) # type: ignore[arg-type] # ensured by assert above + + # if the type has been fully inferred, notify all `DeferredFunctionType`s 
that depend on it. + for cb in self.callbacks: + cb(self.inferred_type) + + if self.store_inferred_type_in_node: + assert self.node + self.node.type = self.inferred_type + for alias in self.aliases: + alias.type = self.inferred_type + + def on_type_ready(self, cb: Callable[[ts.TypeSpec], None]) -> None: + if self.inferred_type: + # type has already been inferred, just call the callback + cb(self.inferred_type) + else: + self.callbacks.append(cb) + + def __call__(self, *args) -> ts.TypeSpec | rules.TypeInferenceRule: + if "offset_provider" in inspect.signature(self.type_rule).parameters: + return_type = self.type_rule(*args, offset_provider=self.offset_provider) + else: + return_type = self.type_rule(*args) + + # return type is a typing rule by itself + if callable(return_type): + return_type = DeferredFunctionType( + node=None, # node will be set by caller + type_rule=return_type, + offset_provider=self.offset_provider, + store_inferred_type_in_node=True, + ) + + # delay storing the type until the return type and all arguments are inferred + on_inferred(self._infer_type_listener, return_type, *args) + + return return_type + + +T = TypeVar("T", bound=itir.Node) + + +@dataclasses.dataclass +class ITIRTypeInference(eve.NodeTranslator): + """ + TODO + """ + + offset_provider: Any + + @functools.cached_property + def dimensions(self) -> dict[str, common.Dimension]: + dimensions: dict[str, common.Dimension] = {} + for offset_name, provider in self.offset_provider.items(): + dimensions[offset_name] = common.Dimension( + value=offset_name, kind=common.DimensionKind.LOCAL + ) + if isinstance(provider, common.Dimension): + dimensions[provider.value] = provider + elif isinstance(provider, common.Connectivity): + dimensions[provider.origin_axis.value] = provider.origin_axis + dimensions[provider.neighbor_axis.value] = provider.neighbor_axis + return dimensions + + @classmethod + def apply(cls, node: T, *, offset_provider, inplace: bool = False) -> T: + instance = 
cls(offset_provider=offset_provider) + if not inplace: + node = copy.deepcopy(node) + instance.visit( + node, + ctx={ + name: DeferredFunctionType( + type_rule=rules.type_inference_rules[name], + # builtin functions are polymorphic + store_inferred_type_in_node=False, + offset_provider=offset_provider, + ) + for name in rules.type_inference_rules.keys() + }, + ) + return node + + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + result = super().visit(node, **kwargs) + if isinstance(node, itir.Node): + if isinstance(result, ts.TypeSpec): + # TODO: verify types match + node.type = result + elif isinstance(result, DeferredFunctionType): + pass + elif callable(result): + # TODO: only do for type rules not every callable + return DeferredFunctionType( + node=node, + type_rule=result, + store_inferred_type_in_node=True, + offset_provider=self.offset_provider, + ) + else: + raise AssertionError( + f"Expected a 'TypeSpec' or 'DeferredFunctionType', but got " + f"`{type(result).__name__}`" + ) + return result + + def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx): + params: dict[str, ts.DataType] = {} + for param in node.params: + assert isinstance(param.type, ts.DataType) + params[param.id] = param.type + + function_definitions: dict[str, rules.TypeInferenceRule] = {} + for fun_def in node.function_definitions: + function_definitions[fun_def.id] = self.visit(fun_def, ctx=ctx | function_definitions) + + closures = self.visit(node.closures, ctx=ctx | params | function_definitions) + return it_ts.FencilType(params=list(params.values()), closures=closures) + + def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx): + domain: it_ts.DomainType = self.visit(node.domain, ctx=ctx) + inputs: list[ts.FieldType] = self.visit(node.inputs, ctx=ctx) + output: ts.FieldType = self.visit(node.output, ctx=ctx) + + assert isinstance(domain, it_ts.DomainType) + for output_el in type_info.primitive_constituents(output): + assert 
isinstance(output_el, ts.FieldType) + + stencil_args = [] + for input_ in inputs: + defined_dims: list[common.Dimension] | None = None + + def extract_dtype_and_defined_dims(el_type: ts.TypeSpec): + nonlocal defined_dims + assert isinstance(el_type, (ts.FieldType, ts.ScalarType)) + el_type = type_info.promote(el_type, always_field=True) + if not defined_dims: + defined_dims = el_type.dims # type: ignore[union-attr] # ensured by always_field + else: + # tuple inputs must all have the same defined dimensions as we + # create an iterator of tuples from them + assert defined_dims == el_type.dims # type: ignore[union-attr] # ensured by always_field + return el_type.dtype # type: ignore[union-attr] # ensured by always_field + + element_type = type_info.apply_to_primitive_constituents( + extract_dtype_and_defined_dims, input_ + ) + + assert defined_dims is not None + + stencil_args.append( + it_ts.IteratorType( + position_dims=domain.dims, defined_dims=defined_dims, element_type=element_type + ) + ) + + stencil_type_rule = self.visit(node.stencil, ctx=ctx) + stencil_returns = stencil_type_rule(*stencil_args) + + return it_ts.StencilClosureType( + domain=domain, + stencil=ts.FunctionType( + pos_only_args=stencil_args, + pos_or_kw_args={}, + kw_only_args={}, + returns=stencil_returns, + ), + output=output, + inputs=inputs, + ) + + def visit_Node(self, node: itir.Node, **kwargs): + raise NotImplementedError( + f"No type deduction rule for nodes of type " f"'{type(node).__name__}'." 
+ ) + + def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs): + assert node.value in self.dimensions + return ts.DimensionType(dim=self.dimensions[node.value]) + + def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs): + # TODO: this happens in tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py + if _is_representable_as_int(node.value): + return it_ts.OffsetLiteralType(value=int(node.value)) + else: + assert isinstance(node.value, str) and node.value in self.dimensions + return it_ts.OffsetLiteralType(value=self.dimensions[node.value]) + + def visit_Literal(self, node: itir.Literal, **kwargs): + assert isinstance(node.type, ts.ScalarType) + return node.type + + def visit_SymRef(self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec]): + # for testing it is useful to be able to use types without a declaration, but just storing + # the type in the node itself. + if node.type: + assert node.id not in ctx or _is_compatible_type(ctx[node.id], node.type) + return node.type + # TODO: only allow in testing + if node.id not in ctx: + return ts.DeferredType(constraint=None) + result = ctx[node.id] + if isinstance(result, DeferredFunctionType): + result.aliases.append(node) + return result + + def visit_Lambda( + self, node: itir.Lambda | itir.FunctionDefinition, *, ctx: dict[str, ts.TypeSpec] + ): + def fun(*args): + return self.visit( + node.expr, ctx=ctx | {p.id: a for p, a in zip(node.params, args, strict=True)} + ) + + return DeferredFunctionType( + node=node, + type_rule=fun, + store_inferred_type_in_node=True, + offset_provider=self.offset_provider, + ) + + visit_FunctionDefinition = visit_Lambda + + def visit_FunCall(self, node: itir.FunCall, *, ctx: dict[str, ts.TypeSpec]): + if is_call_to(node, "cast_"): + value, type_constructor = node.args + assert ( + isinstance(type_constructor, itir.SymRef) + and type_constructor.id in itir.TYPEBUILTINS + ) + return ts.ScalarType(kind=getattr(ts.ScalarKind, 
type_constructor.id.upper())) + + if is_call_to(node, "tuple_get"): + index_literal, tuple_ = node.args + self.visit(tuple_, ctx=ctx) # ensure tuple is typed + assert isinstance(index_literal, itir.Literal) + index = int(index_literal.value) + assert isinstance(tuple_.type, ts.TupleType) + return tuple_.type.types[index] + + fun = self.visit(node.fun, ctx=ctx) + args = self.visit(node.args, ctx=ctx) + + result = fun(*args) + + if isinstance(result, DeferredFunctionType): + assert not result.node + result.node = node + + return result + + +infer = ITIRTypeInference.apply diff --git a/src/gt4py/next/iterator/type_system/rules.py b/src/gt4py/next/iterator/type_system/rules.py new file mode 100644 index 0000000000..61650629b5 --- /dev/null +++ b/src/gt4py/next/iterator/type_system/rules.py @@ -0,0 +1,250 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from gt4py.eve import extended_typing as xtyping +from gt4py.eve.extended_typing import Iterable, Optional, Union +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.type_system import type_specifications as it_ts +from gt4py.next.type_system import type_info, type_specifications as ts + + +TypeSpecOrTypeInferenceRule = Union[ts.TypeSpec, "TypeInferenceRule"] + +TypeInferenceRule = xtyping.Callable[..., TypeSpecOrTypeInferenceRule] + +#: dictionary from function name to its type inference rule +type_inference_rules: dict[str, TypeInferenceRule] = {} + + +def _is_derefable_iterator_type(it_type: it_ts.IteratorType) -> bool: + if it_type.position_dims == "unknown": + return True + it_position_dim_names = [dim.value for dim in it_type.position_dims] # TODO + return all(dim.value in it_position_dim_names for dim in it_type.defined_dims) + + +def _register_type_inference_rule( + rule: Optional[TypeInferenceRule] = None, *, fun_names: Optional[Iterable[str]] = None +): + def wrapper(rule): + nonlocal fun_names + if not fun_names: + fun_names = [rule.__name__] + else: + # store names in function object for better debuggability + rule.fun_names = fun_names + for fun_ in fun_names: + type_inference_rules[fun_] = rule + + if rule: + return wrapper(rule) + else: + return wrapper + + +@_register_type_inference_rule( + fun_names=itir.UNARY_MATH_NUMBER_BUILTINS | itir.UNARY_MATH_FP_BUILTINS +) +def _(val: ts.ScalarType) -> ts.ScalarType: + return val + + +@_register_type_inference_rule +def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: + return base + + +@_register_type_inference_rule(fun_names=itir.BINARY_MATH_NUMBER_BUILTINS) +def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: + assert lhs == rhs + return lhs + + +@_register_type_inference_rule( + fun_names=itir.UNARY_MATH_FP_PREDICATE_BUILTINS | itir.UNARY_LOGICAL_BUILTINS +) +def _(arg: 
ts.ScalarType) -> ts.ScalarType: + return ts.ScalarType(kind=ts.ScalarKind.BOOL) + + +@_register_type_inference_rule( + fun_names=itir.BINARY_MATH_COMPARISON_BUILTINS | itir.BINARY_LOGICAL_BUILTINS +) +def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: + return ts.ScalarType(kind=ts.ScalarKind.BOOL) + + +@_register_type_inference_rule +def deref(it: it_ts.IteratorType) -> ts.DataType: + assert isinstance(it, it_ts.IteratorType) + assert _is_derefable_iterator_type(it) + return it.element_type + + +@_register_type_inference_rule +def can_deref(it: it_ts.IteratorType) -> ts.ScalarType: + assert isinstance(it, it_ts.IteratorType) + assert _is_derefable_iterator_type(it) + return ts.ScalarType(kind=ts.ScalarKind.BOOL) + + +@_register_type_inference_rule +def if_(cond: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: + assert isinstance(cond, ts.ScalarType) and cond.kind == ts.ScalarKind.BOOL + assert true_branch == false_branch + return true_branch + + +@_register_type_inference_rule +def make_const_list(scalar: ts.ScalarType) -> it_ts.ListType: + assert isinstance(scalar, ts.ScalarType) + return it_ts.ListType(element_type=scalar) + + +@_register_type_inference_rule +def list_get(index: ts.ScalarType, list_: it_ts.ListType) -> ts.DataType: + assert isinstance(index, ts.ScalarType) and type_info.is_integral(index) + assert isinstance(list_, it_ts.ListType) + return list_.element_type + + +@_register_type_inference_rule +def named_range( + dim: ts.DimensionType, start: ts.ScalarType, stop: ts.ScalarType +) -> it_ts.NamedRangeType: + assert isinstance(dim, ts.DimensionType) + return it_ts.NamedRangeType(dim=dim.dim) + + +@_register_type_inference_rule(fun_names=["cartesian_domain", "unstructured_domain"]) +def _(*args: it_ts.NamedRangeType) -> it_ts.DomainType: + assert all(isinstance(arg, it_ts.NamedRangeType) for arg in args) + return it_ts.DomainType(dims=[arg.dim for arg in args]) + + 
+@_register_type_inference_rule +def make_tuple(*args: ts.DataType) -> ts.TupleType: + return ts.TupleType(types=list(args)) + + +@_register_type_inference_rule +def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: + assert ( + isinstance(offset_literal, it_ts.OffsetLiteralType) + and isinstance(offset_literal.value, common.Dimension) + and offset_literal.value.kind == common.DimensionKind.LOCAL + ) + assert isinstance(it, it_ts.IteratorType) + return it_ts.ListType(element_type=it.element_type) + + +@_register_type_inference_rule +def lift(stencil: TypeInferenceRule) -> TypeInferenceRule: + def apply_lift(*its: it_ts.IteratorType) -> it_ts.IteratorType: + stencil_args = [] + for it in its: + assert isinstance(it, it_ts.IteratorType) + stencil_args.append( + it_ts.IteratorType( + # the positions are only known when we deref + position_dims="unknown", + defined_dims=it.defined_dims, + element_type=it.element_type, + ) + ) + stencil_return_type = stencil(*stencil_args) + assert isinstance(stencil_return_type, ts.DataType) + + position_dims = its[0].position_dims if its else [] + # we would need to look inside the stencil to find out where the resulting iterator + # is defined, e.g. 
using trace shift, instead just use an empty list which means + # everywhere + defined_dims: list[common.Dimension] = [] + return it_ts.IteratorType( + position_dims=position_dims, defined_dims=defined_dims, element_type=stencil_return_type + ) + + return apply_lift + + +@_register_type_inference_rule +def scan( + scan_pass: TypeInferenceRule, direction: ts.ScalarType, init: ts.ScalarType +) -> TypeInferenceRule: + assert isinstance(direction, ts.ScalarType) and direction.kind == ts.ScalarKind.BOOL + + def apply_scan(*its: it_ts.IteratorType) -> ts.DataType: + result = scan_pass(init, *its) + assert isinstance(result, ts.DataType) + return result + + return apply_scan + + +@_register_type_inference_rule +def map_(op: TypeInferenceRule) -> TypeInferenceRule: + def applied_map(*args: it_ts.ListType) -> it_ts.ListType: + assert len(args) > 0 + assert all(isinstance(arg, it_ts.ListType) for arg in args) + arg_el_types = [arg.element_type for arg in args] + el_type = op(*arg_el_types) + assert isinstance(el_type, ts.DataType) + return it_ts.ListType(element_type=el_type) + + return applied_map + + +@_register_type_inference_rule +def reduce(op: TypeInferenceRule, init: ts.TypeSpec) -> TypeInferenceRule: + def applied_reduce(*args: it_ts.ListType): + assert all(isinstance(arg, it_ts.ListType) for arg in args) + return op(init, *(arg.element_type for arg in args)) + + return applied_reduce + + +@_register_type_inference_rule +def shift(*offset_literals, offset_provider) -> TypeInferenceRule: + def apply_shift(it: it_ts.IteratorType): + assert isinstance(it, it_ts.IteratorType) + if it.position_dims == "unknown": # nothing to do here + return it + new_position_dims = [*it.position_dims] + assert len(offset_literals) % 2 == 0 + for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True): + assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( + offset_axis.value, common.Dimension + ) + provider = 
offset_provider[offset_axis.value.value] # TODO: naming + if isinstance(provider, common.Dimension): + pass + elif isinstance(provider, common.Connectivity): + found = False + for i, dim in enumerate(new_position_dims): + if dim.value == provider.origin_axis.value: + assert not found + new_position_dims[i] = provider.neighbor_axis + found = True + assert found + else: + raise NotImplementedError() + return it_ts.IteratorType( + position_dims=new_position_dims, + defined_dims=it.defined_dims, + element_type=it.element_type, + ) + + return apply_shift diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py new file mode 100644 index 0000000000..811ec74605 --- /dev/null +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -0,0 +1,62 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dataclasses +import typing + +from gt4py._core.definitions import IntegralScalar +from gt4py.next import common +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass(frozen=True) +class NamedRangeType(ts.TypeSpec): + dim: common.Dimension + + +@dataclasses.dataclass(frozen=True) +class DomainType(ts.DataType): + dims: list[common.Dimension] + + +# TODO: how about ts.OffsetType? 
+@dataclasses.dataclass(frozen=True) +class OffsetLiteralType(ts.TypeSpec): + value: IntegralScalar | common.Dimension + + +@dataclasses.dataclass(frozen=True) +class ListType(ts.DataType): + element_type: ts.DataType + + +@dataclasses.dataclass(frozen=True) +class IteratorType(ts.DataType, ts.CallableType): # todo: rename to iterator + position_dims: list[common.Dimension] | typing.Literal["unknown"] + defined_dims: list[common.Dimension] + element_type: ts.DataType + + +@dataclasses.dataclass(frozen=True) +class StencilClosureType(ts.TypeSpec): + domain: DomainType + stencil: ts.FunctionType + output: ts.FieldType | ts.TupleType # todo: validate tuple of fields + inputs: list[ts.FieldType] + + +@dataclasses.dataclass(frozen=True) +class FencilType(ts.TypeSpec): + params: list[ts.DataType] + closures: list[StencilClosureType] diff --git a/src/gt4py/next/program_processors/formatters/type_check.py b/src/gt4py/next/program_processors/formatters/type_check.py deleted file mode 100644 index 03aeef1264..0000000000 --- a/src/gt4py/next/program_processors/formatters/type_check.py +++ /dev/null @@ -1,32 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . 
-# -# SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any - -from gt4py.next.iterator import ir as itir, type_inference -from gt4py.next.iterator.transforms import apply_common_transforms, global_tmps -from gt4py.next.program_processors.processor_interface import program_formatter - - -@program_formatter -def check_type_inference(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - type_inference.pprint(type_inference.infer(program, offset_provider=kwargs["offset_provider"])) - transformed = apply_common_transforms( - program, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] - ) - if isinstance(transformed, global_tmps.FencilWithTemporaries): - transformed = transformed.fencil - return type_inference.pformat( - type_inference.infer(transformed, offset_provider=kwargs["offset_provider"]) - ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 566b4b5a20..9ed8288218 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -18,14 +18,11 @@ from dace.sdfg.state import LoopRegion import gt4py.eve as eve -from gt4py.next import Dimension, DimensionKind, type_inference as next_typing +from gt4py.next import Dimension, DimensionKind from gt4py.next.common import Connectivity -from gt4py.next.iterator import ( - ir as itir, - transforms as itir_transforms, - type_inference as itir_typing, -) +from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef +from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt from .itir_to_tasklet import ( @@ -156,7 +153,6 @@ class ItirToSDFG(eve.NodeVisitor): 
storage_types: dict[str, ts.TypeSpec] column_axis: Optional[Dimension] offset_provider: dict[str, Any] - node_types: dict[int, next_typing.Type] unique_id: int use_field_canonical_representation: bool @@ -202,6 +198,7 @@ def add_storage_for_temporaries( # Here we collect these values in a symbol map. tmp_ids = set(tmp.id for tmp in self.tmps) for sym in node_params: + breakpoint() if sym.id not in tmp_ids and sym.kind != "Iterator": name_ = str(sym.id) type_ = self.storage_types[name_] @@ -214,6 +211,7 @@ def add_storage_for_temporaries( # We visit the domain of the temporary field, passing the set of available symbols. assert isinstance(tmp.domain, itir.FunCall) + breakpoint() self.node_types.update(itir_typing.infer_all(tmp.domain)) domain_ctx = Context(program_sdfg, defs_state, symbol_map) tmp_domain = self._visit_domain(tmp.domain, domain_ctx) @@ -275,12 +273,12 @@ def get_output_nodes( self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState ) -> dict[str, dace.nodes.AccessNode]: # Visit output node, which could be a `make_tuple` expression, to collect the required access nodes - output_symbols_pass = GatherOutputSymbolsPass(sdfg, state, self.node_types) + output_symbols_pass = GatherOutputSymbolsPass(sdfg, state) output_symbols_pass.visit(closure.output) # Visit output node again to generate the corresponding tasklet context = Context(sdfg, state, output_symbols_pass.symbol_refs) translator = PythonTaskletCodegen( - self.offset_provider, context, self.node_types, self.use_field_canonical_representation + self.offset_provider, context, self.use_field_canonical_representation ) output_nodes = flatten_list(translator.visit(closure.output)) return {node.value.data: node.value for node in output_nodes} @@ -289,7 +287,8 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) program_sdfg.debuginfo = dace_debuginfo(node) entry_state = program_sdfg.add_state("program_entry", is_start_block=True) - 
self.node_types = itir_typing.infer_all(node) + + node = itir_type_inference.infer(node, offset_provider=self.offset_provider) # Filter neighbor tables from offset providers. neighbor_tables = get_used_connectivities(node, self.offset_provider) @@ -675,7 +674,6 @@ def _visit_scan_stencil_closure( lambda_domain, input_arrays, connectivity_arrays, - self.node_types, self.use_field_canonical_representation, ) @@ -760,7 +758,6 @@ def _visit_parallel_stencil_closure( index_domain, input_arrays, connectivity_arrays, - self.node_types, self.use_field_canonical_representation, ) @@ -785,7 +782,6 @@ def _visit_domain( translator = PythonTaskletCodegen( self.offset_provider, context, - self.node_types, self.use_field_canonical_representation, ) lb = translator.visit(lower_bound)[0] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 778846b218..6da819d20d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -22,11 +22,11 @@ import gt4py.eve.codegen from gt4py import eve -from gt4py.next import Dimension, type_inference as next_typing +from gt4py.next import Dimension from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value, Connectivity -from gt4py.next.iterator import ir as itir, type_inference as itir_typing +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import FunCall, Lambda -from gt4py.next.iterator.type_inference import Val +from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_specifications as ts from .utility import ( @@ -54,10 +54,17 @@ } -def itir_type_as_dace_type(type_: next_typing.Type): - if isinstance(type_, itir_typing.Primitive): - return _TYPE_MAPPING[type_.name] - raise NotImplementedError() +def 
itir_type_as_dace_type(type_: ts.TypeSpec): + # TODO: cleanup + dtype: ts.TypeSpec + if isinstance(type_, ts.FieldType): + dtype = type_.dtype + elif isinstance(type_, it_ts.IteratorType): + dtype = type_.element_type + else: + dtype = type_ + assert isinstance(dtype, ts.ScalarType) + return _TYPE_MAPPING[dtype.kind.name.lower()] def get_reduce_identity_value(op_name_: str, type_: Any): @@ -567,7 +574,6 @@ def build_if_state(arg, state): node_taskgen = PythonTaskletCodegen( transformer.offset_provider, node_context, - transformer.node_types, transformer.use_field_canonical_representation, ) return node_taskgen.visit(arg) @@ -707,9 +713,7 @@ def builtin_cast( target_type = node_args[1] assert isinstance(target_type, itir.SymRef) expr = _MATH_BUILTINS_MAPPING[target_type.id].format(*internals) - node_type = transformer.node_types[id(node)] - assert isinstance(node_type, itir_typing.Val) - type_ = itir_type_as_dace_type(node_type.dtype) + type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference return transformer.add_expr_tasklet( list(zip(args, internals)), expr, type_, "cast", dace_debuginfo=di ) @@ -858,7 +862,6 @@ def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] class GatherOutputSymbolsPass(eve.NodeVisitor): _sdfg: dace.SDFG _state: dace.SDFGState - _node_types: dict[int, next_typing.Type] _symbol_map: dict[str, TaskletExpr] @property @@ -866,39 +869,34 @@ def symbol_refs(self): """Dictionary of symbols referenced from the output expression.""" return self._symbol_map - def __init__(self, sdfg, state, node_types): + def __init__(self, sdfg, state): self._sdfg = sdfg self._state = state - self._node_types = node_types self._symbol_map = {} def visit_SymRef(self, node: itir.SymRef): param = str(node.id) if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: - node_type = self._node_types[id(node)] - assert isinstance(node_type, Val) access_node = self._state.add_access(param, 
debuginfo=self._sdfg.debuginfo) self._symbol_map[param] = ValueExpr( - access_node, dtype=itir_type_as_dace_type(node_type.dtype) + access_node, + dtype=itir_type_as_dace_type(node.type), # type: ignore[arg-type] # ensure by type inference ) class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): offset_provider: dict[str, Any] context: Context - node_types: dict[int, next_typing.Type] use_field_canonical_representation: bool def __init__( self, offset_provider: dict[str, Any], context: Context, - node_types: dict[int, next_typing.Type], use_field_canonical_representation: bool, ): self.offset_provider = offset_provider self.context = context - self.node_types = node_types self.use_field_canonical_representation = use_field_canonical_representation def get_sorted_field_dimensions(self, dims: Sequence[str]): @@ -976,7 +974,6 @@ def visit_Lambda( lambda_taskgen = PythonTaskletCodegen( self.offset_provider, lambda_context, - self.node_types, self.use_field_canonical_representation, ) @@ -1019,9 +1016,7 @@ def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | Iter return value def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]: - node_type = self.node_types[id(node)] - assert isinstance(node_type, Val) - return [SymbolExpr(node.value, itir_type_as_dace_type(node_type.dtype))] + return [SymbolExpr(node.value, itir_type_as_dace_type(node.type))] def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: node.fun.location = node.location @@ -1266,9 +1261,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: def _visit_reduce(self, node: itir.FunCall): di = dace_debuginfo(node, self.context.body.debuginfo) - node_type = self.node_types[id(node)] - assert isinstance(node_type, itir_typing.Val) - reduce_dtype = itir_type_as_dace_type(node_type.dtype) + reduce_dtype = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference if len(node.args) == 1: assert ( 
@@ -1449,9 +1442,7 @@ def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args ] expr = fmt.format(*internals) - node_type = self.node_types[id(node)] - assert isinstance(node_type, itir_typing.Val) - type_ = itir_type_as_dace_type(node_type.dtype) + type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference return self.add_expr_tasklet( expr_args, expr, @@ -1517,7 +1508,6 @@ def closure_to_tasklet_sdfg( domain: dict[str, str], inputs: Sequence[tuple[str, ts.TypeSpec]], connectivities: Sequence[tuple[dace.ndarray, str]], - node_types: dict[int, next_typing.Type], use_field_canonical_representation: bool, ) -> tuple[Context, Sequence[ValueExpr]]: body = dace.SDFG("tasklet_toplevel") @@ -1555,9 +1545,7 @@ def closure_to_tasklet_sdfg( body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) context = Context(body, state, symbol_map) - translator = PythonTaskletCodegen( - offset_provider, context, node_types, use_field_canonical_representation - ) + translator = PythonTaskletCodegen(offset_provider, context, use_field_canonical_representation) args = [itir.SymRef(id=name) for name, _ in inputs] if is_scan(node.stencil): diff --git a/src/gt4py/next/type_inference.py b/src/gt4py/next/type_inference.py deleted file mode 100644 index fe2add820b..0000000000 --- a/src/gt4py/next/type_inference.py +++ /dev/null @@ -1,353 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . 
-# -# SPDX-License-Identifier: GPL-3.0-or-later - -from __future__ import annotations - -import typing -from collections import abc - -import gt4py.eve as eve -from gt4py.eve import datamodels -from gt4py.eve.utils import noninstantiable - - -"""Customizable constraint-based inference. - -Based on the classical constraint-based two-pass type consisting of the following passes: - 1. Constraint collection - 2. Type unification -""" - - -V = typing.TypeVar("V", bound="TypeVar") -T = typing.TypeVar("T", bound="Type") - - -@noninstantiable -class Type(eve.Node, unsafe_hash=True): # type: ignore[call-arg] - """Base class for all types. - - The initial type constraint collection pass treats all instances of Type as hashable frozen - nodes, that is, no in-place modification is used. - - In the type unification phase however, in-place modifications are used for efficient - renaming/node replacements and special care is taken to handle hash values that change due to - those modifications. - """ - - def handle_constraint( - self, other: Type, add_constraint: abc.Callable[[Type, Type], None] - ) -> bool: - """Implement special type-specific constraint handling for `self` ≡ `other`. - - New constraints can be added using the provided callback (`add_constraint`). Should return - `True` if the provided constraint `self` ≡ `other` was handled, `False` otherwise. If the - handler detects an unsatisfiable constraint, raise a `TypeError`. 
- """ - return False - - -class TypeVar(Type): - """Type variable.""" - - idx: int - - _counter: typing.ClassVar[int] = 0 - - @staticmethod - def fresh_index() -> int: - TypeVar._counter += 1 - return TypeVar._counter - - @classmethod - def fresh(cls: type[V], **kwargs: typing.Any) -> V: - """Create a type variable with a previously unused index.""" - return cls(idx=cls.fresh_index(), **kwargs) - - -class _TypeVarReindexer(eve.NodeTranslator): - """Reindex type variables in a type tree.""" - - def __init__(self, indexer: abc.Callable[[dict[int, int]], int]): - super().__init__() - self.indexer = indexer - - def visit_TypeVar(self, node: V, *, index_map: dict[int, int]) -> V: - node = self.generic_visit(node, index_map=index_map) - new_index = index_map.setdefault(node.idx, self.indexer(index_map)) - new_values = { - typing.cast(str, k): (new_index if k == "idx" else v) - for k, v in node.iter_children_items() - } - return node.__class__(**new_values) - - -@typing.overload -def freshen(dtypes: list[T]) -> list[T]: ... - - -@typing.overload -def freshen(dtypes: T) -> T: ... 
- - -def freshen(dtypes: list[T] | T) -> list[T] | T: - """Re-instantiate `dtype` with fresh type variables.""" - if not isinstance(dtypes, list): - assert isinstance(dtypes, Type) - return freshen([dtypes])[0] - - def indexer(index_map: dict[int, int]) -> int: - return TypeVar.fresh_index() - - index_map = dict[int, int]() - return [_TypeVarReindexer(indexer).visit(dtype, index_map=index_map) for dtype in dtypes] - - -def reindex_vars(dtypes: typing.Any) -> typing.Any: - """Reindex all type variables, to have nice indices starting at zero.""" - - def indexer(index_map: dict[int, int]) -> int: - return len(index_map) - - index_map = dict[int, int]() - return _TypeVarReindexer(indexer).visit(dtypes, index_map=index_map) - - -class _FreeVariables(eve.NodeVisitor): - """Collect type variables within a type expression.""" - - def visit_TypeVar(self, node: TypeVar, *, free_variables: set[TypeVar]) -> None: - self.generic_visit(node, free_variables=free_variables) - free_variables.add(node) - - -def _free_variables(x: Type) -> set[TypeVar]: - """Collect type variables within a type expression.""" - fv = set[TypeVar]() - _FreeVariables().visit(x, free_variables=fv) - return fv - - -class _Dedup(eve.NodeTranslator): - """Deduplicate type nodes that have the same value but a different `id`.""" - - def visit(self, node: Type | typing.Sequence[Type], *, memo: dict[Type, Type]) -> typing.Any: # type: ignore[override] - if isinstance(node, Type): - return memo.setdefault(node, node) - return self.generic_visit(node, memo=memo) - - -def _assert_constituent_types(value: typing.Any, allowed_types: tuple[type, ...]) -> None: - if isinstance(value, tuple): - for el in value: - _assert_constituent_types(el, allowed_types) - else: - assert isinstance(value, allowed_types) - - -class _Renamer: - """Efficiently rename (that is, replace) nodes in a type expression. - - Works by collecting all parent nodes of all nodes in a tree. 
If a node should be replaced by - another, all referencing parent nodes can be found efficiently and modified in place. - - Note that all types have to be registered before they can be used in a `rename` call. - - Besides basic renaming, this also resolves `ValTuple` to full `Tuple` if possible after - renaming. - """ - - def __init__(self) -> None: - self._parents = dict[Type, list[tuple[Type, str]]]() - - def register(self, dtype: Type) -> None: - """Register a type for possible future renaming. - - Collects the parent nodes of all nodes in the type tree. - """ - - def collect_parents(node: Type) -> None: - for field, child in node.iter_children_items(): - if isinstance(child, Type): - self._parents.setdefault(child, []).append((node, typing.cast(str, field))) - collect_parents(child) - else: - _assert_constituent_types(child, (int, str)) - - collect_parents(dtype) - - def _update_node(self, node: Type, field: str, replacement: Type) -> None: - """Replace a field of a node by some other value. - - Basically performs `setattr(node, field, replacement)`. Further, updates the mapping of node - parents and handles the possibly changing hash value of the updated node. 
- """ - # Pop the node out of the parents dict as its hash could change after modification - popped = self._parents.pop(node, None) - - # Update the node's field - setattr(node, field, replacement) - - # Register `node` to be the new parent of `replacement` - self._parents.setdefault(replacement, []).append((node, field)) - - # Put back possible previous entries to the parents dict after possible hash change - if popped: - self._parents[node] = popped - - def rename(self, node: Type, replacement: Type) -> None: - """Rename/replace all occurrences of `node` to/by `replacement`.""" - try: - # Find parent nodes - nodes = self._parents.pop(node) - except KeyError: - return - - for node, field in nodes: - # Default case: just update a field value of the node - self._update_node(node, field, replacement) - - -class _Box(Type): - """Simple value holder, used for wrapping root nodes of a type tree. - - This makes sure that all root nodes have a parent node which can be updated by the `_Renamer`. - """ - - value: Type - - -class _Unifier: - """A classical type unifier (Robinson, 1971). - - Computes the most general type satisfying all given constraints. Uses a `_Renamer` for efficient - type variable renaming. 
- """ - - def __init__(self, dtypes: list[Type], constraints: set[tuple[Type, Type]]) -> None: - # Wrap the original `dtype` and all `constraints` to make sure they have a parent node and - # thus the root nodes are correctly handled by the renamer - self._dtypes = [_Box(value=dtype) for dtype in dtypes] - self._constraints = [(_Box(value=s), _Box(value=t)) for s, t in constraints] - - # Create a renamer and register `dtype` and all `constraints` types - self._renamer = _Renamer() - for dtype in self._dtypes: - self._renamer.register(dtype) - for s, t in self._constraints: - self._renamer.register(s) - self._renamer.register(t) - - def unify(self) -> tuple[list[Type] | Type, list[tuple[Type, Type]]]: - """Run the unification.""" - unsatisfiable_constraints = [] - while self._constraints: - constraint = self._constraints.pop() - try: - handled = self._handle_constraint(constraint) - if not handled: - # Try with swapped LHS and RHS - handled = self._handle_constraint(constraint[::-1]) - except TypeError: - # custom constraint handler raised an error as constraint is not satisfiable - # (contrary to just not handled) - handled = False - - if not handled: - unsatisfiable_constraints.append((constraint[0].value, constraint[1].value)) - - unboxed_dtypes = [dtype.value for dtype in self._dtypes] - - return unboxed_dtypes, unsatisfiable_constraints - - def _rename(self, x: Type, y: Type) -> None: - """Type renaming/replacement.""" - self._renamer.register(x) - self._renamer.register(y) - self._renamer.rename(x, y) - - def _add_constraint(self, x: Type, y: Type) -> None: - """Register a new constraint.""" - x = _Box(value=x) - y = _Box(value=y) - self._renamer.register(x) - self._renamer.register(y) - self._constraints.append((x, y)) - - def _handle_constraint(self, constraint: tuple[_Box, _Box]) -> bool: - """Handle a single constraint.""" - s, t = (c.value for c in constraint) - if s == t: - # Constraint is satisfied if LHS equals RHS - return True - - if type(s) is 
TypeVar: - assert s not in _free_variables(t) - # Just replace LHS by RHS if LHS is a type variable - self._rename(s, t) - return True - - if s.handle_constraint(t, self._add_constraint): - # Use a custom constraint handler if available - return True - - if type(s) is type(t): - assert s not in _free_variables(t) and t not in _free_variables(s) - assert datamodels.fields(s).keys() == datamodels.fields(t).keys() - for k in datamodels.fields(s).keys(): - sv = getattr(s, k) - tv = getattr(t, k) - if isinstance(sv, Type): - assert isinstance(tv, Type) - self._add_constraint(sv, tv) - else: - assert sv == tv - return True - - # Constraint handling failed - return False - - -@typing.overload -def unify( - dtypes: list[Type], constraints: set[tuple[Type, Type]] -) -> tuple[list[Type], list[tuple[Type, Type]]]: ... - - -@typing.overload -def unify( - dtypes: Type, constraints: set[tuple[Type, Type]] -) -> tuple[Type, list[tuple[Type, Type]]]: ... - - -def unify( - dtypes: list[Type] | Type, constraints: set[tuple[Type, Type]] -) -> tuple[list[Type] | Type, list[tuple[Type, Type]]]: - """ - Unify all given constraints. - - Returns the unified types and a list of unsatisfiable constraints. 
- """ - if isinstance(dtypes, Type): - result_types, unsatisfiable_constraints = unify([dtypes], constraints) - return result_types[0], unsatisfiable_constraints - - # Deduplicate type nodes, this can speed up later things a bit - memo = dict[Type, Type]() - dtypes = [_Dedup().visit(dtype, memo=memo) for dtype in dtypes] - constraints = {_Dedup().visit(c, memo=memo) for c in constraints} - del memo - - unifier = _Unifier(dtypes, constraints) - return unifier.unify() diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index a05b9afde8..646912bf3b 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -90,18 +90,21 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: @typing.overload def primitive_constituents( - symbol_type: ts.TypeSpec, with_path_arg: typing.Literal[False] = False + symbol_type: ts.TypeSpec, + with_path_arg: typing.Literal[False] = False, ) -> XIterable[ts.TypeSpec]: ... @typing.overload def primitive_constituents( - symbol_type: ts.TypeSpec, with_path_arg: typing.Literal[True] + symbol_type: ts.TypeSpec, + with_path_arg: typing.Literal[True], ) -> XIterable[tuple[ts.TypeSpec, tuple[int, ...]]]: ... def primitive_constituents( - symbol_type: ts.TypeSpec, with_path_arg: bool = False + symbol_type: ts.TypeSpec, + with_path_arg: bool = False, ) -> XIterable[ts.TypeSpec] | XIterable[tuple[ts.TypeSpec, tuple[int, ...]]]: """ Return the primitive types contained in a composite type. @@ -145,12 +148,11 @@ def __call__(self, *args: Any) -> _R: ... # TODO(havogt): the complicated typing is a hint that this function needs refactoring def apply_to_primitive_constituents( - symbol_type: ts.TypeSpec, - fun: (Callable[[ts.TypeSpec], _T] | Callable[[ts.TypeSpec, tuple[int, ...]], _T]), - _path: tuple[int, ...] 
= (), - *, + fun: Callable[..., _T], + *symbol_types: ts.TypeSpec, with_path_arg: bool = False, tuple_constructor: TupleConstructorType[_R] = lambda *elements: ts.TupleType(types=[*elements]), # type: ignore[assignment] # probably related to https://github.com/python/mypy/issues/10854 + _path: tuple[int, ...] = (), ) -> _T | _R: """ Apply function to all primitive constituents of a type. @@ -159,28 +161,42 @@ def apply_to_primitive_constituents( >>> tuple_type = ts.TupleType(types=[int_type, int_type]) >>> print( ... apply_to_primitive_constituents( - ... tuple_type, lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type) + ... lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type), + ... tuple_type, ... ) ... ) tuple[Field[[], int64], Field[[], int64]] + + >>> apply_to_primitive_constituents( + ... lambda primitive_type, path: (path, primitive_type), + ... tuple_type, + ... with_path_arg=True, + ... tuple_constructor=lambda *elements: dict(elements), + ... ) + {(0,): ScalarType(kind=, shape=None), (1,): ScalarType(kind=, shape=None)} """ - if isinstance(symbol_type, ts.TupleType): + + # TODO: check structure matches + if isinstance(symbol_types[0], ts.TupleType): + assert all(isinstance(symbol_type, ts.TupleType) for symbol_type in symbol_types) return tuple_constructor( *[ apply_to_primitive_constituents( - el, fun, + *el_types, _path=(*_path, i), with_path_arg=with_path_arg, tuple_constructor=tuple_constructor, ) - for i, el in enumerate(symbol_type.types) + for i, el_types in enumerate( + zip(*(symbol_type.types for symbol_type in symbol_types)) # type: ignore[attr-defined] # ensured by assert above + ) ] ) if with_path_arg: - return fun(symbol_type, _path) # type: ignore[call-arg] # mypy not aware of `with_path_arg` + return fun(*symbol_types, path=_path) else: - return fun(symbol_type) # type: ignore[call-arg] # mypy not aware of `with_path_arg` + return fun(*symbol_types) def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: @@ 
-453,7 +469,9 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: return False -def promote(*types: ts.FieldType | ts.ScalarType) -> ts.FieldType | ts.ScalarType: +def promote( + *types: ts.FieldType | ts.ScalarType, always_field: bool = False +) -> ts.FieldType | ts.ScalarType: """ Promote a set of field or scalar types to a common type. @@ -476,7 +494,7 @@ def promote(*types: ts.FieldType | ts.ScalarType) -> ts.FieldType | ts.ScalarTyp ... ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. """ - if all(isinstance(type_, ts.ScalarType) for type_ in types): + if not always_field and all(isinstance(type_, ts.ScalarType) for type_ in types): if not all(type_ == types[0] for type_ in types): raise ValueError("Could not promote scalars of different dtype (not implemented).") if not all(type_.shape is None for type_ in types): # type: ignore[union-attr] @@ -504,17 +522,28 @@ def return_type( @return_type.register def return_type_func( - func_type: ts.FunctionType, *, with_args: list[ts.TypeSpec], with_kwargs: dict[str, ts.TypeSpec] + func_type: ts.FunctionType, + *, + with_args: list[ts.TypeSpec], + with_kwargs: dict[str, ts.TypeSpec], ) -> ts.TypeSpec: return func_type.returns @return_type.register def return_type_field( - field_type: ts.FieldType, *, with_args: list[ts.TypeSpec], with_kwargs: dict[str, ts.TypeSpec] + field_type: ts.FieldType, + *, + with_args: list[ts.TypeSpec], + with_kwargs: dict[str, ts.TypeSpec], ) -> ts.FieldType: try: - accepts_args(field_type, with_args=with_args, with_kwargs=with_kwargs, raise_exception=True) + accepts_args( + field_type, + with_args=with_args, + with_kwargs=with_kwargs, + raise_exception=True, + ) except ValueError as ex: raise ValueError("Could not deduce return type of invalid remap operation.") from ex @@ -625,7 +654,8 @@ def structural_function_signature_incompatibilities( missing_positional_args = [] for i, arg_type in zip( - 
range(len(func_type.pos_only_args), num_pos_params), func_type.pos_or_kw_args.keys() + range(len(func_type.pos_only_args), num_pos_params), + func_type.pos_or_kw_args.keys(), ): if args[i] is UNDEFINED_ARG: missing_positional_args.append(f"'{arg_type}'") @@ -675,7 +705,7 @@ def function_signature_incompatibilities_func( num_pos_params = len(func_type.pos_only_args) + len(func_type.pos_or_kw_args) assert len(args) >= num_pos_params for i, (a_arg, b_arg) in enumerate( - zip(func_type.pos_only_args + list(func_type.pos_or_kw_args.values()), args) + zip(list(func_type.pos_only_args) + list(func_type.pos_or_kw_args.values()), args) ): if ( b_arg is not UNDEFINED_ARG @@ -697,7 +727,9 @@ def function_signature_incompatibilities_func( @function_signature_incompatibilities.register def function_signature_incompatibilities_field( - field_type: ts.FieldType, args: list[ts.TypeSpec], kwargs: dict[str, ts.TypeSpec] + field_type: ts.FieldType, + args: list[ts.TypeSpec], + kwargs: dict[str, ts.TypeSpec], ) -> Iterator[str]: if len(args) != 1: yield f"Function takes 1 argument, but {len(args)} were given." diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 9487d2f12b..d80495cb68 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -11,16 +11,22 @@ # distribution for a copy of the license or check . 
# # SPDX-License-Identifier: GPL-3.0-or-later - from dataclasses import dataclass -from typing import Iterator, Optional +from typing import Iterator, Optional, Sequence, Union from gt4py.eve.type_definitions import IntEnum +from gt4py.eve.utils import content_hash from gt4py.next import common as func_common class TypeSpec: - pass + def __hash__(self) -> int: + return hash(content_hash(self)) + + def __init_subclass__(cls): + cls.__hash__ = TypeSpec.__hash__ + + # TODO: use __init_subclass__ @dataclass(frozen=True) @@ -115,10 +121,10 @@ def __str__(self) -> str: @dataclass(frozen=True) class FunctionType(TypeSpec, CallableType): - pos_only_args: list[DataType | DeferredType] - pos_or_kw_args: dict[str, DataType | DeferredType] - kw_only_args: dict[str, DataType | DeferredType] - returns: DataType | DeferredType | VoidType + pos_only_args: Sequence[TypeSpec] + pos_or_kw_args: dict[str, TypeSpec] + kw_only_args: dict[str, TypeSpec] + returns: Union[TypeSpec] def __str__(self) -> str: arg_strs = [str(arg) for arg in self.pos_only_args] diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 396d4c06b6..85f91abc03 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -162,8 +162,8 @@ def from_type_hint( # TODO(tehrengruber): print better error when no return type annotation is given return ts.FunctionType( - pos_only_args=new_args, # type: ignore[arg-type] # checked in assert - pos_or_kw_args=kwargs, # type: ignore[arg-type] # checked in assert + pos_only_args=new_args, + pos_or_kw_args=kwargs, kw_only_args={}, # TODO returns=returns, ) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 86eac69712..2d6db84644 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -92,7 +92,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ITIR_PRETTY_PRINTER = ( 
"gt4py.next.program_processors.formatters.pretty_print.format_itir_and_check" ) - ITIR_TYPE_CHECKER = "gt4py.next.program_processors.formatters.type_check.check_type_inference" LISP_FORMATTER = "gt4py.next.program_processors.formatters.lisp.format_lisp" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 388849bf09..943da56427 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -236,6 +236,7 @@ def simple_mesh() -> MeshDescriptor: C2E.value: gtx.NeighborTableOffsetProvider( c2e_arr, Cell, Edge, 4, has_skip_values=False ), + # "KDim": KDim }, ) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index b24cec7d02..67927437bc 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -56,11 +56,6 @@ def test_simple_indirection(program_processor): pytest.xfail("Applied shifts in if_ statements are not supported in TraceShift pass.") - if program_processor in [type_check.check_type_inference, gtfn_format_sourcecode]: - pytest.xfail( - "We only support applied shifts in type_inference." - ) # TODO fix test or generalize itir? 
- shape = [8] inp = gtx.as_field([IDim], np.arange(0, shape[0] + 2), origin={IDim: 1}) rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index bcea9e0901..0cf6e61b27 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -23,6 +23,7 @@ from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +# cross-reference why new type inference does not support this @fundef def ldif(d): return lambda inp: deref(shift(d, -1)(inp)) - deref(inp) @@ -47,6 +48,16 @@ def lap(inp): return dif2(i)(inp) + dif2(j)(inp) +@fundef +def lap2(inp): + return -4.0 * deref(inp) + ( + deref(shift(i, 1)(inp)) + + deref(shift(i, -1)(inp)) + + deref(shift(j, 1)(inp)) + + deref(shift(j, -1)(inp)) + ) + + IDim = gtx.Dimension("IDim") JDim = gtx.Dimension("JDim") KDim = gtx.Dimension("KDim") diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index c9406884e6..b890c66287 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -69,7 +69,6 @@ def lift_mode(request): # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), - (next_tests.definitions.ProgramFormatterId.ITIR_TYPE_CHECKER, False), (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), ] + OPTIONAL_PROCESSORS, diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 
731c163343..f73828e4f6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -12,1068 +12,303 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import numpy as np +import pytest -import gt4py.next as gtx -from gt4py.next.iterator import ir, type_inference as ti +from gt4py import eve +from gt4py.next import common +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to +from gt4py.next.iterator.type_system import ( + inference as itir_type_inference, + type_specifications as it_ts, +) +from gt4py.next.type_system import type_specifications as ts + +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh + +from next_tests.integration_tests.cases import ( + C2E, + E2V, + V2E, + E2VDim, + IDim, + Ioff, + JDim, + KDim, + Koff, + V2EDim, + Vertex, + Edge, + mesh_descriptor, + exec_alloc_descriptor, + unstructured_case, +) +bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +float64_list_type = it_ts.ListType(element_type=float64_type) +int_list_type = it_ts.ListType(element_type=int_type) -def test_unsatisfiable_constraints(): - a = ir.Sym(id="a", dtype=("float32", False)) - b = ir.Sym(id="b", dtype=("int32", False)) - - testee = im.lambda_(a, b)(im.plus("a", "b")) - - # The type inference uses a set to store the constraints. Since the TypeVar indices use a - # global counter the constraint resolution order depends on previous runs of the inference. - # To avoid false positives we just ignore which way the constraints have been resolved. - # (The previous description has never been verified.) 
- expected_error = [ - ( - "Type inference failed: Can not satisfy constraints:\n" - " Primitive(name='int32') ≡ Primitive(name='float32')" - ), - ( - "Type inference failed: Can not satisfy constraints:\n" - " Primitive(name='float32') ≡ Primitive(name='int32')" - ), - ] +bool_i_field = ts.FieldType(dims=[IDim], dtype=bool_type) +bool_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=bool_type) +bool_edge_k_field = ts.FieldType(dims=[Edge, KDim], dtype=bool_type) - try: - inferred = ti.infer(testee) - except ti.UnsatisfiableConstraintsError as e: - assert str(e) in expected_error +it_on_v_of_e_type = it_ts.IteratorType( + position_dims=[Vertex, KDim], defined_dims=[Edge, KDim], element_type=int_type +) +it_on_e_of_e_type = it_ts.IteratorType( + position_dims=[Edge, KDim], defined_dims=[Edge, KDim], element_type=int_type +) -def test_unsatisfiable_constraints(): - a = ir.Sym(id="a", dtype=("float32", False)) - b = ir.Sym(id="b", dtype=("int32", False)) +it_ijk_type = it_ts.IteratorType( + position_dims=[IDim, JDim, KDim], defined_dims=[IDim, JDim, KDim], element_type=int_type +) - testee = im.lambda_(a, b)(im.plus("a", "b")) - # TODO(tehrengruber): For whatever reason the ordering in the error message is not - # deterministic. Ignoring for now, as we want to refactor the type inference anyway. 
- expected_error = [ +def expression_test_cases(): + return ( + # itir expr, type + (im.call("abs")(1), int_type), + (im.call("power")(2.0, 2), float64_type), + (im.plus(1, 2), int_type), + (im.eq(1, 2), bool_type), + (im.deref(im.ref("it", it_on_e_of_e_type)), it_on_e_of_e_type.element_type), + (im.call("can_deref")(im.ref("it", it_on_e_of_e_type)), bool_type), + (im.if_(True, 1, 2), int_type), + (im.call("make_const_list")(True), it_ts.ListType(element_type=bool_type)), + (im.call("list_get")(0, im.ref("l", it_ts.ListType(element_type=bool_type))), bool_type), ( - "Type inference failed: Can not satisfy constraints:\n" - " Primitive(name='int32') ≡ Primitive(name='float32')" + im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), + it_ts.NamedRangeType(dim=Vertex), ), ( - "Type inference failed: Can not satisfy constraints:\n" - " Primitive(name='float32') ≡ Primitive(name='int32')" - ), - ] - - try: - inferred = ti.infer(testee) - except ti.UnsatisfiableConstraintsError as e: - assert str(e) in expected_error - - -def test_sym_ref(): - testee = ir.SymRef(id="x") - expected = ti.TypeVar(idx=0) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "T₀" - - -def test_bool_literal(): - testee = im.literal_from_value(False) - expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="bool"), size=ti.TypeVar(idx=0)) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "bool⁰" - - -def test_int_literal(): - testee = im.literal("3", "int32") - expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int32"), size=ti.TypeVar(idx=0)) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "int32⁰" - - -def test_float_literal(): - testee = im.literal("3.0", "float64") - expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="float64"), size=ti.TypeVar(idx=0)) - inferred = ti.infer(testee) - assert inferred == expected - assert 
ti.pformat(inferred) == "float64⁰" - - -def test_deref(): - testee = ir.SymRef(id="deref") - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=2), - defined_loc=ti.TypeVar(idx=2), - ) - ), - ret=ti.Val(kind=ti.Value(), dtype=ti.TypeVar(idx=0), size=ti.TypeVar(idx=1)), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(It[T₂, T₂, T₀¹]) → T₀¹" - - -def test_deref_call(): - testee = ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="x")]) - expected = ti.Val(kind=ti.Value(), dtype=ti.TypeVar(idx=0), size=ti.TypeVar(idx=1)) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "T₀¹" - - -def test_lambda(): - testee = ir.Lambda(params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - expected = ti.FunctionType(args=ti.Tuple.from_elems(ti.TypeVar(idx=0)), ret=ti.TypeVar(idx=0)) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(T₀) → T₀" - - -def test_typed_lambda(): - testee = ir.Lambda( - params=[ir.Sym(id="x", kind="Iterator", dtype=("float64", False))], expr=ir.SymRef(id="x") - ) - expected_val = ti.Val( - kind=ti.Iterator(), - dtype=ti.Primitive(name="float64"), - size=ti.TypeVar(idx=0), - current_loc=ti.TypeVar(idx=1), - defined_loc=ti.TypeVar(idx=2), - ) - expected = ti.FunctionType(args=ti.Tuple.from_elems(expected_val), ret=expected_val) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(It[T₁, T₂, float64⁰]) → It[T₁, T₂, float64⁰]" - - -def test_plus(): - testee = ir.SymRef(id="plus") - t = ti.Val(kind=ti.Value(), dtype=ti.TypeVar(idx=0), size=ti.TypeVar(idx=1)) - expected = ti.FunctionType(args=ti.Tuple.from_elems(t, t), ret=t) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(T₀¹, T₀¹) → T₀¹" - - -def test_power(): - testee 
= im.call("power")(im.literal_from_value(1.0), im.literal_from_value(2)) - expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="float64"), size=ti.TypeVar(idx=0)) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "float64⁰" - - -def test_eq(): - testee = ir.SymRef(id="eq") - t = ti.Val(kind=ti.Value(), dtype=ti.TypeVar(idx=0), size=ti.TypeVar(idx=1)) - expected = ti.FunctionType( - args=ti.Tuple.from_elems(t, t), - ret=ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="bool"), size=ti.TypeVar(idx=1)), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(T₀¹, T₀¹) → bool¹" - - -def test_if(): - testee = ir.SymRef(id="if_") - c = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="bool"), size=ti.TypeVar(idx=0)) - t = ti.TypeVar(idx=1) - expected = ti.FunctionType(args=ti.Tuple.from_elems(c, t, t), ret=t) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(bool⁰, T₁, T₁) → T₁" - - -def test_if_call(): - testee = im.if_("cond", im.literal("1", "int32"), im.literal("1", "int32")) - expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int32"), size=ti.TypeVar(idx=0)) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "int32⁰" - - -def test_not(): - testee = ir.SymRef(id="not_") - t = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="bool"), size=ti.TypeVar(idx=0)) - expected = ti.FunctionType(args=ti.Tuple.from_elems(t), ret=t) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(bool⁰) → bool⁰" - - -def test_and(): - testee = ir.SymRef(id="and_") - t = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="bool"), size=ti.TypeVar(idx=0)) - expected = ti.FunctionType(args=ti.Tuple.from_elems(t, t), ret=t) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(bool⁰, bool⁰) → bool⁰" - - -def test_cast(): - testee = 
ir.FunCall( - fun=ir.SymRef(id="cast_"), - args=[im.literal("1.", "float64"), ir.SymRef(id="int64")], - ) - expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int64"), size=ti.TypeVar(idx=0)) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "int64⁰" - - -def test_lift(): - testee = ir.SymRef(id="lift") - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - ti.FunctionType( - args=ti.ValTuple( - kind=ti.Iterator(), - dtypes=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=2), - defined_locs=ti.TypeVar(idx=3), - ), - ret=ti.Val(kind=ti.Value(), dtype=ti.TypeVar(idx=4), size=ti.TypeVar(idx=1)), - ) - ), - ret=ti.FunctionType( - args=ti.ValTuple( - kind=ti.Iterator(), - dtypes=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=5), - defined_locs=ti.TypeVar(idx=3), + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ), - ret=ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=4), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=5), - defined_loc=ti.TypeVar(idx=2), + it_ts.DomainType(dims=[IDim]), + ), + ( + im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1) ), + it_ts.DomainType(dims=[Vertex]), ), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ( - ti.pformat(inferred) - == "((It[T₂, …₃, T¹], …)₀ → T₄¹) → (It[T₅, …₃, T¹], …)₀ → It[T₅, T₂, T₄¹]" - ) - - -def test_lift_lambda_without_args(): - testee = ir.FunCall( - fun=ir.SymRef(id="lift"), args=[ir.Lambda(params=[], expr=ir.SymRef(id="x"))] - ) - expected = ti.FunctionType( - args=ti.ValTuple( - kind=ti.Iterator(), - dtypes=ti.EmptyTuple(), - size=ti.TypeVar(idx=0), - current_loc=ti.TypeVar(idx=1), - defined_locs=ti.EmptyTuple(), + ( + im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type)), + ts.TupleType(types=[int_type, bool_type]), ), - ret=ti.Val( - kind=ti.Iterator(), - 
dtype=ti.TypeVar(idx=2), - size=ti.TypeVar(idx=0), - current_loc=ti.TypeVar(idx=1), - defined_loc=ti.TypeVar(idx=3), + (im.tuple_get(0, im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type))), int_type), + (im.tuple_get(1, im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type))), bool_type), + ( + im.neighbors("E2V", im.ref("a", it_on_e_of_e_type)), + it_ts.ListType(element_type=it_on_e_of_e_type.element_type), ), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "() → It[T₁, T₃, T₂⁰]" - - -def test_lift_application(): - testee = ir.FunCall(fun=ir.SymRef(id="lift"), args=[ir.SymRef(id="deref")]) - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=2), - defined_loc=ti.TypeVar(idx=3), - ) + (im.call("cast_")(1, "int32"), int_type), + # lift + # scan + ( + im.map_(im.ref("plus"))(im.ref("a", int_list_type), im.ref("b", int_list_type)), + int_list_type, ), - ret=ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=2), - defined_loc=ti.TypeVar(idx=3), + (im.call(im.call("reduce")("plus", 0))(im.ref("l", int_list_type)), int_type), + ( + im.call( + im.call("reduce")( + im.lambda_("acc", "a", "b")( + im.make_tuple( + im.plus(im.tuple_get(0, "acc"), "a"), + im.plus(im.tuple_get(1, "acc"), "b"), + ) + ), + im.make_tuple(0, 0.0), + ) + )(im.ref("la", int_list_type), im.ref("lb", float64_list_type)), + ts.TupleType(types=[int_type, float64_type]), ), + (im.shift("V2E", 1)(im.ref("it", it_on_v_of_e_type)), it_on_e_of_e_type), + (im.shift("Ioff", 1)(im.ref("it", it_ijk_type)), it_ijk_type), ) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(It[T₂, T₃, T₀¹]) → It[T₂, T₃, T₀¹]" - -def test_lifted_call(): - testee = ir.FunCall( - fun=ir.FunCall(fun=ir.SymRef(id="lift"), args=[ir.SymRef(id="deref")]), - 
args=[ir.SymRef(id="x")], - ) - expected = ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=2), - defined_loc=ti.TypeVar(idx=3), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "It[T₂, T₃, T₀¹]" +@pytest.mark.parametrize("test_case", expression_test_cases()) +def test_expression_type(test_case): + mesh = simple_mesh() + offset_provider = {**mesh.offset_provider, "Ioff": IDim, "Joff": JDim, "Koff": KDim} -def test_make_tuple(): - testee = ir.FunCall( - fun=ir.SymRef(id="make_tuple"), - args=[ - im.literal("True", "bool"), - im.literal("42.0", "float64"), - ir.SymRef(id="x"), - ], - ) - expected = ti.Val( - kind=ti.Value(), - dtype=ti.Tuple.from_elems( - ti.Primitive(name="bool"), ti.Primitive(name="float64"), ti.TypeVar(idx=0) - ), - size=ti.TypeVar(idx=1), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(bool, float64, T₀)¹" + testee, expected_type = test_case + result = itir_type_inference.infer(testee, offset_provider=offset_provider) + assert result.type == expected_type -def test_tuple_get(): - testee = ir.FunCall( - fun=ir.SymRef(id="tuple_get"), - args=[ - im.literal("1", ir.INTEGER_INDEX_BUILTIN), - ir.FunCall( - fun=ir.SymRef(id="make_tuple"), - args=[ - im.literal("True", "bool"), - im.literal("42.0", "float64"), - ], - ), - ], - ) - expected = ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="float64"), size=ti.TypeVar(idx=0)) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "float64⁰" +def test_adhoc_polymorphism(): + func = im.lambda_("a")(im.lambda_("b")(im.make_tuple("a", "b"))) + testee = im.call(im.call(func)(im.ref("a_", bool_type)))(im.ref("b_", int_type)) + result = itir_type_inference.infer(testee, offset_provider={}) -def test_tuple_get_in_lambda(): - testee = ir.Lambda( - params=[ir.Sym(id="x")], - expr=ir.FunCall( - 
fun=ir.SymRef(id="tuple_get"), - args=[im.literal("1", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], - ), - ) - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - ti.Val( - kind=ti.TypeVar(idx=0), - dtype=ti.Tuple( - front=ti.TypeVar(idx=1), - others=ti.Tuple(front=ti.TypeVar(idx=2), others=ti.TypeVar(idx=3)), - ), - size=ti.TypeVar(idx=4), - ) - ), - ret=ti.Val(kind=ti.TypeVar(idx=0), dtype=ti.TypeVar(idx=2), size=ti.TypeVar(idx=4)), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(ItOrVal₀[(T₁, T₂):T₃⁴]) → ItOrVal₀[T₂⁴]" + assert result.type == ts.TupleType(types=[bool_type, int_type]) -def test_neighbors(): - testee = ir.FunCall( - fun=ir.SymRef(id="neighbors"), args=[ir.OffsetLiteral(value="V2E"), ir.SymRef(id="it")] - ) - expected = ti.Val( - kind=ti.Value(), - dtype=ti.List( - dtype=ti.TypeVar(idx=0), max_length=ti.TypeVar(idx=1), has_skip_values=ti.TypeVar(idx=2) - ), - size=ti.TypeVar(idx=3), - ) - inferred = ti.infer(testee) - assert expected == inferred - assert ti.pformat(inferred) == "L[T₀, T₁, T₂]³" +# TODO: test failure when something is not typed +# TODO: test lift with no args +# TODO: lambda function that is not called +# TODO: partially applied function in a let +# TODO: function calling itself +# TODO: lambda function called with different argument types +# reduce(λ(_fuse_maps_1, _fuse_maps_3, _fuse_maps_4) → _fuse_maps_1 + (_fuse_maps_3 + _fuse_maps_4), 0)( +# neighbors(V2Eₒ, in_edges), neighbors(V2Eₒ, in_edges) +# ) -def test_reduce(): - reduction_f = ir.Lambda( - params=[ir.Sym(id="acc"), ir.Sym(id="x"), ir.Sym(id="y")], - expr=ir.FunCall( - fun=ir.SymRef(id="plus"), - args=[ - ir.SymRef(id="acc"), - ir.FunCall( - fun=ir.SymRef(id="cast_"), # cast to the type of `init` - args=[ - ir.FunCall( - fun=ir.SymRef(id="multiplies"), - args=[ - ir.SymRef(id="x"), - ir.FunCall( - fun=ir.SymRef( - id="cast_" - ), # force `x` to be of type `float64` -> `y` is unconstrained - 
args=[ir.SymRef(id="y"), ir.SymRef(id="float64")], - ), - ], - ), - ir.SymRef(id="int64"), - ], - ), - ], - ), - ) - testee = ir.FunCall(fun=ir.SymRef(id="reduce"), args=[reduction_f, im.literal("0", "int64")]) - expected = ti.FunctionType( - args=ti.ValListTuple( - kind=ti.Value(), - list_dtypes=ti.Tuple.from_elems(ti.Primitive(name="float64"), ti.TypeVar(idx=0)), - max_length=ti.TypeVar(idx=1), - has_skip_values=ti.TypeVar(idx=2), - size=ti.TypeVar(idx=3), - ), - ret=ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int64"), size=ti.TypeVar(idx=3)), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(L[float64, T₁, T₂]³, L[T₀, T₁, T₂]³) → int64³" +def test_aliased_function(): + testee = im.let("f", im.lambda_("x")("x"))(im.call("f")(1)) + result = itir_type_inference.infer(testee, offset_provider={}) -def test_scan(): - scan_f = ir.Lambda( - params=[ir.Sym(id="acc"), ir.Sym(id="x"), ir.Sym(id="y")], - expr=ir.FunCall( - fun=ir.SymRef(id="plus"), - args=[ - ir.SymRef(id="acc"), - ir.FunCall( - fun=ir.SymRef(id="multiplies"), - args=[ - ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="x")]), - ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ], - ), - ], - ), - ) - testee = ir.FunCall( - fun=ir.SymRef(id="scan"), - args=[scan_f, im.literal("True", "bool"), im.literal("0", "int64")], + assert result.args[0].type == ts.FunctionType( + pos_only_args=[int_type], pos_or_kw_args={}, kw_only_args={}, returns=int_type ) - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - ti.Val( - kind=ti.Iterator(), - dtype=ti.Primitive(name="int64"), - size=ti.Column(), - current_loc=ti.TypeVar(idx=0), - defined_loc=ti.TypeVar(idx=0), - ), - ti.Val( - kind=ti.Iterator(), - dtype=ti.Primitive(name="int64"), - size=ti.Column(), - current_loc=ti.TypeVar(idx=0), - defined_loc=ti.TypeVar(idx=0), - ), - ), - ret=ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int64"), size=ti.Column()), - ) - inferred = 
ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(It[T₀, T₀, int64ᶜ], It[T₀, T₀, int64ᶜ]) → int64ᶜ" - - -def test_shift(): - testee = ir.FunCall( - fun=ir.SymRef(id="shift"), args=[ir.OffsetLiteral(value="i"), ir.OffsetLiteral(value=1)] - ) - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=2), - defined_loc=ti.TypeVar(idx=3), - ) - ), - ret=ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=4), - defined_loc=ti.TypeVar(idx=3), - ), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred) == "(It[T₂, T₃, T₀¹]) → It[T₄, T₃, T₀¹]" - - -def test_shift_with_cartesian_offset_provider(): - testee = ir.FunCall( - fun=ir.SymRef(id="shift"), args=[ir.OffsetLiteral(value="i"), ir.OffsetLiteral(value=1)] - ) - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=2), - defined_loc=ti.TypeVar(idx=3), - ) - ), - ret=ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=2), - defined_loc=ti.TypeVar(idx=3), - ), - ) - offset_provider = {"i": gtx.Dimension("IDim")} - inferred = ti.infer(testee, offset_provider=offset_provider) - assert inferred == expected - assert ti.pformat(inferred) == "(It[T₂, T₃, T₀¹]) → It[T₂, T₃, T₀¹]" - - -def test_partial_shift_with_cartesian_offset_provider(): - testee = ir.FunCall(fun=ir.SymRef(id="shift"), args=[ir.OffsetLiteral(value="i")]) - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=2), - defined_loc=ti.TypeVar(idx=3), - ) - ), - ret=ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - 
size=ti.TypeVar(idx=1), - current_loc=ti.TypeVar(idx=2), - defined_loc=ti.TypeVar(idx=3), - ), - ) - offset_provider = {"i": gtx.Dimension("IDim")} - inferred = ti.infer(testee, offset_provider=offset_provider) - assert inferred == expected - assert ti.pformat(inferred) == "(It[T₂, T₃, T₀¹]) → It[T₂, T₃, T₀¹]" + assert result.type == int_type -def test_shift_with_unstructured_offset_provider(): - testee = ir.FunCall( - fun=ir.SymRef(id="shift"), args=[ir.OffsetLiteral(value="V2E"), ir.OffsetLiteral(value=0)] - ) - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.Location(name="Vertex"), - defined_loc=ti.TypeVar(idx=2), - ) - ), - ret=ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.Location(name="Edge"), - defined_loc=ti.TypeVar(idx=2), - ), - ) - offset_provider = { - "V2E": gtx.NeighborTableOffsetProvider( - np.empty((0, 1), dtype=np.int64), gtx.Dimension("Vertex"), gtx.Dimension("Edge"), 1 - ) - } - inferred = ti.infer(testee, offset_provider=offset_provider) - assert inferred == expected - assert ti.pformat(inferred) == "(It[Vertex, T₂, T₀¹]) → It[Edge, T₂, T₀¹]" +def test_late_offset_axis(): + mesh = simple_mesh() + func = im.lambda_("dim")(im.shift(im.ref("dim"), 1)(im.ref("it", it_on_v_of_e_type))) + testee = im.call(func)(im.ensure_offset("V2E")) -def test_partial_shift_with_unstructured_offset_provider(): - testee = ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ - ir.OffsetLiteral(value="V2E"), - ir.OffsetLiteral(value=0), - ir.OffsetLiteral(value="E2C"), - ], - ) - expected = ti.FunctionType( - args=ti.Tuple.from_elems( - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - current_loc=ti.Location(name="Vertex"), - defined_loc=ti.TypeVar(idx=2), - ) - ), - ret=ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.TypeVar(idx=1), - 
current_loc=ti.Location(name="Cell"), - defined_loc=ti.TypeVar(idx=2), - ), - ) - offset_provider = { - "V2E": gtx.NeighborTableOffsetProvider( - np.empty((0, 1), dtype=np.int64), gtx.Dimension("Vertex"), gtx.Dimension("Edge"), 1 - ), - "E2C": gtx.NeighborTableOffsetProvider( - np.empty((0, 1), dtype=np.int64), gtx.Dimension("Edge"), gtx.Dimension("Cell"), 1 - ), - } - inferred = ti.infer(testee, offset_provider=offset_provider) - assert inferred == expected - assert ti.pformat(inferred) == "(It[Vertex, T₂, T₀¹]) → It[Cell, T₂, T₀¹]" + result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + assert result.type == it_on_e_of_e_type -def test_function_definition(): - testee = ir.FunctionDefinition(id="f", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - expected = ti.LetPolymorphic( - dtype=ti.FunctionType(args=ti.Tuple.from_elems(ti.TypeVar(idx=0)), ret=ti.TypeVar(idx=0)) +def test_cartesian_fencil_definition(): + cartesian_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ) - inferred = ti.infer(testee) - assert inferred == expected - assert ti.pformat(inferred.dtype) == "(T₀) → T₀" - -def test_dynamic_offset(): - """Test that the type of a dynamic offset is correctly inferred.""" - offset_it = ir.SymRef(id="offset_it") - testee = ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ - ir.OffsetLiteral(value="V2E"), - ir.FunCall(fun=ir.SymRef(id="deref"), args=[offset_it]), + testee = itir.FencilDefinition( + id="f", + function_definitions=[], + params=[im.sym("inp", bool_i_field), im.sym("out", bool_i_field)], + closures=[ + itir.StencilClosure( + domain=cartesian_domain, + stencil=im.ref("deref"), + output=im.ref("out"), + inputs=[im.ref("inp")], + ), ], ) - inferred_all: dict[int, ti.Type] = ti.infer_all(testee) - offset_it_type = inferred_all[id(offset_it)] - assert isinstance(offset_it_type, ti.Val) and offset_it_type.kind == ti.Iterator() + result = itir_type_inference.infer(testee, 
offset_provider={"Ioff": IDim}) -CARTESIAN_DOMAIN = ir.FunCall( - fun=ir.SymRef(id="cartesian_domain"), - args=[ - ir.FunCall( - fun=ir.SymRef(id="named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - ir.SymRef(id="i"), - ], - ), - ir.FunCall( - fun=ir.SymRef(id="named_range"), - args=[ - ir.AxisLiteral(value="JDim"), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - ir.SymRef(id="j"), + closure_type = it_ts.StencilClosureType( + domain=it_ts.DomainType(dims=[IDim]), + stencil=ts.FunctionType( + pos_only_args=[ + it_ts.IteratorType( + position_dims=[IDim], defined_dims=[IDim], element_type=bool_type + ) ], + pos_or_kw_args={}, + kw_only_args={}, + returns=bool_type, ), - ir.FunCall( - fun=ir.SymRef(id="named_range"), - args=[ - ir.AxisLiteral(value="KDim"), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - ir.SymRef(id="k"), - ], - ), - ], -) + output=bool_i_field, + inputs=[bool_i_field], + ) + fencil_type = it_ts.FencilType(params=[bool_i_field, bool_i_field], closures=[closure_type]) + assert result.type == fencil_type + assert result.closures[0].type == closure_type -def test_stencil_closure(): - testee = ir.StencilClosure( - domain=CARTESIAN_DOMAIN, - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp")], - ) - expected = ti.Closure( - output=ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=1), - ), - inputs=ti.Tuple.from_elems( - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=1), - ) - ), +def test_unstructured_fencil_definition(): + mesh = simple_mesh() + unstructured_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), + im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), ) - inferred = ti.infer(testee) - assert inferred == expected - assert 
ti.pformat(inferred) == "(It[ANYWHERE, T₁, T₀ᶜ]) ⇒ It[ANYWHERE, T₁, T₀ᶜ]" - -def test_fencil_definition(): - testee = ir.FencilDefinition( + testee = itir.FencilDefinition( id="f", function_definitions=[], - params=[ - ir.Sym(id="i"), - ir.Sym(id="j"), - ir.Sym(id="k"), - ir.Sym(id="a"), - ir.Sym(id="b"), - ir.Sym(id="c"), - ir.Sym(id="d"), - ], + params=[im.sym("inp", bool_edge_k_field), im.sym("out", bool_vertex_k_field)], closures=[ - ir.StencilClosure( - domain=CARTESIAN_DOMAIN, - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="b"), - inputs=[ir.SymRef(id="a")], - ), - ir.StencilClosure( - domain=CARTESIAN_DOMAIN, - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="d"), - inputs=[ir.SymRef(id="c")], + itir.StencilClosure( + domain=unstructured_domain, + stencil=im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), + output=im.ref("out"), + inputs=[im.ref("inp")], ), ], ) - expected = ti.FencilDefinitionType( - name="f", - fundefs=ti.EmptyTuple(), - params=ti.Tuple.from_elems( - ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int32"), size=ti.Scalar()), - ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int32"), size=ti.Scalar()), - ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int32"), size=ti.Scalar()), - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=1), - ), - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=0), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=1), - ), - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=2), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=3), - ), - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=2), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=3), - ), - ), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ( - ti.pformat(inferred) - == "{f(int32ˢ, int32ˢ, int32ˢ, It[ANYWHERE, T₁, T₀ᶜ], It[ANYWHERE, T₁, T₀ᶜ], 
It[ANYWHERE, T₃, T₂ᶜ], It[ANYWHERE, T₃, T₂ᶜ])}" - ) + result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) -def test_fencil_definition_same_closure_input(): - f1 = ir.FunctionDefinition( - id="f1", params=[im.sym("vertex_it")], expr=im.deref(im.shift("E2V")("vertex_it")) + closure_type = it_ts.StencilClosureType( + domain=it_ts.DomainType(dims=[Vertex, KDim]), + stencil=ts.FunctionType( + pos_only_args=[ + it_ts.IteratorType( + position_dims=[Vertex, KDim], defined_dims=[Edge, KDim], element_type=bool_type + ) + ], + pos_or_kw_args={}, + kw_only_args={}, + returns=bool_type, + ), + output=bool_vertex_k_field, + inputs=[bool_edge_k_field], ) - f2 = ir.FunctionDefinition(id="f2", params=[im.sym("vertex_it")], expr=im.deref("vertex_it")) - - testee = ir.FencilDefinition( - id="fencil", - function_definitions=[f1, f2], - params=[im.sym("vertex_it"), im.sym("output_edge_it"), im.sym("output_vertex_it")], - closures=[ - ir.StencilClosure( - domain=im.call("unstructured_domain")( - im.call("named_range")( - ir.AxisLiteral(value="Edge"), - im.literal("0", "int32"), - im.literal("10", "int32"), - ) - ), - stencil=im.ref("f1"), - output=im.ref("output_edge_it"), - inputs=[im.ref("vertex_it")], - ), - ir.StencilClosure( - domain=im.call("unstructured_domain")( - im.call("named_range")( - ir.AxisLiteral(value="Vertex"), - im.literal("0", "int32"), - im.literal("10", "int32"), - ) - ), - stencil=im.ref("f2"), - output=im.ref("output_vertex_it"), - inputs=[im.ref("vertex_it")], - ), - ], + fencil_type = it_ts.FencilType( + params=[bool_edge_k_field, bool_vertex_k_field], closures=[closure_type] ) + assert result.type == fencil_type + assert result.closures[0].type == closure_type - offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider( - np.empty((0, 2), dtype=np.int64), - gtx.Dimension("Edge"), - gtx.Dimension("Vertex"), - 2, - False, - ) - } - inferred_all: dict[int, ti.Type] = ti.infer_all(testee, offset_provider=offset_provider) - - # 
validate locations of fencil params - fencil_param_types = [inferred_all[id(testee.params[i])] for i in range(3)] - assert fencil_param_types[0].defined_loc == ti.Location(name="Vertex") - assert fencil_param_types[1].defined_loc == ti.Location(name="Edge") - assert fencil_param_types[2].defined_loc == ti.Location(name="Vertex") - - # validate locations of stencil params - f1_param_type: ti.Val = inferred_all[id(f1.params[0])] - assert f1_param_type.current_loc == ti.Location(name="Edge") - assert f1_param_type.defined_loc == ti.Location(name="Vertex") - # f2 is polymorphic and there is no shift inside so we only get a TypeVar here - f2_param_type: ti.Val = inferred_all[id(f2.params[0])] - assert isinstance(f2_param_type.current_loc, ti.TypeVar) - assert isinstance(f2_param_type.defined_loc, ti.TypeVar) +def test_function_definition(): + cartesian_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ) -def test_fencil_definition_with_function_definitions(): - fundefs = [ - ir.FunctionDefinition(id="f", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")), - ir.FunctionDefinition( - id="g", - params=[ir.Sym(id="x")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="x")]), - ), - ] - testee = ir.FencilDefinition( - id="foo", - function_definitions=fundefs, - params=[ - ir.Sym(id="i"), - ir.Sym(id="j"), - ir.Sym(id="k"), - ir.Sym(id="a"), - ir.Sym(id="b"), - ir.Sym(id="c"), - ir.Sym(id="d"), - ir.Sym(id="x"), - ir.Sym(id="y"), + testee = itir.FencilDefinition( + id="f", + function_definitions=[ + itir.FunctionDefinition(id="foo", params=[im.sym("it")], expr=im.deref("it")), + itir.FunctionDefinition(id="bar", params=[im.sym("it")], expr=im.call("foo")("it")), ], + params=[im.sym("inp", bool_i_field), im.sym("out", bool_i_field)], closures=[ - ir.StencilClosure( - domain=CARTESIAN_DOMAIN, - stencil=ir.SymRef(id="g"), - output=ir.SymRef(id="b"), - inputs=[ir.SymRef(id="a")], - ), - ir.StencilClosure( - 
domain=CARTESIAN_DOMAIN, - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="d"), - inputs=[ir.SymRef(id="c")], - ), - ir.StencilClosure( - domain=CARTESIAN_DOMAIN, - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall( - fun=ir.SymRef(id="g"), - args=[ir.FunCall(fun=ir.SymRef(id="f"), args=[ir.SymRef(id="y")])], - ), - ), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], + itir.StencilClosure( + domain=cartesian_domain, + stencil=im.ref("bar"), + output=im.ref("out"), + inputs=[im.ref("inp")], ), ], ) - expected = ti.FencilDefinitionType( - name="foo", - fundefs=ti.Tuple.from_elems( - ti.FunctionDefinitionType( - name="f", - fun=ti.FunctionType( - args=ti.Tuple.from_elems(ti.TypeVar(idx=0)), ret=ti.TypeVar(idx=0) - ), - ), - ti.FunctionDefinitionType( - name="g", - fun=ti.FunctionType( - args=ti.Tuple.from_elems( - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=1), - size=ti.TypeVar(idx=2), - current_loc=ti.TypeVar(idx=3), - defined_loc=ti.TypeVar(idx=3), - ) - ), - ret=ti.Val(kind=ti.Value(), dtype=ti.TypeVar(idx=1), size=ti.TypeVar(idx=2)), - ), - ), - ), - params=ti.Tuple.from_elems( - ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int32"), size=ti.Scalar()), - ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int32"), size=ti.Scalar()), - ti.Val(kind=ti.Value(), dtype=ti.Primitive(name="int32"), size=ti.Scalar()), - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=4), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=5), - ), - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=4), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=5), - ), - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=6), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=7), - ), - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=6), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=7), - ), - ti.Val( - kind=ti.Iterator(), - 
dtype=ti.TypeVar(idx=8), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=9), - ), - ti.Val( - kind=ti.Iterator(), - dtype=ti.TypeVar(idx=8), - size=ti.Column(), - current_loc=ti.ANYWHERE, - defined_loc=ti.TypeVar(idx=9), - ), - ), - ) - inferred = ti.infer(testee) - assert inferred == expected - assert ( - ti.pformat(inferred) - == "{f :: (T₀) → T₀, g :: (It[T₃, T₃, T₁²]) → T₁², foo(int32ˢ, int32ˢ, int32ˢ, It[ANYWHERE, T₅, T₄ᶜ], It[ANYWHERE, T₅, T₄ᶜ], It[ANYWHERE, T₇, T₆ᶜ], It[ANYWHERE, T₇, T₆ᶜ], It[ANYWHERE, T₉, T₈ᶜ], It[ANYWHERE, T₉, T₈ᶜ])}" - ) - -def test_save_types_to_annex(): - testee = im.lambda_("a")(im.plus("a", im.literal("1", "float32"))) - ti.infer(testee, save_to_annex=True) - param_type = testee.params[0].annex.type - assert isinstance(param_type, ti.Val) and param_type.dtype.name == "float32" + result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) - -def test_pformat(): - vs = [ti.TypeVar(idx=i) for i in range(5)] - assert ti.pformat(vs[0]) == "T₀" - assert ti.pformat(ti.Tuple.from_elems(*vs[:2])) == "(T₀, T₁)" - assert ( - ti.pformat(ti.Tuple(front=vs[0], others=ti.Tuple(front=vs[1], others=vs[2]))) - == "(T₀, T₁):T₂" - ) - assert ti.pformat(ti.FunctionType(args=vs[0], ret=vs[1])) == "T₀ → T₁" - assert ti.pformat(ti.Val(kind=vs[0], dtype=vs[1], size=vs[2])) == "ItOrVal₀[T₁²]" - assert ti.pformat(ti.Val(kind=ti.Value(), dtype=vs[0], size=vs[1])) == "T₀¹" - assert ti.pformat(ti.Val(kind=ti.Iterator(), dtype=vs[0], size=vs[1])) == "It[T₀¹]" - assert ( - ti.pformat( - ti.Val( - kind=ti.Iterator(), dtype=vs[0], size=vs[1], current_loc=vs[2], defined_loc=vs[3] - ) - ) - == "It[T₂, T₃, T₀¹]" - ) - assert ti.pformat(ti.Val(kind=ti.Value(), dtype=vs[0], size=ti.Scalar())) == "T₀ˢ" - assert ti.pformat(ti.Val(kind=ti.Value(), dtype=vs[0], size=ti.Column())) == "T₀ᶜ" - assert ti.pformat(ti.ValTuple(kind=vs[0], dtypes=vs[1], size=vs[2])) == "(ItOrVal₀[T²], …)₁" - assert ( - ti.pformat( - ti.ValListTuple( - 
list_dtypes=ti.Tuple.from_elems(vs[0], vs[1]), - max_length=vs[2], - has_skip_values=vs[3], - size=vs[4], - ) - ) - == "(L[T₀, T₂, T₃]⁴, L[T₁, T₂, T₃]⁴)" - ) - assert ( - ti.pformat( - ti.ValListTuple(list_dtypes=vs[0], max_length=vs[1], has_skip_values=vs[2], size=vs[3]) - ) - == "(L[…₀, T₁, T₂]³, …)" - ) - assert ti.pformat(ti.Primitive(name="foo")) == "foo" - assert ti.pformat(ti.Closure(output=vs[0], inputs=vs[1])) == "T₁ ⇒ T₀" - assert ( - ti.pformat(ti.FunctionDefinitionType(name="f", fun=ti.FunctionType(args=vs[0], ret=vs[1]))) - == "f :: T₀ → T₁" - ) - assert ( - ti.pformat( - ti.FencilDefinitionType(name="f", fundefs=ti.EmptyTuple(), params=ti.EmptyTuple()) - ) - == "{f()}" + closure_type = it_ts.StencilClosureType( + domain=it_ts.DomainType(dims=[IDim]), + stencil=ts.FunctionType( + pos_only_args=[ + it_ts.IteratorType( + position_dims=[IDim], defined_dims=[IDim], element_type=bool_type + ) + ], + pos_or_kw_args={}, + kw_only_args={}, + returns=bool_type, + ), + output=bool_i_field, + inputs=[bool_i_field], ) + fencil_type = it_ts.FencilType(params=[bool_i_field, bool_i_field], closures=[closure_type]) + assert result.type == fencil_type + assert result.closures[0].type == closure_type diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 330f66bee5..eac9b8bd10 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -152,7 +152,9 @@ def test_inline_trivial_make_tuple(): def test_propagate_to_if_on_tuples(): testee = im.tuple_get(0, im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4))) expected = im.if_( - "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4)) + im.ref("cond", "bool"), + im.tuple_get(0, im.make_tuple(1, 2)), + im.tuple_get(0, im.make_tuple(3, 4)), 
) actual = CollapseTuple.apply( testee, @@ -190,9 +192,9 @@ def test_propagate_nested_lift(): def test_if_on_tuples_with_let(): - testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))( - im.tuple_get(0, "val") - ) + testee = im.let( + "val", im.if_(im.ref("cond", "bool"), im.make_tuple(1, 2), im.make_tuple(3, 4)) + )(im.tuple_get(0, "val")) expected = im.if_("cond", 1, 3) actual = CollapseTuple.apply(testee, remove_letified_make_tuple_elements=False) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 8bad361a20..e65a7f7972 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -15,6 +15,7 @@ import gt4py.next as gtx from gt4py.eve.utils import UIDs +from gt4py.next import common from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.global_tmps import ( @@ -26,6 +27,16 @@ split_closures, update_domains, ) +from gt4py.next.type_system import type_specifications as ts + + +IDim = common.Dimension(value="IDim") +JDim = common.Dimension(value="JDim") +KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) +index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, ir.INTEGER_INDEX_BUILTIN.upper())) +float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) +index_field_type_factory = lambda dim: ts.FieldType(dims=[dim], dtype=index_type) def test_split_closures(): @@ -33,7 +44,11 @@ def test_split_closures(): testee = ir.FencilDefinition( id="f", function_definitions=[], - params=[im.sym("d"), im.sym("inp"), im.sym("out")], + params=[ + im.sym("d", i_field_type), + im.sym("inp", i_field_type), + im.sym("out", 
i_field_type), + ], closures=[ ir.StencilClosure( domain=im.call("cartesian_domain")(), @@ -58,12 +73,12 @@ def test_split_closures(): id="f", function_definitions=[], params=[ - im.sym("d"), - im.sym("inp"), - im.sym("out"), - im.sym("_tmp_1"), - im.sym("_tmp_2"), - im.sym("_gtmp_auto_domain"), + im.sym("d", i_field_type), + im.sym("inp", i_field_type), + im.sym("out", i_field_type), + im.sym("_tmp_1", i_field_type), + im.sym("_tmp_2", i_field_type), + im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), ], closures=[ ir.StencilClosure( @@ -87,7 +102,10 @@ def test_split_closures(): ], ) actual = split_closures(testee, offset_provider={}) - assert actual.tmps == [Temporary(id="_tmp_1"), Temporary(id="_tmp_2")] + assert actual.tmps == [ + Temporary(id="_tmp_1", dtype=float_type), + Temporary(id="_tmp_2", dtype=float_type), + ] assert actual.fencil == expected @@ -96,7 +114,11 @@ def test_split_closures_simple_heuristics(): testee = ir.FencilDefinition( id="f", function_definitions=[], - params=[im.sym("d"), im.sym("inp"), im.sym("out")], + params=[ + im.sym("d", i_field_type), + im.sym("inp", i_field_type), + im.sym("out", i_field_type), + ], closures=[ ir.StencilClosure( domain=im.call("cartesian_domain")(), @@ -115,11 +137,11 @@ def test_split_closures_simple_heuristics(): id="f", function_definitions=[], params=[ - im.sym("d"), - im.sym("inp"), - im.sym("out"), - im.sym("_tmp_1"), - im.sym("_gtmp_auto_domain"), + im.sym("d", i_field_type), + im.sym("inp", i_field_type), + im.sym("out", i_field_type), + im.sym("_tmp_1", i_field_type), + im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), ], closures=[ ir.StencilClosure( @@ -139,9 +161,11 @@ def test_split_closures_simple_heuristics(): ], ) actual = split_closures( - testee, extraction_heuristics=SimpleTemporaryExtractionHeuristics, offset_provider={} + testee, + extraction_heuristics=SimpleTemporaryExtractionHeuristics, + offset_provider={"I": IDim}, ) - assert actual.tmps == 
[Temporary(id="_tmp_1")] + assert actual.tmps == [Temporary(id="_tmp_1", dtype=float_type)] assert actual.fencil == expected @@ -151,7 +175,7 @@ def test_split_closures_lifted_scan(): testee = ir.FencilDefinition( id="f", function_definitions=[], - params=[im.sym("inp"), im.sym("out")], + params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], closures=[ ir.StencilClosure( domain=im.call("cartesian_domain")(), @@ -181,7 +205,12 @@ def test_split_closures_lifted_scan(): expected = ir.FencilDefinition( id="f", function_definitions=[], - params=[im.sym("inp"), im.sym("out"), im.sym("_tmp_1"), im.sym("_gtmp_auto_domain")], + params=[ + im.sym("inp", i_field_type), + im.sym("out", i_field_type), + im.sym("_tmp_1", i_field_type), + im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), + ], closures=[ ir.StencilClosure( domain=AUTO_DOMAIN, @@ -211,7 +240,7 @@ def test_split_closures_lifted_scan(): ) actual = split_closures(testee, offset_provider={}) - assert actual.tmps == [Temporary(id="_tmp_1")] + assert actual.tmps == [Temporary(id="_tmp_1", dtype=float_type)] assert actual.fencil == expected @@ -221,8 +250,14 @@ def test_update_cartesian_domains(): id="f", function_definitions=[], params=[ - im.sym(name) - for name in ("i", "j", "k", "inp", "out", "_gtmp_0", "_gtmp_1", "_gtmp_auto_domain") + im.sym("i", index_type), + im.sym("j", index_type), + im.sym("k", index_type), + im.sym("inp", i_field_type), + im.sym("out", i_field_type), + im.sym("_gtmp_0", i_field_type), + im.sym("_gtmp_1", i_field_type), + im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), ], closures=[ ir.StencilClosure( @@ -350,13 +385,13 @@ def test_collect_tmps_info(): id="f", function_definitions=[], params=[ - ir.Sym(id="i"), - ir.Sym(id="j"), - ir.Sym(id="k"), - ir.Sym(id="inp", dtype=("float64", False)), - ir.Sym(id="out", dtype=("float64", False)), - ir.Sym(id="_gtmp_0"), - ir.Sym(id="_gtmp_1"), + im.sym("i", index_type), + im.sym("j", index_type), + im.sym("k", 
index_type), + im.sym("inp", i_field_type), + im.sym("out", i_field_type), + im.sym("_gtmp_0", i_field_type), + im.sym("_gtmp_1", i_field_type), ], closures=[ ir.StencilClosure( @@ -413,15 +448,15 @@ def test_collect_tmps_info(): ], ), params=[ir.Sym(id="i"), ir.Sym(id="j"), ir.Sym(id="k"), ir.Sym(id="inp"), ir.Sym(id="out")], - tmps=[Temporary(id="_gtmp_0"), Temporary(id="_gtmp_1")], + tmps=[Temporary(id="_gtmp_0", dtype=float_type), Temporary(id="_gtmp_1", dtype=float_type)], ) expected = FencilWithTemporaries( fencil=testee.fencil, params=testee.params, tmps=[ - Temporary(id="_gtmp_0", domain=tmp_domain, dtype="float64"), - Temporary(id="_gtmp_1", domain=tmp_domain, dtype="float64"), + Temporary(id="_gtmp_0", domain=tmp_domain, dtype=float_type), + Temporary(id="_gtmp_1", domain=tmp_domain, dtype=float_type), ], ) - actual = collect_tmps_info(testee, offset_provider={}) + actual = collect_tmps_info(testee, offset_provider={"I": IDim, "J": JDim, "K": KDim}) assert actual == expected diff --git a/tests/next_tests/unit_tests/test_type_inference.py b/tests/next_tests/unit_tests/test_type_inference.py deleted file mode 100644 index 3db67320f1..0000000000 --- a/tests/next_tests/unit_tests/test_type_inference.py +++ /dev/null @@ -1,84 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from gt4py.next import type_inference as ti - - -def test_renamer(): - class Foo(ti.Type): - bar: ti.Type - baz: ti.Type - - class Bar(ti.Type): ... 
- - r = ti._Renamer() - actual = [ - ( - ti._Box(value=Foo(bar=ti.TypeVar(idx=0), baz=ti.TypeVar(idx=1))), - ti._Box(value=ti.TypeVar(idx=0)), - ) - ] - src = ti.TypeVar(idx=0) - dst = ti.TypeVar(idx=1) - for s, t in actual: - r.register(s) - r.register(t) - r.register(src) - r.register(dst) - r.rename(src, dst) - expected = [ - ( - ti._Box(value=Foo(bar=ti.TypeVar(idx=1), baz=ti.TypeVar(idx=1))), - ti._Box(value=ti.TypeVar(idx=1)), - ) - ] - assert actual == expected - - -def test_custom_type_inference(): - class Fun(ti.Type): - arg: ti.Type - ret: ti.Type - - class Basic(ti.Type): - name: str - - class SpecialFun(ti.Type): - arg_and_ret: ti.Type - - def __eq__(self, other): - if isinstance(other, Fun): - return self.arg_and_ret == other.arg == other.ret - return isinstance(other, SpecialFun) and self.arg_and_ret == other.arg_and_ret - - def handle_constraint(self, other, add_constraint): - if isinstance(other, Fun): - add_constraint(self.arg_and_ret, other.arg) - add_constraint(self.arg_and_ret, other.ret) - return True - return False - - v = [ti.TypeVar(idx=i) for i in range(5)] - constraints = { - (v[0], SpecialFun(arg_and_ret=v[2])), - (Fun(arg=v[0], ret=v[3]), v[4]), - (Basic(name="int"), v[1]), - (v[1], v[2]), - } - dtype = v[4] - - expected = Fun(arg=Fun(arg=Basic(name="int"), ret=Basic(name="int")), ret=ti.TypeVar(idx=0)) - - actual = ti.reindex_vars(ti.unify(dtype, constraints)[0]) - assert actual == expected From beabac21957fcad1e534e85215f7b905470b444b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 07:59:55 +0200 Subject: [PATCH 07/52] Fix some tests --- src/gt4py/next/iterator/ir.py | 3 +- src/gt4py/next/iterator/tracing.py | 17 ++---- .../iterator/transforms/collapse_tuple.py | 2 +- .../next/iterator/transforms/global_tmps.py | 10 ++-- .../next/iterator/type_system/inference.py | 54 +++++++++++++------ .../runners/dace_iterator/itir_to_sdfg.py | 5 +- .../transforms_tests/test_global_tmps.py | 5 +- 7 files changed, 57 
insertions(+), 39 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index a2ecd30bd1..b3392c03a6 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -12,8 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import typing -from typing import Any, ClassVar, List, Optional, Union +from typing import ClassVar, List, Optional, Union import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 41159b30e7..2f8a18143e 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -35,7 +35,7 @@ SymRef, ) from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_info, type_specifications, type_translation +from gt4py.next.type_system import type_translation TRACING = "tracing" @@ -280,18 +280,9 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: arg_type = None kind, dtype = None, None if use_arg_types: - # TODO(tehrengruber): Fields of tuples are not supported yet. Just ignore them for now. - if not _contains_tuple_dtype_field(arg): - arg_type = type_translation.from_value(arg) - # TODO(tehrengruber): Support more types. 
- if isinstance(arg_type, type_specifications.FieldType): - kind = "Iterator" - dtype = ( - arg_type.dtype.kind.name.lower(), # actual dtype - type_info.is_local_field(arg_type), # is list - ) - - params.append(Sym(id=param_name, type=arg_type, kind=kind, dtype=dtype)) + arg_type = type_translation.from_value(arg) + + params.append(Sym(id=param_name, type=arg_type)) return params diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 2da12f83dc..b42de122a1 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -22,10 +22,10 @@ from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir -from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda +from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_info, type_specifications as ts diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index b776a92a0f..083845a531 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -23,7 +23,6 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir -from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.pretty_printer import PrettyPrinter @@ -33,7 +32,10 @@ from gt4py.next.iterator.transforms.inline_lambdas import 
InlineLambdas from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs -from gt4py.next.iterator.type_system import type_specifications as it_ts +from gt4py.next.iterator.type_system import ( + inference as itir_type_inference, + type_specifications as it_ts, +) from gt4py.next.type_system import type_specifications as ts @@ -599,7 +601,9 @@ def collect_tmps_info(node: FencilWithTemporaries, *, offset_provider) -> Fencil return FencilWithTemporaries( fencil=node.fencil, params=node.params, - tmps=[ir.Temporary(id=tmp.id, domain=domains[tmp.id], dtype=tmp.dtype) for tmp in node.tmps], + tmps=[ + ir.Temporary(id=tmp.id, domain=domains[tmp.id], dtype=tmp.dtype) for tmp in node.tmps + ], ) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 818647637e..8c10346fd5 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -199,6 +199,32 @@ def __call__(self, *args) -> ts.TypeSpec | rules.TypeInferenceRule: T = TypeVar("T", bound=itir.Node) +def _get_dimensions_from_offset_provider(offset_provider) -> dict[str, common.Dimension]: + dimensions: dict[str, common.Dimension] = {} + for offset_name, provider in offset_provider.items(): + dimensions[offset_name] = common.Dimension( + value=offset_name, kind=common.DimensionKind.LOCAL + ) + if isinstance(provider, common.Dimension): + dimensions[provider.value] = provider + elif isinstance(provider, common.Connectivity): + dimensions[provider.origin_axis.value] = provider.origin_axis + dimensions[provider.neighbor_axis.value] = provider.neighbor_axis + return dimensions + + +def _get_dimensions_from_types(types) -> dict[str, common.Dimension]: + dimensions: list[common.Dimension] = [] + for type_ in types: + + def _get_dimensions(el_type): + if isinstance(el_type, ts.FieldType): + 
dimensions.extend(el_type.dims) + + type_info.apply_to_primitive_constituents(_get_dimensions, type_) + return {dim.value: dim for dim in dimensions} + + @dataclasses.dataclass class ITIRTypeInference(eve.NodeTranslator): """ @@ -207,23 +233,19 @@ class ITIRTypeInference(eve.NodeTranslator): offset_provider: Any - @functools.cached_property - def dimensions(self) -> dict[str, common.Dimension]: - dimensions: dict[str, common.Dimension] = {} - for offset_name, provider in self.offset_provider.items(): - dimensions[offset_name] = common.Dimension( - value=offset_name, kind=common.DimensionKind.LOCAL - ) - if isinstance(provider, common.Dimension): - dimensions[provider.value] = provider - elif isinstance(provider, common.Connectivity): - dimensions[provider.origin_axis.value] = provider.origin_axis - dimensions[provider.neighbor_axis.value] = provider.neighbor_axis - return dimensions + dimensions: dict[str, common.Dimension] @classmethod def apply(cls, node: T, *, offset_provider, inplace: bool = False) -> T: - instance = cls(offset_provider=offset_provider) + instance = cls( + offset_provider=offset_provider, + dimensions=( + _get_dimensions_from_offset_provider(offset_provider) + | _get_dimensions_from_types( + node.pre_walk_values().if_isinstance(itir.Node).getattr("type").if_is_not(None) + ) + ), + ) if not inplace: node = copy.deepcopy(node) instance.visit( @@ -334,7 +356,9 @@ def visit_Node(self, node: itir.Node, **kwargs): ) def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs): - assert node.value in self.dimensions + assert ( + node.value in self.dimensions + ), f"Dimension {node.value} not present in offset provider." 
return ts.DimensionType(dim=self.dimensions[node.value]) def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 7496b1e573..1b7e435982 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -20,10 +20,7 @@ import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind from gt4py.next.common import Connectivity -from gt4py.next.iterator import ( - ir as itir, - transforms as itir_transforms, -) +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index a832023693..e71435c2fb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -447,7 +447,10 @@ def test_collect_tmps_info(): ], ), params=[ir.Sym(id="i"), ir.Sym(id="j"), ir.Sym(id="k"), ir.Sym(id="inp"), ir.Sym(id="out")], - tmps=[ir.Temporary(id="_gtmp_0", dtype=float_type), ir.Temporary(id="_gtmp_1", dtype=float_type)], + tmps=[ + ir.Temporary(id="_gtmp_0", dtype=float_type), + ir.Temporary(id="_gtmp_1", dtype=float_type), + ], ) expected = FencilWithTemporaries( fencil=testee.fencil, From ed4bf27d1cda64f1e51fa472b945e0c91bc396f1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 08:09:31 +0200 Subject: [PATCH 08/52] Cleanup --- src/gt4py/next/iterator/ir.py | 2 ++ 
.../ir_utils/common_pattern_matcher.py | 2 -- .../next/iterator/type_system/inference.py | 35 ++++++++++--------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index b3392c03a6..feef7f2508 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -31,6 +31,8 @@ class Node(eve.Node): location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False) + # be very careful here: this will break many places when compare is not False as we oftentimes + # do comparisons like `node.fun == im.ref(...)`. type: Optional[ts.TypeSpec] = eve.field(default=None, repr=False, compare=False) def __str__(self) -> str: diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index f7b80ced4f..5aab3866a4 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -40,8 +40,6 @@ def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]: def is_call_to(node: itir.Node, fun: str | list[str]) -> TypeGuard[itir.FunCall]: if isinstance(fun, (list, tuple, set)): return any((is_call_to(node, f) for f in fun)) - # TODO: fix in all places that we don't do node.fun == im.ref(...) 
because this breaks - # when the lhs has a type return ( isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun ) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 8c10346fd5..1216327f12 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -215,13 +215,14 @@ def _get_dimensions_from_offset_provider(offset_provider) -> dict[str, common.Di def _get_dimensions_from_types(types) -> dict[str, common.Dimension]: dimensions: list[common.Dimension] = [] - for type_ in types: - def _get_dimensions(el_type): - if isinstance(el_type, ts.FieldType): - dimensions.extend(el_type.dims) + def _get_dimensions(el_type): + if isinstance(el_type, ts.FieldType): + dimensions.extend(el_type.dims) + for type_ in types: type_info.apply_to_primitive_constituents(_get_dimensions, type_) + return {dim.value: dim for dim in dimensions} @@ -285,7 +286,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) return result - def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx): + def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.FencilType: params: dict[str, ts.DataType] = {} for param in node.params: assert isinstance(param.type, ts.DataType) @@ -298,7 +299,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx): closures = self.visit(node.closures, ctx=ctx | params | function_definitions) return it_ts.FencilType(params=list(params.values()), closures=closures) - def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx): + def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: domain: it_ts.DomainType = self.visit(node.domain, ctx=ctx) inputs: list[ts.FieldType] = self.visit(node.inputs, ctx=ctx) output: ts.FieldType = self.visit(node.output, ctx=ctx) @@ -350,18 +351,13 @@ def 
extract_dtype_and_defined_dims(el_type: ts.TypeSpec): inputs=inputs, ) - def visit_Node(self, node: itir.Node, **kwargs): - raise NotImplementedError( - f"No type deduction rule for nodes of type " f"'{type(node).__name__}'." - ) - - def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs): + def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: assert ( node.value in self.dimensions ), f"Dimension {node.value} not present in offset provider." return ts.DimensionType(dim=self.dimensions[node.value]) - def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs): + def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.OffsetLiteralType: # TODO: this happens in tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py if _is_representable_as_int(node.value): return it_ts.OffsetLiteralType(value=int(node.value)) @@ -369,11 +365,11 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs): assert isinstance(node.value, str) and node.value in self.dimensions return it_ts.OffsetLiteralType(value=self.dimensions[node.value]) - def visit_Literal(self, node: itir.Literal, **kwargs): + def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType: assert isinstance(node.type, ts.ScalarType) return node.type - def visit_SymRef(self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec]): + def visit_SymRef(self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec]) -> ts.TypeSpec | rules.TypeInferenceRule: # for testing it is useful to be able to use types without a declaration, but just storing # the type in the node itself. 
if node.type: @@ -389,7 +385,7 @@ def visit_SymRef(self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec]): def visit_Lambda( self, node: itir.Lambda | itir.FunctionDefinition, *, ctx: dict[str, ts.TypeSpec] - ): + ) -> DeferredFunctionType: def fun(*args): return self.visit( node.expr, ctx=ctx | {p.id: a for p, a in zip(node.params, args, strict=True)} @@ -404,7 +400,7 @@ def fun(*args): visit_FunctionDefinition = visit_Lambda - def visit_FunCall(self, node: itir.FunCall, *, ctx: dict[str, ts.TypeSpec]): + def visit_FunCall(self, node: itir.FunCall, *, ctx: dict[str, ts.TypeSpec]) -> ts.TypeSpec | rules.TypeInferenceRule: if is_call_to(node, "cast_"): value, type_constructor = node.args assert ( @@ -432,5 +428,10 @@ def visit_FunCall(self, node: itir.FunCall, *, ctx: dict[str, ts.TypeSpec]): return result + def visit_Node(self, node: itir.Node, **kwargs): + raise NotImplementedError( + f"No type deduction rule for nodes of type " f"'{type(node).__name__}'." + ) + infer = ITIRTypeInference.apply From 1c94bf0de5fca5c4cfba06d7398fe65fe4549fc1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 10:30:45 +0200 Subject: [PATCH 09/52] Use is_call_to instead of equality comparison with itir.Ref. 
--- .../ir_utils/common_pattern_matcher.py | 11 +++++--- .../iterator/transforms/collapse_tuple.py | 24 +++++++++--------- .../next/iterator/transforms/global_tmps.py | 22 ++++++++-------- .../iterator/transforms/propagate_deref.py | 25 +++++-------------- .../transforms_tests/test_propagate_deref.py | 10 +++++++- 5 files changed, 46 insertions(+), 46 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index a4b074a4b6..5ba73d4047 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -14,7 +14,6 @@ from typing import TypeGuard from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: @@ -32,6 +31,10 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) -def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]: - """Match expression of the form `if_(cond, true_branch, false_branch)`.""" - return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") +def is_call_to(node: itir.Node, fun: str | list[str]) -> TypeGuard[itir.FunCall]: + """Match expression that are calls to a given function.""" + if isinstance(fun, (list, tuple, set)): + return any((is_call_to(node, f) for f in fun)) + return ( + isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun + ) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 4b8182a781..cbe9ebe8d0 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -23,8 +23,12 @@ from gt4py.eve import utils as eve_utils from gt4py.next import type_inference from gt4py.next.iterator import ir, 
type_inference as it_type_inference -from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + ir_makers as im, + misc as ir_misc, +) +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to, is_let from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda from gt4py.next.type_system import type_info @@ -66,7 +70,7 @@ def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): def _is_trivial_make_tuple_call(node: ir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" - if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")): + if not is_call_to(node, "make_tuple"): return False if not all( isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) @@ -256,7 +260,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: )( *let_expr.args # type: ignore[attr-defined] # ensured by is_let ) - elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): + elif cpm.is_call_to(node.args[1], "if_"): idx = node.args[0] cond, true_branch, false_branch = node.args[1].args return im.if_( @@ -273,11 +277,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. 
bound_vars: dict[str, ir.Expr] = {} new_args: list[ir.Expr] = [] for arg in node.args: - if ( - isinstance(node, ir.FunCall) - and node.fun == im.ref("make_tuple") - and not _is_trivial_make_tuple_call(node) - ): + if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): el_name = self._letify_make_tuple_uids.sequential_id() new_args.append(im.ref(el_name)) bound_vars[el_name] = arg @@ -298,7 +298,7 @@ def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.N return None def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: - if not node.fun == im.ref("if_"): + if not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: @@ -306,7 +306,7 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N # `let (b, if cond then {1, 2} else {3, 4})) b[0]` # -> `if cond then let(b, {1, 2})(b[0]) else let(b, {3, 4})(b[0])` for i, arg in enumerate(node.args): - if is_if_call(arg): + if is_call_to(arg, "if_"): cond, true_branch, false_branch = arg.args new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) @@ -340,6 +340,6 @@ def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: if is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if node.fun.expr == im.ref(arg_sym.id): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let return arg return None diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py 
b/src/gt4py/next/iterator/transforms/global_tmps.py index 42d68318a0..3a39d6314c 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -23,8 +23,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir, type_inference -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.pretty_printer import PrettyPrinter from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.cse import extract_subexpression @@ -139,7 +138,7 @@ class TemporaryExtractionPredicate: def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: """Determine if `expr` is an applied lift that should be extracted as a temporary.""" - if not is_applied_lift(expr): + if not cpm.is_applied_lift(expr): return False # do not extract when the result is a list (i.e. a lift expression used in a `reduce` call) # as we can not create temporaries for these stencils @@ -185,7 +184,7 @@ def _closure_parameter_argument_mapping(closure: ir.StencilClosure): to `arg`. In case the stencil is a scan, a mapping from closure inputs to scan pass (i.e. first arg is ignored) is returned. 
""" - is_scan = isinstance(closure.stencil, ir.FunCall) and closure.stencil.fun == im.ref("scan") + is_scan = cpm.is_call_to(closure.stencil, "scan") if is_scan: stencil = closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan @@ -242,13 +241,16 @@ def always_extract_heuristics(_): while closure_stack: current_closure: ir.StencilClosure = closure_stack.pop() - if current_closure.stencil == im.ref("deref"): + if ( + isinstance(current_closure.stencil, ir.SymRef) + and current_closure.stencil.id == "deref" + ): closures.append(current_closure) continue - is_scan: bool = isinstance( - current_closure.stencil, ir.FunCall - ) and current_closure.stencil.fun == im.ref("scan") + is_scan: bool = isinstance(current_closure.stencil, ir.FunCall) and cpm.is_call_to( + current_closure.stencil, "scan" + ) current_closure_stencil = ( current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan ) @@ -571,7 +573,7 @@ def update_domains( def _tuple_constituents(node: ir.Expr) -> Iterable[ir.Expr]: - if isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple"): + if cpm.is_call_to(node, "make_tuple"): for arg in node.args: yield from _tuple_constituents(arg) else: @@ -625,7 +627,7 @@ def validate_no_dynamic_offsets(node: ir.Node): """Vaidate we have no dynamic offsets, e.g. 
`shift(Ioff, deref(...))(...)`""" for call_node in node.walk_values().if_isinstance(ir.FunCall): assert isinstance(call_node, ir.FunCall) - if call_node.fun == im.ref("shift"): + if cpm.is_call_to(call_node, "shift"): if any(not isinstance(arg, ir.OffsetLiteral) for arg in call_node.args): raise NotImplementedError("Dynamic offsets not supported in temporary pass.") diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 21551fab6a..9f3e9d48fc 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -13,9 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im # TODO(tehrengruber): This pass can be generalized to all builtins, e.g. 
@@ -44,23 +43,11 @@ def apply(cls, node: ir.Node): return cls().visit(node) def visit_FunCall(self, node: ir.FunCall): - if P(ir.FunCall, fun=ir.SymRef(id="deref"), args=[P(ir.FunCall, fun=P(ir.Lambda))]).match( - node - ): - builtin = node.fun - lambda_fun: ir.Lambda = node.args[0].fun # type: ignore[attr-defined] # invariant ensured by pattern match above - lambda_args: list[ir.Expr] = node.args[0].args # type: ignore[attr-defined] # invariant ensured by pattern match above - node = ir.FunCall( - fun=ir.Lambda( - params=lambda_fun.params, expr=ir.FunCall(fun=builtin, args=[lambda_fun.expr]) - ), - args=lambda_args, - ) - elif ( - node.fun == im.ref("deref") - and isinstance(node.args[0], ir.FunCall) - and node.args[0].fun == im.ref("if_") - ): + if cpm.is_call_to(node, "deref") and cpm.is_let(node.args[0]): + fun: ir.Lambda = node.args[0].fun # type: ignore[assignment] # ensured by is_let + args: list[ir.Expr] = node.args[0].args + node = im.let(*zip(fun.params, args))(im.deref(fun.expr)) # type: ignore[arg-type] # mypy not smart enough + elif cpm.is_call_to(node, "deref") and cpm.is_call_to(node.args[0], "if_"): cond, true_branch, false_branch = node.args[0].args return im.if_(cond, im.deref(true_branch), im.deref(false_branch)) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py index e2e29cd4db..899c108a98 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py @@ -16,9 +16,17 @@ from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -def test_deref_propagation(): +def test_deref_let_propagation(): testee = im.deref(im.call(im.lambda_("inner_it")(im.lift("stencil")("inner_it")))("outer_it")) expected = 
im.call(im.lambda_("inner_it")(im.deref(im.lift("stencil")("inner_it"))))("outer_it") actual = PropagateDeref.apply(testee) assert actual == expected + + +def test_deref_if_propagation(): + testee = im.deref(im.if_("cond", "true_branch", "false_branch")) + expected = im.if_("cond", im.deref("true_branch"), im.deref("false_branch")) + + actual = PropagateDeref.apply(testee) + assert actual == expected From 3799b7a22504d6737633af9d24314f0fec609e51 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 10:47:59 +0200 Subject: [PATCH 10/52] Cleanup --- .../iterator/ir_utils/common_pattern_matcher.py | 11 ++++++++++- .../next/iterator/transforms/collapse_tuple.py | 15 +++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 5ba73d4047..b3cba3c6b0 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -32,7 +32,16 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: def is_call_to(node: itir.Node, fun: str | list[str]) -> TypeGuard[itir.FunCall]: - """Match expression that are calls to a given function.""" + """ + Match expression that are calls to a given function. 
+ + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> node = im.call("make_tuple")(1, 2) + >>> is_call_to(node, "make_tuple") + True + >>> is_call_to(node, "plus") + False + """ if isinstance(fun, (list, tuple, set)): return any((is_call_to(node, f) for f in fun)) return ( diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index cbe9ebe8d0..4e4443696f 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -28,7 +28,6 @@ ir_makers as im, misc as ir_misc, ) -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to, is_let from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda from gt4py.next.type_system import type_info @@ -70,7 +69,7 @@ def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): def _is_trivial_make_tuple_call(node: ir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" - if not is_call_to(node, "make_tuple"): + if not cpm.is_call_to(node, "make_tuple"): return False if not all( isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) @@ -251,7 +250,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` - if is_let(node.args[1]): + if cpm.is_let(node.args[1]): idx, let_expr = node.args return im.call( im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let @@ -289,7 +288,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. 
return None def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node): + if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] @@ -306,7 +305,7 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N # `let (b, if cond then {1, 2} else {3, 4})) b[0]` # -> `if cond then let(b, {1, 2})(b[0]) else let(b, {3, 4})(b[0])` for i, arg in enumerate(node.args): - if is_call_to(arg, "if_"): + if cpm.is_call_to(arg, "if_"): cond, true_branch, false_branch = arg.args new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) @@ -314,14 +313,14 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N return None def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node): + if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} inner_vars = {} original_inner_expr = node.fun.expr # type: ignore[attr-defined] # ensured by is_let for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let assert arg_sym not in inner_vars # TODO(tehrengruber): fix collisions - if is_let(arg): + if cpm.is_let(arg): for sym, val in zip(arg.fun.params, arg.args): # type: ignore[attr-defined] # ensured by is_let assert sym not in outer_vars # TODO(tehrengruber): fix collisions outer_vars[sym] = val @@ -337,7 +336,7 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: return None def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + if cpm.is_let(node) and 
isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let From f75f0e4f6b893a276508dd13cef9e0b79673b210 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 10:49:03 +0200 Subject: [PATCH 11/52] Cleanup --- src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index b3cba3c6b0..29db997684 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -33,7 +33,7 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: def is_call_to(node: itir.Node, fun: str | list[str]) -> TypeGuard[itir.FunCall]: """ - Match expression that are calls to a given function. + Match call expression to a given function. >>> from gt4py.next.iterator.ir_utils import ir_makers as im >>> node = im.call("make_tuple")(1, 2) From 779ba3d90bf1ba664c163d9de739793f5c8881bc Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 11:04:20 +0200 Subject: [PATCH 12/52] Address reviewer comments --- .../iterator/ir_utils/common_pattern_matcher.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 29db997684..bd6dedc0f6 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -11,6 +11,7 @@ # distribution for a copy of the license or check . 
# # SPDX-License-Identifier: GPL-3.0-or-later +from collections.abc import Iterable from typing import TypeGuard from gt4py.next.iterator import ir as itir @@ -31,19 +32,21 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) -def is_call_to(node: itir.Node, fun: str | list[str]) -> TypeGuard[itir.FunCall]: +def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]: """ Match call expression to a given function. >>> from gt4py.next.iterator.ir_utils import ir_makers as im - >>> node = im.call("make_tuple")(1, 2) - >>> is_call_to(node, "make_tuple") - True + >>> node = im.call("plus")(1, 2) >>> is_call_to(node, "plus") + True + >>> is_call_to(node, "minus") False + >>> is_call_to(node, ("plus", "minus")) + True """ - if isinstance(fun, (list, tuple, set)): + if isinstance(fun, (list, tuple, set, Iterable)) and not isinstance(fun, str): return any((is_call_to(node, f) for f in fun)) return ( isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun - ) + ) \ No newline at end of file From c7e1e9356ae25d1e8628b86d56dea77dfce688fa Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 11:06:38 +0200 Subject: [PATCH 13/52] Cleanup --- src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index bd6dedc0f6..4933307c53 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -49,4 +49,4 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC return any((is_call_to(node, f) for f in fun)) return ( isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun - ) \ No newline at end of file + 
) From 172982f304e90664ce135214423f414ed998efa2 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 11:22:19 +0200 Subject: [PATCH 14/52] Merge origin_tehrengruber/fix_im_ref_comp --- src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/iterator/ir.py | 1 + .../ir_utils/common_pattern_matcher.py | 26 ++++++++++----- .../iterator/transforms/collapse_tuple.py | 33 +++++++++---------- .../next/iterator/transforms/global_tmps.py | 22 +++++++------ .../iterator/transforms/propagate_deref.py | 25 ++++---------- .../transforms_tests/test_propagate_deref.py | 10 +++++- 7 files changed, 62 insertions(+), 57 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index fe1bffb42c..41ec7125f8 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -178,7 +178,7 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: size_params.append( itir.Sym( id=_size_arg_from_field(param.id, dim_idx), - type=ts.ScalarType(kind=ts.ScalarKind.INT32), + type=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN)), ) ) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index feef7f2508..f1932c3a70 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -33,6 +33,7 @@ class Node(eve.Node): # be very careful here: this will break many places when compare is not False as we oftentimes # do comparisons like `node.fun == im.ref(...)`. 
+ # TODO(tehrengruber): type: Optional[ts.TypeSpec] = eve.field(default=None, repr=False, compare=False) def __str__(self) -> str: diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 5aab3866a4..0d4185c864 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -11,10 +11,11 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import collections.abc +from collections.abc import Iterable from typing import TypeGuard from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: @@ -32,14 +33,21 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) -def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]: - """Match expression of the form `if_(cond, true_branch, false_branch)`.""" - return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") - - -def is_call_to(node: itir.Node, fun: str | list[str]) -> TypeGuard[itir.FunCall]: - if isinstance(fun, (list, tuple, set)): +def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]: + """ + Match call expression to a given function. 
+ + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> node = im.call("plus")(1, 2) + >>> is_call_to(node, "plus") + True + >>> is_call_to(node, "minus") + False + >>> is_call_to(node, ("plus", "minus")) + True + """ + if isinstance(fun, (list, tuple, set, collections.abc.Iterable)) and not isinstance(fun, str): return any((is_call_to(node, f) for f in fun)) return ( isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun - ) + ) \ No newline at end of file diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index b42de122a1..4d3054f752 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -22,8 +22,11 @@ from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + ir_makers as im, + misc as ir_misc, +) from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -38,7 +41,7 @@ def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): def _is_trivial_make_tuple_call(node: ir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" - if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")): + if not cpm.is_call_to(node, "make_tuple"): return False if not all( isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) @@ -215,7 +218,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: # 
TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` - if is_let(node.args[1]): + if cpm.is_let(node.args[1]): idx, let_expr = node.args return im.call( im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let @@ -224,7 +227,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: )( *let_expr.args # type: ignore[attr-defined] # ensured by is_let ) - elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): + elif cpm.is_call_to(node.args[1], "if_"): idx = node.args[0] cond, true_branch, false_branch = node.args[1].args return im.if_( @@ -241,11 +244,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. bound_vars: dict[str, ir.Expr] = {} new_args: list[ir.Expr] = [] for arg in node.args: - if ( - isinstance(node, ir.FunCall) - and node.fun == im.ref("make_tuple") - and not _is_trivial_make_tuple_call(node) - ): + if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): el_name = self._letify_make_tuple_uids.sequential_id() new_args.append(im.ref(el_name)) bound_vars[el_name] = arg @@ -257,7 +256,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. 
return None def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node): + if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] @@ -266,7 +265,7 @@ def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.N return None def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: - if not node.fun == im.ref("if_"): + if not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: @@ -274,7 +273,7 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N # `let (b, if cond then {1, 2} else {3, 4})) b[0]` # -> `if cond then let(b, {1, 2})(b[0]) else let(b, {3, 4})(b[0])` for i, arg in enumerate(node.args): - if is_if_call(arg): + if cpm.is_call_to(arg, "if_"): cond, true_branch, false_branch = arg.args new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) @@ -282,14 +281,14 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N return None def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node): + if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} inner_vars = {} original_inner_expr = node.fun.expr # type: ignore[attr-defined] # ensured by is_let for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let assert arg_sym not in inner_vars # TODO(tehrengruber): fix collisions - if is_let(arg): + if cpm.is_let(arg): for sym, val in zip(arg.fun.params, arg.args): # type: ignore[attr-defined] # ensured by 
is_let assert sym not in outer_vars # TODO(tehrengruber): fix collisions outer_vars[sym] = val @@ -305,9 +304,9 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: return None def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: - if is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if node.fun.expr == im.ref(arg_sym.id): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let return arg return None diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 083845a531..9b1b04ef29 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -23,8 +23,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.pretty_printer import PrettyPrinter from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.cse import extract_subexpression @@ -144,7 +143,7 @@ class TemporaryExtractionPredicate: def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: """Determine if `expr` is an applied lift that should be extracted as a temporary.""" - if not is_applied_lift(expr): + if not cpm.is_applied_lift(expr): return False # do not extract when the result is 
a list (i.e. a lift expression used in a `reduce` call) # as we can not create temporaries for these stencils @@ -190,7 +189,7 @@ def _closure_parameter_argument_mapping(closure: ir.StencilClosure): to `arg`. In case the stencil is a scan, a mapping from closure inputs to scan pass (i.e. first arg is ignored) is returned. """ - is_scan = isinstance(closure.stencil, ir.FunCall) and closure.stencil.fun == im.ref("scan") + is_scan = cpm.is_call_to(closure.stencil, "scan") if is_scan: stencil = closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan @@ -247,13 +246,16 @@ def always_extract_heuristics(_): while closure_stack: current_closure: ir.StencilClosure = closure_stack.pop() - if current_closure.stencil == im.ref("deref"): + if ( + isinstance(current_closure.stencil, ir.SymRef) + and current_closure.stencil.id == "deref" + ): closures.append(current_closure) continue - is_scan: bool = isinstance( - current_closure.stencil, ir.FunCall - ) and current_closure.stencil.fun == im.ref("scan") + is_scan: bool = isinstance(current_closure.stencil, ir.FunCall) and cpm.is_call_to( + current_closure.stencil, "scan" + ) current_closure_stencil = ( current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan ) @@ -578,7 +580,7 @@ def update_domains( def _tuple_constituents(node: ir.Expr) -> Iterable[ir.Expr]: - if isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple"): + if cpm.is_call_to(node, "make_tuple"): for arg in node.args: yield from _tuple_constituents(arg) else: @@ -611,7 +613,7 @@ def validate_no_dynamic_offsets(node: ir.Node): """Vaidate we have no dynamic offsets, e.g. 
`shift(Ioff, deref(...))(...)`""" for call_node in node.walk_values().if_isinstance(ir.FunCall): assert isinstance(call_node, ir.FunCall) - if call_node.fun == im.ref("shift"): + if cpm.is_call_to(call_node, "shift"): if any(not isinstance(arg, ir.OffsetLiteral) for arg in call_node.args): raise NotImplementedError("Dynamic offsets not supported in temporary pass.") diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 21551fab6a..9f3e9d48fc 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -13,9 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im # TODO(tehrengruber): This pass can be generalized to all builtins, e.g. 
@@ -44,23 +43,11 @@ def apply(cls, node: ir.Node): return cls().visit(node) def visit_FunCall(self, node: ir.FunCall): - if P(ir.FunCall, fun=ir.SymRef(id="deref"), args=[P(ir.FunCall, fun=P(ir.Lambda))]).match( - node - ): - builtin = node.fun - lambda_fun: ir.Lambda = node.args[0].fun # type: ignore[attr-defined] # invariant ensured by pattern match above - lambda_args: list[ir.Expr] = node.args[0].args # type: ignore[attr-defined] # invariant ensured by pattern match above - node = ir.FunCall( - fun=ir.Lambda( - params=lambda_fun.params, expr=ir.FunCall(fun=builtin, args=[lambda_fun.expr]) - ), - args=lambda_args, - ) - elif ( - node.fun == im.ref("deref") - and isinstance(node.args[0], ir.FunCall) - and node.args[0].fun == im.ref("if_") - ): + if cpm.is_call_to(node, "deref") and cpm.is_let(node.args[0]): + fun: ir.Lambda = node.args[0].fun # type: ignore[assignment] # ensured by is_let + args: list[ir.Expr] = node.args[0].args + node = im.let(*zip(fun.params, args))(im.deref(fun.expr)) # type: ignore[arg-type] # mypy not smart enough + elif cpm.is_call_to(node, "deref") and cpm.is_call_to(node.args[0], "if_"): cond, true_branch, false_branch = node.args[0].args return im.if_(cond, im.deref(true_branch), im.deref(false_branch)) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py index e2e29cd4db..899c108a98 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py @@ -16,9 +16,17 @@ from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -def test_deref_propagation(): +def test_deref_let_propagation(): testee = im.deref(im.call(im.lambda_("inner_it")(im.lift("stencil")("inner_it")))("outer_it")) expected = 
im.call(im.lambda_("inner_it")(im.deref(im.lift("stencil")("inner_it"))))("outer_it") actual = PropagateDeref.apply(testee) assert actual == expected + + +def test_deref_if_propagation(): + testee = im.deref(im.if_("cond", "true_branch", "false_branch")) + expected = im.if_("cond", im.deref("true_branch"), im.deref("false_branch")) + + actual = PropagateDeref.apply(testee) + assert actual == expected From 79bee51f9edf5a6d185b9d52750768f7d1a05844 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 11:27:52 +0200 Subject: [PATCH 15/52] Cleanup --- src/gt4py/next/iterator/transforms/global_tmps.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 3a39d6314c..a3260d5a37 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -248,9 +248,7 @@ def always_extract_heuristics(_): closures.append(current_closure) continue - is_scan: bool = isinstance(current_closure.stencil, ir.FunCall) and cpm.is_call_to( - current_closure.stencil, "scan" - ) + is_scan: bool = cpm.is_call_to(current_closure.stencil, "scan") current_closure_stencil = ( current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan ) From dd5e60154c0838d95a6102fe182f08bf0d3b46cf Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 11:29:16 +0200 Subject: [PATCH 16/52] Cleanup --- src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/iterator/ir.py | 4 +--- src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py | 5 ++--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 41ec7125f8..a5d549e7ff 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -178,7 +178,7 @@ def 
_gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: size_params.append( itir.Sym( id=_size_arg_from_field(param.id, dim_idx), - type=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN)), + type=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())), ) ) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index f1932c3a70..79e7ac0a81 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -31,9 +31,7 @@ class Node(eve.Node): location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False) - # be very careful here: this will break many places when compare is not False as we oftentimes - # do comparisons like `node.fun == im.ref(...)`. - # TODO(tehrengruber): + # TODO(tehrengruber): include in comparison if value is not None type: Optional[ts.TypeSpec] = eve.field(default=None, repr=False, compare=False) def __str__(self) -> str: diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 0d4185c864..4933307c53 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -11,7 +11,6 @@ # distribution for a copy of the license or check . 
# # SPDX-License-Identifier: GPL-3.0-or-later -import collections.abc from collections.abc import Iterable from typing import TypeGuard @@ -46,8 +45,8 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC >>> is_call_to(node, ("plus", "minus")) True """ - if isinstance(fun, (list, tuple, set, collections.abc.Iterable)) and not isinstance(fun, str): + if isinstance(fun, (list, tuple, set, Iterable)) and not isinstance(fun, str): return any((is_call_to(node, f) for f in fun)) return ( isinstance(node, itir.FunCall) and isinstance(node.fun, itir.SymRef) and node.fun.id == fun - ) \ No newline at end of file + ) From f41e0327616a0f1754df1201dbc39731c5abbf38 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 15:30:35 +0200 Subject: [PATCH 17/52] Merge origin_tehrengruber/fix_im_ref_comp --- src/gt4py/next/ffront/past_to_itir.py | 4 +- src/gt4py/next/iterator/tracing.py | 1 - .../iterator/transforms/collapse_tuple.py | 8 +- .../next/iterator/type_system/inference.py | 163 ++++++++++-------- .../next/type_system/type_specifications.py | 4 +- .../iterator_tests/test_type_inference.py | 13 +- .../transforms_tests/test_collapse_tuple.py | 41 +++-- 7 files changed, 138 insertions(+), 96 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a5d549e7ff..09ed645ed3 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -178,7 +178,9 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: size_params.append( itir.Sym( id=_size_arg_from_field(param.id, dim_idx), - type=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())), + type=ts.ScalarType( + kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ), ) ) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 2f8a18143e..9436acccd4 100644 --- a/src/gt4py/next/iterator/tracing.py +++ 
b/src/gt4py/next/iterator/tracing.py @@ -278,7 +278,6 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: ) arg_type = None - kind, dtype = None, None if use_arg_types: arg_type = type_translation.from_value(arg) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 4d3054f752..f3342a591c 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -112,6 +112,8 @@ def apply( offset_provider=None, # manually passing flags is mostly for allowing separate testing of the modes flags=None, + # allow sym references without a symbol declaration, mostly for testing + allow_undeclared_symbols: bool = False, ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. @@ -130,7 +132,11 @@ def apply( offset_provider = offset_provider or {} if not ignore_tuple_size: - node = itir_type_inference.infer(node, offset_provider=offset_provider) + node = itir_type_inference.infer( + node, + offset_provider=offset_provider, + allow_undeclared_symbols=allow_undeclared_symbols, + ) new_node = cls( ignore_tuple_size=ignore_tuple_size, diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 1216327f12..551fcc14cc 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -78,33 +78,24 @@ def is_compatible_element(el_type_a: ts.TypeSpec, el_type_b: ts.TypeSpec): return is_compatible +# TODO(tehrengruber): remove after documentation is written # Problems: -# - how to get the kind of the dimension in here? X -# maybe directly attach the type to an axis literal? 
-# - lift X (also mention to Hannes) -# - is_compatible -# - late offset literal in (also mention to Hannes) -# tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py # - what happens when we get a lambda function whose params are already typed # - write back params type in lambda # - documentation # describe why lambda can only have one type. Describe idea to solve e.g. # let("f", lambda x: x)(f(1)+f(1.)) # -> let("f_int", lambda x: x, "f_float", lambda x: x)(f_int(1)+f_float(1.)) -# - make types hashable -# - ~~either Eve with Coercion and no runtime checking,~~ dataclass hash with cached property # - document how scans are handled (also mention to Hannes) # - types are stored in the node, but will be incomplete after some passes -# - deferred type for testing -# - visit_FunctionDefinition - - # Design decisions # Only the parameters of fencils need to be typed. # Lambda functions are not polymorphic. -def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, "DeferredFunctionType"]) -> None: +def on_inferred( + callback: Callable, *args: Union[ts.TypeSpec, "ObservableTypeInferenceRule"] +) -> None: """ Execute `callback` as soon as all `args` have a type. """ @@ -118,7 +109,7 @@ def mark_ready(i, type_): callback(*inferred_args) for i, arg in enumerate(args): - if isinstance(arg, DeferredFunctionType): + if isinstance(arg, ObservableTypeInferenceRule): arg.on_type_ready(functools.partial(mark_ready, i)) else: assert isinstance(arg, ts.TypeSpec) @@ -126,23 +117,23 @@ def mark_ready(i, type_): @dataclasses.dataclass -class DeferredFunctionType: +class ObservableTypeInferenceRule: """ This class wraps a raw type inference rule to handle typing of functions. - As functions are represented by type inference rules + TODO: As functions in ITIR are represented by type inference rules, i.e. 
regular callables, """ #: type rule that given a set of types or type rules returns the return type or a type rule type_rule: rules.TypeInferenceRule #: offset provider used by some type rules - offset_provider: Any + offset_provider: common.OffsetProvider #: node that has this type node: Optional[itir.Node] = None #: list of references to this function aliases: list[itir.SymRef] = dataclasses.field(default_factory=list) #: list of callbacks executed as soon as the type is ready - callbacks: list[Any] = dataclasses.field(default_factory=list) + callbacks: list[Callable[[ts.TypeSpec], None]] = dataclasses.field(default_factory=list) #: the inferred type when ready and None until then inferred_type: Optional[ts.FunctionType] = None #: whether to store the type in the node or not @@ -158,7 +149,7 @@ def infer_type( def _infer_type_listener(self, return_type: ts.TypeSpec, *args: ts.TypeSpec) -> None: self.inferred_type = self.infer_type(return_type, *args) # type: ignore[arg-type] # ensured by assert above - # if the type has been fully inferred, notify all `DeferredFunctionType`s that depend on it. + # if the type has been fully inferred, notify all `ObservableTypeInferenceRule`s that depend on it. 
for cb in self.callbacks: cb(self.inferred_type) @@ -183,7 +174,7 @@ def __call__(self, *args) -> ts.TypeSpec | rules.TypeInferenceRule: # return type is a typing rule by itself if callable(return_type): - return_type = DeferredFunctionType( + return_type = ObservableTypeInferenceRule( node=None, # node will be set by caller type_rule=return_type, offset_provider=self.offset_provider, @@ -196,9 +187,6 @@ def __call__(self, *args) -> ts.TypeSpec | rules.TypeInferenceRule: return return_type -T = TypeVar("T", bound=itir.Node) - - def _get_dimensions_from_offset_provider(offset_provider) -> dict[str, common.Dimension]: dimensions: dict[str, common.Dimension] = {} for offset_name, provider in offset_provider.items(): @@ -226,18 +214,68 @@ def _get_dimensions(el_type): return {dim.value: dim for dim in dimensions} +T = TypeVar("T", bound=itir.Node) + + +def _convert_closure_input_to_iterator( + domain: it_ts.DomainType, input_: ts.TypeSpec +) -> it_ts.IteratorType: + input_dims: list[common.Dimension] | None = None + + def extract_dtype_and_dims(el_type: ts.TypeSpec): + nonlocal input_dims + assert isinstance(el_type, (ts.FieldType, ts.ScalarType)) + el_type = type_info.promote(el_type, always_field=True) + if not input_dims: + input_dims = el_type.dims # type: ignore[union-attr] # ensured by always_field + else: + # tuple inputs must all have the same defined dimensions as we + # create an iterator of tuples from them + assert input_dims == el_type.dims # type: ignore[union-attr] # ensured by always_field + return el_type.dtype # type: ignore[union-attr] # ensured by always_field + + element_type = type_info.apply_to_primitive_constituents(extract_dtype_and_dims, input_) + + assert input_dims is not None + + # handle neighbor / sparse input fields + defined_dims = [] + is_nb_field = False + for dim in input_dims: + if dim.kind == common.DimensionKind.LOCAL: + assert not is_nb_field + is_nb_field = True + else: + defined_dims.append(dim) + if is_nb_field: + element_type 
= it_ts.ListType(element_type=element_type) + + return it_ts.IteratorType( + position_dims=domain.dims, defined_dims=defined_dims, element_type=element_type + ) + + @dataclasses.dataclass class ITIRTypeInference(eve.NodeTranslator): """ TODO """ - offset_provider: Any - + offset_provider: common.OffsetProvider + #: Mapping from a dimension name to the actual dimension instance dimensions: dict[str, common.Dimension] + #: Allow sym refs to symbols that have not been declared. Mostly used in testing. + allow_undeclared_symbols: bool @classmethod - def apply(cls, node: T, *, offset_provider, inplace: bool = False) -> T: + def apply( + cls, + node: T, + *, + offset_provider: common.OffsetProvider, + inplace: bool = False, + allow_undeclared_symbols: bool = False, + ) -> T: instance = cls( offset_provider=offset_provider, dimensions=( @@ -246,13 +284,14 @@ def apply(cls, node: T, *, offset_provider, inplace: bool = False) -> T: node.pre_walk_values().if_isinstance(itir.Node).getattr("type").if_is_not(None) ) ), + allow_undeclared_symbols=allow_undeclared_symbols, ) if not inplace: node = copy.deepcopy(node) instance.visit( node, ctx={ - name: DeferredFunctionType( + name: ObservableTypeInferenceRule( type_rule=rules.type_inference_rules[name], # builtin functions are polymorphic store_inferred_type_in_node=False, @@ -267,13 +306,13 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: result = super().visit(node, **kwargs) if isinstance(node, itir.Node): if isinstance(result, ts.TypeSpec): - # TODO: verify types match + if node.type: + assert _is_compatible_type(node.type, result) node.type = result - elif isinstance(result, DeferredFunctionType): + elif isinstance(result, ObservableTypeInferenceRule): pass elif callable(result): - # TODO: only do for type rules not every callable - return DeferredFunctionType( + return ObservableTypeInferenceRule( node=node, type_rule=result, store_inferred_type_in_node=True, @@ -281,7 +320,7 @@ def visit(self, node: 
concepts.RootNode, **kwargs: Any) -> Any: ) else: raise AssertionError( - f"Expected a 'TypeSpec' or 'DeferredFunctionType', but got " + f"Expected a 'TypeSpec' or 'ObservableTypeInferenceRule', but got " f"`{type(result).__name__}`" ) return result @@ -308,35 +347,8 @@ def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.Stenc for output_el in type_info.primitive_constituents(output): assert isinstance(output_el, ts.FieldType) - stencil_args = [] - for input_ in inputs: - defined_dims: list[common.Dimension] | None = None - - def extract_dtype_and_defined_dims(el_type: ts.TypeSpec): - nonlocal defined_dims - assert isinstance(el_type, (ts.FieldType, ts.ScalarType)) - el_type = type_info.promote(el_type, always_field=True) - if not defined_dims: - defined_dims = el_type.dims # type: ignore[union-attr] # ensured by always_field - else: - # tuple inputs must all have the same defined dimensions as we - # create an iterator of tuples from them - assert defined_dims == el_type.dims # type: ignore[union-attr] # ensured by always_field - return el_type.dtype # type: ignore[union-attr] # ensured by always_field - - element_type = type_info.apply_to_primitive_constituents( - extract_dtype_and_defined_dims, input_ - ) - - assert defined_dims is not None - - stencil_args.append( - it_ts.IteratorType( - position_dims=domain.dims, defined_dims=defined_dims, element_type=element_type - ) - ) - stencil_type_rule = self.visit(node.stencil, ctx=ctx) + stencil_args = [_convert_closure_input_to_iterator(domain, input_) for input_ in inputs] stencil_returns = stencil_type_rule(*stencil_args) return it_ts.StencilClosureType( @@ -357,8 +369,9 @@ def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionTyp ), f"Dimension {node.value} not present in offset provider." return ts.DimensionType(dim=self.dimensions[node.value]) + # TODO: revisit what we want to do with OffsetLiterals as we already have an Offset type in + # the frontend. 
def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.OffsetLiteralType: - # TODO: this happens in tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py if _is_representable_as_int(node.value): return it_ts.OffsetLiteralType(value=int(node.value)) else: @@ -369,29 +382,29 @@ def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType: assert isinstance(node.type, ts.ScalarType) return node.type - def visit_SymRef(self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec]) -> ts.TypeSpec | rules.TypeInferenceRule: - # for testing it is useful to be able to use types without a declaration, but just storing - # the type in the node itself. - if node.type: - assert node.id not in ctx or _is_compatible_type(ctx[node.id], node.type) - return node.type - # TODO: only allow in testing - if node.id not in ctx: + def visit_SymRef( + self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec] + ) -> ts.TypeSpec | rules.TypeInferenceRule: + # for testing, it is useful to be able to use types without a declaration + if self.allow_undeclared_symbols and node.id not in ctx: + # type has been stored in the node itself + if node.type: + return node.type return ts.DeferredType(constraint=None) result = ctx[node.id] - if isinstance(result, DeferredFunctionType): + if isinstance(result, ObservableTypeInferenceRule): result.aliases.append(node) return result def visit_Lambda( self, node: itir.Lambda | itir.FunctionDefinition, *, ctx: dict[str, ts.TypeSpec] - ) -> DeferredFunctionType: + ) -> ObservableTypeInferenceRule: def fun(*args): return self.visit( node.expr, ctx=ctx | {p.id: a for p, a in zip(node.params, args, strict=True)} ) - return DeferredFunctionType( + return ObservableTypeInferenceRule( node=node, type_rule=fun, store_inferred_type_in_node=True, @@ -400,7 +413,9 @@ def fun(*args): visit_FunctionDefinition = visit_Lambda - def visit_FunCall(self, node: itir.FunCall, *, ctx: dict[str, ts.TypeSpec]) -> ts.TypeSpec | 
rules.TypeInferenceRule: + def visit_FunCall( + self, node: itir.FunCall, *, ctx: dict[str, ts.TypeSpec] + ) -> ts.TypeSpec | rules.TypeInferenceRule: if is_call_to(node, "cast_"): value, type_constructor = node.args assert ( @@ -422,7 +437,7 @@ def visit_FunCall(self, node: itir.FunCall, *, ctx: dict[str, ts.TypeSpec]) -> t result = fun(*args) - if isinstance(result, DeferredFunctionType): + if isinstance(result, ObservableTypeInferenceRule): assert not result.node result.node = node diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index d80495cb68..7eafcf237e 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -23,11 +23,9 @@ class TypeSpec: def __hash__(self) -> int: return hash(content_hash(self)) - def __init_subclass__(cls): + def __init_subclass__(cls) -> None: cls.__hash__ = TypeSpec.__hash__ - # TODO: use __init_subclass__ - @dataclass(frozen=True) class DataType(TypeSpec): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index f73828e4f6..3c78d53fe6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -14,11 +14,8 @@ import pytest -from gt4py import eve -from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to from gt4py.next.iterator.type_system import ( inference as itir_type_inference, type_specifications as it_ts, @@ -139,7 +136,9 @@ def test_expression_type(test_case): offset_provider = {**mesh.offset_provider, "Ioff": IDim, "Joff": JDim, "Koff": KDim} testee, expected_type = test_case - result = itir_type_inference.infer(testee, offset_provider=offset_provider) + result = 
itir_type_inference.infer( + testee, offset_provider=offset_provider, allow_undeclared_symbols=True + ) assert result.type == expected_type @@ -147,7 +146,7 @@ def test_adhoc_polymorphism(): func = im.lambda_("a")(im.lambda_("b")(im.make_tuple("a", "b"))) testee = im.call(im.call(func)(im.ref("a_", bool_type)))(im.ref("b_", int_type)) - result = itir_type_inference.infer(testee, offset_provider={}) + result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) assert result.type == ts.TupleType(types=[bool_type, int_type]) @@ -180,7 +179,9 @@ def test_late_offset_axis(): func = im.lambda_("dim")(im.shift(im.ref("dim"), 1)(im.ref("it", it_on_v_of_e_type))) testee = im.call(func)(im.ensure_offset("V2E")) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer( + testee, offset_provider=mesh.offset_provider, allow_undeclared_symbols=True + ) assert result.type == it_on_e_of_e_type diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index eac9b8bd10..6fd876b630 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -24,6 +24,7 @@ def test_simple_make_tuple_tuple_get(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + allow_undeclared_symbols=True, ) expected = tuple_of_size_2 @@ -40,6 +41,7 @@ def test_nested_make_tuple_tuple_get(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + allow_undeclared_symbols=True, ) assert actual == tup_of_size2_from_lambda @@ -54,6 +56,7 @@ def test_different_tuples_make_tuple_tuple_get(): testee, remove_letified_make_tuple_elements=False, 
flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + allow_undeclared_symbols=True, ) assert actual == testee # did nothing @@ -66,6 +69,7 @@ def test_incompatible_order_make_tuple_tuple_get(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + allow_undeclared_symbols=True, ) assert actual == testee # did nothing @@ -76,6 +80,7 @@ def test_incompatible_size_make_tuple_tuple_get(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + allow_undeclared_symbols=True, ) assert actual == testee # did nothing @@ -83,7 +88,10 @@ def test_incompatible_size_make_tuple_tuple_get(): def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second"))) actual = CollapseTuple.apply( - testee, ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET + testee, + ignore_tuple_size=True, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + allow_undeclared_symbols=True, ) assert actual == im.make_tuple("first", "second") @@ -95,6 +103,7 @@ def test_simple_tuple_get_make_tuple(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE, + allow_undeclared_symbols=True, ) assert expected == actual @@ -106,14 +115,16 @@ def test_propagate_tuple_get(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET, + allow_undeclared_symbols=True, ) assert expected == actual def test_letify_make_tuple_elements(): - opaque_call = im.call("opaque")() - testee = im.make_tuple(opaque_call, opaque_call) - expected = im.let(("_tuple_el_1", opaque_call), ("_tuple_el_2", opaque_call))( + # anything that is not trivial, i.e. 
a SymRef, works here + el1, el2 = im.let("foo", "foo")("foo"), im.let("bar", "bar")("bar") + testee = im.make_tuple(el1, el2) + expected = im.let(("_tuple_el_1", el1), ("_tuple_el_2", el2))( im.make_tuple("_tuple_el_1", "_tuple_el_2") ) @@ -121,6 +132,7 @@ def test_letify_make_tuple_elements(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + allow_undeclared_symbols=True, ) assert actual == expected @@ -133,6 +145,7 @@ def test_letify_make_tuple_with_trivial_elements(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + allow_undeclared_symbols=True, ) assert actual == expected @@ -145,12 +158,15 @@ def test_inline_trivial_make_tuple(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE, + allow_undeclared_symbols=True, ) assert actual == expected def test_propagate_to_if_on_tuples(): - testee = im.tuple_get(0, im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4))) + testee = im.tuple_get( + 0, im.if_(im.ref("cond", "bool"), im.make_tuple(1, 2), im.make_tuple(3, 4)) + ) expected = im.if_( im.ref("cond", "bool"), im.tuple_get(0, im.make_tuple(1, 2)), @@ -160,22 +176,24 @@ def test_propagate_to_if_on_tuples(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, ) assert actual == expected def test_propagate_to_if_on_tuples_with_let(): - testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))( - im.tuple_get(0, "val") - ) + testee = im.let( + "val", im.if_(im.ref("cond", "bool"), im.make_tuple(1, 2), im.make_tuple(3, 4)) + )(im.tuple_get(0, "val")) expected = im.if_( - "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4)) + im.ref("cond"), im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4)) ) actual = CollapseTuple.apply( testee, 
remove_letified_make_tuple_elements=True, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + allow_undeclared_symbols=True, ) assert actual == expected @@ -187,6 +205,7 @@ def test_propagate_nested_lift(): testee, remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET, + allow_undeclared_symbols=True, ) assert actual == expected @@ -196,5 +215,7 @@ def test_if_on_tuples_with_let(): "val", im.if_(im.ref("cond", "bool"), im.make_tuple(1, 2), im.make_tuple(3, 4)) )(im.tuple_get(0, "val")) expected = im.if_("cond", 1, 3) - actual = CollapseTuple.apply(testee, remove_letified_make_tuple_elements=False) + actual = CollapseTuple.apply( + testee, remove_letified_make_tuple_elements=False, allow_undeclared_symbols=True + ) assert actual == expected From 129f6b362fb329a6212992c61c96bc258ce69dc4 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 18 Apr 2024 16:01:33 +0200 Subject: [PATCH 18/52] Add test for neighbor / sparse input field --- .../type_system/type_specifications.py | 12 +++- .../iterator_tests/test_type_inference.py | 58 ++++++++++++++----- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 811ec74605..982683a5eb 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -30,7 +30,6 @@ class DomainType(ts.DataType): dims: list[common.Dimension] -# TODO: how about ts.OffsetType? 
@dataclasses.dataclass(frozen=True) class OffsetLiteralType(ts.TypeSpec): value: IntegralScalar | common.Dimension @@ -42,7 +41,7 @@ class ListType(ts.DataType): @dataclasses.dataclass(frozen=True) -class IteratorType(ts.DataType, ts.CallableType): # todo: rename to iterator +class IteratorType(ts.DataType, ts.CallableType): position_dims: list[common.Dimension] | typing.Literal["unknown"] defined_dims: list[common.Dimension] element_type: ts.DataType @@ -52,9 +51,16 @@ class IteratorType(ts.DataType, ts.CallableType): # todo: rename to iterator class StencilClosureType(ts.TypeSpec): domain: DomainType stencil: ts.FunctionType - output: ts.FieldType | ts.TupleType # todo: validate tuple of fields + output: ts.FieldType | ts.TupleType inputs: list[ts.FieldType] + def __post_init__(self): + # local import to avoid importing type_info from a type_specification module + from gt4py.next.type_system import type_info + + for el_type in type_info.primitive_constituents(self.output): + assert isinstance(el_type, ts.FieldType), "All constituent types must be field types." 
+ @dataclasses.dataclass(frozen=True) class FencilType(ts.TypeSpec): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 3c78d53fe6..ed7c999a0f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -48,9 +48,10 @@ float64_list_type = it_ts.ListType(element_type=float64_type) int_list_type = it_ts.ListType(element_type=int_type) -bool_i_field = ts.FieldType(dims=[IDim], dtype=bool_type) -bool_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=bool_type) -bool_edge_k_field = ts.FieldType(dims=[Edge, KDim], dtype=bool_type) +float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) +float_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=float64_type) +float_edge_k_field = ts.FieldType(dims=[Edge, KDim], dtype=float64_type) +float_vertex_v2e_field = ts.FieldType(dims=[Vertex, V2EDim], dtype=float64_type) it_on_v_of_e_type = it_ts.IteratorType( position_dims=[Vertex, KDim], defined_dims=[Edge, KDim], element_type=int_type @@ -193,7 +194,7 @@ def test_cartesian_fencil_definition(): testee = itir.FencilDefinition( id="f", function_definitions=[], - params=[im.sym("inp", bool_i_field), im.sym("out", bool_i_field)], + params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], closures=[ itir.StencilClosure( domain=cartesian_domain, @@ -218,10 +219,10 @@ def test_cartesian_fencil_definition(): kw_only_args={}, returns=bool_type, ), - output=bool_i_field, - inputs=[bool_i_field], + output=float_i_field, + inputs=[float_i_field], ) - fencil_type = it_ts.FencilType(params=[bool_i_field, bool_i_field], closures=[closure_type]) + fencil_type = it_ts.FencilType(params=[float_i_field, float_i_field], closures=[closure_type]) assert result.type == fencil_type assert result.closures[0].type == closure_type @@ -236,7 +237,7 @@ def 
test_unstructured_fencil_definition(): testee = itir.FencilDefinition( id="f", function_definitions=[], - params=[im.sym("inp", bool_edge_k_field), im.sym("out", bool_vertex_k_field)], + params=[im.sym("inp", float_edge_k_field), im.sym("out", float_vertex_k_field)], closures=[ itir.StencilClosure( domain=unstructured_domain, @@ -261,11 +262,11 @@ def test_unstructured_fencil_definition(): kw_only_args={}, returns=bool_type, ), - output=bool_vertex_k_field, - inputs=[bool_edge_k_field], + output=float_vertex_k_field, + inputs=[float_edge_k_field], ) fencil_type = it_ts.FencilType( - params=[bool_edge_k_field, bool_vertex_k_field], closures=[closure_type] + params=[float_edge_k_field, float_vertex_k_field], closures=[closure_type] ) assert result.type == fencil_type assert result.closures[0].type == closure_type @@ -282,7 +283,7 @@ def test_function_definition(): itir.FunctionDefinition(id="foo", params=[im.sym("it")], expr=im.deref("it")), itir.FunctionDefinition(id="bar", params=[im.sym("it")], expr=im.call("foo")("it")), ], - params=[im.sym("inp", bool_i_field), im.sym("out", bool_i_field)], + params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], closures=[ itir.StencilClosure( domain=cartesian_domain, @@ -307,9 +308,36 @@ def test_function_definition(): kw_only_args={}, returns=bool_type, ), - output=bool_i_field, - inputs=[bool_i_field], + output=float_i_field, + inputs=[float_i_field], ) - fencil_type = it_ts.FencilType(params=[bool_i_field, bool_i_field], closures=[closure_type]) + fencil_type = it_ts.FencilType(params=[float_i_field, float_i_field], closures=[closure_type]) assert result.type == fencil_type assert result.closures[0].type == closure_type + + +def test_fencil_with_nb_field_input(): + mesh = simple_mesh() + unstructured_domain = im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), + im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + ) + + testee = itir.FencilDefinition( + 
id="f", + function_definitions=[], + params=[im.sym("inp", float_vertex_v2e_field), im.sym("out", float_vertex_k_field)], + closures=[ + itir.StencilClosure( + domain=unstructured_domain, + stencil=im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))), + output=im.ref("out"), + inputs=[im.ref("inp")], + ), + ], + ) + + result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + + assert result.closures[0].stencil.expr.args[0].type == float64_list_type + assert result.closures[0].stencil.type.returns == float64_type From a99ca9a9db498c4798a82cd6423981020f12ad12 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 19 Apr 2024 17:37:34 +0200 Subject: [PATCH 19/52] Multiple fixes --- .../next/iterator/transforms/global_tmps.py | 46 ++++++---- .../iterator/transforms/symbol_ref_utils.py | 78 +++++++++++++++-- .../next/iterator/transforms/trace_shifts.py | 6 +- .../next/iterator/type_system/inference.py | 87 +++++++++++++++---- src/gt4py/next/iterator/type_system/rules.py | 5 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 21 +++-- .../next/program_processors/runners/dace.py | 1 + .../runners/dace_iterator/itir_to_sdfg.py | 51 ++++------- .../runners/dace_iterator/utility.py | 8 -- .../next/type_system/type_specifications.py | 4 +- .../test_horizontal_indirection.py | 4 +- .../iterator_tests/test_type_inference.py | 3 + 12 files changed, 207 insertions(+), 107 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 6193b02a87..1e929c4b19 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -115,18 +115,20 @@ def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir >>> print(canonicalize_applied_lift(["inp"], expr)) (↑(λ(inp) → (λ(a) → ·a)((↑deref)(inp))))(inp) """ - assert ( - isinstance(node, ir.FunCall) - and isinstance(node.fun, ir.FunCall) - and 
node.fun.fun == ir.SymRef(id="lift") - ) - stencil = node.fun.args[0] + assert cpm.is_applied_lift(node) + stencil = node.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied lift it_args = node.args if any(not isinstance(it_arg, ir.SymRef) for it_arg in it_args): - used_closure_params = collect_symbol_refs(node) - assert not (set(used_closure_params) - set(closure_params)) - return im.lift(im.lambda_(*used_closure_params)(im.call(stencil)(*it_args)))( - *used_closure_params + closure_param_refs = collect_symbol_refs(node, as_ref=True) + assert not ({str(ref.id) for ref in closure_param_refs} - set(closure_params)) + new_node = im.lift( + im.lambda_(*[im.sym(param.id) for param in closure_param_refs])( + im.call(stencil)(*it_args) + ) + )(*closure_param_refs) + # add types again + return itir_type_inference.infer( + new_node, inplace=True, allow_undeclared_symbols=True, offset_provider={} ) return node @@ -147,7 +149,8 @@ def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: return False # do not extract when the result is a list (i.e. 
a lift expression used in a `reduce` call) # as we can not create temporaries for these stencils - if isinstance(expr.type, it_ts.ListType): + assert isinstance(expr.type, it_ts.IteratorType) + if isinstance(expr.type.element_type, it_ts.ListType): return False if self.heuristics and not self.heuristics(expr): return False @@ -175,6 +178,7 @@ def closure_shifts(self): return trace_shifts.TraceShifts.apply(self.closure, inputs_only=False) def __call__(self, expr: ir.Expr) -> bool: + return True shifts = self.closure_shifts[id(expr)] if len(shifts) > 1: return True @@ -280,15 +284,17 @@ def always_extract_heuristics(_): ) # make sure the arguments to the applied lift are only symbols - # (otherwise we would need to canonicalize using `canonicalize_applied_lift` - # this doesn't seem to be necessary right now as we extract the lifts - # in post-order of the tree) + if not all(isinstance(arg, ir.SymRef) for arg in lift_expr.args): + lift_expr = canonicalize_applied_lift( + [str(param.id) for param in current_closure_stencil.params], lift_expr + ) assert all(isinstance(arg, ir.SymRef) for arg in lift_expr.args) # create a mapping from the closures parameters to the closure arguments closure_param_arg_mapping = _closure_parameter_argument_mapping(current_closure) - stencil: ir.Node = lift_expr.fun.args[0] # usually an ir.Lambda or scan + # usually an ir.Lambda or scan + stencil: ir.Node = lift_expr.fun.args[0] # type: ignore[attr-defined] # ensured by canonicalize_applied_lift # allocate a new temporary assert isinstance(stencil.type, ts.FunctionType) @@ -344,9 +350,7 @@ def always_extract_heuristics(_): fencil=ir.FencilDefinition( id=node.id, function_definitions=node.function_definitions, - params=node.params - + [im.sym(name, type_) for name, type_ in tmps] - + [im.sym(AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant + params=node.params + [im.sym(name) for name, _ in tmps] + [im.sym(AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # 
value is a global constant closures=list(reversed(closures)), location=node.location, ), @@ -598,13 +602,17 @@ def collect_tmps_info(node: FencilWithTemporaries, *, offset_provider) -> Fencil assert output_field.id not in domains or domains[output_field.id] == closure.domain domains[output_field.id] = closure.domain - return FencilWithTemporaries( + new_node = FencilWithTemporaries( fencil=node.fencil, params=node.params, tmps=[ ir.Temporary(id=tmp.id, domain=domains[tmp.id], dtype=tmp.dtype) for tmp in node.tmps ], ) + # TODO(tehrengruber): type inference is only really needed to infer the types of the temporaries + # and write them to the params of the inner fencil. This should be cleaned up after we + # refactored the IR. + return itir_type_inference.infer(new_node, offset_provider=offset_provider) def validate_no_dynamic_offsets(node: ir.Node): diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 05d137e8c4..28b29940d3 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses +import typing from collections import defaultdict from typing import Iterable, Optional, Sequence @@ -22,8 +23,9 @@ @dataclasses.dataclass class CountSymbolRefs(eve.PreserveLocationVisitor, eve.NodeVisitor): - ref_counts: dict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int)) + ref_counts: dict[itir.SymRef, int] = dataclasses.field(default_factory=dict) + @typing.overload @classmethod def apply( cls, @@ -31,7 +33,29 @@ def apply( symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, - ) -> dict[str, int]: + as_ref: typing.Literal[False] = False, + ) -> dict[str, int]: ... 
+ + @typing.overload + @classmethod + def apply( + cls, + node: itir.Node | Sequence[itir.Node], + symbol_names: Optional[Iterable[str]] = None, + *, + ignore_builtins: bool = True, + as_ref: typing.Literal[True], + ) -> dict[itir.SymRef, int]: ... + + @classmethod + def apply( + cls, + node: itir.Node | Sequence[itir.Node], + symbol_names: Optional[Iterable[str]] = None, + *, + ignore_builtins: bool = True, + as_ref: bool = False, + ) -> dict[str, int] | dict[itir.SymRef, int]: """ Count references to given or all symbols in scope. @@ -45,6 +69,11 @@ def apply( >>> CountSymbolRefs.apply(expr, symbol_names=["x", "z"]) {'x': 2, 'z': 1} + + In some cases, e.g. when the type of the reference is required, the references instead + of strings can be retrieved. + >>> CountSymbolRefs.apply(expr, as_ref=True) + {SymRef(id=SymbolRef('x')): 2, SymRef(id=SymbolRef('y')): 2, SymRef(id=SymbolRef('z')): 1} """ if ignore_builtins: inactive_refs = {str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_} @@ -55,12 +84,22 @@ def apply( obj.visit(node, inactive_refs=inactive_refs) if symbol_names: - return {k: obj.ref_counts.get(k, 0) for k in symbol_names} - return dict(obj.ref_counts) + ref_counts = {k: v for k, v in obj.ref_counts.items() if k.id in symbol_names} + else: + ref_counts = obj.ref_counts + + result: dict[str, int] | dict[itir.SymRef, int] + if as_ref: + result = ref_counts + else: + result = {str(k.id): v for k, v in ref_counts.items()} + + return defaultdict(int, result) def visit_SymRef(self, node: itir.SymRef, *, inactive_refs: set[str]): if node.id not in inactive_refs: - self.ref_counts[str(node.id)] += 1 + self.ref_counts.setdefault(node, 0) + self.ref_counts[node] += 1 def visit_Lambda(self, node: itir.Lambda, *, inactive_refs: set[str]): inactive_refs = inactive_refs | {param.id for param in node.params} @@ -68,16 +107,41 @@ def visit_Lambda(self, node: itir.Lambda, *, inactive_refs: set[str]): self.generic_visit(node, inactive_refs=inactive_refs) 
+@typing.overload +def collect_symbol_refs( + node: itir.Node | Sequence[itir.Node], + symbol_names: Optional[Iterable[str]] = None, + *, + ignore_builtins: bool = True, + as_ref: typing.Literal[False] = False, +) -> list[str]: ... + + +@typing.overload +def collect_symbol_refs( + node: itir.Node | Sequence[itir.Node], + symbol_names: Optional[Iterable[str]] = None, + *, + ignore_builtins: bool = True, + as_ref: typing.Literal[True], +) -> list[itir.SymRef]: ... + + def collect_symbol_refs( node: itir.Node | Sequence[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, -) -> list[str]: + as_ref: bool = False, +): + assert as_ref in [True, False] return [ symbol_name for symbol_name, count in CountSymbolRefs.apply( - node, symbol_names, ignore_builtins=ignore_builtins + node, + symbol_names, + ignore_builtins=ignore_builtins, + as_ref=typing.cast(typing.Literal[True, False], as_ref), ).items() if count > 0 ] diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 925dbb8f43..17a55be4a2 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -51,10 +51,10 @@ def copy_recorded_shifts(from_: ir.Node, to: ir.Node) -> None: class Sentinel(enum.Enum): - VALUE = object() - TYPE = object() + VALUE = enum.auto() + TYPE = enum.auto() - ALL_NEIGHBORS = object() + ALL_NEIGHBORS = enum.auto() @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 551fcc14cc..4093581d06 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -11,7 +11,7 @@ # distribution for a copy of the license or check . 
# # SPDX-License-Identifier: GPL-3.0-or-later - +import collections.abc import copy import dataclasses import functools @@ -23,6 +23,7 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to +from gt4py.next.iterator.transforms import global_tmps from gt4py.next.iterator.type_system import rules, type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts @@ -39,9 +40,11 @@ def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): """ Predicate to determine if two types are compatible. - This function gracefully handles iterators with unknown positions which are considered - compatible to any other positions of another iterator. Beside that this function - simply checks for equality of types. + This function gracefully handles: + - iterators with unknown positions which are considered compatible to any other positions + of another iterator. + - iterators which are defined everywhere, i.e. empty defined dimensions + Beside that this function simply checks for equality of types. 
>>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) >>> IDim = common.Dimension(value="IDim") @@ -68,7 +71,8 @@ def is_compatible_element(el_type_a: ts.TypeSpec, el_type_b: ts.TypeSpec): if isinstance(el_type_a, it_ts.IteratorType) and isinstance(el_type_b, it_ts.IteratorType): if not any(el_type.position_dims == "unknown" for el_type in [el_type_a, el_type_b]): is_compatible &= el_type_a.position_dims == el_type_b.position_dims - is_compatible &= el_type_a.defined_dims == el_type_b.defined_dims + if el_type_a.defined_dims and el_type_b.defined_dims: + is_compatible &= el_type_a.defined_dims == el_type_b.defined_dims is_compatible &= el_type_a.element_type == el_type_b.element_type else: is_compatible &= el_type_a == el_type_b @@ -84,8 +88,10 @@ def is_compatible_element(el_type_a: ts.TypeSpec, el_type_b: ts.TypeSpec): # - write back params type in lambda # - documentation # describe why lambda can only have one type. Describe idea to solve e.g. -# let("f", lambda x: x)(f(1)+f(1.)) -# -> let("f_int", lambda x: x, "f_float", lambda x: x)(f_int(1)+f_float(1.)) +# `let("f", lambda x: x)(f(1)+f(1.)) +# -> let("f_int", lambda x: x, "f_float", lambda x: x)(f_int(1)+f_float(1.))` +# describe where this is needed, e.g.: +# `if_(cond, fun_tail(it_on_vertex), fun_tail(it_on_vertex_k))` # - document how scans are handled (also mention to Hannes) # - types are stored in the node, but will be incomplete after some passes # Design decisions @@ -202,16 +208,20 @@ def _get_dimensions_from_offset_provider(offset_provider) -> dict[str, common.Di def _get_dimensions_from_types(types) -> dict[str, common.Dimension]: - dimensions: list[common.Dimension] = [] - - def _get_dimensions(el_type): - if isinstance(el_type, ts.FieldType): - dimensions.extend(el_type.dims) - - for type_ in types: - type_info.apply_to_primitive_constituents(_get_dimensions, type_) - - return {dim.value: dim for dim in dimensions} + def _get_dimensions(obj: Any): + if isinstance(obj, common.Dimension): + yield 
obj + elif isinstance(obj, ts.TypeSpec): + for field in dataclasses.fields(obj.__class__): + yield from _get_dimensions(getattr(obj, field.name)) + elif isinstance(obj, collections.abc.Mapping): + for el in obj.values(): + yield from _get_dimensions(el) + elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, str): + for el in obj: + yield from _get_dimensions(el) + + return {dim.value: dim for dim in _get_dimensions(types)} T = TypeVar("T", bound=itir.Node) @@ -255,6 +265,14 @@ def extract_dtype_and_dims(el_type: ts.TypeSpec): ) +def _type_inference_rule_from_function_type(fun_type: ts.FunctionType): + def type_rule(*args, **kwargs): + assert type_info.accepts_args(fun_type, with_args=args, with_kwargs=kwargs) + return fun_type.returns + + return type_rule + + @dataclasses.dataclass class ITIRTypeInference(eve.NodeTranslator): """ @@ -281,7 +299,11 @@ def apply( dimensions=( _get_dimensions_from_offset_provider(offset_provider) | _get_dimensions_from_types( - node.pre_walk_values().if_isinstance(itir.Node).getattr("type").if_is_not(None) + node.pre_walk_values() + .if_isinstance(itir.Node) + .getattr("type") + .if_is_not(None) + .to_list() ) ), allow_undeclared_symbols=allow_undeclared_symbols, @@ -338,6 +360,32 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.F closures = self.visit(node.closures, ctx=ctx | params | function_definitions) return it_ts.FencilType(params=list(params.values()), closures=closures) + def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries, *, ctx): + # TODO(tehrengruber): This implementation is not very appealing. Since we are about to + # refactor the PR anyway this is fine for now. 
+ params: dict[str, ts.DataType] = {} + for param in node.params: + assert isinstance(param.type, ts.DataType) + params[param.id] = param.type + # infer types of temporary declarations + tmps: dict[str, ts.FieldType] = {} + for tmp_node in node.tmps: + tmps[tmp_node.id] = self.visit(tmp_node, ctx=ctx | params) + # and store them in the inner fencil + for fencil_param in node.fencil.params: + if fencil_param.id in tmps: + fencil_param.type = tmps[fencil_param.id] + self.visit(node.fencil, ctx=ctx) + return node.fencil.type + + def visit_Temporary(self, node: itir.Temporary, *, ctx): + domain = self.visit(node.domain, ctx=ctx) + assert isinstance(domain, it_ts.DomainType) + assert node.dtype + return type_info.apply_to_primitive_constituents( + lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype + ) + def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: domain: it_ts.DomainType = self.visit(node.domain, ctx=ctx) inputs: list[ts.FieldType] = self.visit(node.inputs, ctx=ctx) @@ -389,8 +437,11 @@ def visit_SymRef( if self.allow_undeclared_symbols and node.id not in ctx: # type has been stored in the node itself if node.type: + if isinstance(node.type, ts.FunctionType): + return _type_inference_rule_from_function_type(node.type) return node.type return ts.DeferredType(constraint=None) + assert node.id in ctx result = ctx[node.id] if isinstance(result, ObservableTypeInferenceRule): result.aliases.append(node) diff --git a/src/gt4py/next/iterator/type_system/rules.py b/src/gt4py/next/iterator/type_system/rules.py index 61650629b5..e225acc4b2 100644 --- a/src/gt4py/next/iterator/type_system/rules.py +++ b/src/gt4py/next/iterator/type_system/rules.py @@ -104,7 +104,10 @@ def can_deref(it: it_ts.IteratorType) -> ts.ScalarType: @_register_type_inference_rule def if_(cond: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: assert isinstance(cond, ts.ScalarType) and cond.kind == 
ts.ScalarKind.BOOL - assert true_branch == false_branch + # TODO(tehrengruber): Enable this or a similar check. In case the true- and false-branch are + # iterators defined on different positions this fails. For the GTFN backend we also don't + # want this, but for roundtrip it is totally fine. + # assert true_branch == false_branch # noqa: ERA001 return true_branch diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index e5ba965e3b..dfcb70cb4f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -45,10 +45,12 @@ UnstructuredDomain, ) from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef -from gt4py.next.type_system import type_info +from gt4py.next.type_system import type_info, type_specifications as ts -def pytype_to_cpptype(t: str) -> Optional[str]: +def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: + if isinstance(t, ts.ScalarType): + t = t.kind.name.lower() try: return { "float32": "float", @@ -550,19 +552,16 @@ def visit_Program(self, node: itir.Program, **kwargs: Any) -> Program: def visit_Temporary( self, node: itir.Temporary, *, params: list, **kwargs: Any ) -> TemporaryAllocation: - def dtype_to_cpp(x: int | tuple | str) -> str: - if isinstance(x, int): - return f"std::remove_const_t<::gridtools::sid::element_type>" - if isinstance(x, tuple): - return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x) + ">" - assert isinstance(x, str) + def dtype_to_cpp(x: ts.DataType) -> str: + if isinstance(x, ts.TupleType): + assert all(isinstance(i, ts.ScalarType) for i in x.types) + return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" + assert isinstance(x, ts.ScalarType) res = pytype_to_cpptype(x) assert isinstance(res, str) return res - assert isinstance( - node.dtype, (int, 
tuple, str) - ) # TODO(havogt): this looks weird, consider refactoring + assert node.dtype return TemporaryAllocation( id=node.id, dtype=dtype_to_cpp(node.dtype), domain=self.visit(node.domain, **kwargs) ) diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 266d7b3530..23e95f7446 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -86,6 +86,7 @@ class Params: run_dace_cpu_with_temporaries = DaCeBackendFactory( cached=True, auto_optimize=True, use_temporaries=True ) +run_dace_cpu = run_dace_cpu_with_temporaries run_dace_gpu = DaCeBackendFactory(gpu=True, cached=True, auto_optimize=True) run_dace_gpu_with_temporaries = DaCeBackendFactory( diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 1b7e435982..00248fbe2d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -38,7 +38,6 @@ from .utility import ( add_mapped_nested_sdfg, as_dace_type, - as_scalar_type, connectivity_identifier, dace_debuginfo, filter_connectivities, @@ -196,14 +195,10 @@ def add_storage_for_temporaries( symbol_map: dict[str, TaskletExpr] = {} # The shape of temporary arrays might be defined based on scalar values passed as program arguments. # Here we collect these values in a symbol map. 
- tmp_ids = set(tmp.id for tmp in self.tmps) for sym in node_params: - breakpoint() - if sym.id not in tmp_ids and sym.kind != "Iterator": + if isinstance(sym.type, ts.ScalarType): name_ = str(sym.id) - type_ = self.storage_types[name_] - assert isinstance(type_, ts.ScalarType) - symbol_map[name_] = SymbolExpr(name_, as_dace_type(type_)) + symbol_map[name_] = SymbolExpr(name_, as_dace_type(sym.type)) tmp_symbols: dict[str, str] = {} for tmp in self.tmps: @@ -211,47 +206,33 @@ def add_storage_for_temporaries( # We visit the domain of the temporary field, passing the set of available symbols. assert isinstance(tmp.domain, itir.FunCall) - breakpoint() - self.node_types.update(itir_typing.infer_all(tmp.domain)) domain_ctx = Context(program_sdfg, defs_state, symbol_map) tmp_domain = self._visit_domain(tmp.domain, domain_ctx) - # We build the FieldType for this temporary array. - dims: list[Dimension] = [] - for dim, _ in tmp_domain: - dims.append( - Dimension( - value=dim, - kind=( - DimensionKind.VERTICAL - if self.column_axis is not None and self.column_axis.value == dim - else DimensionKind.HORIZONTAL - ), - ) - ) - assert isinstance(tmp.dtype, str) - type_ = ts.FieldType(dims=dims, dtype=as_scalar_type(tmp.dtype)) - self.storage_types[tmp_name] = type_ + if isinstance(tmp.type, ts.TupleType): + raise NotImplementedError("Temporaries of tuples are not supported.") + assert isinstance(tmp.type, ts.FieldType) and isinstance(tmp.dtype, ts.ScalarType) + + # We store the FieldType for this temporary array. + self.storage_types[tmp_name] = tmp.type # N.B.: skip generation of symbolic strides and just let dace assign default strides, for now. # Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics # to assign optimal stride values. 
- tmp_shape, _ = new_array_symbols(tmp_name, len(dims)) + tmp_shape, _ = new_array_symbols(tmp_name, len(tmp.type.dims)) _, tmp_array = program_sdfg.add_array( - tmp_name, tmp_shape, as_dace_type(type_.dtype), transient=True + tmp_name, tmp_shape, as_dace_type(tmp.dtype), transient=True ) # Loop through all dimensions to visit the symbolic expressions for array shape and offset. # These expressions are later mapped to interstate symbols. for (_, (begin, end)), shape_sym in zip(tmp_domain, tmp_array.shape): - """ - The temporary field has a dimension range defined by `begin` and `end` values. - Therefore, the actual size is given by the difference `end.value - begin.value`. - Instead of allocating the actual size, we allocate space to enable indexing from 0 - because we want to avoid using dace array offsets (which will be deprecated soon). - The result should still be valid, but the stencil will be using only a subset - of the array. - """ + # The temporary field has a dimension range defined by `begin` and `end` values. + # Therefore, the actual size is given by the difference `end.value - begin.value`. + # Instead of allocating the actual size, we allocate space to enable indexing from 0 + # because we want to avoid using dace array offsets (which will be deprecated soon). + # The result should still be valid, but the stencil will be using only a subset + # of the array. 
if not (isinstance(begin, SymbolExpr) and begin.value == "0"): warnings.warn( f"Domain start offset for temporary {tmp_name} is ignored.", stacklevel=2 diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 5f852b2838..a5276f7da4 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -52,14 +52,6 @@ def as_dace_type(type_: ts.ScalarType) -> dace.dtypes.typeclass: raise ValueError(f"Scalar type '{type_}' not supported.") -def as_scalar_type(typestr: str) -> ts.ScalarType: - try: - kind = getattr(ts.ScalarKind, typestr.upper()) - except AttributeError as ex: - raise ValueError(f"Data type {typestr} not supported.") from ex - return ts.ScalarType(kind) - - def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Connectivity]: return { offset: table diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 7eafcf237e..3dc2a13a60 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -19,15 +19,15 @@ from gt4py.next import common as func_common +@dataclass(frozen=True) class TypeSpec: def __hash__(self) -> int: return hash(content_hash(self)) def __init_subclass__(cls) -> None: - cls.__hash__ = TypeSpec.__hash__ + cls.__hash__ = TypeSpec.__hash__ # type: ignore[method-assign] -@dataclass(frozen=True) class DataType(TypeSpec): """ A base type for all types that represent data storage. 
diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index 67927437bc..8c80a86b65 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -30,11 +30,9 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import fundef, offset -from gt4py.next.program_processors.formatters import type_check -from gt4py.next.program_processors.formatters.gtfn import format_cpp as gtfn_format_sourcecode from next_tests.integration_tests.cases import IDim -from next_tests.unit_tests.conftest import program_processor, run_processor +from next_tests.unit_tests.conftest import run_processor I = offset("I") diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index ed7c999a0f..3921c698fe 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -341,3 +341,6 @@ def test_fencil_with_nb_field_input(): assert result.closures[0].stencil.expr.args[0].type == float64_list_type assert result.closures[0].stencil.type.returns == float64_type + + +# TODO(tehrengruber): add tests for itir.Program From 896582ec8a52f8dd3f1b3a56675f1c6ae131f60d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 19 Apr 2024 18:06:42 +0200 Subject: [PATCH 20/52] Multiple fixes --- src/gt4py/next/iterator/pretty_parser.py | 9 +++++- src/gt4py/next/iterator/pretty_printer.py | 2 +- .../next/iterator/type_system/inference.py | 8 ++++-- .../runners/dace_iterator/itir_to_tasklet.py | 4 ++- src/gt4py/next/type_system/type_info.py | 2 -- .../ffront_tests/ffront_test_utils.py | 1 - 
.../iterator_tests/test_anton_toy.py | 28 +++++++++++-------- .../iterator_tests/test_pretty_parser.py | 2 +- .../iterator_tests/test_pretty_printer.py | 5 ++-- 9 files changed, 38 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 3b7a2522a1..05e618d8c1 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -18,6 +18,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_specifications as ts GRAMMAR = """ @@ -31,6 +32,7 @@ SYM: CNAME SYM_REF: CNAME + TYPE_LITERAL: CNAME INT_LITERAL: SIGNED_INT FLOAT_LITERAL: SIGNED_FLOAT OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ" @@ -84,7 +86,7 @@ named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")" function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";" - declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" prec0 ")" ";" + declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";" stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" set_at: prec0 "@" prec0 "←" prec1 ";" fencil_definition: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( stencil_closure )+ "}" @@ -111,6 +113,11 @@ def INT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal: def FLOAT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal: return im.literal(value.value, "float64") + def TYPE_LITERAL(self, value: lark_lexer.Token) -> ts.TypeSpec: + if hasattr(ts.ScalarKind, value.upper()): + return ts.ScalarType(kind=getattr(ts.ScalarKind, value.upper())) + raise NotImplementedError(f"Type {value} not supported.") + def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral: v: Union[int, str] = value.value[:-1] try: diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 
3f224d2ef4..57d8ee637e 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -279,7 +279,7 @@ def visit_Temporary(self, node: ir.Temporary, *, prec: int) -> list[str]: if node.domain is not None: args.append(self._hmerge(["domain="], self.visit(node.domain, prec=0))) if node.dtype is not None: - args.append(self._hmerge(["dtype="], [str(node.dtype)])) + args.append(self._hmerge(["dtype="], [str(node.dtype.kind.name.lower())])) hargs = self._hmerge(*self._hinterleave(args, ", ")) vargs = self._vmerge(*self._hinterleave(args, ",")) oargs = self._optimum(hargs, vargs) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 4093581d06..4a59fd83b0 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -97,7 +97,10 @@ def is_compatible_element(el_type_a: ts.TypeSpec, el_type_b: ts.TypeSpec): # Design decisions # Only the parameters of fencils need to be typed. # Lambda functions are not polymorphic. - +def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: + if node.type: + assert _is_compatible_type(node.type, type_), "Node already has a type which differs." 
+ node.type = type_ def on_inferred( callback: Callable, *args: Union[ts.TypeSpec, "ObservableTypeInferenceRule"] @@ -161,9 +164,10 @@ def _infer_type_listener(self, return_type: ts.TypeSpec, *args: ts.TypeSpec) -> if self.store_inferred_type_in_node: assert self.node + _set_node_type(self.node, self.inferred_type) self.node.type = self.inferred_type for alias in self.aliases: - alias.type = self.inferred_type + _set_node_type(alias, self.inferred_type) def on_type_ready(self, cb: Callable[[ts.TypeSpec], None]) -> None: if self.inferred_type: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 6da819d20d..2e732ac863 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -55,7 +55,9 @@ def itir_type_as_dace_type(type_: ts.TypeSpec): - # TODO: cleanup + # TODO(tehrengruber): this function just converts the scalar type of whatever it is given, + # let it be a field, iterator, or directly a scalar. The caller should take care of the + # extraction. dtype: ts.TypeSpec if isinstance(type_, ts.FieldType): dtype = type_.dtype diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 646912bf3b..12f7d8595d 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -175,8 +175,6 @@ def apply_to_primitive_constituents( ... 
) {(0,): ScalarType(kind=, shape=None), (1,): ScalarType(kind=, shape=None)} """ - - # TODO: check structure matches if isinstance(symbol_types[0], ts.TupleType): assert all(isinstance(symbol_type, ts.TupleType) for symbol_type in symbol_types) return tuple_constructor( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 943da56427..388849bf09 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -236,7 +236,6 @@ def simple_mesh() -> MeshDescriptor: C2E.value: gtx.NeighborTableOffsetProvider( c2e_arr, Cell, Edge, 4, has_skip_values=False ), - # "KDim": KDim }, ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 0cf6e61b27..291fda560a 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -49,7 +49,7 @@ def lap(inp): @fundef -def lap2(inp): +def lap_flat(inp): return -4.0 * deref(inp) + ( deref(shift(i, 1)(inp)) + deref(shift(i, -1)(inp)) @@ -63,16 +63,6 @@ def lap2(inp): KDim = gtx.Dimension("KDim") -@fendef(offset_provider={"i": IDim, "j": JDim}) -def fencil(x, y, z, out, inp): - closure( - cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z)), - lap, - out, - [inp], - ) - - def naive_lap(inp): shape = [inp.shape[0] - 2, inp.shape[1] - 2, inp.shape[2]] out = np.zeros(shape) @@ -90,7 +80,8 @@ def naive_lap(inp): @pytest.mark.uses_origin -def test_anton_toy(program_processor, lift_mode): +@pytest.mark.parametrize("stencil", [lap, lap_flat]) +def test_anton_toy(stencil, 
program_processor, lift_mode): program_processor, validate = program_processor if program_processor in [ @@ -103,6 +94,19 @@ def test_anton_toy(program_processor, lift_mode): if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("TODO: issue with temporaries that crashes the application") + if stencil is lap: + pytest.xfail("Type inference does not support calling lambdas with offset arguments of changing type.") + + @fendef(offset_provider={"i": IDim, "j": JDim}) + def fencil(x, y, z, out, inp): + closure( + cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y), + named_range(KDim, 0, z)), + stencil, + out, + [inp], + ) + shape = [5, 7, 9] rng = np.random.default_rng() inp = gtx.as_field( diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index 5ec80aefba..4cbf32c1f2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -257,7 +257,7 @@ def test_program(): ir.Temporary( id="tmp", domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - dtype=ir.SymRef(id="float64"), + dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64), ), ], body=[ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 44308473b7..666250cace 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -15,6 +15,7 @@ from gt4py.next.iterator import ir from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_specifications as ts def test_hmerge(): @@ -296,7 +297,7 @@ def test_function_definition(): def test_temporary(): - testee = ir.Temporary(id="t", domain=ir.SymRef(id="domain"), 
dtype="float64") + testee = ir.Temporary(id="t", domain=ir.SymRef(id="domain"), dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) expected = "t = temporary(domain=domain, dtype=float64);" actual = pformat(testee) assert actual == expected @@ -358,7 +359,7 @@ def test_program(): ir.Temporary( id="tmp", domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - dtype="float64", + dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64), ), ], body=[ From c8fb2d929346bd3f04744f2a904ec1f25178ff61 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 19 Apr 2024 18:08:19 +0200 Subject: [PATCH 21/52] Cleanup --- src/gt4py/next/iterator/pretty_printer.py | 2 ++ src/gt4py/next/iterator/type_system/inference.py | 1 + .../multi_feature_tests/iterator_tests/test_anton_toy.py | 9 ++++++--- .../unit_tests/iterator_tests/test_pretty_printer.py | 4 +++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 57d8ee637e..80599e983e 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -25,6 +25,7 @@ from gt4py.eve import NodeTranslator from gt4py.next.iterator import ir +from gt4py.next.type_system import type_specifications as ts # replacements for builtin binary operations @@ -279,6 +280,7 @@ def visit_Temporary(self, node: ir.Temporary, *, prec: int) -> list[str]: if node.domain is not None: args.append(self._hmerge(["domain="], self.visit(node.domain, prec=0))) if node.dtype is not None: + assert isinstance(node.dtype, ts.ScalarType) args.append(self._hmerge(["dtype="], [str(node.dtype.kind.name.lower())])) hargs = self._hmerge(*self._hinterleave(args, ", ")) vargs = self._vmerge(*self._hinterleave(args, ",")) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 4a59fd83b0..12e403e0ae 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ 
b/src/gt4py/next/iterator/type_system/inference.py @@ -102,6 +102,7 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: assert _is_compatible_type(node.type, type_), "Node already has a type which differs." node.type = type_ + def on_inferred( callback: Callable, *args: Union[ts.TypeSpec, "ObservableTypeInferenceRule"] ) -> None: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 291fda560a..670536eafc 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -95,13 +95,16 @@ def test_anton_toy(stencil, program_processor, lift_mode): pytest.xfail("TODO: issue with temporaries that crashes the application") if stencil is lap: - pytest.xfail("Type inference does not support calling lambdas with offset arguments of changing type.") + pytest.xfail( + "Type inference does not support calling lambdas with offset arguments of changing type." 
+ ) @fendef(offset_provider={"i": IDim, "j": JDim}) def fencil(x, y, z, out, inp): closure( - cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y), - named_range(KDim, 0, z)), + cartesian_domain( + named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) + ), stencil, out, [inp], diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 666250cace..0f4ac4d2c7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -297,7 +297,9 @@ def test_function_definition(): def test_temporary(): - testee = ir.Temporary(id="t", domain=ir.SymRef(id="domain"), dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + testee = ir.Temporary( + id="t", domain=ir.SymRef(id="domain"), dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) expected = "t = temporary(domain=domain, dtype=float64);" actual = pformat(testee) assert actual == expected From cb0ebbbc1573ec353cbf33401500d475bd3ed4d3 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 19 Apr 2024 18:21:19 +0200 Subject: [PATCH 22/52] Cleanup --- .../next/iterator/type_system/inference.py | 31 ++++++++++++------- .../transforms_tests/test_global_tmps.py | 19 ++++++------ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 12e403e0ae..4835eb3b8e 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -66,18 +66,25 @@ def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): """ is_compatible = True - def is_compatible_element(el_type_a: ts.TypeSpec, el_type_b: ts.TypeSpec): - nonlocal is_compatible - if isinstance(el_type_a, it_ts.IteratorType) and isinstance(el_type_b, it_ts.IteratorType): - if not any(el_type.position_dims 
== "unknown" for el_type in [el_type_a, el_type_b]): - is_compatible &= el_type_a.position_dims == el_type_b.position_dims - if el_type_a.defined_dims and el_type_b.defined_dims: - is_compatible &= el_type_a.defined_dims == el_type_b.defined_dims - is_compatible &= el_type_a.element_type == el_type_b.element_type - else: - is_compatible &= el_type_a == el_type_b - - type_info.apply_to_primitive_constituents(is_compatible_element, type_a, type_b) + if isinstance(type_a, it_ts.IteratorType) and isinstance(type_b, it_ts.IteratorType): + if not any(el_type.position_dims == "unknown" for el_type in [type_a, type_b]): + is_compatible &= type_a.position_dims == type_b.position_dims + if type_a.defined_dims and type_b.defined_dims: + is_compatible &= type_a.defined_dims == type_b.defined_dims + is_compatible &= type_a.element_type == type_b.element_type + elif isinstance(type_a, ts.TupleType): + for el_type_a, el_type_b in zip(type_a.types, type_b.types, strict=True): + is_compatible &= _is_compatible_type(el_type_a, el_type_b) + elif isinstance(type_a, ts.FunctionType): + for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True): + is_compatible &= _is_compatible_type(arg_a, arg_b) + for arg_a, arg_b in zip(type_a.pos_or_kw_args.values(), type_b.pos_or_kw_args.values(), strict=True): + is_compatible &= _is_compatible_type(arg_a, arg_b) + for arg_a, arg_b in zip(type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True): + is_compatible &= _is_compatible_type(arg_a, arg_b) + is_compatible &= _is_compatible_type(type_a.returns, type_b.returns) + else: + is_compatible &= type_a == type_b return is_compatible diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index e71435c2fb..0feb4ea4c4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ 
b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -379,19 +379,18 @@ def test_collect_tmps_info(): for a, s in (("JDim", "j"), ("KDim", "k")) ], ) + + i = im.sym("i", index_type) + j = im.sym("j", index_type) + k = im.sym("k", index_type) + inp = im.sym("inp", i_field_type) + out = im.sym("out", i_field_type) + testee = FencilWithTemporaries( fencil=ir.FencilDefinition( id="f", function_definitions=[], - params=[ - im.sym("i", index_type), - im.sym("j", index_type), - im.sym("k", index_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_gtmp_0", i_field_type), - im.sym("_gtmp_1", i_field_type), - ], + params=[i, j, k, inp, out, im.sym("_gtmp_0", i_field_type), im.sym("_gtmp_1", i_field_type)], closures=[ ir.StencilClosure( domain=tmp_domain, @@ -446,7 +445,7 @@ def test_collect_tmps_info(): ), ], ), - params=[ir.Sym(id="i"), ir.Sym(id="j"), ir.Sym(id="k"), ir.Sym(id="inp"), ir.Sym(id="out")], + params=[i, j, k, inp, out], tmps=[ ir.Temporary(id="_gtmp_0", dtype=float_type), ir.Temporary(id="_gtmp_1", dtype=float_type), From 873c60e9d01133e43981bb19468c7bd278c3d6fd Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 19 Apr 2024 18:28:10 +0200 Subject: [PATCH 23/52] Fix dace --- .../next/program_processors/runners/dace_iterator/__init__.py | 3 +++ .../program_processors/runners/dace_iterator/itir_to_sdfg.py | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 2b56dc0420..4b20b036c9 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -24,6 +24,7 @@ import gt4py.next.iterator.ir as itir from gt4py.next import common from gt4py.next.iterator import transforms as itir_transforms +from gt4py.next.iterator.type_system import inference as 
itir_type_inference from gt4py.next.type_system import type_specifications as ts from .itir_to_sdfg import ItirToSDFG @@ -84,6 +85,8 @@ def preprocess_program( unroll_reduce=unroll_reduce, ) + node = itir_type_inference.infer(node, offset_provider=offset_provider) + if isinstance(node, itir_transforms.global_tmps.FencilWithTemporaries): fencil_definition = node.fencil tmps = node.tmps diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 00248fbe2d..35b9ce1694 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -22,7 +22,6 @@ from gt4py.next.common import Connectivity from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef -from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt from .itir_to_tasklet import ( @@ -269,8 +268,6 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg.debuginfo = dace_debuginfo(node) entry_state = program_sdfg.add_state("program_entry", is_start_block=True) - node = itir_type_inference.infer(node, offset_provider=self.offset_provider) - # Filter neighbor tables from offset providers. 
neighbor_tables = get_used_connectivities(node, self.offset_provider) From d42fa1245ebe34c5ffdfabbd58d0d0ffac1c1a94 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 19 Apr 2024 18:28:16 +0200 Subject: [PATCH 24/52] Formatting --- src/gt4py/next/iterator/type_system/inference.py | 8 ++++++-- .../transforms_tests/test_global_tmps.py | 10 +++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 4835eb3b8e..a4ebae563c 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -78,9 +78,13 @@ def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): elif isinstance(type_a, ts.FunctionType): for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True): is_compatible &= _is_compatible_type(arg_a, arg_b) - for arg_a, arg_b in zip(type_a.pos_or_kw_args.values(), type_b.pos_or_kw_args.values(), strict=True): + for arg_a, arg_b in zip( + type_a.pos_or_kw_args.values(), type_b.pos_or_kw_args.values(), strict=True + ): is_compatible &= _is_compatible_type(arg_a, arg_b) - for arg_a, arg_b in zip(type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True): + for arg_a, arg_b in zip( + type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True + ): is_compatible &= _is_compatible_type(arg_a, arg_b) is_compatible &= _is_compatible_type(type_a.returns, type_b.returns) else: diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 0feb4ea4c4..f1b15f4e18 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -390,7 +390,15 @@ def test_collect_tmps_info(): fencil=ir.FencilDefinition( id="f", 
function_definitions=[], - params=[i, j, k, inp, out, im.sym("_gtmp_0", i_field_type), im.sym("_gtmp_1", i_field_type)], + params=[ + i, + j, + k, + inp, + out, + im.sym("_gtmp_0", i_field_type), + im.sym("_gtmp_1", i_field_type), + ], closures=[ ir.StencilClosure( domain=tmp_domain, From 8c632de6aad209a6c2a716014bd9c4face6fac2d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 19 Apr 2024 18:29:02 +0200 Subject: [PATCH 25/52] Formatting --- src/gt4py/next/iterator/type_system/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index a4ebae563c..1f191e52b3 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -72,10 +72,10 @@ def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): if type_a.defined_dims and type_b.defined_dims: is_compatible &= type_a.defined_dims == type_b.defined_dims is_compatible &= type_a.element_type == type_b.element_type - elif isinstance(type_a, ts.TupleType): + elif isinstance(type_a, ts.TupleType) and isinstance(type_b, ts.TupleType): for el_type_a, el_type_b in zip(type_a.types, type_b.types, strict=True): is_compatible &= _is_compatible_type(el_type_a, el_type_b) - elif isinstance(type_a, ts.FunctionType): + elif isinstance(type_a, ts.FunctionType) and isinstance(type_b, ts.FunctionType): for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True): is_compatible &= _is_compatible_type(arg_a, arg_b) for arg_a, arg_b in zip( From 1d043d168ffc1ca79d9ad035a48c825707f2abad Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 19 Apr 2024 18:31:56 +0200 Subject: [PATCH 26/52] Don't use temporaries in dace --- src/gt4py/next/program_processors/runners/dace.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace.py 
b/src/gt4py/next/program_processors/runners/dace.py index 23e95f7446..266d7b3530 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -86,7 +86,6 @@ class Params: run_dace_cpu_with_temporaries = DaCeBackendFactory( cached=True, auto_optimize=True, use_temporaries=True ) -run_dace_cpu = run_dace_cpu_with_temporaries run_dace_gpu = DaCeBackendFactory(gpu=True, cached=True, auto_optimize=True) run_dace_gpu_with_temporaries = DaCeBackendFactory( From b792a5335a5c5e96e0b174df3ece82175786e87c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Apr 2024 11:24:50 +0200 Subject: [PATCH 27/52] Try fixing test failures by removing types --- src/gt4py/next/iterator/type_system/inference.py | 9 +++++++++ src/gt4py/next/iterator/type_system/rules.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 1f191e52b3..74bdbb70ec 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -288,6 +288,12 @@ def type_rule(*args, **kwargs): return type_rule +class RemoveTypes(eve.NodeTranslator): + def visit_Node(self, node: ts.TypeSpec): + node = self.generic_visit(node) + if not isinstance(node, (itir.Literal, itir.Sym)): + node.type = None + return node @dataclasses.dataclass class ITIRTypeInference(eve.NodeTranslator): @@ -310,6 +316,9 @@ def apply( inplace: bool = False, allow_undeclared_symbols: bool = False, ) -> T: + if not allow_undeclared_symbols: + node = RemoveTypes().visit(node) + instance = cls( offset_provider=offset_provider, dimensions=( diff --git a/src/gt4py/next/iterator/type_system/rules.py b/src/gt4py/next/iterator/type_system/rules.py index e225acc4b2..acc622d57a 100644 --- a/src/gt4py/next/iterator/type_system/rules.py +++ b/src/gt4py/next/iterator/type_system/rules.py @@ -221,7 +221,7 @@ def 
applied_reduce(*args: it_ts.ListType): @_register_type_inference_rule def shift(*offset_literals, offset_provider) -> TypeInferenceRule: - def apply_shift(it: it_ts.IteratorType): + def apply_shift(it: it_ts.IteratorType) -> it_ts.IteratorType: assert isinstance(it, it_ts.IteratorType) if it.position_dims == "unknown": # nothing to do here return it From f0680dfc021e09f2fd82856b7e0cae6b0c9099c8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Apr 2024 12:27:15 +0200 Subject: [PATCH 28/52] Fix pretty print --- .../program_processors/formatters/pretty_print.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/formatters/pretty_print.py b/src/gt4py/next/program_processors/formatters/pretty_print.py index 4f4a15f908..219c97ae92 100644 --- a/src/gt4py/next/program_processors/formatters/pretty_print.py +++ b/src/gt4py/next/program_processors/formatters/pretty_print.py @@ -21,16 +21,9 @@ import gt4py.next.program_processors.processor_interface as ppi -class _RemoveITIRSymTypes(eve.NodeTranslator): - def visit_Sym(self, node: itir.Sym) -> itir.Sym: - return itir.Sym(id=node.id, dtype=None, kind=None) - - @ppi.program_formatter def format_itir_and_check(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - # remove types from ITIR as they are not supported for the roundtrip - root = _RemoveITIRSymTypes().visit(program) - pretty = pretty_printer.pformat(root) + pretty = pretty_printer.pformat(program) parsed = pretty_parser.pparse(pretty) - assert parsed == root + assert parsed == program return pretty From 3603adb73484b1922652950bfb1b2f357091797a Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Apr 2024 12:32:42 +0200 Subject: [PATCH 29/52] Fix lowering tests --- .../next/iterator/type_system/inference.py | 4 +++- .../formatters/pretty_print.py | 1 - .../ffront_tests/test_foast_to_itir.py | 19 +++++++++---------- 3 files changed, 12 insertions(+), 12 deletions(-) diff 
--git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 74bdbb70ec..0fae84037c 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -288,13 +288,15 @@ def type_rule(*args, **kwargs): return type_rule + class RemoveTypes(eve.NodeTranslator): - def visit_Node(self, node: ts.TypeSpec): + def visit_Node(self, node: itir.Node): node = self.generic_visit(node) if not isinstance(node, (itir.Literal, itir.Sym)): node.type = None return node + @dataclasses.dataclass class ITIRTypeInference(eve.NodeTranslator): """ diff --git a/src/gt4py/next/program_processors/formatters/pretty_print.py b/src/gt4py/next/program_processors/formatters/pretty_print.py index 219c97ae92..39a5dc953c 100644 --- a/src/gt4py/next/program_processors/formatters/pretty_print.py +++ b/src/gt4py/next/program_processors/formatters/pretty_print.py @@ -14,7 +14,6 @@ from typing import Any -import gt4py.eve as eve import gt4py.next.iterator.ir as itir import gt4py.next.iterator.pretty_parser as pretty_parser import gt4py.next.iterator.pretty_printer as pretty_printer diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index 2bb2c844a9..10123d95aa 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -32,6 +32,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts, type_translation +from gt4py.next.iterator.type_system import type_specifications as it_ts IDim = gtx.Dimension("IDim") @@ -140,15 +141,13 @@ def copy_field(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(copy_field) lowered = FieldOperatorLowering.apply(parsed) - reference = im.let( - 
itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), "inp" - )( + reference = im.let(ssa.unique_name("tmp", 0), "inp")( im.let( - itir.Sym(id=ssa.unique_name("inp", 0), dtype=("float64", False), kind="Iterator"), + ssa.unique_name("inp", 0), ssa.unique_name("tmp", 0), )( im.let( - itir.Sym(id=ssa.unique_name("tmp2", 0), dtype=("float64", False), kind="Iterator"), + ssa.unique_name("tmp2", 0), ssa.unique_name("inp", 0), )(ssa.unique_name("tmp2", 0)) ) @@ -167,13 +166,13 @@ def unary(inp: gtx.Field[[TDim], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.let( - itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), + ssa.unique_name("tmp", 0), im.promote_to_lifted_stencil("plus")( im.promote_to_const_iterator(im.literal("0", "float64")), "inp" ), )( im.let( - itir.Sym(id=ssa.unique_name("tmp", 1), dtype=("float64", False), kind="Iterator"), + ssa.unique_name("tmp", 1), im.promote_to_lifted_stencil("minus")( im.promote_to_const_iterator(im.literal("0", "float64")), ssa.unique_name("tmp", 0) ), @@ -201,11 +200,11 @@ def unpacking( reference = im.let("__tuple_tmp_0", tuple_expr)( im.let( - itir.Sym(id=ssa.unique_name("tmp1", 0), dtype=("float64", False), kind="Iterator"), + ssa.unique_name("tmp1", 0), tuple_access_0, )( im.let( - itir.Sym(id=ssa.unique_name("tmp2", 0), dtype=("float64", False), kind="Iterator"), + ssa.unique_name("tmp2", 0), tuple_access_1, )(ssa.unique_name("tmp1", 0)) ) @@ -503,7 +502,7 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl ) reference = im.let( - itir.Sym(id=ssa.unique_name("e1_nbh", 0), dtype=("float64", True), kind="Iterator"), + ssa.unique_name("e1_nbh", 0), im.lifted_neighbors("V2E", "e1"), )( im.promote_to_lifted_stencil( From 4517d2f86c390e67142e63e84983e22c52ef38a9 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Apr 2024 12:35:33 +0200 Subject: [PATCH 30/52] Fix missing fixture import in tests --- 
.../feature_tests/iterator_tests/test_horizontal_indirection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index 8c80a86b65..ebb70e04ce 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -32,7 +32,7 @@ from gt4py.next.iterator.runtime import fundef, offset from next_tests.integration_tests.cases import IDim -from next_tests.unit_tests.conftest import run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor I = offset("I") From dbf1001d72320c8da57c11f1752cff6c089310df Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Apr 2024 13:09:27 +0200 Subject: [PATCH 31/52] Fix failing tests --- src/gt4py/next/iterator/pretty_printer.py | 3 +-- .../unit_tests/iterator_tests/test_type_inference.py | 12 ++++++------ .../codegens_tests/gtfn_tests/test_gtfn_module.py | 11 +++++++---- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 80599e983e..851908b370 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -280,8 +280,7 @@ def visit_Temporary(self, node: ir.Temporary, *, prec: int) -> list[str]: if node.domain is not None: args.append(self._hmerge(["domain="], self.visit(node.domain, prec=0))) if node.dtype is not None: - assert isinstance(node.dtype, ts.ScalarType) - args.append(self._hmerge(["dtype="], [str(node.dtype.kind.name.lower())])) + args.append(self._hmerge(["dtype="], [str(node.dtype)])) hargs = self._hmerge(*self._hinterleave(args, ", ")) vargs = self._vmerge(*self._hinterleave(args, ",")) oargs = 
self._optimum(hargs, vargs) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 3921c698fe..86397eb3a4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -212,12 +212,12 @@ def test_cartesian_fencil_definition(): stencil=ts.FunctionType( pos_only_args=[ it_ts.IteratorType( - position_dims=[IDim], defined_dims=[IDim], element_type=bool_type + position_dims=[IDim], defined_dims=[IDim], element_type=float64_type ) ], pos_or_kw_args={}, kw_only_args={}, - returns=bool_type, + returns=float64_type, ), output=float_i_field, inputs=[float_i_field], @@ -255,12 +255,12 @@ def test_unstructured_fencil_definition(): stencil=ts.FunctionType( pos_only_args=[ it_ts.IteratorType( - position_dims=[Vertex, KDim], defined_dims=[Edge, KDim], element_type=bool_type + position_dims=[Vertex, KDim], defined_dims=[Edge, KDim], element_type=float64_type ) ], pos_or_kw_args={}, kw_only_args={}, - returns=bool_type, + returns=float64_type, ), output=float_vertex_k_field, inputs=[float_edge_k_field], @@ -301,12 +301,12 @@ def test_function_definition(): stencil=ts.FunctionType( pos_only_args=[ it_ts.IteratorType( - position_dims=[IDim], defined_dims=[IDim], element_type=bool_type + position_dims=[IDim], defined_dims=[IDim], element_type=float64_type ) ], pos_or_kw_args={}, kw_only_args={}, - returns=bool_type, + returns=float64_type, ), output=float_i_field, inputs=[float_i_field], diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index d3651d3084..d239e0c0d8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ 
b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -20,17 +20,22 @@ from gt4py.next.otf import languages, stages from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_info, type_translation @pytest.fixture def fencil_example(): + IDim = gtx.Dimension("I") + params = [gtx.as_field([IDim], np.empty((1,), dtype=np.float32)), np.float32(3.14)] + param_types = [type_translation.from_value(param) for param in params] + domain = itir.FunCall( fun=itir.SymRef(id="cartesian_domain"), args=[ itir.FunCall( fun=itir.SymRef(id="named_range"), args=[ - itir.AxisLiteral(value="X"), + itir.AxisLiteral(value="I"), im.literal("0", itir.INTEGER_INDEX_BUILTIN), im.literal("10", itir.INTEGER_INDEX_BUILTIN), ], @@ -39,7 +44,7 @@ def fencil_example(): ) fencil = itir.FencilDefinition( id="example", - params=[itir.Sym(id="buf"), itir.Sym(id="sc")], + params=[im.sym(name, type_) for name, type_ in zip(("buf", "sc"), param_types)], function_definitions=[ itir.FunctionDefinition( id="stencil", @@ -56,8 +61,6 @@ def fencil_example(): ) ], ) - IDim = gtx.Dimension("I") - params = [gtx.as_field([IDim], np.empty((1,), dtype=np.float32)), np.float32(3.14)] return fencil, params From 6793d58b1d44b1a664376797806fe76273f8e4d8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Apr 2024 13:09:56 +0200 Subject: [PATCH 32/52] Fix format --- .../codegens_tests/gtfn_tests/test_gtfn_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index d239e0c0d8..93d908bef7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ 
b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -20,7 +20,7 @@ from gt4py.next.otf import languages, stages from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_info, type_translation +from gt4py.next.type_system import type_translation @pytest.fixture From cdf0a7f4567f1a230ce42fee16072c886bbb1d62 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Apr 2024 13:20:19 +0200 Subject: [PATCH 33/52] Fix format --- src/gt4py/next/iterator/pretty_printer.py | 1 - .../unit_tests/iterator_tests/test_type_inference.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 851908b370..3f224d2ef4 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -25,7 +25,6 @@ from gt4py.eve import NodeTranslator from gt4py.next.iterator import ir -from gt4py.next.type_system import type_specifications as ts # replacements for builtin binary operations diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 86397eb3a4..2edaef7e37 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -255,7 +255,9 @@ def test_unstructured_fencil_definition(): stencil=ts.FunctionType( pos_only_args=[ it_ts.IteratorType( - position_dims=[Vertex, KDim], defined_dims=[Edge, KDim], element_type=float64_type + position_dims=[Vertex, KDim], + defined_dims=[Edge, KDim], + element_type=float64_type, ) ], pos_or_kw_args={}, From 17644dc204c0bb52860baf6a8ad7d3dc23d8f161 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Apr 2024 13:40:49 +0200 Subject: [PATCH 34/52] Fix doctests --- 
src/gt4py/next/ffront/foast_to_itir.py | 2 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 +++--- src/gt4py/next/iterator/transforms/global_tmps.py | 6 ++++-- src/gt4py/next/iterator/transforms/symbol_ref_utils.py | 6 +++--- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 53ccbf56a1..433fb37a98 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -63,7 +63,7 @@ class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): >>> lowered.id SymbolName('fieldop') >>> lowered.params # doctest: +ELLIPSIS - [Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))] + [Sym(id=SymbolName('inp'))] """ uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 4fa93546a6..52c3da16b0 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -27,10 +27,10 @@ def sym(sym_or_name: Union[str, itir.Sym], type_=None) -> itir.Sym: Examples -------- >>> sym("a") - Sym(id=SymbolName('a'), kind=None, dtype=None) + Sym(id=SymbolName('a')) >>> sym(itir.Sym(id="b")) - Sym(id=SymbolName('b'), kind=None, dtype=None) + Sym(id=SymbolName('b')) """ if isinstance(sym_or_name, itir.Sym): assert not type_ @@ -110,7 +110,7 @@ class lambda_: Examples -------- >>> lambda_("a")(deref("a")) # doctest: +ELLIPSIS - Lambda(params=[Sym(id=SymbolName('a'), kind=None, dtype=None)], expr=FunCall(fun=SymRef(id=SymbolRef('deref')), args=[SymRef(id=SymbolRef('a'))])) + Lambda(params=[Sym(id=SymbolName('a'))], expr=FunCall(fun=SymRef(id=SymbolRef('deref')), args=[SymRef(id=SymbolRef('a'))])) """ def __init__(self, *args): diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 
1e929c4b19..72c8cbdff1 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -109,7 +109,9 @@ def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir Transform lift such that the arguments to the applied lift are only symbols. - >>> expr = im.lift(im.lambda_("a")(im.deref("a")))(im.lift("deref")("inp")) + >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) + >>> it_type = it_ts.IteratorType(position_dims=[], defined_dims=[], element_type=bool_type) + >>> expr = im.lift(im.lambda_("a")(im.deref("a")))(im.lift("deref")(im.ref("inp", it_type))) >>> print(expr) (↑(λ(a) → ·a))((↑deref)(inp)) >>> print(canonicalize_applied_lift(["inp"], expr)) @@ -126,7 +128,7 @@ def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir im.call(stencil)(*it_args) ) )(*closure_param_refs) - # add types again + # ensure all types are inferred return itir_type_inference.infer( new_node, inplace=True, allow_undeclared_symbols=True, offset_provider={} ) diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 28b29940d3..1cde0e4739 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -63,17 +63,17 @@ def apply( >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> CountSymbolRefs.apply(expr) - {'x': 2, 'y': 2, 'z': 1} + defaultdict(, {'x': 2, 'y': 2, 'z': 1}) If only some symbols are of interests the search can be restricted: >>> CountSymbolRefs.apply(expr, symbol_names=["x", "z"]) - {'x': 2, 'z': 1} + defaultdict(, {'x': 2, 'z': 1}) In some cases, e.g. when the type of the reference is required, the references instead of strings can be retrieved. 
>>> CountSymbolRefs.apply(expr, as_ref=True) - {SymRef(id=SymbolRef('x')): 2, SymRef(id=SymbolRef('y')): 2, SymRef(id=SymbolRef('z')): 1} + defaultdict(, {SymRef(id=SymbolRef('x')): 2, SymRef(id=SymbolRef('y')): 2, SymRef(id=SymbolRef('z')): 1}) """ if ignore_builtins: inactive_refs = {str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_} From 3b8303f87cacc1861b034b1856b02925bd6c1d0b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Apr 2024 17:27:21 +0200 Subject: [PATCH 35/52] Fix cpp tests --- src/gt4py/next/iterator/tracing.py | 21 +++++--------- .../cpp_backend_tests/anton_lap.py | 29 +++++++------------ .../cpp_backend_tests/copy_stencil.py | 6 +++- .../cpp_backend_tests/fvm_nabla.py | 28 +++++++++++++++++- .../cpp_backend_tests/tridiagonal_solve.py | 6 +++- 5 files changed, 56 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 9436acccd4..816cb57d25 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -35,7 +35,7 @@ SymRef, ) from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_translation +from gt4py.next.type_system import type_specifications as ts, type_translation TRACING = "tracing" @@ -252,7 +252,7 @@ def _contains_tuple_dtype_field(arg): return isinstance(arg, common.Field) and any(dim is None for dim in arg.domain.dims) -def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: +def _make_fencil_params(fun, args) -> list[Sym]: params: list[Sym] = [] param_infos = list(inspect.signature(fun).parameters.values()) @@ -278,30 +278,25 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: ) arg_type = None - if use_arg_types: + if isinstance(arg, ts.TypeSpec): + arg_type = arg + else: arg_type = type_translation.from_value(arg) params.append(Sym(id=param_name, type=arg_type)) return params -def trace_fencil_definition( - fun: typing.Callable, args: 
typing.Iterable, *, use_arg_types=True -) -> FencilDefinition: +def trace_fencil_definition(fun: typing.Callable, args: typing.Iterable) -> FencilDefinition: """ Transform fencil given as a callable into `itir.FencilDefinition` using tracing. Arguments: fun: The fencil / callable to trace. - args: A list of arguments, e.g. fields, scalars or composites thereof. If `use_arg_types` - is `False` may also be dummy values. - - Keyword arguments: - use_arg_types: Deduce type of the arguments and add them to the fencil parameter nodes - (i.e. `itir.Sym`s). + args: A list of arguments, e.g. fields, scalars, composites thereof, or directly a type. """ with TracerContext() as _: - params = _make_fencil_params(fun, args, use_arg_types=use_arg_types) + params = _make_fencil_params(fun, args) trace_function_call(fun, args=(_s(param.id) for param in params)) return FencilDefinition( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py index 5af4605988..1e5938506c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py @@ -19,22 +19,7 @@ from gt4py.next.iterator.runtime import closure, fundef, offset from gt4py.next.iterator.tracing import trace_fencil_definition from gt4py.next.program_processors.runners.gtfn import run_gtfn - - -@fundef -def ldif(d): - return lambda inp: deref(shift(d, -1)(inp)) - deref(inp) - - -@fundef -def rdif(d): - return lambda inp: ldif(d)(shift(d, 1)(inp)) - - -@fundef -def dif2(d): - return lambda inp: ldif(d)(lift(rdif(d))(inp)) - +from gt4py.next.type_system import type_specifications as ts i = offset("i") j = offset("j") @@ -42,7 +27,12 @@ def dif2(d): @fundef def lap(inp): - return dif2(i)(inp) + dif2(j)(inp) + return -4.0 * deref(inp) + ( + deref(shift(i, 1)(inp)) + + deref(shift(i, 
-1)(inp)) + + deref(shift(j, 1)(inp)) + + deref(shift(j, -1)(inp)) + ) IDim = gtx.Dimension("IDim") @@ -68,7 +58,10 @@ def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp): raise RuntimeError(f"Usage: {sys.argv[0]} ") output_file = sys.argv[1] - prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False) + ijk_field_type = ts.FieldType( + dims=[IDim, JDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + prog = trace_fencil_definition(lap_fencil, [ijk_field_type] * 8) generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( prog, offset_provider={"i": IDim, "j": JDim}, column_axis=None ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py index 3e8b88ac66..a35e8d3330 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py @@ -19,6 +19,7 @@ from gt4py.next.iterator.runtime import closure, fundef from gt4py.next.iterator.tracing import trace_fencil_definition from gt4py.next.program_processors.runners.gtfn import run_gtfn +from gt4py.next.type_system import type_specifications as ts IDim = gtx.Dimension("IDim") @@ -47,7 +48,10 @@ def copy_fencil(isize, jsize, ksize, inp, out): raise RuntimeError(f"Usage: {sys.argv[0]} ") output_file = sys.argv[1] - prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False) + ijk_field_type = ts.FieldType( + dims=[IDim, JDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + prog = trace_fencil_definition(copy_fencil, [ijk_field_type] * 5) generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( prog, offset_provider={}, column_axis=None ) diff --git 
a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py index fe8b54f95c..eb5f214525 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py @@ -20,6 +20,7 @@ from gt4py.next.iterator.runtime import closure, fundef, offset from gt4py.next.iterator.tracing import trace_fencil_definition from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative +from gt4py.next.type_system import type_specifications as ts E2V = offset("E2V") @@ -57,6 +58,8 @@ def zavgS_fencil(edge_domain, out, pp, S_M): Vertex = gtx.Dimension("Vertex") +Edge = gtx.Dimension("Edge") +V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) K = gtx.Dimension("K", kind=gtx.DimensionKind.VERTICAL) @@ -92,8 +95,31 @@ def mapped_index(_, __) -> int: else: backend = run_gtfn + int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) + vertex_k_field_type = ts.FieldType( + dims=[Vertex, K], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + vertex_field_type = ts.FieldType(dims=[Vertex, K], + dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + edge_k_field = ts.FieldType(dims=[Edge, K], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + edge_field = ts.FieldType(dims=[Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + vertex_v2e_field = ts.FieldType( + dims=[Vertex, V2EDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + # prog = trace(zavgS_fencil, [None] * 4) # TODO allow generating of 2 fencils - prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False) + prog = trace_fencil_definition( + nabla_fencil, + [ + int_type, + int_type, + edge_k_field, + vertex_k_field_type, + ts.TupleType(types=[edge_field, edge_field]), + vertex_v2e_field, + vertex_field_type, + ], + ) offset_provider = { "V2E": 
DummyConnectivity(max_neighbors=6, has_skip_values=True), "E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False), diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py index 9755774fd0..2f23901aeb 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py @@ -20,6 +20,7 @@ from gt4py.next.iterator.tracing import trace_fencil_definition from gt4py.next.iterator.transforms import LiftMode from gt4py.next.program_processors.runners.gtfn import run_gtfn +from gt4py.next.type_system import type_specifications as ts IDim = gtx.Dimension("IDim") @@ -65,7 +66,10 @@ def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x): raise RuntimeError(f"Usage: {sys.argv[0]} ") output_file = sys.argv[1] - prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False) + ijk_field_type = ts.FieldType( + dims=[IDim, JDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + prog = trace_fencil_definition(tridiagonal_solve_fencil, [ijk_field_type] * 8) offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")} generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( prog, From 9535e72d4dae6302b1646d4239960c9a68652fe1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 24 Apr 2024 13:14:43 +0200 Subject: [PATCH 36/52] Add documentation --- .../next/iterator/type_system/inference.py | 92 ++++++++++++++++--- 1 file changed, 80 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 0fae84037c..391f560a02 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ 
b/src/gt4py/next/iterator/type_system/inference.py @@ -140,9 +140,57 @@ def mark_ready(i, type_): @dataclasses.dataclass class ObservableTypeInferenceRule: """ - This class wraps a raw type inference rule to handle typing of functions. + This class wraps a raw type inference rule to handle typing of nodes representing functions. + + The type inference algorithm represents functions as type inference rules, i.e. regular + callables that given a set of arguments compute / deduce the return type. The return type of + functions, let it be a builtin like ``itir.plus`` or a user defined lambda function, is only + defined when all its arguments are typed. + + Let's start with a small example to exemplify this. The power function has a rather simple + type inference rule, where the output type is simply the type of the base. + + >>> def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: + ... return base + >>> float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + >>> int_type = ts.ScalarType(kind=ts.ScalarKind.INT64) + >>> power(float_type, int_type) + ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None) + + Now, consider a simple lambda function that, using the power builtin, squares its argument. A + type inference rule for this function is simple to formulate, but merely gives us the return + type of the function. + + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> square_func = im.lambda_("base")(im.call("power")("base", 2)) + >>> square_func_type_rule = lambda base: power(base, int_type) + >>> square_func_type_rule(float_type) + ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None) + + Note that without a corresponding call the function itself can not be fully typed and as such + the type inference algorithm has to defer typing until then. This task is handled transparently + (in the sense that an ``ObservableTypeInferenceRule`` is a type inference rule again) by this + class.
Given a type inference rule and a node we obtain a new type inference rule that when + evaluated stores the type of the function in the node. + + >>> o_type_rule = ObservableTypeInferenceRule( + ... type_rule=square_func_type_rule, + ... offset_provider={}, + ... node=square_func, + ... store_inferred_type_in_node=True, + ... ) + >>> o_type_rule(float_type) + ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None) + >>> square_func.type == ts.FunctionType( + ... pos_only_args=[float_type], pos_or_kw_args={}, kw_only_args={}, returns=float_type + ... ) + True - TODO: As functions in ITIR are represented by type inference rules, i.e. regular callables, + Note that this is a simple example where the type of the arguments and the return value is + available when the function is called. In order to support higher-order functions, where + arguments or return value are functions themselves (i.e. passed as type rules) this class provides + additional functionality for multiple typing rules to notify each other about a type being + ready.
""" #: type rule that given a set of types or type rules returns the return type or a type rule @@ -188,7 +236,9 @@ def on_type_ready(self, cb: Callable[[ts.TypeSpec], None]) -> None: else: self.callbacks.append(cb) - def __call__(self, *args) -> ts.TypeSpec | rules.TypeInferenceRule: + def __call__( + self, *args: Union[ts.TypeSpec, "ObservableTypeInferenceRule"] + ) -> Union[ts.TypeSpec, "ObservableTypeInferenceRule"]: if "offset_provider" in inspect.signature(self.type_rule).parameters: return_type = self.type_rule(*args, offset_provider=self.offset_provider) else: @@ -240,9 +290,6 @@ def _get_dimensions(obj: Any): return {dim.value: dim for dim in _get_dimensions(types)} -T = TypeVar("T", bound=itir.Node) - - def _convert_closure_input_to_iterator( domain: it_ts.DomainType, input_: ts.TypeSpec ) -> it_ts.IteratorType: @@ -283,7 +330,7 @@ def extract_dtype_and_dims(el_type: ts.TypeSpec): def _type_inference_rule_from_function_type(fun_type: ts.FunctionType): def type_rule(*args, **kwargs): - assert type_info.accepts_args(fun_type, with_args=args, with_kwargs=kwargs) + assert type_info.accepts_args(fun_type, with_args=list(args), with_kwargs=kwargs) return fun_type.returns return type_rule @@ -297,14 +344,19 @@ def visit_Node(self, node: itir.Node): return node +T = TypeVar("T", bound=itir.Node) + + @dataclasses.dataclass class ITIRTypeInference(eve.NodeTranslator): """ - TODO + ITIR type inference algorithm. + + See :py:method:ITIRTypeInference.apply for more details. """ offset_provider: common.OffsetProvider - #: Mapping from a dimension name to the actual dimension instance + #: Mapping from a dimension name to the actual dimension instance. dimensions: dict[str, common.Dimension] #: Allow sym refs to symbols that have not been declared. Mostly used in testing. allow_undeclared_symbols: bool @@ -318,6 +370,18 @@ def apply( inplace: bool = False, allow_undeclared_symbols: bool = False, ) -> T: + """ + Infer the type of ``node`` and its sub-nodes. 
+ + Arguments: + node: The :py:class:`itir.Node` to infer the types of. + + Keyword Arguments: + offset_provider: Offset provider dictionary. + inplace: Write types directly to the given ``node`` instead of returning a copy. + allow_undeclared_symbols: Allow references to symbols that don't have a corresponding + declaration. This is useful for testing or inference on partially inferred sub-nodes. + """ if not allow_undeclared_symbols: node = RemoveTypes().visit(node) @@ -369,7 +433,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) else: raise AssertionError( - f"Expected a 'TypeSpec' or 'ObservableTypeInferenceRule', but got " + f"Expected a 'TypeSpec', `callable` or 'ObservableTypeInferenceRule', but got " f"`{type(result).__name__}`" ) return result @@ -387,7 +451,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.F closures = self.visit(node.closures, ctx=ctx | params | function_definitions) return it_ts.FencilType(params=list(params.values()), closures=closures) - def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries, *, ctx): + def visit_FencilWithTemporaries( + self, node: global_tmps.FencilWithTemporaries, *, ctx + ) -> it_ts.FencilType: # TODO(tehrengruber): This implementation is not very appealing. Since we are about to # refactor the PR anyway this is fine for now. 
params: dict[str, ts.DataType] = {} @@ -403,9 +469,10 @@ def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries, * if fencil_param.id in tmps: fencil_param.type = tmps[fencil_param.id] self.visit(node.fencil, ctx=ctx) + assert isinstance(node.fencil.type, it_ts.FencilType) return node.fencil.type - def visit_Temporary(self, node: itir.Temporary, *, ctx): + def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) assert isinstance(domain, it_ts.DomainType) assert node.dtype @@ -494,6 +561,7 @@ def fun(*args): def visit_FunCall( self, node: itir.FunCall, *, ctx: dict[str, ts.TypeSpec] ) -> ts.TypeSpec | rules.TypeInferenceRule: + # grammar builtins if is_call_to(node, "cast_"): value, type_constructor = node.args assert ( From d9a772c42575ee4da93ccfa2db82a9e3e973b7c7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Apr 2024 08:39:47 +0200 Subject: [PATCH 37/52] Remove debug code --- src/gt4py/next/iterator/transforms/global_tmps.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 72c8cbdff1..e89373077e 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -180,7 +180,6 @@ def closure_shifts(self): return trace_shifts.TraceShifts.apply(self.closure, inputs_only=False) def __call__(self, expr: ir.Expr) -> bool: - return True shifts = self.closure_shifts[id(expr)] if len(shifts) > 1: return True From 5db67fc43c7818e60e88e4fb9c5b88145035b6de Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Apr 2024 09:58:13 +0200 Subject: [PATCH 38/52] Fix temporary codegen in roundtrip --- src/gt4py/next/config.py | 2 +- src/gt4py/next/program_processors/runners/roundtrip.py | 8 +++++--- .../iterator_tests/transforms_tests/test_global_tmps.py | 3 +++ 3 files changed, 9 insertions(+), 4 
deletions(-) diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index 682d5254e5..b5264ac5f1 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -61,7 +61,7 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) +DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=True) #: Verbose flag for DSL compilation errors diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index d90e7e5c8f..0d798b6a6a 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -31,12 +31,14 @@ from gt4py.next.iterator.transforms import global_tmps as gtmps_transform from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import modular_executor, processor_interface as ppi +from gt4py.next.type_system import type_specifications as ts -def _create_tmp(axes: str, origin: str, shape: str, dtype: Any) -> str: - if isinstance(dtype, tuple): - return f"({','.join(_create_tmp(axes, origin, shape, dt) for dt in dtype)},)" +def _create_tmp(axes: str, origin: str, shape: str, dtype: ts.ScalarType | ts.TupleType) -> str: + if isinstance(dtype, ts.TupleType): + return f"({','.join(_create_tmp(axes, origin, shape, dt) for dt in dtype.types)},)" else: + assert isinstance(dtype, ts.ScalarType) return ( f"gtx.as_field([{axes}], np.empty({shape}, dtype=np.dtype('{dtype}')), origin={origin})" ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index f1b15f4e18..51c196ad33 100644 --- 
a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -11,6 +11,9 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +# TODO(tehrengruber): add integration tests for temporaries starting from manually written +# itir. Currently we only test temporaries from frontend code which makes testing changes +# to anything related to temporaries tedious. import copy import gt4py.next as gtx From 2d28181546d6a7b45118b8c3d92b9daf049054f7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Apr 2024 10:00:03 +0200 Subject: [PATCH 39/52] Formatting --- src/gt4py/next/program_processors/runners/roundtrip.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 0d798b6a6a..a592130829 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -34,7 +34,7 @@ from gt4py.next.type_system import type_specifications as ts -def _create_tmp(axes: str, origin: str, shape: str, dtype: ts.ScalarType | ts.TupleType) -> str: +def _create_tmp(axes: str, origin: str, shape: str, dtype: ts.TypeSpec) -> str: if isinstance(dtype, ts.TupleType): return f"({','.join(_create_tmp(axes, origin, shape, dt) for dt in dtype.types)},)" else: @@ -102,6 +102,7 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: axes = ", ".join(label for label, _, _ in domain_ranges) origin = "{" + ", ".join(f"{label}: -{start}" for label, start, _ in domain_ranges) + "}" shape = "(" + ", ".join(f"{stop}-{start}" for _, start, stop in domain_ranges) + ")" + assert node.dtype return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}" From 09006e94fa567121e61896295c6d007157d38eaa Mon Sep 17 00:00:00 2001 From: Till 
Ehrengruber Date: Thu, 25 Apr 2024 10:16:56 +0200 Subject: [PATCH 40/52] Cleanup --- src/gt4py/next/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index b5264ac5f1..682d5254e5 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -61,7 +61,7 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=True) +DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) #: Verbose flag for DSL compilation errors From 7ebb25731748530debd4352002d2d45281984d9c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Apr 2024 17:59:39 +0200 Subject: [PATCH 41/52] Add support for new itir.Program format --- .../next/iterator/type_system/inference.py | 65 +++++++------------ src/gt4py/next/iterator/type_system/rules.py | 52 +++++++++++++++ .../type_system/type_specifications.py | 8 ++- .../codegens/gtfn/itir_to_gtfn_ir.py | 2 + .../iterator_tests/test_type_inference.py | 54 ++++++++++++++- 5 files changed, 136 insertions(+), 45 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 391f560a02..be9afe087f 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -290,44 +290,6 @@ def _get_dimensions(obj: Any): return {dim.value: dim for dim in _get_dimensions(types)} -def _convert_closure_input_to_iterator( - domain: it_ts.DomainType, input_: ts.TypeSpec -) -> it_ts.IteratorType: - input_dims: list[common.Dimension] | None = None - - def extract_dtype_and_dims(el_type: ts.TypeSpec): - nonlocal input_dims - assert isinstance(el_type, (ts.FieldType, ts.ScalarType)) - el_type = 
type_info.promote(el_type, always_field=True) - if not input_dims: - input_dims = el_type.dims # type: ignore[union-attr] # ensured by always_field - else: - # tuple inputs must all have the same defined dimensions as we - # create an iterator of tuples from them - assert input_dims == el_type.dims # type: ignore[union-attr] # ensured by always_field - return el_type.dtype # type: ignore[union-attr] # ensured by always_field - - element_type = type_info.apply_to_primitive_constituents(extract_dtype_and_dims, input_) - - assert input_dims is not None - - # handle neighbor / sparse input fields - defined_dims = [] - is_nb_field = False - for dim in input_dims: - if dim.kind == common.DimensionKind.LOCAL: - assert not is_nb_field - is_nb_field = True - else: - defined_dims.append(dim) - if is_nb_field: - element_type = it_ts.ListType(element_type=element_type) - - return it_ts.IteratorType( - position_dims=domain.dims, defined_dims=defined_dims, element_type=element_type - ) - - def _type_inference_rule_from_function_type(fun_type: ts.FunctionType): def type_rule(*args, **kwargs): assert type_info.accepts_args(fun_type, with_args=list(args), with_kwargs=kwargs) @@ -422,7 +384,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: if node.type: assert _is_compatible_type(node.type, result) node.type = result - elif isinstance(result, ObservableTypeInferenceRule): + elif isinstance(result, ObservableTypeInferenceRule) or result is None: pass elif callable(result): return ObservableTypeInferenceRule( @@ -438,6 +400,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) return result + # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.FencilType: params: dict[str, ts.DataType] = {} for param in node.params: @@ -449,8 +412,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.F 
function_definitions[fun_def.id] = self.visit(fun_def, ctx=ctx | function_definitions) closures = self.visit(node.closures, ctx=ctx | params | function_definitions) - return it_ts.FencilType(params=list(params.values()), closures=closures) + return it_ts.FencilType(params=params, closures=closures) + # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere def visit_FencilWithTemporaries( self, node: global_tmps.FencilWithTemporaries, *, ctx ) -> it_ts.FencilType: @@ -472,6 +436,17 @@ def visit_FencilWithTemporaries( assert isinstance(node.fencil.type, it_ts.FencilType) return node.fencil.type + def visit_Program(self, node: itir.Program, *, ctx): + params: dict[str, ts.DataType] = {} + for param in node.params: + assert isinstance(param.type, ts.DataType) + params[param.id] = param.type + decls: dict[str, ts.FieldType] = {} + for decl_node in node.declarations: + decls[decl_node.id] = self.visit(decl_node, ctx=ctx | params) + self.visit(node.body, ctx=ctx | params | decls) + return it_ts.ProgramType(params=params) + def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) assert isinstance(domain, it_ts.DomainType) @@ -480,6 +455,12 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype ) + def visit_SetAt(self, node: itir.SetAt, *, ctx): + self.visit(node.target, ctx=ctx) + self.visit(node.expr, ctx=ctx) + assert node.target.type == node.expr.type + + # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: domain: it_ts.DomainType = self.visit(node.domain, ctx=ctx) inputs: list[ts.FieldType] = self.visit(node.inputs, ctx=ctx) @@ -490,7 +471,9 @@ def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.Stenc assert 
isinstance(output_el, ts.FieldType) stencil_type_rule = self.visit(node.stencil, ctx=ctx) - stencil_args = [_convert_closure_input_to_iterator(domain, input_) for input_ in inputs] + stencil_args = [ + rules._convert_as_fieldop_input_to_iterator(domain, input_) for input_ in inputs + ] stencil_returns = stencil_type_rule(*stencil_args) return it_ts.StencilClosureType( diff --git a/src/gt4py/next/iterator/type_system/rules.py b/src/gt4py/next/iterator/type_system/rules.py index acc622d57a..2fcf6799f0 100644 --- a/src/gt4py/next/iterator/type_system/rules.py +++ b/src/gt4py/next/iterator/type_system/rules.py @@ -183,6 +183,58 @@ def apply_lift(*its: it_ts.IteratorType) -> it_ts.IteratorType: return apply_lift +def _convert_as_fieldop_input_to_iterator( + domain: it_ts.DomainType, input_: ts.TypeSpec +) -> it_ts.IteratorType: + input_dims: list[common.Dimension] | None = None + + def extract_dtype_and_dims(el_type: ts.TypeSpec): + nonlocal input_dims + assert isinstance(el_type, (ts.FieldType, ts.ScalarType)) + el_type = type_info.promote(el_type, always_field=True) + if not input_dims: + input_dims = el_type.dims # type: ignore[union-attr] # ensured by always_field + else: + # tuple inputs must all have the same defined dimensions as we + # create an iterator of tuples from them + assert input_dims == el_type.dims # type: ignore[union-attr] # ensured by always_field + return el_type.dtype # type: ignore[union-attr] # ensured by always_field + + element_type = type_info.apply_to_primitive_constituents(extract_dtype_and_dims, input_) + + assert input_dims is not None + + # handle neighbor / sparse input fields + defined_dims = [] + is_nb_field = False + for dim in input_dims: + if dim.kind == common.DimensionKind.LOCAL: + assert not is_nb_field + is_nb_field = True + else: + defined_dims.append(dim) + if is_nb_field: + element_type = it_ts.ListType(element_type=element_type) + + return it_ts.IteratorType( + position_dims=domain.dims, defined_dims=defined_dims, 
element_type=element_type + ) + + +@_register_type_inference_rule +def as_fieldop(stencil: TypeInferenceRule, domain: it_ts.DomainType) -> TypeInferenceRule: + def applied_as_fieldop(*fields) -> ts.FieldType: + stencil_return = stencil( + *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields) + ) + assert isinstance(stencil_return, ts.DataType) + return type_info.apply_to_primitive_constituents( + lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type), stencil_return + ) + + return applied_as_fieldop + + @_register_type_inference_rule def scan( scan_pass: TypeInferenceRule, direction: ts.ScalarType, init: ts.ScalarType diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 982683a5eb..b44dd0007b 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -62,7 +62,13 @@ def __post_init__(self): assert isinstance(el_type, ts.FieldType), "All constituent types must be field types." 
+# TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere @dataclasses.dataclass(frozen=True) class FencilType(ts.TypeSpec): - params: list[ts.DataType] + params: dict[str, ts.DataType] closures: list[StencilClosureType] + + +@dataclasses.dataclass(frozen=True) +class ProgramType(ts.TypeSpec): + params: dict[str, ts.DataType] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index dfcb70cb4f..f2ee833c2c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -20,6 +20,7 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, BinaryExpr, @@ -249,6 +250,7 @@ def apply( if not isinstance(node, itir.Program): raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") + node = itir_type_inference.infer(node, offset_provider=offset_provider) grid_type = _get_gridtype(node.body) return cls( offset_provider=offset_provider, column_axis=column_axis, grid_type=grid_type diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 2edaef7e37..3330f6e6f5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -94,23 +94,31 @@ def expression_test_cases(): ), it_ts.DomainType(dims=[Vertex]), ), + # make_tuple ( im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type)), ts.TupleType(types=[int_type, bool_type]), ), + # tuple_get (im.tuple_get(0, im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type))), int_type), (im.tuple_get(1, 
im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type))), bool_type), + # neighbors ( im.neighbors("E2V", im.ref("a", it_on_e_of_e_type)), it_ts.ListType(element_type=it_on_e_of_e_type.element_type), ), + # cast (im.call("cast_")(1, "int32"), int_type), # lift + # ... # scan + # ... + # map ( im.map_(im.ref("plus"))(im.ref("a", int_list_type), im.ref("b", int_list_type)), int_list_type, ), + # reduce (im.call(im.call("reduce")("plus", 0))(im.ref("l", int_list_type)), int_type), ( im.call( @@ -126,8 +134,44 @@ def expression_test_cases(): )(im.ref("la", int_list_type), im.ref("lb", float64_list_type)), ts.TupleType(types=[int_type, float64_type]), ), + # shift (im.shift("V2E", 1)(im.ref("it", it_on_v_of_e_type)), it_on_e_of_e_type), (im.shift("Ioff", 1)(im.ref("it", it_ijk_type)), it_ijk_type), + # as_fieldop + ( + im.call( + im.call("as_fieldop")( + "deref", + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), + ) + )(im.ref("inp", float_i_field)), + float_i_field, + ), + ( + im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), + im.call("unstructured_domain")( + im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), + im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + ), + ) + )(im.ref("inp", float_edge_k_field)), + float_vertex_k_field, + ), + ( + im.call( + im.call("as_fieldop")( + im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), + ) + )(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)), + ts.TupleType(types=[float_i_field, float_i_field]), + ), ) @@ -222,7 +266,9 @@ def test_cartesian_fencil_definition(): output=float_i_field, inputs=[float_i_field], ) - fencil_type = it_ts.FencilType(params=[float_i_field, float_i_field], closures=[closure_type]) + fencil_type = it_ts.FencilType( + params={"inp": float_i_field, "out": 
float_i_field}, closures=[closure_type] + ) assert result.type == fencil_type assert result.closures[0].type == closure_type @@ -268,7 +314,7 @@ def test_unstructured_fencil_definition(): inputs=[float_edge_k_field], ) fencil_type = it_ts.FencilType( - params=[float_edge_k_field, float_vertex_k_field], closures=[closure_type] + params={"inp": float_edge_k_field, "out": float_vertex_k_field}, closures=[closure_type] ) assert result.type == fencil_type assert result.closures[0].type == closure_type @@ -313,7 +359,9 @@ def test_function_definition(): output=float_i_field, inputs=[float_i_field], ) - fencil_type = it_ts.FencilType(params=[float_i_field, float_i_field], closures=[closure_type]) + fencil_type = it_ts.FencilType( + params={"inp": float_i_field, "out": float_i_field}, closures=[closure_type] + ) assert result.type == fencil_type assert result.closures[0].type == closure_type From b23454dde77a5f7167b87fb86a749a472dd2ede3 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 2 May 2024 14:45:16 +0200 Subject: [PATCH 42/52] Small fixes --- .../next/iterator/type_system/inference.py | 4 +- src/gt4py/next/iterator/type_system/rules.py | 42 +++++++++++-------- .../type_system/type_specifications.py | 3 +- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index be9afe087f..e43ba35b47 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -498,7 +498,9 @@ def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionTyp # the frontend. 
def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.OffsetLiteralType: if _is_representable_as_int(node.value): - return it_ts.OffsetLiteralType(value=int(node.value)) + return it_ts.OffsetLiteralType( + value=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) + ) else: assert isinstance(node.value, str) and node.value in self.dimensions return it_ts.OffsetLiteralType(value=self.dimensions[node.value]) diff --git a/src/gt4py/next/iterator/type_system/rules.py b/src/gt4py/next/iterator/type_system/rules.py index 2fcf6799f0..bee3c757f7 100644 --- a/src/gt4py/next/iterator/type_system/rules.py +++ b/src/gt4py/next/iterator/type_system/rules.py @@ -97,7 +97,15 @@ def deref(it: it_ts.IteratorType) -> ts.DataType: @_register_type_inference_rule def can_deref(it: it_ts.IteratorType) -> ts.ScalarType: assert isinstance(it, it_ts.IteratorType) - assert _is_derefable_iterator_type(it) + # note: We don't check if the iterator is derefable here as the iterator only needs + # to have a valid position. Consider a nested reduction, e.g. + # `reduce(plus, 0)(neighbors(V2Eₒ, (↑(λ(a) → reduce(plus, 0)(neighbors(E2Vₒ, a))))(it))` + # When written using a `can_deref` we only care if the edge neighbor of the vertex of `it` + # is valid, i.e. we want `can_deref(shift(V2Eₒ, i)(it))` to return true. But since `it` is an + # iterator backed by a vertex field, the iterator is not derefable in the sense that + # its position is a valid position of the backing field. + # TODO(tehrengruber): Consider renaming can_deref to something that better reflects its + # meaning.
return ts.ScalarType(kind=ts.ScalarKind.BOOL) @@ -186,23 +194,23 @@ def apply_lift(*its: it_ts.IteratorType) -> it_ts.IteratorType: def _convert_as_fieldop_input_to_iterator( domain: it_ts.DomainType, input_: ts.TypeSpec ) -> it_ts.IteratorType: - input_dims: list[common.Dimension] | None = None - - def extract_dtype_and_dims(el_type: ts.TypeSpec): - nonlocal input_dims - assert isinstance(el_type, (ts.FieldType, ts.ScalarType)) - el_type = type_info.promote(el_type, always_field=True) - if not input_dims: - input_dims = el_type.dims # type: ignore[union-attr] # ensured by always_field - else: - # tuple inputs must all have the same defined dimensions as we - # create an iterator of tuples from them - assert input_dims == el_type.dims # type: ignore[union-attr] # ensured by always_field - return el_type.dtype # type: ignore[union-attr] # ensured by always_field - - element_type = type_info.apply_to_primitive_constituents(extract_dtype_and_dims, input_) + # get the dimensions of all non-zero-dimensional field inputs and check they agree + all_input_dims = ( + type_info.primitive_constituents(input_) + .if_isinstance(ts.FieldType) + .getattr("dims") + .filter(lambda dims: len(dims) > 0) + .to_list() + ) + input_dims: list[common.Dimension] + if all_input_dims: + assert all(cur_input_dims == all_input_dims[0] for cur_input_dims in all_input_dims) + input_dims = all_input_dims[0] + else: + input_dims = [] - assert input_dims is not None + element_type: ts.DataType + element_type = type_info.apply_to_primitive_constituents(type_info.extract_dtype, input_) # handle neighbor / sparse input fields defined_dims = [] diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index b44dd0007b..554adb0a71 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -15,7 +15,6 @@ import dataclasses import typing -from 
gt4py._core.definitions import IntegralScalar from gt4py.next import common from gt4py.next.type_system import type_specifications as ts @@ -32,7 +31,7 @@ class DomainType(ts.DataType): @dataclasses.dataclass(frozen=True) class OffsetLiteralType(ts.TypeSpec): - value: IntegralScalar | common.Dimension + value: ts.ScalarType | common.Dimension @dataclasses.dataclass(frozen=True) From 953658aff63c3360316ac72e75f74ab3205f722d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 2 May 2024 15:03:28 +0200 Subject: [PATCH 43/52] Fix list_get --- src/gt4py/next/iterator/type_system/rules.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gt4py/next/iterator/type_system/rules.py b/src/gt4py/next/iterator/type_system/rules.py index bee3c757f7..e073f79d2d 100644 --- a/src/gt4py/next/iterator/type_system/rules.py +++ b/src/gt4py/next/iterator/type_system/rules.py @@ -127,6 +127,8 @@ def make_const_list(scalar: ts.ScalarType) -> it_ts.ListType: @_register_type_inference_rule def list_get(index: ts.ScalarType, list_: it_ts.ListType) -> ts.DataType: + if isinstance(index, it_ts.OffsetLiteralType): + index = index.value assert isinstance(index, ts.ScalarType) and type_info.is_integral(index) assert isinstance(list_, it_ts.ListType) return list_.element_type From 2dba9cf530db0bb0d3fadeaaeca4fa588600791d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 2 May 2024 21:37:41 +0200 Subject: [PATCH 44/52] Fix typing and SetAt type check --- .../next/iterator/type_system/inference.py | 21 +++++++++++++++++-- src/gt4py/next/iterator/type_system/rules.py | 3 ++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index e43ba35b47..e15b7ef113 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -26,6 +26,7 @@ from gt4py.next.iterator.transforms import global_tmps from 
gt4py.next.iterator.type_system import rules, type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.type_system.type_info import primitive_constituents def _is_representable_as_int(s: int | str) -> bool: @@ -455,10 +456,26 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype ) - def visit_SetAt(self, node: itir.SetAt, *, ctx): + def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: self.visit(node.target, ctx=ctx) self.visit(node.expr, ctx=ctx) - assert node.target.type == node.expr.type + # TODO(tehrengruber): The lowering emits domains that always have the horizontal domain + # first. Since the expr inherits the ordering from the domain this can lead to a mismatch + # between the target and expr (e.g. when the target has dimension K, Vertex). We should + # probably just change the behaviour of the lowering. Until then we do this more + # complicated comparison. 
+ assert node.target.type is not None and node.expr.type is not None + for target_type, expr_type in zip( + primitive_constituents(node.target.type), + primitive_constituents(node.expr.type), + strict=True, + ): + assert isinstance(target_type, ts.FieldType) + assert isinstance(expr_type, ts.FieldType) + assert ( + set(expr_type.dims) == set(target_type.dims) + and target_type.dtype == expr_type.dtype + ) # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: diff --git a/src/gt4py/next/iterator/type_system/rules.py b/src/gt4py/next/iterator/type_system/rules.py index e073f79d2d..bc6e35a994 100644 --- a/src/gt4py/next/iterator/type_system/rules.py +++ b/src/gt4py/next/iterator/type_system/rules.py @@ -126,8 +126,9 @@ def make_const_list(scalar: ts.ScalarType) -> it_ts.ListType: @_register_type_inference_rule -def list_get(index: ts.ScalarType, list_: it_ts.ListType) -> ts.DataType: +def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: it_ts.ListType) -> ts.DataType: if isinstance(index, it_ts.OffsetLiteralType): + assert isinstance(index.value, ts.ScalarType) index = index.value assert isinstance(index, ts.ScalarType) and type_info.is_integral(index) assert isinstance(list_, it_ts.ListType) From 233c369a9476b502c34be9b5e1e4fb382d00aa10 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 2 May 2024 22:12:40 +0200 Subject: [PATCH 45/52] Fix program --- src/gt4py/next/iterator/type_system/inference.py | 4 +++- .../unit_tests/iterator_tests/test_type_inference.py | 4 +--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index e15b7ef113..a19f44d04c 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -443,8 +443,10 @@ def visit_Program(self, 
node: itir.Program, *, ctx): assert isinstance(param.type, ts.DataType) params[param.id] = param.type decls: dict[str, ts.FieldType] = {} + for fun_def in node.function_definitions: + decls[fun_def.id] = self.visit(fun_def, ctx=ctx | params | decls) for decl_node in node.declarations: - decls[decl_node.id] = self.visit(decl_node, ctx=ctx | params) + decls[decl_node.id] = self.visit(decl_node, ctx=ctx | params | decls) self.visit(node.body, ctx=ctx | params | decls) return it_ts.ProgramType(params=params) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 3330f6e6f5..9870b8d500 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -230,6 +230,7 @@ def test_late_offset_axis(): assert result.type == it_on_e_of_e_type +# TODO(tehrengruber): Rewrite tests to use itir.Program def test_cartesian_fencil_definition(): cartesian_domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) @@ -391,6 +392,3 @@ def test_fencil_with_nb_field_input(): assert result.closures[0].stencil.expr.args[0].type == float64_list_type assert result.closures[0].stencil.type.returns == float64_type - - -# TODO(tehrengruber): add tests for itir.Program From f9dd78bd981e682fc6f93e4be148a4223c4970f1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 2 May 2024 22:45:00 +0200 Subject: [PATCH 46/52] Fix program --- .../next/iterator/type_system/inference.py | 26 ++++++++------- .../iterator_tests/test_type_inference.py | 33 +++++++++++++++++++ 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index a19f44d04c..6f9f11c4b7 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ 
-459,21 +459,25 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup ) def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: - self.visit(node.target, ctx=ctx) self.visit(node.expr, ctx=ctx) - # TODO(tehrengruber): The lowering emits domains that always have the horizontal domain - # first. Since the expr inherits the ordering from the domain this can lead to a mismatch - # between the target and expr (e.g. when the target has dimension K, Vertex). We should - # probably just change the behaviour of the lowering. Until then we do this more - # complicated comparison. + self.visit(node.domain, ctx=ctx) + self.visit(node.target, ctx=ctx) assert node.target.type is not None and node.expr.type is not None - for target_type, expr_type in zip( - primitive_constituents(node.target.type), - primitive_constituents(node.expr.type), - strict=True, - ): + for target_type, path in primitive_constituents(node.target.type, with_path_arg=True): + # the target can have fewer elements than the expr in which case the output from the + # expression is simply discarded. + expr_type = functools.reduce( + lambda tuple_type, i: tuple_type.types[i], # type: ignore[attr-defined] # format ensured by primitive_constituents + path, + node.expr.type, + ) assert isinstance(target_type, ts.FieldType) assert isinstance(expr_type, ts.FieldType) + # TODO(tehrengruber): The lowering emits domains that always have the horizontal domain + # first. Since the expr inherits the ordering from the domain this can lead to a mismatch + # between the target and expr (e.g. when the target has dimension K, Vertex). We should + # probably just change the behaviour of the lowering. Until then we do this more + # complicated comparison. 
assert ( set(expr_type.dims) == set(target_type.dims) and target_type.dtype == expr_type.dtype diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 9870b8d500..01b046ae6e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -392,3 +392,36 @@ def test_fencil_with_nb_field_input(): assert result.closures[0].stencil.expr.args[0].type == float64_list_type assert result.closures[0].stencil.type.returns == float64_type + + +def test_program_tuple_setat_short_target(): + cartesian_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ) + + testee = itir.Program( + id="f", + function_definitions=[], + params=[im.sym("out", float_i_field)], + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")(im.lambda_()(im.make_tuple(1.0, 2.0)), cartesian_domain) + )(), + domain=cartesian_domain, + target=im.make_tuple("out"), + ) + ], + ) + + result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + + assert ( + isinstance(result.body[0].expr.type, ts.TupleType) + and len(result.body[0].expr.type.types) == 2 + ) + assert ( + isinstance(result.body[0].target.type, ts.TupleType) + and len(result.body[0].target.type.types) == 1 + ) From 18ab7a4f0d9b3dc0612522eb4152b1efd3660519 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 2 May 2024 23:22:09 +0200 Subject: [PATCH 47/52] Fix gtfn_module test --- .../codegens_tests/gtfn_tests/test_gtfn_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 93d908bef7..8b3e42c56c 100644 --- 
a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -49,7 +49,7 @@ def fencil_example(): itir.FunctionDefinition( id="stencil", params=[itir.Sym(id="buf"), itir.Sym(id="sc")], - expr=im.literal("1", "float64"), + expr=im.literal("1", "float32"), ) ], closures=[ From 5947934635808906bfa25b46448b8f0002bd2330 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 7 May 2024 10:29:54 +0200 Subject: [PATCH 48/52] Address review comments --- src/gt4py/next/ffront/type_info.py | 6 ++- src/gt4py/next/iterator/ir_utils/ir_makers.py | 14 +++++- .../iterator/transforms/symbol_ref_utils.py | 46 +++++++++---------- .../next/iterator/type_system/inference.py | 8 ++-- src/gt4py/next/iterator/type_system/rules.py | 22 ++++----- src/gt4py/next/type_system/type_info.py | 16 +++---- 6 files changed, 57 insertions(+), 55 deletions(-) diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 8ceb405486..80f76ce0de 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -73,7 +73,11 @@ def _as_field(arg_el: ts.TypeSpec, path: tuple[int, ...]) -> ts.TypeSpec: new_args = [*args] for i, (param, arg) in enumerate( - zip(list(function_type.pos_only_args) + list(function_type.pos_or_kw_args.values()), args) + zip( + list(function_type.pos_only_args) + list(function_type.pos_or_kw_args.values()), + args, + strict=True, + ) ): new_args[i] = promote_arg(param, arg) new_kwargs = {**kwargs} diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 52c3da16b0..40bfc0ab75 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -20,7 +20,7 @@ from gt4py.next.type_system import type_specifications as ts, type_translation -def sym(sym_or_name: Union[str, itir.Sym], 
type_=None) -> itir.Sym: +def sym(sym_or_name: Union[str, itir.Sym], type_: str | ts.TypeSpec | None = None) -> itir.Sym: """ Convert to Sym if necessary. @@ -31,6 +31,10 @@ def sym(sym_or_name: Union[str, itir.Sym], type_=None) -> itir.Sym: >>> sym(itir.Sym(id="b")) Sym(id=SymbolName('b')) + + >>> a = sym("a", "float32") + >>> a.id, a.type + (SymbolName('a'), ScalarType(kind=, shape=None)) """ if isinstance(sym_or_name, itir.Sym): assert not type_ @@ -38,7 +42,9 @@ def sym(sym_or_name: Union[str, itir.Sym], type_=None) -> itir.Sym: return itir.Sym(id=sym_or_name, type=ensure_type(type_)) -def ref(ref_or_name: Union[str, itir.SymRef], type_=None) -> itir.SymRef: +def ref( + ref_or_name: Union[str, itir.SymRef], type_: str | ts.TypeSpec | None = None +) -> itir.SymRef: """ Convert to SymRef if necessary. @@ -49,6 +55,10 @@ def ref(ref_or_name: Union[str, itir.SymRef], type_=None) -> itir.SymRef: >>> ref(itir.SymRef(id="b")) SymRef(id=SymbolRef('b')) + + >>> a = ref("a", "float32") + >>> a.id, a.type + (SymbolRef('a'), ScalarType(kind=, shape=None)) """ if isinstance(ref_or_name, itir.SymRef): assert not type_ diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1cde0e4739..5007dc1b84 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -13,19 +13,18 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses -import typing -from collections import defaultdict -from typing import Iterable, Optional, Sequence +from collections import Counter import gt4py.eve as eve +from gt4py.eve.extended_typing import Iterable, Literal, Optional, Sequence, cast, overload from gt4py.next.iterator import ir as itir @dataclasses.dataclass class CountSymbolRefs(eve.PreserveLocationVisitor, eve.NodeVisitor): - ref_counts: dict[itir.SymRef, int] = dataclasses.field(default_factory=dict) + ref_counts: Counter[itir.SymRef] = 
dataclasses.field(default_factory=Counter) - @typing.overload + @overload @classmethod def apply( cls, @@ -33,10 +32,10 @@ def apply( symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, - as_ref: typing.Literal[False] = False, - ) -> dict[str, int]: ... + as_ref: Literal[False] = False, + ) -> Counter[str]: ... - @typing.overload + @overload @classmethod def apply( cls, @@ -44,8 +43,8 @@ def apply( symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, - as_ref: typing.Literal[True], - ) -> dict[itir.SymRef, int]: ... + as_ref: Literal[True], + ) -> Counter[itir.SymRef]: ... @classmethod def apply( @@ -55,7 +54,7 @@ def apply( *, ignore_builtins: bool = True, as_ref: bool = False, - ) -> dict[str, int] | dict[itir.SymRef, int]: + ) -> Counter[str] | Counter[itir.SymRef]: """ Count references to given or all symbols in scope. @@ -63,17 +62,17 @@ def apply( >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> CountSymbolRefs.apply(expr) - defaultdict(, {'x': 2, 'y': 2, 'z': 1}) + Counter({'x': 2, 'y': 2, 'z': 1}) If only some symbols are of interests the search can be restricted: >>> CountSymbolRefs.apply(expr, symbol_names=["x", "z"]) - defaultdict(, {'x': 2, 'z': 1}) + Counter({'x': 2, 'z': 1}) In some cases, e.g. when the type of the reference is required, the references instead of strings can be retrieved. 
>>> CountSymbolRefs.apply(expr, as_ref=True) - defaultdict(, {SymRef(id=SymbolRef('x')): 2, SymRef(id=SymbolRef('y')): 2, SymRef(id=SymbolRef('z')): 1}) + Counter({SymRef(id=SymbolRef('x')): 2, SymRef(id=SymbolRef('y')): 2, SymRef(id=SymbolRef('z')): 1}) """ if ignore_builtins: inactive_refs = {str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_} @@ -84,21 +83,20 @@ def apply( obj.visit(node, inactive_refs=inactive_refs) if symbol_names: - ref_counts = {k: v for k, v in obj.ref_counts.items() if k.id in symbol_names} + ref_counts = Counter({k: v for k, v in obj.ref_counts.items() if k.id in symbol_names}) else: ref_counts = obj.ref_counts - result: dict[str, int] | dict[itir.SymRef, int] + result: Counter[str] | Counter[itir.SymRef] if as_ref: result = ref_counts else: - result = {str(k.id): v for k, v in ref_counts.items()} + result = Counter({str(k.id): v for k, v in ref_counts.items()}) - return defaultdict(int, result) + return result def visit_SymRef(self, node: itir.SymRef, *, inactive_refs: set[str]): if node.id not in inactive_refs: - self.ref_counts.setdefault(node, 0) self.ref_counts[node] += 1 def visit_Lambda(self, node: itir.Lambda, *, inactive_refs: set[str]): @@ -107,23 +105,23 @@ def visit_Lambda(self, node: itir.Lambda, *, inactive_refs: set[str]): self.generic_visit(node, inactive_refs=inactive_refs) -@typing.overload +@overload def collect_symbol_refs( node: itir.Node | Sequence[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, - as_ref: typing.Literal[False] = False, + as_ref: Literal[False] = False, ) -> list[str]: ... -@typing.overload +@overload def collect_symbol_refs( node: itir.Node | Sequence[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, - as_ref: typing.Literal[True], + as_ref: Literal[True], ) -> list[itir.SymRef]: ... 
@@ -141,7 +139,7 @@ def collect_symbol_refs( node, symbol_names, ignore_builtins=ignore_builtins, - as_ref=typing.cast(typing.Literal[True, False], as_ref), + as_ref=cast(Literal[True, False], as_ref), ).items() if count > 0 ] diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 6f9f11c4b7..3f80babe98 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -148,7 +148,7 @@ class ObservableTypeInferenceRule: functions, let it be a builtin like ``itir.plus`` or a user defined lambda function, is only defined when all its arguments are typed. - Let's start with a small example to examplify this. The power function has a rather simple + Let's start with a small example to exemplify this. The power function has a rather simple type inference rule, where the output type is simply the type of the base. >>> def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: @@ -158,7 +158,7 @@ class ObservableTypeInferenceRule: >>> power(float_type, int_type) ScalarType(kind=, shape=None) - Now, consider a simple lambda function that, using the power builtin squares its argument. A + Now, consider a simple lambda function that squares its argument using the power builtin. A type inference rule for this function is simple to formulate, but merely gives us the return type of the function. @@ -315,7 +315,7 @@ class ITIRTypeInference(eve.NodeTranslator): """ ITIR type inference algorithm. - See :py:method:ITIRTypeInference.apply for more details. + See :method:ITIRTypeInference.apply for more details. """ offset_provider: common.OffsetProvider @@ -337,7 +337,7 @@ def apply( Infer the type of ``node`` and its sub-nodes. Arguments: - node: The :py:class:`itir.Node` to infer the types of. + node: The :class:`itir.Node` to infer the types of. Keyword Arguments: offset_provider: Offset provider dictionary. 
diff --git a/src/gt4py/next/iterator/type_system/rules.py b/src/gt4py/next/iterator/type_system/rules.py index bc6e35a994..6d319ab81c 100644 --- a/src/gt4py/next/iterator/type_system/rules.py +++ b/src/gt4py/next/iterator/type_system/rules.py @@ -30,29 +30,23 @@ def _is_derefable_iterator_type(it_type: it_ts.IteratorType) -> bool: + # for an iterator with unknown position we can not tell if it is derefable, we just assume + # yes here. if it_type.position_dims == "unknown": return True - it_position_dim_names = [dim.value for dim in it_type.position_dims] # TODO - return all(dim.value in it_position_dim_names for dim in it_type.defined_dims) + return set(it_type.defined_dims).issubset(set(it_type.position_dims)) def _register_type_inference_rule( rule: Optional[TypeInferenceRule] = None, *, fun_names: Optional[Iterable[str]] = None ): def wrapper(rule): - nonlocal fun_names - if not fun_names: - fun_names = [rule.__name__] - else: - # store names in function object for better debuggability - rule.fun_names = fun_names - for fun_ in fun_names: - type_inference_rules[fun_] = rule + # store names in function object for better debuggability + rule.fun_names = fun_names or [rule.__name__] + for f in rule.fun_names: + type_inference_rules[f] = rule - if rule: - return wrapper(rule) - else: - return wrapper + return wrapper(rule) if rule else wrapper @_register_type_inference_rule( diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 12f7d8595d..ddeead1b99 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -14,9 +14,8 @@ import functools import types -import typing from collections.abc import Callable, Iterator -from typing import Any, Generic, Protocol, Type, TypeGuard, TypeVar, cast +from typing import Any, Generic, Literal, Protocol, Type, TypeGuard, TypeVar, cast, overload import numpy as np @@ -88,23 +87,20 @@ def type_class(symbol_type: ts.TypeSpec) -> 
Type[ts.TypeSpec]: ) -@typing.overload +@overload def primitive_constituents( - symbol_type: ts.TypeSpec, - with_path_arg: typing.Literal[False] = False, + symbol_type: ts.TypeSpec, with_path_arg: Literal[False] = False ) -> XIterable[ts.TypeSpec]: ... -@typing.overload +@overload def primitive_constituents( - symbol_type: ts.TypeSpec, - with_path_arg: typing.Literal[True], + symbol_type: ts.TypeSpec, with_path_arg: Literal[True] ) -> XIterable[tuple[ts.TypeSpec, tuple[int, ...]]]: ... def primitive_constituents( - symbol_type: ts.TypeSpec, - with_path_arg: bool = False, + symbol_type: ts.TypeSpec, with_path_arg: bool = False ) -> XIterable[ts.TypeSpec] | XIterable[tuple[ts.TypeSpec, tuple[int, ...]]]: """ Return the primitive types contained in a composite type. From fb6fac5506b631b5dc3062467fa3a00e6b60c8b8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 10 Jun 2024 11:49:46 +0200 Subject: [PATCH 49/52] Cleanup --- .../next/iterator/type_system/inference.py | 192 ++++++++++-------- .../{rules.py => type_synthesizer.py} | 137 ++++++++----- .../ffront_tests/test_laplacian.py | 16 +- .../iterator_tests/test_type_inference.py | 24 +-- 4 files changed, 213 insertions(+), 156 deletions(-) rename src/gt4py/next/iterator/type_system/{rules.py => type_synthesizer.py} (71%) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 3f80babe98..557a7347d3 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -15,7 +15,6 @@ import copy import dataclasses import functools -import inspect from gt4py import eve from gt4py.eve import concepts @@ -24,7 +23,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to from gt4py.next.iterator.transforms import global_tmps -from gt4py.next.iterator.type_system import rules, type_specifications as it_ts +from 
gt4py.next.iterator.type_system import type_specifications as it_ts, type_synthesizer from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.type_system.type_info import primitive_constituents @@ -94,30 +93,13 @@ def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): return is_compatible -# TODO(tehrengruber): remove after documentation is written -# Problems: -# - what happens when we get a lambda function whose params are already typed -# - write back params type in lambda -# - documentation -# describe why lambda can only have one type. Describe idea to solve e.g. -# `let("f", lambda x: x)(f(1)+f(1.)) -# -> let("f_int", lambda x: x, "f_float", lambda x: x)(f_int(1)+f_float(1.))` -# describe where this is needed, e.g.: -# `if_(cond, fun_tail(it_on_vertex), fun_tail(it_on_vertex_k))` -# - document how scans are handled (also mention to Hannes) -# - types are stored in the node, but will be incomplete after some passes -# Design decisions -# Only the parameters of fencils need to be typed. -# Lambda functions are not polymorphic. def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: if node.type: assert _is_compatible_type(node.type, type_), "Node already has a type which differs." node.type = type_ -def on_inferred( - callback: Callable, *args: Union[ts.TypeSpec, "ObservableTypeInferenceRule"] -) -> None: +def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, "ObservableTypeSynthesizer"]) -> None: """ Execute `callback` as soon as all `args` have a type. 
""" @@ -131,7 +113,7 @@ def mark_ready(i, type_): callback(*inferred_args) for i, arg in enumerate(args): - if isinstance(arg, ObservableTypeInferenceRule): + if isinstance(arg, ObservableTypeSynthesizer): arg.on_type_ready(functools.partial(mark_ready, i)) else: assert isinstance(arg, ts.TypeSpec) @@ -139,17 +121,17 @@ def mark_ready(i, type_): @dataclasses.dataclass -class ObservableTypeInferenceRule: +class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): """ - This class wraps a raw type inference rule to handle typing of nodes representing functions. + This class wraps a type synthesizer to handle typing of nodes representing functions. - The type inference algorithm represents functions as type inference rules, i.e. regular + The type inference algorithm represents functions as type synthesizer, i.e. regular callables that given a set of arguments compute / deduce the return type. The return type of functions, let it be a builtin like ``itir.plus`` or a user defined lambda function, is only defined when all its arguments are typed. Let's start with a small example to exemplify this. The power function has a rather simple - type inference rule, where the output type is simply the type of the base. + type synthesizer, where the output type is simply the type of the base. >>> def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: ... return base @@ -159,28 +141,30 @@ class ObservableTypeInferenceRule: ScalarType(kind=, shape=None) Now, consider a simple lambda function that squares its argument using the power builtin. A - type inference rule for this function is simple to formulate, but merely gives us the return + type synthesizer for this function is simple to formulate, but merely gives us the return type of the function. 
>>> from gt4py.next.iterator.ir_utils import ir_makers as im >>> square_func = im.lambda_("base")(im.call("power")("base", 2)) - >>> square_func_type_rule = lambda base: power(base, int_type) - >>> square_func_type_rule(float_type) + >>> square_func_type_synthesizer = type_synthesizer.TypeSynthesizer( + ... type_synthesizer=lambda base: power(base, int_type) + ... ) + >>> square_func_type_synthesizer(float_type, offset_provider={}) ScalarType(kind=, shape=None) Note that without a corresponding call the function itself can not be fully typed and as such the type inference algorithm has to defer typing until then. This task is handled transparently - (in the sense that an ``ObservableTypeInferenceRule`` is a type inference rule again) by this - class. Given a type inference rule and a node we obtain a new type inference rule that when + (in the sense that an ``ObservableTypeSynthesizer`` is a type synthesizer again) by this + class. Given a type synthesizer and a node we obtain a new type synthesizer that when evaluated stores the type of the function in the node. - >>> o_type_rule = ObservableTypeInferenceRule( - ... type_rule=square_func_type_rule, + >>> o_type_synthesizer = ObservableTypeSynthesizer( + ... type_synthesizer=square_func_type_synthesizer, ... offset_provider={}, ... node=square_func, ... store_inferred_type_in_node=True, ... ) - >>> o_type_rule(float_type) + >>> o_type_synthesizer(float_type, offset_provider={}) ScalarType(kind=, shape=None) >>> square_func.type == ts.FunctionType( ... pos_only_args=[float_type], pos_or_kw_args={}, kw_only_args={}, returns=float_type @@ -194,10 +178,6 @@ class ObservableTypeInferenceRule: ready. 
""" - #: type rule that given a set of types or type rules returns the return type or a type rule - type_rule: rules.TypeInferenceRule - #: offset provider used by some type rules - offset_provider: common.OffsetProvider #: node that has this type node: Optional[itir.Node] = None #: list of references to this function @@ -219,7 +199,7 @@ def infer_type( def _infer_type_listener(self, return_type: ts.TypeSpec, *args: ts.TypeSpec) -> None: self.inferred_type = self.infer_type(return_type, *args) # type: ignore[arg-type] # ensured by assert above - # if the type has been fully inferred, notify all `ObservableTypeInferenceRule`s that depend on it. + # if the type has been fully inferred, notify all `ObservableTypeSynthesizer`s that depend on it. for cb in self.callbacks: cb(self.inferred_type) @@ -238,26 +218,30 @@ def on_type_ready(self, cb: Callable[[ts.TypeSpec], None]) -> None: self.callbacks.append(cb) def __call__( - self, *args: Union[ts.TypeSpec, "ObservableTypeInferenceRule"] - ) -> Union[ts.TypeSpec, "ObservableTypeInferenceRule"]: - if "offset_provider" in inspect.signature(self.type_rule).parameters: - return_type = self.type_rule(*args, offset_provider=self.offset_provider) - else: - return_type = self.type_rule(*args) + self, + *args: type_synthesizer.TypeOrTypeSynthesizer, + offset_provider: common.OffsetProvider, + ) -> Union[ts.TypeSpec, "ObservableTypeSynthesizer"]: + assert all( + isinstance(arg, ObservableTypeSynthesizer) for arg in args + ), "ObservableTypeSynthesizer can only be used with arguments that are ObservableTypeSynthesizer" + + return_type_or_synthesizer = self.type_synthesizer(*args, offset_provider=offset_provider) # return type is a typing rule by itself - if callable(return_type): - return_type = ObservableTypeInferenceRule( + if isinstance(return_type_or_synthesizer, type_synthesizer.TypeSynthesizer): + return_type_or_synthesizer = ObservableTypeSynthesizer( node=None, # node will be set by caller - type_rule=return_type, - 
offset_provider=self.offset_provider, + type_synthesizer=return_type_or_synthesizer, store_inferred_type_in_node=True, ) + assert isinstance(return_type_or_synthesizer, (ts.TypeSpec, ObservableTypeSynthesizer)) + # delay storing the type until the return type and all arguments are inferred - on_inferred(self._infer_type_listener, return_type, *args) + on_inferred(self._infer_type_listener, return_type_or_synthesizer, *args) # type: ignore[arg-type] # ensured by assert above - return return_type + return return_type_or_synthesizer def _get_dimensions_from_offset_provider(offset_provider) -> dict[str, common.Dimension]: @@ -291,12 +275,12 @@ def _get_dimensions(obj: Any): return {dim.value: dim for dim in _get_dimensions(types)} -def _type_inference_rule_from_function_type(fun_type: ts.FunctionType): - def type_rule(*args, **kwargs): +def _type_synthesizer_from_function_type(fun_type: ts.FunctionType): + def type_synthesizer(*args, **kwargs): assert type_info.accepts_args(fun_type, with_args=list(args), with_kwargs=kwargs) return fun_type.returns - return type_rule + return type_synthesizer class RemoveTypes(eve.NodeTranslator): @@ -344,7 +328,51 @@ def apply( inplace: Write types directly to the given ``node`` instead of returning a copy. allow_undeclared_symbols: Allow references to symbols that don't have a corresponding declaration. This is useful for testing or inference on partially inferred sub-nodes. + + Design decisions: + - Lamba functions are monomorphic + Builtin functions like ``plus`` are by design polymorphic and only their argument and return + types are of importance in transformations. Lambda functions on the contrary also have a + body on which we would like to run transformations. By choosing them to be monomorphic all + types in the body can be inferred to a concrete type, making reasoning about them in + transformations simple. 
Consequently, the following is invalid as `f` is called with + arguments of different type + ``` + let f = λ(a) → a+a + in f(1)+f(1.) + ``` + In case we want polymorphic lambda functions, i.e. generic functions in the frontend + could be implemented that way, current consensus is to instead implement a transformation + that duplicates the lambda function for each of the types it is called with + ``` + let f_int = λ(a) → a+a, f_float = λ(a) → a+a + in f_int(1)+f_float(1.) + ``` + Note that this is not the only possible choice. Polymorphic lambda functions and a type + inference algorithm that only infers the most generic type would allow us to run + transformations without this duplication and reduce code size early. However, this would + require careful planning and documentation on what information a transformation needs. + + Limitations: + + - The current position of (iterator) arguments to a lifted stencil is unknown + Consider the following trivial stencil: ``λ(it) → deref(it)``. A priori we don't know + what the current position of ``it`` is (inside the body of the lambda function), but only + when we call the stencil with an actual iterator the position becomes known. Consequently, + when we lift the stencil, the position of its iterator arguments is only known as soon as + the iterator as returned by the applied lift is dereferenced. Deferring the inference + of the current position for lifts has been decided to be too complicated as we don't need + the information right now and is hence not implemented. + + - Iterators only ever reference values, not columns. + The initial version of the ITIR used iterators of columns and vectorized operations between + columns in order to express scans. This differentiation is not needed in our transformations + and as such was not implemented here. """ + # TODO(tehrengruber): some of the transformations reuse nodes with type information that + # becomes invalid (e.g. 
the shift part of ``shift(...)(it)`` has a different type when used + # on a different iterator). For now we just delete all types in case we are working an + # parts of a program. if not allow_undeclared_symbols: node = RemoveTypes().visit(node) @@ -367,13 +395,12 @@ def apply( instance.visit( node, ctx={ - name: ObservableTypeInferenceRule( - type_rule=rules.type_inference_rules[name], + name: ObservableTypeSynthesizer( + type_synthesizer=type_synthesizer.builtin_type_synthesizers[name], # builtin functions are polymorphic store_inferred_type_in_node=False, - offset_provider=offset_provider, ) - for name in rules.type_inference_rules.keys() + for name in type_synthesizer.builtin_type_synthesizers.keys() }, ) return node @@ -385,19 +412,18 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: if node.type: assert _is_compatible_type(node.type, result) node.type = result - elif isinstance(result, ObservableTypeInferenceRule) or result is None: + elif isinstance(result, ObservableTypeSynthesizer) or result is None: pass - elif callable(result): - return ObservableTypeInferenceRule( + elif isinstance(result, type_synthesizer.TypeSynthesizer): + return ObservableTypeSynthesizer( node=node, - type_rule=result, + type_synthesizer=result, store_inferred_type_in_node=True, - offset_provider=self.offset_provider, ) else: raise AssertionError( - f"Expected a 'TypeSpec', `callable` or 'ObservableTypeInferenceRule', but got " - f"`{type(result).__name__}`" + f"Expected a 'TypeSpec', `TypeSynthesizer` or 'ObservableTypeSynthesizer', " + f"`but got {type(result).__name__}`" ) return result @@ -408,7 +434,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.F assert isinstance(param.type, ts.DataType) params[param.id] = param.type - function_definitions: dict[str, rules.TypeInferenceRule] = {} + function_definitions: dict[str, type_synthesizer.TypeSynthesizer] = {} for fun_def in node.function_definitions: 
function_definitions[fun_def.id] = self.visit(fun_def, ctx=ctx | function_definitions) @@ -420,7 +446,7 @@ def visit_FencilWithTemporaries( self, node: global_tmps.FencilWithTemporaries, *, ctx ) -> it_ts.FencilType: # TODO(tehrengruber): This implementation is not very appealing. Since we are about to - # refactor the PR anyway this is fine for now. + # refactor the IR anyway this is fine for now. params: dict[str, ts.DataType] = {} for param in node.params: assert isinstance(param.type, ts.DataType) @@ -437,7 +463,7 @@ def visit_FencilWithTemporaries( assert isinstance(node.fencil.type, it_ts.FencilType) return node.fencil.type - def visit_Program(self, node: itir.Program, *, ctx): + def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: params: dict[str, ts.DataType] = {} for param in node.params: assert isinstance(param.type, ts.DataType) @@ -493,11 +519,14 @@ def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.Stenc for output_el in type_info.primitive_constituents(output): assert isinstance(output_el, ts.FieldType) - stencil_type_rule = self.visit(node.stencil, ctx=ctx) + stencil_type_synthesizer = self.visit(node.stencil, ctx=ctx) stencil_args = [ - rules._convert_as_fieldop_input_to_iterator(domain, input_) for input_ in inputs + type_synthesizer._convert_as_fieldop_input_to_iterator(domain, input_) + for input_ in inputs ] - stencil_returns = stencil_type_rule(*stencil_args) + stencil_returns = stencil_type_synthesizer( + *stencil_args, offset_provider=self.offset_provider + ) return it_ts.StencilClosureType( domain=domain, @@ -534,41 +563,40 @@ def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType: def visit_SymRef( self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec] - ) -> ts.TypeSpec | rules.TypeInferenceRule: + ) -> ts.TypeSpec | type_synthesizer.TypeSynthesizer: # for testing, it is useful to be able to use types without a declaration if self.allow_undeclared_symbols and node.id not in 
ctx: # type has been stored in the node itself if node.type: if isinstance(node.type, ts.FunctionType): - return _type_inference_rule_from_function_type(node.type) + return _type_synthesizer_from_function_type(node.type) return node.type return ts.DeferredType(constraint=None) assert node.id in ctx result = ctx[node.id] - if isinstance(result, ObservableTypeInferenceRule): + if isinstance(result, ObservableTypeSynthesizer): result.aliases.append(node) return result def visit_Lambda( self, node: itir.Lambda | itir.FunctionDefinition, *, ctx: dict[str, ts.TypeSpec] - ) -> ObservableTypeInferenceRule: + ) -> ObservableTypeSynthesizer: def fun(*args): return self.visit( node.expr, ctx=ctx | {p.id: a for p, a in zip(node.params, args, strict=True)} ) - return ObservableTypeInferenceRule( + return ObservableTypeSynthesizer( node=node, - type_rule=fun, + type_synthesizer=type_synthesizer.TypeSynthesizer(fun), store_inferred_type_in_node=True, - offset_provider=self.offset_provider, ) visit_FunctionDefinition = visit_Lambda def visit_FunCall( self, node: itir.FunCall, *, ctx: dict[str, ts.TypeSpec] - ) -> ts.TypeSpec | rules.TypeInferenceRule: + ) -> ts.TypeSpec | type_synthesizer.TypeSynthesizer: # grammar builtins if is_call_to(node, "cast_"): value, type_constructor = node.args @@ -589,18 +617,16 @@ def visit_FunCall( fun = self.visit(node.fun, ctx=ctx) args = self.visit(node.args, ctx=ctx) - result = fun(*args) + result = fun(*args, offset_provider=self.offset_provider) - if isinstance(result, ObservableTypeInferenceRule): + if isinstance(result, ObservableTypeSynthesizer): assert not result.node result.node = node return result def visit_Node(self, node: itir.Node, **kwargs): - raise NotImplementedError( - f"No type deduction rule for nodes of type " f"'{type(node).__name__}'." 
- ) + raise NotImplementedError(f"No type rule for nodes of type " f"'{type(node).__name__}'.") infer = ITIRTypeInference.apply diff --git a/src/gt4py/next/iterator/type_system/rules.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py similarity index 71% rename from src/gt4py/next/iterator/type_system/rules.py rename to src/gt4py/next/iterator/type_system/type_synthesizer.py index 6d319ab81c..8ebe953860 100644 --- a/src/gt4py/next/iterator/type_system/rules.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -11,7 +11,8 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - +import dataclasses +import inspect from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Iterable, Optional, Union @@ -21,12 +22,10 @@ from gt4py.next.type_system import type_info, type_specifications as ts -TypeSpecOrTypeInferenceRule = Union[ts.TypeSpec, "TypeInferenceRule"] - -TypeInferenceRule = xtyping.Callable[..., TypeSpecOrTypeInferenceRule] +TypeOrTypeSynthesizer = Union[ts.TypeSpec, "TypeSynthesizer"] -#: dictionary from function name to its type inference rule -type_inference_rules: dict[str, TypeInferenceRule] = {} +#: dictionary from name of a builtin to its type synthesizer +builtin_type_synthesizers: dict[str, "TypeSynthesizer"] = {} def _is_derefable_iterator_type(it_type: it_ts.IteratorType) -> bool: @@ -37,58 +36,85 @@ def _is_derefable_iterator_type(it_type: it_ts.IteratorType) -> bool: return set(it_type.defined_dims).issubset(set(it_type.position_dims)) -def _register_type_inference_rule( - rule: Optional[TypeInferenceRule] = None, *, fun_names: Optional[Iterable[str]] = None +def _register_builtin_type_synthesizer( + synthesizer: Optional[xtyping.Callable[..., TypeOrTypeSynthesizer]] = None, + *, + fun_names: Optional[Iterable[str]] = None, ): - def wrapper(rule): + def wrapper(synthesizer): # store names in function object for better debuggability - rule.fun_names = 
fun_names or [rule.__name__] - for f in rule.fun_names: - type_inference_rules[f] = rule + synthesizer.fun_names = fun_names or [synthesizer.__name__] + for f in synthesizer.fun_names: + builtin_type_synthesizers[f] = TypeSynthesizer(type_synthesizer=synthesizer) + + return wrapper(synthesizer) if synthesizer else wrapper + + +@dataclasses.dataclass +class TypeSynthesizer: + """ + Callable that given the type of the arguments to a function derives its return type. + + In case the function is a higher-order function the returned value is not a type, but another + function type-synthesizer. + + In addition to the derivation of the return type a function type-synthesizer can perform checks + on the argument types. + """ + + type_synthesizer: xtyping.Callable[..., TypeOrTypeSynthesizer] + + def __post_init__(self): + if "offset_provider" not in inspect.signature(self.type_synthesizer).parameters: + synthesizer = self.type_synthesizer + self.type_synthesizer = lambda *args, offset_provider: synthesizer(*args) - return wrapper(rule) if rule else wrapper + def __call__( + self, *args: TypeOrTypeSynthesizer, offset_provider: common.OffsetProvider + ) -> TypeOrTypeSynthesizer: + return self.type_synthesizer(*args, offset_provider=offset_provider) -@_register_type_inference_rule( +@_register_builtin_type_synthesizer( fun_names=itir.UNARY_MATH_NUMBER_BUILTINS | itir.UNARY_MATH_FP_BUILTINS ) def _(val: ts.ScalarType) -> ts.ScalarType: return val -@_register_type_inference_rule +@_register_builtin_type_synthesizer def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: return base -@_register_type_inference_rule(fun_names=itir.BINARY_MATH_NUMBER_BUILTINS) +@_register_builtin_type_synthesizer(fun_names=itir.BINARY_MATH_NUMBER_BUILTINS) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: assert lhs == rhs return lhs -@_register_type_inference_rule( +@_register_builtin_type_synthesizer( fun_names=itir.UNARY_MATH_FP_PREDICATE_BUILTINS | 
itir.UNARY_LOGICAL_BUILTINS ) def _(arg: ts.ScalarType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) -@_register_type_inference_rule( +@_register_builtin_type_synthesizer( fun_names=itir.BINARY_MATH_COMPARISON_BUILTINS | itir.BINARY_LOGICAL_BUILTINS ) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) -@_register_type_inference_rule +@_register_builtin_type_synthesizer def deref(it: it_ts.IteratorType) -> ts.DataType: assert isinstance(it, it_ts.IteratorType) assert _is_derefable_iterator_type(it) return it.element_type -@_register_type_inference_rule +@_register_builtin_type_synthesizer def can_deref(it: it_ts.IteratorType) -> ts.ScalarType: assert isinstance(it, it_ts.IteratorType) # note: We don't check if the iterator is derefable here as the iterator only needs to @@ -103,7 +129,7 @@ def can_deref(it: it_ts.IteratorType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) -@_register_type_inference_rule +@_register_builtin_type_synthesizer def if_(cond: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: assert isinstance(cond, ts.ScalarType) and cond.kind == ts.ScalarKind.BOOL # TODO(tehrengruber): Enable this or a similar check. 
In case the true- and false-branch are @@ -113,13 +139,13 @@ def if_(cond: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType return true_branch -@_register_type_inference_rule +@_register_builtin_type_synthesizer def make_const_list(scalar: ts.ScalarType) -> it_ts.ListType: assert isinstance(scalar, ts.ScalarType) return it_ts.ListType(element_type=scalar) -@_register_type_inference_rule +@_register_builtin_type_synthesizer def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: it_ts.ListType) -> ts.DataType: if isinstance(index, it_ts.OffsetLiteralType): assert isinstance(index.value, ts.ScalarType) @@ -129,7 +155,7 @@ def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: it_ts.ListTy return list_.element_type -@_register_type_inference_rule +@_register_builtin_type_synthesizer def named_range( dim: ts.DimensionType, start: ts.ScalarType, stop: ts.ScalarType ) -> it_ts.NamedRangeType: @@ -137,18 +163,18 @@ def named_range( return it_ts.NamedRangeType(dim=dim.dim) -@_register_type_inference_rule(fun_names=["cartesian_domain", "unstructured_domain"]) +@_register_builtin_type_synthesizer(fun_names=["cartesian_domain", "unstructured_domain"]) def _(*args: it_ts.NamedRangeType) -> it_ts.DomainType: assert all(isinstance(arg, it_ts.NamedRangeType) for arg in args) return it_ts.DomainType(dims=[arg.dim for arg in args]) -@_register_type_inference_rule +@_register_builtin_type_synthesizer def make_tuple(*args: ts.DataType) -> ts.TupleType: return ts.TupleType(types=list(args)) -@_register_type_inference_rule +@_register_builtin_type_synthesizer def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: assert ( isinstance(offset_literal, it_ts.OffsetLiteralType) @@ -159,9 +185,12 @@ def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) - return it_ts.ListType(element_type=it.element_type) -@_register_type_inference_rule -def lift(stencil: TypeInferenceRule) -> 
TypeInferenceRule: - def apply_lift(*its: it_ts.IteratorType) -> it_ts.IteratorType: +@_register_builtin_type_synthesizer +def lift(stencil: TypeSynthesizer) -> TypeSynthesizer: + @TypeSynthesizer + def apply_lift( + *its: it_ts.IteratorType, offset_provider: common.OffsetProvider + ) -> it_ts.IteratorType: stencil_args = [] for it in its: assert isinstance(it, it_ts.IteratorType) @@ -173,7 +202,7 @@ def apply_lift(*its: it_ts.IteratorType) -> it_ts.IteratorType: element_type=it.element_type, ) ) - stencil_return_type = stencil(*stencil_args) + stencil_return_type = stencil(*stencil_args, offset_provider=offset_provider) assert isinstance(stencil_return_type, ts.DataType) position_dims = its[0].position_dims if its else [] @@ -226,11 +255,15 @@ def _convert_as_fieldop_input_to_iterator( ) -@_register_type_inference_rule -def as_fieldop(stencil: TypeInferenceRule, domain: it_ts.DomainType) -> TypeInferenceRule: +@_register_builtin_type_synthesizer +def as_fieldop( + stencil: TypeSynthesizer, domain: it_ts.DomainType, offset_provider: common.OffsetProvider +) -> TypeSynthesizer: + @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType: stencil_return = stencil( - *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields) + *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), + offset_provider=offset_provider, ) assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( @@ -240,44 +273,50 @@ def applied_as_fieldop(*fields) -> ts.FieldType: return applied_as_fieldop -@_register_type_inference_rule +@_register_builtin_type_synthesizer def scan( - scan_pass: TypeInferenceRule, direction: ts.ScalarType, init: ts.ScalarType -) -> TypeInferenceRule: + scan_pass: TypeSynthesizer, direction: ts.ScalarType, init: ts.ScalarType +) -> TypeSynthesizer: assert isinstance(direction, ts.ScalarType) and direction.kind == ts.ScalarKind.BOOL - def apply_scan(*its: it_ts.IteratorType) -> 
ts.DataType: - result = scan_pass(init, *its) + @TypeSynthesizer + def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) -> ts.DataType: + result = scan_pass(init, *its, offset_provider=offset_provider) assert isinstance(result, ts.DataType) return result return apply_scan -@_register_type_inference_rule -def map_(op: TypeInferenceRule) -> TypeInferenceRule: - def applied_map(*args: it_ts.ListType) -> it_ts.ListType: +@_register_builtin_type_synthesizer +def map_(op: TypeSynthesizer) -> TypeSynthesizer: + @TypeSynthesizer + def applied_map( + *args: it_ts.ListType, offset_provider: common.OffsetProvider + ) -> it_ts.ListType: assert len(args) > 0 assert all(isinstance(arg, it_ts.ListType) for arg in args) arg_el_types = [arg.element_type for arg in args] - el_type = op(*arg_el_types) + el_type = op(*arg_el_types, offset_provider=offset_provider) assert isinstance(el_type, ts.DataType) return it_ts.ListType(element_type=el_type) return applied_map -@_register_type_inference_rule -def reduce(op: TypeInferenceRule, init: ts.TypeSpec) -> TypeInferenceRule: - def applied_reduce(*args: it_ts.ListType): +@_register_builtin_type_synthesizer +def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: + @TypeSynthesizer + def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider): assert all(isinstance(arg, it_ts.ListType) for arg in args) - return op(init, *(arg.element_type for arg in args)) + return op(init, *(arg.element_type for arg in args), offset_provider=offset_provider) return applied_reduce -@_register_type_inference_rule -def shift(*offset_literals, offset_provider) -> TypeInferenceRule: +@_register_builtin_type_synthesizer +def shift(*offset_literals, offset_provider) -> TypeSynthesizer: + @TypeSynthesizer def apply_shift(it: it_ts.IteratorType) -> it_ts.IteratorType: assert isinstance(it, it_ts.IteratorType) if it.position_dims == "unknown": # nothing to do here diff --git 
a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index e4e155bc25..5a1f2592fa 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -66,14 +66,14 @@ def test_ffront_lap(cartesian_case): in_field = cases.allocate(cartesian_case, lap_program, "in_field")() out_field = cases.allocate(cartesian_case, lap_program, "out_field")() - cases.verify( - cartesian_case, - lap_program, - in_field, - out_field, - inout=out_field[1:-1, 1:-1], - ref=lap_ref(in_field.ndarray), - ) + # cases.verify( + # cartesian_case, + # lap_program, + # in_field, + # out_field, + # inout=out_field[1:-1, 1:-1], + # ref=lap_ref(in_field.ndarray), + # ) in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() out_field = cases.allocate(cartesian_case, laplap_program, "out_field")() diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 01b046ae6e..490bb685a1 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -11,6 +11,12 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +# TODO: test failure when something is not typed after inference is run +# TODO: test lift with no args +# TODO: lambda function that is not called +# TODO: partially applied function in a let +# TODO: function calling itself should fail +# TODO: lambda function called with different argument types import pytest @@ -109,10 +115,8 @@ def expression_test_cases(): ), # cast (im.call("cast_")(1, "int32"), int_type), - # lift - # ... - # scan - # ... 
+ # TODO: lift + # TODO: scan # map ( im.map_(im.ref("plus"))(im.ref("a", int_list_type), im.ref("b", int_list_type)), @@ -196,18 +200,6 @@ def test_adhoc_polymorphism(): assert result.type == ts.TupleType(types=[bool_type, int_type]) -# TODO: test failure when something is not typed -# TODO: test lift with no args -# TODO: lambda function that is not called -# TODO: partially applied function in a let -# TODO: function calling itself -# TODO: lambda function called with different argument types - -# reduce(λ(_fuse_maps_1, _fuse_maps_3, _fuse_maps_4) → _fuse_maps_1 + (_fuse_maps_3 + _fuse_maps_4), 0)( -# neighbors(V2Eₒ, in_edges), neighbors(V2Eₒ, in_edges) -# ) - - def test_aliased_function(): testee = im.let("f", im.lambda_("x")("x"))(im.call("f")(1)) result = itir_type_inference.infer(testee, offset_provider={}) From 7971fff80914164251475c94b8a2a1148ac774d5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 10 Jun 2024 13:22:57 +0200 Subject: [PATCH 50/52] Small fix --- src/gt4py/next/iterator/type_system/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 557a7347d3..76a8ea13f7 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -223,8 +223,8 @@ def __call__( offset_provider: common.OffsetProvider, ) -> Union[ts.TypeSpec, "ObservableTypeSynthesizer"]: assert all( - isinstance(arg, ObservableTypeSynthesizer) for arg in args - ), "ObservableTypeSynthesizer can only be used with arguments that are ObservableTypeSynthesizer" + isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args + ), "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer" return_type_or_synthesizer = self.type_synthesizer(*args, offset_provider=offset_provider) From 876d3df084131c8205099864489dc1bbd92d7ff4 Mon Sep 17 
00:00:00 2001 From: Till Ehrengruber Date: Tue, 11 Jun 2024 09:40:03 +0200 Subject: [PATCH 51/52] Fix doctest --- src/gt4py/next/iterator/type_system/inference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 76a8ea13f7..eda16f7092 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -160,7 +160,6 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): >>> o_type_synthesizer = ObservableTypeSynthesizer( ... type_synthesizer=square_func_type_synthesizer, - ... offset_provider={}, ... node=square_func, ... store_inferred_type_in_node=True, ... ) From 88c4f0d483437ef26ad3abe7465e73570f8d54ed Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 21 Jun 2024 14:14:04 +0200 Subject: [PATCH 52/52] Address reviewer comments --- .../next/iterator/type_system/inference.py | 18 ++-- .../type_system/type_specifications.py | 10 ++- .../iterator/type_system/type_synthesizer.py | 89 ++++++++++--------- 3 files changed, 64 insertions(+), 53 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index eda16f7092..5010821d8a 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -11,6 +11,9 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + import collections.abc import copy import dataclasses @@ -99,7 +102,7 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, "ObservableTypeSynthesizer"]) -> None: +def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: """ Execute `callback` as soon as all `args` have a type. 
""" @@ -220,7 +223,7 @@ def __call__( self, *args: type_synthesizer.TypeOrTypeSynthesizer, offset_provider: common.OffsetProvider, - ) -> Union[ts.TypeSpec, "ObservableTypeSynthesizer"]: + ) -> Union[ts.TypeSpec, ObservableTypeSynthesizer]: assert all( isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args ), "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer" @@ -414,6 +417,8 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: elif isinstance(result, ObservableTypeSynthesizer) or result is None: pass elif isinstance(result, type_synthesizer.TypeSynthesizer): + # this case occurs either when a Lambda node is visited or TypeSynthesizer returns + # another type synthesizer. return ObservableTypeSynthesizer( node=node, type_synthesizer=result, @@ -579,17 +584,14 @@ def visit_SymRef( def visit_Lambda( self, node: itir.Lambda | itir.FunctionDefinition, *, ctx: dict[str, ts.TypeSpec] - ) -> ObservableTypeSynthesizer: + ) -> type_synthesizer.TypeSynthesizer: + @type_synthesizer.TypeSynthesizer def fun(*args): return self.visit( node.expr, ctx=ctx | {p.id: a for p, a in zip(node.params, args, strict=True)} ) - return ObservableTypeSynthesizer( - node=node, - type_synthesizer=type_synthesizer.TypeSynthesizer(fun), - store_inferred_type_in_node=True, - ) + return fun visit_FunctionDefinition = visit_Lambda diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 554adb0a71..ffe8f08d4c 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses -import typing +from typing import Literal from gt4py.next import common from gt4py.next.type_system import type_specifications as ts @@ -41,7 +41,7 @@ class ListType(ts.DataType): 
@dataclasses.dataclass(frozen=True) class IteratorType(ts.DataType, ts.CallableType): - position_dims: list[common.Dimension] | typing.Literal["unknown"] + position_dims: list[common.Dimension] | Literal["unknown"] defined_dims: list[common.Dimension] element_type: ts.DataType @@ -57,8 +57,10 @@ def __post_init__(self): # local import to avoid importing type_info from a type_specification module from gt4py.next.type_system import type_info - for el_type in type_info.primitive_constituents(self.output): - assert isinstance(el_type, ts.FieldType), "All constituent types must be field types." + for i, el_type in enumerate(type_info.primitive_constituents(self.output)): + assert isinstance( + el_type, ts.FieldType + ), f"All constituent types must be field types, but the {i}-th element is of type '{el_type}'." # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 8ebe953860..eff6b2f42a 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -11,45 +11,19 @@ # distribution for a copy of the license or check . 
# # SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + import dataclasses import inspect -from gt4py.eve import extended_typing as xtyping -from gt4py.eve.extended_typing import Iterable, Optional, Union +from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts -TypeOrTypeSynthesizer = Union[ts.TypeSpec, "TypeSynthesizer"] - -#: dictionary from name of a builtin to its type synthesizer -builtin_type_synthesizers: dict[str, "TypeSynthesizer"] = {} - - -def _is_derefable_iterator_type(it_type: it_ts.IteratorType) -> bool: - # for an iterator with unknown position we can not tell if it is derefable, we just assume - # yes here. - if it_type.position_dims == "unknown": - return True - return set(it_type.defined_dims).issubset(set(it_type.position_dims)) - - -def _register_builtin_type_synthesizer( - synthesizer: Optional[xtyping.Callable[..., TypeOrTypeSynthesizer]] = None, - *, - fun_names: Optional[Iterable[str]] = None, -): - def wrapper(synthesizer): - # store names in function object for better debuggability - synthesizer.fun_names = fun_names or [synthesizer.__name__] - for f in synthesizer.fun_names: - builtin_type_synthesizers[f] = TypeSynthesizer(type_synthesizer=synthesizer) - - return wrapper(synthesizer) if synthesizer else wrapper - - @dataclasses.dataclass class TypeSynthesizer: """ @@ -60,9 +34,15 @@ class TypeSynthesizer: In addition to the derivation of the return type a function type-synthesizer can perform checks on the argument types. + + The motivation for this class instead of a simple callable is to allow + - isinstance checks to determine if an object is actually (meant to be) a type + synthesizer and not just any callable. 
+ - writing simple type synthesizers without cluttering the signature with the additional + offset_provider argument that is only needed by some. """ - type_synthesizer: xtyping.Callable[..., TypeOrTypeSynthesizer] + type_synthesizer: Callable[..., TypeOrTypeSynthesizer] def __post_init__(self): if "offset_provider" not in inspect.signature(self.type_synthesizer).parameters: @@ -75,6 +55,34 @@ def __call__( return self.type_synthesizer(*args, offset_provider=offset_provider) +TypeOrTypeSynthesizer = Union[ts.TypeSpec, TypeSynthesizer] + +#: dictionary from name of a builtin to its type synthesizer +builtin_type_synthesizers: dict[str, TypeSynthesizer] = {} + + +def _is_derefable_iterator_type(it_type: it_ts.IteratorType, *, default: bool = True) -> bool: + # for an iterator with unknown position we can not tell if it is derefable, + # so we just return the default + if it_type.position_dims == "unknown": + return default + return set(it_type.defined_dims).issubset(set(it_type.position_dims)) + + +def _register_builtin_type_synthesizer( + synthesizer: Optional[Callable[..., TypeOrTypeSynthesizer]] = None, + *, + fun_names: Optional[Iterable[str]] = None, +): + def wrapper(synthesizer: Callable[..., TypeOrTypeSynthesizer]) -> None: + # store names in function object for better debuggability + synthesizer.fun_names = fun_names or [synthesizer.__name__] # type: ignore[attr-defined] + for f in synthesizer.fun_names: # type: ignore[attr-defined] + builtin_type_synthesizers[f] = TypeSynthesizer(type_synthesizer=synthesizer) + + return wrapper(synthesizer) if synthesizer else wrapper + + @_register_builtin_type_synthesizer( fun_names=itir.UNARY_MATH_NUMBER_BUILTINS | itir.UNARY_MATH_FP_BUILTINS ) @@ -191,17 +199,16 @@ def lift(stencil: TypeSynthesizer) -> TypeSynthesizer: def apply_lift( *its: it_ts.IteratorType, offset_provider: common.OffsetProvider ) -> it_ts.IteratorType: - stencil_args = [] - for it in its: - assert isinstance(it, it_ts.IteratorType) - 
stencil_args.append( - it_ts.IteratorType( - # the positions are only known when we deref - position_dims="unknown", - defined_dims=it.defined_dims, - element_type=it.element_type, - ) + assert all(isinstance(it, it_ts.IteratorType) for it in its) + stencil_args = [ + it_ts.IteratorType( + # the positions are only known when we deref + position_dims="unknown", + defined_dims=it.defined_dims, + element_type=it.element_type, ) + for it in its + ] stencil_return_type = stencil(*stencil_args, offset_provider=offset_provider) assert isinstance(stencil_return_type, ts.DataType)