Skip to content

Commit

Permalink
[Runtime] Dynamically load cuTensorMapEncodeTiled (#4330)
Browse files Browse the repository at this point in the history
`cuTensorMapEncodeTiled` is only present in CUDA-12-compatible drivers and is
missing from CUDA-11 ones.

This is a spiritual follow-up to
#2771: it allows the symbol to be queried dynamically,
so that when run on an older driver the call returns an error.
Also, fix the behavior of `occupancyMaxActiveClusters` when the symbol is not
found (before this change it would crash with a null pointer dereference; now
it should return a structured exception).
  • Loading branch information
malfet authored Jul 16, 2024
1 parent 7c42f6b commit f9f2960
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions third_party/nvidia/backend/driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
// Function-pointer type matching cuOccupancyMaxActiveClusters, so the symbol
// can be held in a pointer obtained at run time (see
// getCuOccupancyMaxActiveClustersHandle) rather than referenced directly.
// The `_t` suffix is required: defineGetFunctionHandle builds its return type
// as symbolName##_t.
typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
int *numClusters, CUfunction func, const CUlaunchConfig *config);

// Function-pointer type matching cuTensorMapEncodeTiled. The symbol is looked
// up at run time (see getCuTensorMapEncodeTiledHandle) because it is only
// present in CUDA-12-compatible drivers and is missing from CUDA-11 ones.
// The `_t` suffix is required: defineGetFunctionHandle builds its return type
// as symbolName##_t.
typedef CUresult (*cuTensorMapEncodeTiled_t)(
CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
const cuuint64_t *globalStrides, const cuuint32_t *boxDim,
const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
CUtensorMapFloatOOBfill oobFill);

#define defineGetFunctionHandle(name, symbolName) \
static symbolName##_t name() { \
/* Open the shared library */ \
Expand All @@ -168,6 +176,9 @@ typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
// Instantiates `static cuOccupancyMaxActiveClusters_t
// getCuOccupancyMaxActiveClustersHandle()`. Callers compare the returned
// pointer against NULL before using it.
defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
cuOccupancyMaxActiveClusters);

// Instantiates `static cuTensorMapEncodeTiled_t
// getCuTensorMapEncodeTiledHandle()`. Callers compare the returned pointer
// against NULL before using it.
defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
cuTensorMapEncodeTiled);

static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
maxActiveClusters = -1;
Expand Down Expand Up @@ -206,6 +217,9 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL;
if (cuOccupancyMaxActiveClusters == NULL) {
cuOccupancyMaxActiveClusters = getCuOccupancyMaxActiveClustersHandle();
if (cuOccupancyMaxActiveClusters == NULL) {
return NULL;
}
}

Py_BEGIN_ALLOW_THREADS;
Expand Down Expand Up @@ -288,6 +302,13 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
}
assert((elementSize * tensorDim) >= 32 && "block size too small.");
int rank = 1;
static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
if (cuTensorMapEncodeTiled == NULL) {
cuTensorMapEncodeTiled = getCuTensorMapEncodeTiledHandle();
if (cuTensorMapEncodeTiled == NULL) {
return NULL;
}
}
CUresult result = cuTensorMapEncodeTiled(
(CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
Expand Down

0 comments on commit f9f2960

Please sign in to comment.