Skip to content

Commit

Permalink
bug[next]: Fix reverse operators in embedded (#1467)
Browse files Browse the repository at this point in the history
The reverse operators where not swapping their arguments.

E.g. `1.0 - field` did `field - 1.0`.
  • Loading branch information
havogt committed Feb 26, 2024
1 parent d9004cd commit f59dc2c
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 34 deletions.
29 changes: 21 additions & 8 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,16 @@
jnp: Optional[ModuleType] = None # type:ignore[no-redef]


def _make_builtin(builtin_name: str, array_builtin_name: str) -> Callable[..., NdArrayField]:
def _make_builtin(
builtin_name: str, array_builtin_name: str, reverse=False
) -> Callable[..., NdArrayField]:
def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
first = fields[0]
assert isinstance(first, NdArrayField)
first = None
for f in fields:
if isinstance(f, NdArrayField):
first = f
break
assert first is not None
xp = first.__class__.array_ns
op = getattr(xp, array_builtin_name)

Expand All @@ -67,7 +73,8 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
else:
assert core_defs.is_scalar_type(f)
transformed.append(f)

if reverse:
transformed.reverse()
new_data = op(*transformed)
return first.__class__.from_array(new_data, domain=domain_intersection)

Expand Down Expand Up @@ -248,17 +255,21 @@ def __setitem__(

__pos__ = _make_builtin("pos", "positive")

__sub__ = __rsub__ = _make_builtin("sub", "subtract")
__sub__ = _make_builtin("sub", "subtract")
__rsub__ = _make_builtin("sub", "subtract", reverse=True)

__mul__ = __rmul__ = _make_builtin("mul", "multiply")

__truediv__ = __rtruediv__ = _make_builtin("div", "divide")
__truediv__ = _make_builtin("div", "divide")
__rtruediv__ = _make_builtin("div", "divide", reverse=True)

__floordiv__ = __rfloordiv__ = _make_builtin("floordiv", "floor_divide")
__floordiv__ = _make_builtin("floordiv", "floor_divide")
__rfloordiv__ = _make_builtin("floordiv", "floor_divide", reverse=True)

__pow__ = _make_builtin("pow", "power")

__mod__ = __rmod__ = _make_builtin("mod", "mod")
__mod__ = _make_builtin("mod", "mod")
__rmod__ = _make_builtin("mod", "mod", reverse=True)

__ne__ = _make_builtin("not_equal", "not_equal") # type: ignore # mypy wants return `bool`

Expand Down Expand Up @@ -620,6 +631,8 @@ def __setitem__(


def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
if field.domain.dims == new_dimensions:
return field
domain_slice: list[slice | None] = []
named_ranges = []
for dim in new_dimensions:
Expand Down
89 changes: 63 additions & 26 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import pytest

from gt4py._core import definitions as core_defs
from gt4py.next import common
from gt4py.next.common import Dimension, Domain, UnitRange
from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field
Expand Down Expand Up @@ -69,9 +70,14 @@ def unary_logical_op(request):
yield request.param


def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=None):
def _make_field_or_scalar(
lst: Iterable | core_defs.Scalar, nd_array_implementation, *, domain=None, dtype=None
):
"""Creates a field from an Iterable or returns a scalar."""
if not dtype:
dtype = nd_array_implementation.float32
dtype = np.float32
if isinstance(lst, core_defs.SCALAR_TYPES):
return dtype(lst)
buffer = nd_array_implementation.asarray(lst, dtype=dtype)
if domain is None:
domain = tuple(
Expand All @@ -83,6 +89,18 @@ def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=No
)


def _np_asarray_or_scalar(value: Iterable | core_defs.Scalar, dtype=None):
"""Creates a numpy array from an Iterable or returns a scalar."""
if not dtype:
dtype = np.float32

return (
dtype(value)
if isinstance(value, core_defs.SCALAR_TYPES)
else np.asarray(value, dtype=dtype)
)


@pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data())
def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementation):
if builtin_name == "gamma":
Expand All @@ -94,7 +112,7 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati

expected = ref_impl(*[np.asarray(inp, dtype=np.float32) for inp in inputs])

field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs]
field_inputs = [_make_field_or_scalar(inp, nd_array_implementation) for inp in inputs]

builtin = getattr(fbuiltins, builtin_name)
result = builtin(*field_inputs)
Expand All @@ -107,7 +125,9 @@ def test_where_builtin(nd_array_implementation):
true_ = np.asarray([1.0, 2.0], dtype=np.float32)
false_ = np.asarray([3.0, 4.0], dtype=np.float32)

field_inputs = [_make_field(inp, nd_array_implementation) for inp in [cond, true_, false_]]
field_inputs = [
_make_field_or_scalar(inp, nd_array_implementation) for inp in [cond, true_, false_]
]
expected = np.where(cond, true_, false_)

