Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Nested scalars args & cleanup #1540

Merged
merged 79 commits into from
Sep 16, 2024

Conversation

tehrengruber
Copy link
Contributor

@tehrengruber tehrengruber commented May 2, 2024

  • Cleanup handling of tuple arguments in GTFN backend. This enables using scalar arguments in both the domain expression of a program and as an argument to a field operator simultaneously. Scalar arguments to an applied as_fieldop are detected using the new type inference.
  • Adds support for nested scalar args in the GTFN backend
  • Adds support for nested scalar & field args in GTFN and Embedded

" tuple) need to have the same shape and dimensions."
)
size_args.extend(shape if shape else [None] * len(dims))
if shapes_and_dims: # scalar or zero-dim field otherwise
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it mean we allow writing to a scalar?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the implementation _field_constituents_shape_and_dims, it seems the comment is wrong as 0d fields probably returns a tuple of empty tuples.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it mean we allow writing to a scalar?

Not sure what motivated this question. This code block only extracts size arguments, the change here merely avoids errors for cases like (scalar, (field1, field2)). I've added a comment that explains this a little more.

Looking at the implementation _field_constituents_shape_and_dims, it seems the comment is wrong as 0d fields probably returns a tuple of empty tuples.

The comment was alright though not helpful^^, I've fixed _field_constituents_shape_and_dims which was wrong and added a test test_zero_dim_tuple_arg.

.filter(lambda dims: len(dims) > 0)
.to_list()
)
if len(fields_dims) > 0: # param has no field constituent otherwise
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is a "field constituent"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A constituent of a composite with type field, e.g. here (scalar, (field, scalar)) the composite is a tuple and field is a field constituent. I have introduced the nomenclature composite and constituent as a generalization of tuples and structs where the constituents are the elements and members respectively. Given the amount of question marks this caused we should probably revisit this naming.

@@ -676,11 +676,23 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]:

def _get_axes(
field_or_tuple: LocatedField | tuple,
*,
ignore_zero_dims=False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the PR it's not obvious to me how this is used. Please add an itir test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added doctests here. The motivation is similar as in the frontend for (scalar, (field1, field2)) we sometimes only want to check that field1 and field2 have the same dimensions, but don't care about scalar.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the changes in this file related to the newly added tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they are required for "mixed" tuple args, e.g. an argument with type (scalar, (field, scalar)). This is the code path in make_in_iterator where this function is called.

src/gt4py/next/otf/binding/nanobind.py Show resolved Hide resolved
for input_ in inputs:
lowered_input = self.visit(input_, **kwargs)

def _convert_input(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • unclear why this should be a closure
  • _convert_input -> better name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored this section and the convert_output case below to use a common utility function _process_elements.

raise ValueError("Expected 'SymRef' or 'make_tuple' in output argument.")
lowered_output = self.visit(node)

def _convert_output(el_type: ts.ScalarType | ts.FieldType, path: tuple[int, ...]) -> Expr:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see comment at _convert_input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's discuss this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above.

@tehrengruber tehrengruber changed the base branch from itir_type_inference to main July 31, 2024 07:00
@@ -61,7 +61,7 @@ def env_flag_to_bool(name: str, default: bool) -> bool:
#: Master debug flag
#: Changes defaults for all the other options to be as helpful for debugging as possible.
#: Does not override values set in environment variables.
DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False)
DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What motivated you to change this here instead of setting the env var? Anyway, this is a reminder to revert.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't figured out how to have default env vars for test execution in a project in PyCharm. Reverted.

Comment on lines 107 to 113
if dims:
assert hasattr(arg, "shape") and len(arg.shape) == len(dims)
yield (arg.shape, dims)
else:
yield (tuple(), dims)
pass
case ts.ScalarType():
yield (tuple(), [])
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constructs looks weird to me, I guess it's equivalent to replacing pass by return which might be slightly more expressive, but still ugly. Alternatives could be yield from []

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used yield from [] now.

src/gt4py/next/ffront/past_to_itir.py Outdated Show resolved Hide resolved
@@ -676,11 +676,23 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]:

def _get_axes(
field_or_tuple: LocatedField | tuple,
*,
ignore_zero_dims=False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the changes in this file related to the newly added tests?

src/gt4py/next/otf/binding/nanobind.py Show resolved Hide resolved
@@ -895,7 +935,7 @@ def deref(self) -> Any:

assert self.pos is not None
shifted_pos = self.pos.copy()
axes = _get_axes(self.field)
axes = _get_axes(self.field, ignore_zero_dims=True)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO:

  1. write iterator tests
  2. check if (vertex_field, vertex_k_field) should also be valid here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I've added two tests.
  2. I've added two test named test_tuple_arg_with_unpromotable_dims and test_scalar_arg_with_field. The implementation is beyond the scope of this PR.

Copy link
Contributor

@havogt havogt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 remarks

Comment on lines +100 to +107
if (
isinstance(expr, FunCall)
and isinstance(expr.fun, SymRef)
and expr.fun.id == "tuple_get"
and len(expr.args) == 2
and _is_ref_or_tuple_expr_of_ref(expr.args[1])
):
return True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this case tested?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is indeed tested, the _process_elements creates such gtfn_ir expressions for example in the test_multicopy case.

tests/next_tests/integration_tests/cases.py Outdated Show resolved Hide resolved
@tehrengruber tehrengruber merged commit 15baa37 into GridTools:main Sep 16, 2024
31 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants