Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug[next]: fix lowering of astype on tuples containing scalars #1642

Merged
merged 3 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,15 +319,19 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id

def create_cast(expr: itir.Expr) -> itir.FunCall:
return im.as_fieldop(
im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type)))
)(expr)
def create_cast(expr: itir.Expr, t: ts.TypeSpec) -> itir.FunCall:
if isinstance(t, ts.FieldType):
return im.as_fieldop(
im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type)))
)(expr)
else:
assert isinstance(t, ts.ScalarType)
return im.call("cast_")(expr, str(new_type))

if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
return create_cast(obj)
return create_cast(obj, node.type)

return lowering_utils.process_elements(create_cast, obj, node.type)
return lowering_utils.process_elements(create_cast, obj, node.type, with_type=True)

def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
Expand Down
13 changes: 11 additions & 2 deletions src/gt4py/next/ffront/lowering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def process_elements(
process_func: Callable[..., itir.Expr],
objs: itir.Expr | Iterable[itir.Expr],
current_el_type: ts.TypeSpec,
with_type: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update to the docstring is missing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

) -> itir.FunCall:
"""
Recursively applies a processing function to all primitive constituents of a tuple.
Expand All @@ -118,7 +119,10 @@ def process_elements(

let_ids = tuple(f"__val_{eve_utils.content_hash(obj)}" for obj in objs)
body = _process_elements_impl(
process_func, tuple(im.ref(let_id) for let_id in let_ids), current_el_type
process_func,
tuple(im.ref(let_id) for let_id in let_ids),
current_el_type,
with_type=with_type,
)

return im.let(*(zip(let_ids, objs, strict=True)))(body)
Expand All @@ -131,6 +135,7 @@ def _process_elements_impl(
process_func: Callable[..., itir.Expr],
_current_el_exprs: Iterable[T],
current_el_type: ts.TypeSpec,
with_type: bool,
) -> itir.Expr:
if isinstance(current_el_type, ts.TupleType):
result = im.make_tuple(
Expand All @@ -141,13 +146,17 @@ def _process_elements_impl(
im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs
),
current_el_type.types[i],
with_type=with_type,
)
for i in range(len(current_el_type.types))
)
)
elif type_info.contains_local_field(current_el_type):
raise NotImplementedError("Processing fields with local dimension is not implemented.")
else:
result = process_func(*_current_el_exprs)
if with_type:
result = process_func(*_current_el_exprs, current_el_type)
else:
result = process_func(*_current_el_exprs)

return result
32 changes: 31 additions & 1 deletion tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from gt4py.next import (
astype,
broadcast,
common,
float32,
float64,
int32,
Expand All @@ -27,7 +28,6 @@
min_over,
neighbor_sum,
where,
common,
)
from gt4py.next.ffront import type_specifications as ts_ffront
from gt4py.next.ffront.ast_passes import single_static_assign as ssa
Expand Down Expand Up @@ -291,6 +291,18 @@ def foo(a: gtx.Field[[TDim], float64]):
assert lowered.expr == reference


def test_astype_scalar():
def foo(a: float64):
return astype(a, int32)

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)

reference = im.call("cast_")("a", "int32")

assert lowered.expr == reference


def test_astype_tuple():
def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]):
return astype(a, int32)
Expand All @@ -311,6 +323,24 @@ def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]):
assert lowered_inlined.expr == reference


def test_astype_tuple_scalar_and_field():
def foo(a: tuple[gtx.Field[[TDim], float64], float64]):
return astype(a, int32)

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered)

reference = im.make_tuple(
im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))(
im.tuple_get(0, "a")
),
im.call("cast_")(im.tuple_get(1, "a"), "int32"),
)

assert lowered_inlined.expr == reference


def test_astype_nested_tuple():
def foo(
a: tuple[
Expand Down
Loading