result = fbuiltins.where(*field_inputs)
Expand Down Expand Up @@ -147,37 +167,54 @@ def test_where_builtin_with_tuple(nd_array_implementation):
expected0 = np.where(cond, true0, false0)
expected1 = np.where(cond, true1, false1)

cond_field = _make_field(cond, nd_array_implementation, dtype=bool)
field_true = tuple(_make_field(inp, nd_array_implementation) for inp in [true0, true1])
field_false = tuple(_make_field(inp, nd_array_implementation) for inp in [false0, false1])
cond_field = _make_field_or_scalar(cond, nd_array_implementation, dtype=bool)
field_true = tuple(
_make_field_or_scalar(inp, nd_array_implementation) for inp in [true0, true1]
)
field_false = tuple(
_make_field_or_scalar(inp, nd_array_implementation) for inp in [false0, false1]
)

result = fbuiltins.where(cond_field, field_true, field_false)
assert np.allclose(result[0].ndarray, expected0)
assert np.allclose(result[1].ndarray, expected1)


def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation):
inp_a = [-1.0, 4.2, 42]
inp_b = [2.0, 3.0, -3.0]
inputs = [inp_a, inp_b]
@pytest.mark.parametrize(
"lhs, rhs",
[
([-1.0, 4.2, 42], [2.0, 3.0, -3.0]),
(2.0, [2.0, 3.0, -3.0]), # scalar with field, tests reverse operators
],
)
def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation, lhs, rhs):
inputs = [lhs, rhs]

expected = binary_arithmetic_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs])
expected = binary_arithmetic_op(*[_np_asarray_or_scalar(inp) for inp in inputs])

field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs]
field_inputs = [_make_field_or_scalar(inp, nd_array_implementation) for inp in inputs]

result = binary_arithmetic_op(*field_inputs)

assert np.allclose(result.ndarray, expected)


def test_binary_logical_ops(binary_logical_op, nd_array_implementation):
inp_a = [True, True, False, False]
inp_b = [True, False, True, False]
inputs = [inp_a, inp_b]
@pytest.mark.parametrize(
"lhs, rhs",
[
([True, True, False, False], [True, False, True, False]),
(True, [True, False]),
(False, [True, False]),
],
)
def test_binary_logical_ops(binary_logical_op, nd_array_implementation, lhs, rhs):
inputs = [lhs, rhs]

expected = binary_logical_op(*[np.asarray(inp) for inp in inputs])
expected = binary_logical_op(*[_np_asarray_or_scalar(inp, dtype=bool) for inp in inputs])

field_inputs = [_make_field(inp, nd_array_implementation, dtype=bool) for inp in inputs]
field_inputs = [
_make_field_or_scalar(inp, nd_array_implementation, dtype=bool) for inp in inputs
]

result = binary_logical_op(*field_inputs)

Expand All @@ -192,7 +229,7 @@ def test_unary_logical_ops(unary_logical_op, nd_array_implementation):

expected = unary_logical_op(np.asarray(inp))

field_input = _make_field(inp, nd_array_implementation, dtype=bool)
field_input = _make_field_or_scalar(inp, nd_array_implementation, dtype=bool)

result = unary_logical_op(field_input)

Expand All @@ -204,7 +241,7 @@ def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation):

expected = unary_arithmetic_op(np.asarray(inp, dtype=np.float32))

field_input = _make_field(inp, nd_array_implementation)
field_input = _make_field_or_scalar(inp, nd_array_implementation)

result = unary_arithmetic_op(field_input)

Expand Down Expand Up @@ -255,8 +292,8 @@ def test_mixed_fields(product_nd_array_implementation):

expected = np.asarray(inp_a) + np.asarray(inp_b)

field_inp_a = _make_field(inp_a, first_impl)
field_inp_b = _make_field(inp_b, second_impl)
field_inp_a = _make_field_or_scalar(inp_a, first_impl)
field_inp_b = _make_field_or_scalar(inp_b, second_impl)

result = field_inp_a + field_inp_b
assert np.allclose(result.ndarray, expected)
Expand All @@ -273,9 +310,9 @@ def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field:

expected = np.asarray(inp_a) * np.asarray(inp_b) + np.asarray(inp_c)

field_inp_a = _make_field(inp_a, np)
field_inp_b = _make_field(inp_b, np)
field_inp_c = _make_field(inp_c, np)
field_inp_a = _make_field_or_scalar(inp_a, np)
field_inp_b = _make_field_or_scalar(inp_b, np)
field_inp_c = _make_field_or_scalar(inp_c, np)

result = fma(field_inp_a, field_inp_b, field_inp_c)
assert np.allclose(result.ndarray, expected)
Expand Down

0 comments on commit f59dc2c

Please sign in to comment.