Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[next]: Fix usage of DaCe fast-call to SDFG #1656

Merged
merged 37 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
69c9ac4
Fix for DaCe fast_call API
edopao Sep 20, 2024
db97d2b
override all symbols in SDFG call except for connectivities
edopao Sep 23, 2024
087802d
Fix previous commit
edopao Sep 23, 2024
631039b
Fix previous commit (1)
edopao Sep 23, 2024
a51a6fb
Fix previous commit (2)
edopao Sep 23, 2024
9602aa5
Fix previous commit (3)
edopao Sep 23, 2024
468645d
Edit code comments
edopao Sep 23, 2024
681c985
Add test coverage
edopao Sep 23, 2024
b16be89
Revert "Add test coverage"
edopao Sep 23, 2024
357beb0
Revert "Edit code comments"
edopao Sep 23, 2024
07cec75
Revert "Fix previous commit (3)"
edopao Sep 23, 2024
68a2790
Revert "Fix previous commit (2)"
edopao Sep 23, 2024
06eb1f4
Revert "Fix previous commit (1)"
edopao Sep 23, 2024
e4d8a7c
Revert "Fix previous commit"
edopao Sep 23, 2024
ba7ab6b
Revert "override all symbols in SDFG call except for connectivities"
edopao Sep 23, 2024
7360579
Support both np and cp arrays
edopao Sep 23, 2024
e1486a4
code refactoring
edopao Sep 23, 2024
549c91b
Pre-commit
edopao Sep 23, 2024
903ff8b
Pre-commit
edopao Sep 23, 2024
a15613b
Make fast_call configurable
edopao Sep 23, 2024
ccc0775
Update testcase
edopao Sep 23, 2024
d76a6cd
cleanup
edopao Sep 23, 2024
76502f9
Ensure gpu arrays are cupy buffers
edopao Sep 23, 2024
a5cf7db
Merge remote-tracking branch 'origin/main' into dace-fix-fast_call
edopao Sep 24, 2024
3b0a890
Special handling of connectivity tables
edopao Sep 24, 2024
e4834da
Code refactoring
edopao Sep 24, 2024
d5f3d65
Loop through entire arglist, not only the positional arguments
edopao Sep 25, 2024
1ec0e92
remove extra change
edopao Sep 25, 2024
f8b616a
Add test coverage
edopao Sep 25, 2024
ad85f06
Fix gpu test case
edopao Sep 25, 2024
745fe5e
Fix gpu test case (1)
edopao Sep 25, 2024
e854fd6
Relax assertion on first call
edopao Sep 25, 2024
b837aa8
cleanup testcase
edopao Sep 25, 2024
70a1dd5
Address review comments
edopao Sep 26, 2024
f50d89e
Fix mypy error
edopao Sep 26, 2024
c0cbde5
remove unnecessary isinstance check on critical path
edopao Sep 26, 2024
3aea328
Merge remote-tracking branch 'origin/main' into dace-fix-fast_call
edopao Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,9 @@ def _ensure_is_on_device(
return connectivity_arg


def _get_connectivity_args(
    neighbor_tables: Mapping[str, gtx_common.NeighborTable], device: dace.dtypes.DeviceType
) -> dict[str, Any]:
    """
    Build the connectivity arguments for an SDFG call.

    Each entry of `neighbor_tables` is mapped from its connectivity identifier
    (as produced by `dace_util.connectivity_identifier`) to the neighbor table
    buffer, after passing it through `_ensure_is_on_device` for `device`.
    """
    connectivity_args: dict[str, Any] = {}
    for offset_name, neighbor_table in neighbor_tables.items():
        array_name = dace_util.connectivity_identifier(offset_name)
        connectivity_args[array_name] = _ensure_is_on_device(neighbor_table.table, device)
    return connectivity_args


def _get_shape_args(
arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
) -> Mapping[str, int]:
arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray]
) -> dict[str, int]:
shape_args: dict[str, int] = {}
for name, value in args.items():
for sym, size in zip(arrays[name].shape, value.shape, strict=True):
Expand All @@ -101,8 +90,8 @@ def _get_shape_args(


def _get_stride_args(
arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
) -> Mapping[str, int]:
arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray]
) -> dict[str, int]:
stride_args = {}
for name, value in args.items():
for sym, stride_size in zip(arrays[name].strides, value.strides, strict=True):
Expand All @@ -121,6 +110,27 @@ def _get_stride_args(
return stride_args


def get_sdfg_conn_args(
    sdfg: dace.SDFG,
    offset_provider: dict[str, Any],
    on_gpu: bool,
) -> dict[str, np.typing.NDArray]:
    """
    Collect the connectivity tables that the given SDFG actually uses.

    Only connectivities whose identifier appears in `sdfg.arrays` are returned.
    Each returned table is passed through `_ensure_is_on_device`, so that its
    memory buffer is allocated for the target device (GPU when `on_gpu` is set,
    CPU otherwise).

    Returns:
        Mapping from SDFG array name to the connectivity table buffer.
    """
    if on_gpu:
        target_device = dace.DeviceType.GPU
    else:
        target_device = dace.DeviceType.CPU

    used_tables = {}
    neighbor_tables = dace_util.filter_connectivities(offset_provider)
    for offset_name, table_provider in neighbor_tables.items():
        assert isinstance(table_provider, gtx_common.NeighborTable)
        array_name = dace_util.connectivity_identifier(offset_name)
        if array_name not in sdfg.arrays:
            # connectivity table not referenced by this SDFG
            continue
        used_tables[array_name] = _ensure_is_on_device(table_provider.table, target_device)

    return used_tables


def get_sdfg_args(
sdfg: dace.SDFG,
*args: Any,
Expand All @@ -138,17 +148,9 @@ def get_sdfg_args(
"""
offset_provider = kwargs["offset_provider"]

neighbor_tables: dict[str, gtx_common.NeighborTable] = {}
for offset, connectivity in dace_util.filter_connectivities(offset_provider).items():
assert isinstance(connectivity, gtx_common.NeighborTable)
neighbor_tables[offset] = connectivity
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

dace_args = _get_args(sdfg, args, use_field_canonical_representation)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = _get_connectivity_args(neighbor_tables, device)
# keep only connectivity tables that are used in the sdfg
dace_conn_args = {n: v for n, v in dace_conn_args.items() if n in sdfg.arrays}
dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu)
dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = _get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = _get_stride_args(sdfg.arrays, dace_field_args)
Expand Down
11 changes: 10 additions & 1 deletion src/gt4py/next/program_processors/runners/dace_common/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from __future__ import annotations

from typing import Any, Mapping, Optional, Sequence
import re
from typing import Any, Final, Mapping, Optional, Sequence

import dace

Expand All @@ -17,6 +18,10 @@
from gt4py.next.type_system import type_specifications as ts


# Regex to match the symbols for field shape and strides, e.g. `__field_size_0`
# or `__field_stride_1`. A raw string literal is required here: `\d` is not a
# valid Python string escape, so a plain literal triggers an invalid-escape warning.
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"__.+_(size|stride)_\d+")


def as_scalar_type(typestr: str) -> ts.ScalarType:
"""Obtain GT4Py scalar type from generic numpy string representation."""
try:
Expand All @@ -38,6 +43,10 @@ def field_stride_symbol_name(field_name: str, axis: int) -> str:
return f"__{field_name}_stride_{axis}"


def is_field_symbol(name: str) -> bool:
    """
    Check whether `name` looks like a field shape or stride symbol.

    The check uses `FIELD_SYMBOL_RE.match`, which anchors the pattern at the
    start of `name` only; trailing characters after the axis index are not rejected.
    """
    # a successful match yields a (always truthy) match object, otherwise None
    return bool(FIELD_SYMBOL_RE.match(name))


def debug_info(
node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None
) -> Optional[dace.dtypes.DebugInfo]:
Expand Down
76 changes: 37 additions & 39 deletions src/gt4py/next/program_processors/runners/dace_common/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import ctypes
import dataclasses
from typing import Any, Optional
from typing import Any

import dace
import factory
Expand All @@ -20,26 +20,23 @@
from gt4py.next import common, config
from gt4py.next.otf import arguments, languages, stages, step_types, workflow
from gt4py.next.otf.compilation import cache
from gt4py.next.program_processors.runners.dace_common import dace_backend
from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils


class CompiledDaceProgram(stages.CompiledProgram):
sdfg_program: dace.CompiledSDFG
# Map SDFG argument to its position in program ABI; scalar arguments that are not used in the SDFG will not be present.
sdfg_arg_position: list[Optional[int]]

def __init__(self, program: dace.CompiledSDFG):
# extract position of arguments in program ABI
sdfg_arglist = program.sdfg.signature_arglist(with_types=False)
sdfg_arg_pos_mapping = {param: pos for pos, param in enumerate(sdfg_arglist)}
sdfg_used_symbols = program.sdfg.used_symbols(all_symbols=False)
# Sorted list of SDFG arguments as they appear in program ABI and corresponding data type;
# scalar arguments that are not used in the SDFG will not be present.
sdfg_arglist: list[tuple[str, dace.dtypes.Data]]

def __init__(self, program: dace.CompiledSDFG):
self.sdfg_program = program
self.sdfg_arg_position = [
sdfg_arg_pos_mapping[param]
if param in program.sdfg.arrays or param in sdfg_used_symbols
else None
for param in program.sdfg.arg_names
# `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument
# name to its data type, in the same order as arguments appear in the program ABI.
# This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`.
self.sdfg_arglist = [
(arg_name, arg_type) for arg_name, arg_type in program.sdfg.arglist().items()
]

def __call__(self, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -94,13 +91,6 @@ class Meta:
model = DaCeCompiler


def _get_ctype_value(arg: Any, dtype: dace.dtypes.dataclass) -> Any:
    """
    Return `arg` as a ctypes object.

    Values that already are ctypes simple data or pointer instances are returned
    unchanged; anything else is wrapped in the ctypes type given by `dtype.as_ctypes()`.
    """
    if isinstance(arg, (ctypes._SimpleCData, ctypes._Pointer)):
        # already a ctypes object, nothing to convert
        return arg
    return dtype.as_ctypes()(arg)


def convert_args(
inp: CompiledDaceProgram,
device: core_defs.DeviceType = core_defs.DeviceType.CPU,
Expand All @@ -119,28 +109,36 @@ def decorated_program(
args = (*args, out)
if len(sdfg.arg_names) > len(args):
args = (*args, *arguments.iter_size_args(args))

if sdfg_program._lastargs:
# The scalar arguments should be replaced with the actual value; for field arguments,
# the data pointer should remain the same otherwise fast-call cannot be used and
# the args list needs to be reconstructed.
kwargs = dict(zip(sdfg.arg_names, args, strict=True))
kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu))

use_fast_call = True
for arg, param, pos in zip(args, sdfg.arg_names, inp.sdfg_arg_position, strict=True):
if isinstance(arg, common.Field):
desc = sdfg.arrays[param]
assert isinstance(desc, dace.data.Array)
assert isinstance(sdfg_program._lastargs[0][pos], ctypes.c_void_p)
if sdfg_program._lastargs[0][pos].value != get_array_interface_ptr(
arg.ndarray, desc.storage
):
last_call_args = sdfg_program._lastargs[0]
# The scalar arguments should be overridden with the new value; for field arguments,
# the data pointer should remain the same otherwise fast_call cannot be used and
# the arguments list has to be reconstructed.
for i, (arg_name, arg_type) in enumerate(inp.sdfg_arglist):
if isinstance(arg_type, dace.data.Array):
assert arg_name in kwargs, f"Argument '{arg_name}' not found."
data_ptr = get_array_interface_ptr(kwargs[arg_name], arg_type.storage)
assert isinstance(last_call_args[i], ctypes.c_void_p)
if last_call_args[i].value != data_ptr:
use_fast_call = False
break
elif param in sdfg.arrays:
desc = sdfg.arrays[param]
assert isinstance(desc, dace.data.Scalar)
sdfg_program._lastargs[0][pos] = _get_ctype_value(arg, desc.dtype)
elif pos:
sym_dtype = sdfg.symbols[param]
sdfg_program._lastargs[0][pos] = _get_ctype_value(arg, sym_dtype)
else:
assert isinstance(arg_type, dace.data.Scalar)
assert isinstance(last_call_args[i], ctypes._SimpleCData)
if arg_name in kwargs:
edopao marked this conversation as resolved.
Show resolved Hide resolved
# override the scalar value used in previous program call
actype = arg_type.dtype.as_ctypes()
last_call_args[i] = actype(kwargs[arg_name])
else:
# shape and strides of arrays are supposed not to change, and can therefore be omitted
assert dace_utils.is_field_symbol(
arg_name
), f"Argument '{arg_name}' not found."

if use_fast_call:
return sdfg_program.fast_call(*sdfg_program._lastargs)
Expand Down
Loading
Loading