From ba353d3bf7c5352906d34bf4ab770341401d8938 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 19 Feb 2024 15:31:12 +0100 Subject: [PATCH] feat[next][dace]: Remove offsets in connectivity arrays (#1460) Remove generation of offset symbols for connectivity arrays. --- .../runners/dace_iterator/__init__.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 5a5df5ce14..aba5656192 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -140,16 +140,23 @@ def get_shape_args( return shape_args -def get_offset_args( - sdfg: dace.SDFG, - args: Sequence[Any], -) -> Mapping[str, int]: +def get_offset_args(sdfg: dace.SDFG, args: Sequence[Any]) -> Mapping[str, int]: sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays sdfg_params: Sequence[str] = sdfg.arg_names + field_args = {param: arg for param, arg in zip(sdfg_params, args) if common.is_field(arg)} + + # assume that arrays for connectivity tables do not use offset + assert all( + drange.start == 0 + for sdfg_param, arg in field_args.items() + if sdfg_param.startswith("__connectivity") + for drange in arg.domain.ranges + ) + return { str(sym): -drange.start - for sdfg_param, arg in zip(sdfg_params, args) - if common.is_field(arg) + for sdfg_param, arg in field_args.items() + if not sdfg_param.startswith("__connectivity") for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain)) } @@ -331,6 +338,8 @@ def build_sdfg_from_itir( symbols: dict[str, int] = {} device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) + elif on_gpu: + autoopt.apply_gpu_storage(sdfg) if on_gpu: sdfg.apply_gpu_transformations()