Skip to content

Commit

Permalink
Adds compute follows data check to KernelDispatcher
Browse files Browse the repository at this point in the history
    - Added checks for compute follows data compliance to
      kernel compilation.
    - Removed support for __getitem__ in KernelDispatcher
    - Address review comments.
  • Loading branch information
diptorupd committed Oct 26, 2023
1 parent 0cdd1b1 commit 6656dcd
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 43 deletions.
6 changes: 6 additions & 0 deletions numba_dpex/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ def __init__(self, kernel_name, *, usmarray_argnum_list) -> None:
f"usm_ndarray arguments {usmarray_args} were not allocated "
"on the same queue."
)
else:
self.message = (
f'Execution queue for kernel "{kernel_name}" could '
"be deduced using compute follows data programming model. The "
"kernel has no USMNdArray argument."
)
super().__init__(self.message)


Expand Down
90 changes: 47 additions & 43 deletions numba_dpex/experimental/kernel_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,12 @@
from numba_dpex import config, spirv_generator
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.exceptions import (
InvalidKernelLaunchArgsError,
ExecutionQueueInferenceError,
UnsupportedKernelArgumentError,
)
from numba_dpex.core.kernel_interface.indexers import NdRange, Range
from numba_dpex.core.pipelines import kernel_compiler
from numba_dpex.core.types import DpnpNdArray

_KernelLauncherLowerResult = namedtuple(
"_KernelLauncherLowerResult",
["sig", "fndesc", "library", "call_helper"],
)

_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"])

