Skip to content

Commit

Permalink
Merge pull request #13441 from rapidsai/branch-23.06
Browse files Browse the repository at this point in the history
Forward-merge branch-23.06 to branch-23.08
  • Loading branch information
GPUtester authored May 25, 2023
2 parents 0f0ebfd + 799c3f9 commit 0536a3a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
46 changes: 46 additions & 0 deletions python/cudf/cudf/tests/test_numba_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
import subprocess
import sys

import pytest

IS_CUDA_11 = False
try:
from ptxcompiler.patch import NO_DRIVER, safe_get_versions

versions = safe_get_versions()
if versions != NO_DRIVER:
driver_version, runtime_version = versions
if driver_version < (12, 0):
IS_CUDA_11 = True
except ModuleNotFoundError:
pass

TEST_NUMBA_MVC_ENABLED = """
import numba.cuda
import cudf
from cudf.utils._numba import _CUDFNumbaConfig, _patch_numba_mvc
_patch_numba_mvc()
@numba.cuda.jit
def test_kernel(x):
id = numba.cuda.grid(1)
if id < len(x):
x[id] += 1
s = cudf.Series([1, 2, 3])
with _CUDFNumbaConfig():
test_kernel.forall(len(s))(s)
"""


@pytest.mark.skipif(
not IS_CUDA_11, reason="Minor Version Compatibility test for CUDA 11"
)
def test_numba_mvc_enabled_cuda_11():
cp = subprocess.run(
[sys.executable, "-c", TEST_NUMBA_MVC_ENABLED], capture_output=True
)
assert cp.returncode == 0
30 changes: 25 additions & 5 deletions python/cudf/cudf/utils/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import glob
import os
import sys
import warnings

from numba import config
from numba import config as numba_config

CC_60_PTX_FILE = os.path.join(
os.path.dirname(__file__), "../core/udf/shim_60.ptx"
Expand Down Expand Up @@ -64,6 +65,25 @@ def _get_ptx_file(path, prefix):
return regular_result[1]


def _patch_numba_mvc():
# Enable the config option for minor version compatibility
numba_config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY = 1

if "numba.cuda" in sys.modules:
# Patch numba for version 0.57.0 MVC support, which must know the
# config value at import time. We cannot guarantee the order of imports
# between cudf and numba.cuda so we patch numba to ensure it has these
# names available.
# See https://github.com/numba/numba/issues/8977 for details.
import numba.cuda
from cubinlinker import CubinLinker, CubinLinkerError
from ptxcompiler import compile_ptx

numba.cuda.cudadrv.driver.compile_ptx = compile_ptx
numba.cuda.cudadrv.driver.CubinLinker = CubinLinker
numba.cuda.cudadrv.driver.CubinLinkerError = CubinLinkerError


def _setup_numba():
"""
Configure the numba linker for use with cuDF. This consists of
Expand Down Expand Up @@ -108,7 +128,7 @@ def _setup_numba():
if (driver_version < ptx_toolkit_version) or (
driver_version < runtime_version
):
config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY = 1
_patch_numba_mvc()


def _get_cuda_version_from_ptx_file(path):
Expand Down Expand Up @@ -164,8 +184,8 @@ def _get_cuda_version_from_ptx_file(path):

class _CUDFNumbaConfig:
def __enter__(self):
self.enter_val = config.CUDA_LOW_OCCUPANCY_WARNINGS
config.CUDA_LOW_OCCUPANCY_WARNINGS = 0
self.enter_val = numba_config.CUDA_LOW_OCCUPANCY_WARNINGS
numba_config.CUDA_LOW_OCCUPANCY_WARNINGS = 0

def __exit__(self, exc_type, exc_value, traceback):
config.CUDA_LOW_OCCUPANCY_WARNINGS = self.enter_val
numba_config.CUDA_LOW_OCCUPANCY_WARNINGS = self.enter_val

0 comments on commit 0536a3a

Please sign in to comment.