Skip to content

Commit

Permalink
feat[next][dace]: Remove offsets in connectivity arrays (#1460)
Browse files Browse the repository at this point in the history
Remove generation of offset symbols for connectivity arrays.
  • Loading branch information
edopao committed Feb 19, 2024
1 parent e631c7f commit ba353d3
Showing 1 changed file with 15 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,23 @@ def get_shape_args(
return shape_args


def get_offset_args(
sdfg: dace.SDFG,
args: Sequence[Any],
) -> Mapping[str, int]:
def get_offset_args(sdfg: dace.SDFG, args: Sequence[Any]) -> Mapping[str, int]:
sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays
sdfg_params: Sequence[str] = sdfg.arg_names
field_args = {param: arg for param, arg in zip(sdfg_params, args) if common.is_field(arg)}

# assume that arrays for connectivity tables do not use offset
assert all(
drange.start == 0
for sdfg_param, arg in field_args.items()
if sdfg_param.startswith("__connectivity")
for drange in arg.domain.ranges
)

return {
str(sym): -drange.start
for sdfg_param, arg in zip(sdfg_params, args)
if common.is_field(arg)
for sdfg_param, arg in field_args.items()
if not sdfg_param.startswith("__connectivity")
for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain))
}

Expand Down Expand Up @@ -331,6 +338,8 @@ def build_sdfg_from_itir(
symbols: dict[str, int] = {}
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)
elif on_gpu:
autoopt.apply_gpu_storage(sdfg)

if on_gpu:
sdfg.apply_gpu_transformations()
Expand Down

0 comments on commit ba353d3

Please sign in to comment.