_KernelCompileResult = namedtuple(
Expand All @@ -38,6 +32,43 @@


class _KernelCompiler(_FunctionCompiler):
"""A special compiler class used to compile numba_dpex.kernel decorated
functions.
"""

def _check_queue_equivalence_of_args(
    self, py_func_name: str, args: [types.Type, ...]
):
    """Check that all DpnpNdArray kernel arguments share one queue.

    Walks over the inferred argument types and compares the ``queue``
    attribute of every ``DpnpNdArray`` type against the first one found.

    Args:
        py_func_name (str): Name of the kernel that is being evaluated.
            It is only used to build the error raised on failure.
        args: Numba inferred types, one for each argument passed to the
            kernel.

    Raises:
        ExecutionQueueInferenceError: If the DpnpNdArray arguments do not
            all have the same ``queue``. The positions of all
            DpnpNdArray arguments are passed along as
            ``usmarray_argnum_list``.
        ExecutionQueueInferenceError: If no queue could be found, i.e.
            no DpnpNdArray argument was passed to the kernel.
    """
    common_queue = None
    usmarray_argnums = []
    queues_differ = False

    for argnum, arg in enumerate(args):
        if not isinstance(arg, DpnpNdArray):
            continue
        # Remember every array argument position so a mismatch can be
        # reported with the full set of arguments involved.
        usmarray_argnums.append(argnum)
        if common_queue is None:
            common_queue = arg.queue
        elif common_queue != arg.queue:
            queues_differ = True

    if queues_differ:
        # NOTE(review): assumes ExecutionQueueInferenceError accepts a list
        # of integer argument positions - confirm against exceptions.py.
        raise ExecutionQueueInferenceError(
            kernel_name=py_func_name,
            usmarray_argnum_list=usmarray_argnums,
        )

    if common_queue is None:
        raise ExecutionQueueInferenceError(
            kernel_name=py_func_name, usmarray_argnum_list=None
        )

def _compile_to_spirv(
self, kernel_library, kernel_fndesc, kernel_targetctx
):
Expand Down Expand Up @@ -156,9 +187,6 @@ def __init__(
targetoptions["experimental"] = True

self._kernel_name = pyfunc.__name__
self._range = None
self._ndrange = None

self.typingctx = self.targetdescr.typing_context
self.targetctx = self.targetdescr.target_context

Expand All @@ -185,7 +213,7 @@ def __init__(
self._cache = NullCache()
compiler_class = self._impl_kinds[impl_kind]
self._impl_kind = impl_kind
self._compiler = compiler_class(
self._compiler: _KernelCompiler = compiler_class(
pyfunc, self.targetdescr, targetoptions, locals, pipeline_class
)
self._cache_hits = Counter()
Expand Down Expand Up @@ -265,9 +293,14 @@ def cb_llvm(dur):
)
with ev.trigger_event("numba_dpex:compile", data=ev_details):
try:
self._compiler._check_queue_equivalence_of_args(
self._kernel_name, args
)
kcres: _KernelCompileResult = self._compiler.compile(
args, return_type
)
except ExecutionQueueInferenceError as eqie:
raise eqie
except errors.ForceLiteralArg as e:

def folded(args, kws):
Expand All @@ -283,40 +316,11 @@ def folded(args, kws):
return kcres.kernel_module

def __getitem__(self, args):
    """Reject square-bracket launch configuration.

    Square-bracket notation for configuring launch arguments is not
    supported by this dispatcher.

    Args:
        args: The object passed inside the square brackets. It is
            ignored.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "Square-bracket notation for configuring launch arguments is not "
        "supported."
    )

def __call__(self, *args, **kw_args):
"""Functor to launch a kernel."""
Expand Down
90 changes: 90 additions & 0 deletions numba_dpex/tests/experimental/test_exec_queue_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0


import dpctl
import dpnp
import pytest

import numba_dpex.experimental as exp_dpex
from numba_dpex import Range
from numba_dpex.core.exceptions import ExecutionQueueInferenceError


@exp_dpex.kernel(
    release_gil=False,
    no_compile=True,
    no_cpython_wrapper=True,
    no_cfunc_wrapper=True,
)
def add(a, b, c):
    """Write the sum of the first elements of ``b`` and ``a`` to ``c[0]``."""
    first_sum = b[0] + a[0]
    c[0] = first_sum


def test_successful_execution_queue_inference():
    """
    Tests if KernelDispatcher successfully infers the execution queue for the
    kernel when all array arguments are allocated on the same queue.
    """

    q = dpctl.SyclQueue()
    a = dpnp.ones(100, sycl_queue=q)
    b = dpnp.ones_like(a, sycl_queue=q)
    c = dpnp.zeros_like(a, sycl_queue=q)
    r = Range(100)

    # FIXME: This test fails unexpectedly if the NUMBA_CAPTURED_ERRORS is set
    # to "new_style"
    try:
        exp_dpex.call_kernel(add, r, a, b, c)
    except Exception as exc:
        # Catch Exception rather than everything so KeyboardInterrupt and
        # SystemExit still propagate, and report the underlying error.
        pytest.fail(f"Unexpected error when calling kernel: {exc!r}")

    assert c[0] == b[0] + a[0]


def test_execution_queue_inference_error():
    """
    Tests if KernelDispatcher raises ExecutionQueueInferenceError when
    dpnp.ndarray arguments do not share the same dpctl.SyclQueue instance.
    """

    q1 = dpctl.SyclQueue()
    q2 = dpctl.SyclQueue()
    a = dpnp.ones(100, sycl_queue=q1)
    b = dpnp.ones_like(a, sycl_queue=q2)
    c = dpnp.zeros_like(a, sycl_queue=q1)
    r = Range(100)

    from numba.core import config

    current_captured_error_style = config.CAPTURED_ERRORS
    config.CAPTURED_ERRORS = "new_style"

    # Restore the global setting even when the expectation below fails, so
    # the changed error style does not leak into other tests.
    try:
        with pytest.raises(ExecutionQueueInferenceError):
            exp_dpex.call_kernel(add, r, a, b, c)
    finally:
        config.CAPTURED_ERRORS = current_captured_error_style


def test_error_when_no_array_args():
    """
    Tests if KernelDispatcher raises ExecutionQueueInferenceError when no
    dpnp.ndarray arguments were passed to a kernel.
    """
    a = 1
    b = 2
    c = 3
    r = Range(100)

    from numba.core import config

    current_captured_error_style = config.CAPTURED_ERRORS
    config.CAPTURED_ERRORS = "new_style"

    # Restore the global setting even when the expectation below fails, so
    # the changed error style does not leak into other tests.
    try:
        with pytest.raises(ExecutionQueueInferenceError):
            exp_dpex.call_kernel(add, r, a, b, c)
    finally:
        config.CAPTURED_ERRORS = current_captured_error_style

0 comments on commit 6656dcd

Please sign in to comment.