diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 71e66ca771..7063faee16 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -73,20 +73,9 @@ def _ensure_is_on_device( return connectivity_arg -def _get_connectivity_args( - neighbor_tables: Mapping[str, gtx_common.NeighborTable], device: dace.dtypes.DeviceType -) -> dict[str, Any]: - return { - dace_util.connectivity_identifier(offset): _ensure_is_on_device( - offset_provider.table, device - ) - for offset, offset_provider in neighbor_tables.items() - } - - def _get_shape_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] -) -> Mapping[str, int]: + arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] +) -> dict[str, int]: shape_args: dict[str, int] = {} for name, value in args.items(): for sym, size in zip(arrays[name].shape, value.shape, strict=True): @@ -101,8 +90,8 @@ def _get_shape_args( def _get_stride_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] -) -> Mapping[str, int]: + arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] +) -> dict[str, int]: stride_args = {} for name, value in args.items(): for sym, stride_size in zip(arrays[name].strides, value.strides, strict=True): @@ -121,6 +110,27 @@ def _get_stride_args( return stride_args +def get_sdfg_conn_args( + sdfg: dace.SDFG, + offset_provider: dict[str, Any], + on_gpu: bool, +) -> dict[str, np.typing.NDArray]: + """ + Extracts the connectivity tables that are used in the sdfg and ensures + that the memory buffers are allocated for the target device. + """ + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + + connectivity_args = {} + for offset, connectivity in dace_util.filter_connectivities(offset_provider).items(): + assert isinstance(connectivity, gtx_common.NeighborTable) + param = dace_util.connectivity_identifier(offset) + if param in sdfg.arrays: + connectivity_args[param] = _ensure_is_on_device(connectivity.table, device) + + return connectivity_args + + def get_sdfg_args( sdfg: dace.SDFG, *args: Any, @@ -138,17 +148,9 @@ def get_sdfg_args( """ offset_provider = kwargs["offset_provider"] - neighbor_tables: dict[str, gtx_common.NeighborTable] = {} - for offset, connectivity in dace_util.filter_connectivities(offset_provider).items(): - assert isinstance(connectivity, gtx_common.NeighborTable) - neighbor_tables[offset] = connectivity - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - dace_args = _get_args(sdfg, args, use_field_canonical_representation) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} - dace_conn_args = _get_connectivity_args(neighbor_tables, device) - # keep only connectivity tables that are used in the sdfg - dace_conn_args = {n: v for n, v in dace_conn_args.items() if n in sdfg.arrays} + dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu) dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = _get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = _get_stride_args(sdfg.arrays, dace_field_args) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index c8f9e37a6b..a892040303 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -8,7 +8,8 @@ from __future__ import annotations -from typing import Any, Mapping, Optional, Sequence +import re +from typing import Any, Final, Mapping, Optional, Sequence import dace @@ -17,6 +18,10 @@ from gt4py.next.type_system import type_specifications as ts +# regex to match the symbols for field shape and strides +FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile("__.+_(size|stride)_\d+") + + def as_scalar_type(typestr: str) -> ts.ScalarType: """Obtain GT4Py scalar type from generic numpy string representation.""" try: @@ -38,6 +43,10 @@ def field_stride_symbol_name(field_name: str, axis: int) -> str: return f"__{field_name}_stride_{axis}" +def is_field_symbol(name: str) -> bool: + return FIELD_SYMBOL_RE.match(name) is not None + + def debug_info( node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None ) -> Optional[dace.dtypes.DebugInfo]: diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index dbe2b70ff8..a772ed0745 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -10,7 +10,7 @@ import ctypes import dataclasses -from typing import Any, Optional +from typing import Any import dace import factory @@ -20,26 +20,23 @@ from gt4py.next import common, config from gt4py.next.otf import arguments, languages, stages, step_types, workflow from gt4py.next.otf.compilation import cache -from gt4py.next.program_processors.runners.dace_common import dace_backend +from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils class CompiledDaceProgram(stages.CompiledProgram): sdfg_program: dace.CompiledSDFG - # Map SDFG argument to its position in program ABI; scalar arguments that are not used in the SDFG will not be present. - sdfg_arg_position: list[Optional[int]] - def __init__(self, program: dace.CompiledSDFG): - # extract position of arguments in program ABI - sdfg_arglist = program.sdfg.signature_arglist(with_types=False) - sdfg_arg_pos_mapping = {param: pos for pos, param in enumerate(sdfg_arglist)} - sdfg_used_symbols = program.sdfg.used_symbols(all_symbols=False) + # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; + # scalar arguments that are not used in the SDFG will not be present. + sdfg_arglist: list[tuple[str, dace.dtypes.Data]] + def __init__(self, program: dace.CompiledSDFG): self.sdfg_program = program - self.sdfg_arg_position = [ - sdfg_arg_pos_mapping[param] - if param in program.sdfg.arrays or param in sdfg_used_symbols - else None - for param in program.sdfg.arg_names + # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument + # name to its data type, in the same order as arguments appear in the program ABI. + # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. + self.sdfg_arglist = [ + (arg_name, arg_type) for arg_name, arg_type in program.sdfg.arglist().items() ] def __call__(self, *args: Any, **kwargs: Any) -> None: @@ -94,13 +91,6 @@ class Meta: model = DaCeCompiler -def _get_ctype_value(arg: Any, dtype: dace.dtypes.dataclass) -> Any: - if not isinstance(arg, (ctypes._SimpleCData, ctypes._Pointer)): - actype = dtype.as_ctypes() - return actype(arg) - return arg - - def convert_args( inp: CompiledDaceProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU, @@ -119,28 +109,36 @@ def decorated_program( args = (*args, out) if len(sdfg.arg_names) > len(args): args = (*args, *arguments.iter_size_args(args)) + if sdfg_program._lastargs: - # The scalar arguments should be replaced with the actual value; for field arguments, - # the data pointer should remain the same otherwise fast-call cannot be used and - # the args list needs to be reconstructed. + kwargs = dict(zip(sdfg.arg_names, args, strict=True)) + kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu)) + use_fast_call = True - for arg, param, pos in zip(args, sdfg.arg_names, inp.sdfg_arg_position, strict=True): - if isinstance(arg, common.Field): - desc = sdfg.arrays[param] - assert isinstance(desc, dace.data.Array) - assert isinstance(sdfg_program._lastargs[0][pos], ctypes.c_void_p) - if sdfg_program._lastargs[0][pos].value != get_array_interface_ptr( - arg.ndarray, desc.storage - ): + last_call_args = sdfg_program._lastargs[0] + # The scalar arguments should be overridden with the new value; for field arguments, + # the data pointer should remain the same otherwise fast_call cannot be used and + # the arguments list has to be reconstructed. + for i, (arg_name, arg_type) in enumerate(inp.sdfg_arglist): + if isinstance(arg_type, dace.data.Array): + assert arg_name in kwargs, f"Argument '{arg_name}' not found." + data_ptr = get_array_interface_ptr(kwargs[arg_name], arg_type.storage) + assert isinstance(last_call_args[i], ctypes.c_void_p) + if last_call_args[i].value != data_ptr: use_fast_call = False break - elif param in sdfg.arrays: - desc = sdfg.arrays[param] - assert isinstance(desc, dace.data.Scalar) - sdfg_program._lastargs[0][pos] = _get_ctype_value(arg, desc.dtype) - elif pos: - sym_dtype = sdfg.symbols[param] - sdfg_program._lastargs[0][pos] = _get_ctype_value(arg, sym_dtype) + else: + assert isinstance(arg_type, dace.data.Scalar) + assert isinstance(last_call_args[i], ctypes._SimpleCData) + if arg_name in kwargs: + # override the scalar value used in previous program call + actype = arg_type.dtype.as_ctypes() + last_call_args[i] = actype(kwargs[arg_name]) + else: + # shape and strides of arrays are supposed not to change, and can therefore be omitted + assert dace_utils.is_field_symbol( + arg_name + ), f"Argument '{arg_name}' not found." if use_fast_call: return sdfg_program.fast_call(*sdfg_program._lastargs) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index dcbab29efc..953491bde3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -10,52 +10,81 @@ import ctypes import unittest -from typing import Any import numpy as np import pytest +import gt4py._core.definitions as core_defs import gt4py.next as gtx -from gt4py.next import int32 from gt4py.next.ffront.fbuiltins import where from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( + E2V, cartesian_case, + unstructured_case, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, + mesh_descriptor, ) +from unittest.mock import patch from . import pytestmark dace = pytest.importorskip("dace") -def get_scalar_values_from_sdfg_args( - args: tuple[list[ctypes._SimpleCData], list[ctypes._SimpleCData]], -) -> list[Any]: - runtime_args, init_args = args - return [ - arg.value for arg in [*runtime_args, *init_args] if not isinstance(arg, ctypes.c_void_p) - ] +def make_mocks(monkeypatch): + # Wrap `compiled_sdfg.CompiledSDFG.fast_call` with mock object + mock_fast_call = unittest.mock.MagicMock() + dace_fast_call = dace.codegen.compiled_sdfg.CompiledSDFG.fast_call + + def mocked_fast_call(self, *args, **kwargs): + mock_fast_call.__call__(*args, **kwargs) + fast_call_result = dace_fast_call(self, *args, **kwargs) + # invalidate all scalar positional arguments to ensure that they are properly set + # next time the SDFG is executed before fast_call + positional_args = set(self.sdfg.arg_names) + sdfg_arglist = self.sdfg.arglist() + for i, (arg_name, arg_type) in enumerate(sdfg_arglist.items()): + if arg_name in positional_args and isinstance(arg_type, dace.data.Scalar): + assert isinstance(self._lastargs[0][i], ctypes.c_int) + self._lastargs[0][i].value = -1 + return fast_call_result + + monkeypatch.setattr(dace.codegen.compiled_sdfg.CompiledSDFG, "fast_call", mocked_fast_call) + + # Wrap `compiled_sdfg.CompiledSDFG._construct_args` with mock object + mock_construct_args = unittest.mock.MagicMock() + dace_construct_args = dace.codegen.compiled_sdfg.CompiledSDFG._construct_args + + def mocked_construct_args(self, *args, **kwargs): + mock_construct_args.__call__(*args, **kwargs) + return dace_construct_args(self, *args, **kwargs) + + monkeypatch.setattr( + dace.codegen.compiled_sdfg.CompiledSDFG, "_construct_args", mocked_construct_args + ) + + return mock_fast_call, mock_construct_args def test_dace_fastcall(cartesian_case, monkeypatch): """Test reuse of SDFG arguments between program calls by means of SDFG fastcall API.""" if not cartesian_case.executor or "dace" not in cartesian_case.executor.__name__: - pytest.skip("DaCe-specific testcase.") + pytest.skip("requires dace backend") @gtx.field_operator def testee( a: cases.IField, a_idx: cases.IField, unused_field: cases.IField, - a0: int32, - a1: int32, - a2: int32, - unused_scalar: int32, + a0: gtx.int32, + a1: gtx.int32, + a2: gtx.int32, + unused_scalar: gtx.int32, ) -> cases.IField: t0 = where(a_idx == 0, a + a0, a) t1 = where(a_idx == 1, t0 + a1, t0) @@ -70,27 +99,7 @@ def testee( unused_field = cases.allocate(cartesian_case, testee, "unused_field")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - # Wrap `compiled_sdfg.CompiledSDFG.fast_call` with mock object - mock_fast_call = unittest.mock.MagicMock() - mock_fast_call_attr = dace.codegen.compiled_sdfg.CompiledSDFG.fast_call - - def mocked_fast_call(self, *args, **kwargs): - mock_fast_call.__call__(*args, **kwargs) - return mock_fast_call_attr(self, *args, **kwargs) - - monkeypatch.setattr(dace.codegen.compiled_sdfg.CompiledSDFG, "fast_call", mocked_fast_call) - - # Wrap `compiled_sdfg.CompiledSDFG._construct_args` with mock object - mock_construct_args = unittest.mock.MagicMock() - mock_construct_args_attr = dace.codegen.compiled_sdfg.CompiledSDFG._construct_args - - def mocked_construct_args(self, *args, **kwargs): - mock_construct_args.__call__(*args, **kwargs) - return mock_construct_args_attr(self, *args, **kwargs) - - monkeypatch.setattr( - dace.codegen.compiled_sdfg.CompiledSDFG, "_construct_args", mocked_construct_args - ) + mock_fast_call, mock_construct_args = make_mocks(monkeypatch) # Reset mock objects and run/verify GT4Py program def verify_testee(): @@ -111,54 +120,88 @@ def verify_testee(): # On first run, the SDFG arguments will have to be constructed verify_testee() mock_construct_args.assert_called_once() - # here we store the reference to the tuple of arguments passed to `fast_call` on first run and compare on successive runs - fast_call_args = mock_fast_call.call_args.args - # and the scalar values in the order they appear in the program ABI - fast_call_scalar_values = get_scalar_values_from_sdfg_args(fast_call_args) - - def check_one_scalar_arg_changed(prev_scalar_args): - new_scalar_args = get_scalar_values_from_sdfg_args(mock_fast_call.call_args.args) - diff = np.array(new_scalar_args) - np.array(prev_scalar_args) - assert np.count_nonzero(diff) == 1 - - def check_scalar_args_all_same(prev_scalar_args): - new_scalar_args = get_scalar_values_from_sdfg_args(mock_fast_call.call_args.args) - diff = np.array(new_scalar_args) - np.array(prev_scalar_args) - assert np.count_nonzero(diff) == 0 - - def check_pointer_args_all_same(): - for arg, prev in zip(mock_fast_call.call_args.args, fast_call_args, strict=True): - if isinstance(arg, ctypes._Pointer): - assert arg == prev # Now modify the scalar arguments, used and unused ones: reuse previous SDFG arguments for i in range(4): a_offset[i] += 1 verify_testee() mock_construct_args.assert_not_called() - assert mock_fast_call.call_args.args == fast_call_args - check_pointer_args_all_same() - if i < 3: - # same arguments tuple object but one scalar value is changed - check_one_scalar_arg_changed(fast_call_scalar_values) - # update reference scalar values - fast_call_scalar_values = get_scalar_values_from_sdfg_args(fast_call_args) - else: - # unused scalar argument: the symbol is removed from the SDFG arglist and therefore no change - check_scalar_args_all_same(fast_call_scalar_values) # Modify content of current buffer: reuse previous SDFG arguments for buff in (a, unused_field): buff[0] += 1 verify_testee() mock_construct_args.assert_not_called() - # same arguments tuple object and same content - assert mock_fast_call.call_args.args == fast_call_args - check_pointer_args_all_same() - check_scalar_args_all_same(fast_call_scalar_values) # Pass a new buffer, which should trigger reconstruct of SDFG arguments: fastcall API will not be used a = cases.allocate(cartesian_case, testee, "a")() verify_testee() mock_construct_args.assert_called_once() - assert mock_fast_call.call_args.args != fast_call_args + + +def test_dace_fastcall_with_connectivity(unstructured_case, monkeypatch): + """Test reuse of SDFG arguments between program calls by means of SDFG fastcall API.""" + + if not unstructured_case.executor or "dace" not in unstructured_case.executor.__name__: + pytest.skip("requires dace backend") + + connectivity_E2V = unstructured_case.offset_provider["E2V"] + assert isinstance(connectivity_E2V, gtx.common.NeighborTable) + + # check that test connectivities are allocated on host memory + # this is an assumption to test that fast_call cannot be used for gpu tests + assert isinstance(connectivity_E2V.table, np.ndarray) + + @gtx.field_operator + def testee(a: cases.VField) -> cases.EField: + return a(E2V[0]) + + (a,), kwfields = cases.get_default_data(unstructured_case, testee) + numpy_ref = lambda a: a[connectivity_E2V.table[:, 0]] + + mock_fast_call, mock_construct_args = make_mocks(monkeypatch) + + # Reset mock objects and run/verify GT4Py program + def verify_testee(offset_provider): + mock_construct_args.reset_mock() + mock_fast_call.reset_mock() + cases.verify( + unstructured_case, + testee, + a, + **kwfields, + offset_provider=offset_provider, + ref=numpy_ref(a.asnumpy()), + ) + mock_fast_call.assert_called_once() + + if gtx.allocators.is_field_allocator_for( + unstructured_case.executor.allocator, core_defs.DeviceType.CPU + ): + offset_provider = unstructured_case.offset_provider + else: + assert gtx.allocators.is_field_allocator_for( + unstructured_case.executor.allocator, gtx.allocators.CUPY_DEVICE + ) + + import cupy as cp + + # The test connectivities are numpy arrays, by default, and they are copied + # to gpu memory at each program call (see `dace_backend._ensure_is_on_device`), + # therefore fast_call cannot be used (unless cupy reuses the same cupy array + # from the its memory pool, but this behavior is random and unpredictable). + # Here we copy the connectivity to gpu memory, and resuse the same cupy array + # on multiple program calls, in order to ensure that fast_call is used. + offset_provider = { + "E2V": gtx.NeighborTableOffsetProvider( + table=cp.asarray(connectivity_E2V.table), + origin_axis=connectivity_E2V.origin_axis, + neighbor_axis=connectivity_E2V.neighbor_axis, + max_neighbors=connectivity_E2V.max_neighbors, + has_skip_values=connectivity_E2V.has_skip_values, + ) + } + + verify_testee(offset_provider) + verify_testee(offset_provider) + mock_construct_args.assert_not_called()