Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compact with the function select by cuGetProcAddress #32

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

VincentLeeMax
Copy link
Contributor

when cuda==11.3, running whith https://github.com/NVIDIA/cuda-samples/tree/v11.3/Samples/reduction, cudaMalloc will meet cudaErrorDeviceUninitialized error. Update the corresponding cuda_library_entry function to the function returned by cuGetProcAddress will fix it.
image

@mYmNeo
Copy link
Contributor

mYmNeo commented Nov 16, 2022

The problem may be the missing entry cuGetProcAddress of cuda_entry_enum_t in cuda-helper.h. The load_necessary_data didn't load real cuGetProcAddress, so caused the panic in #20.

@VincentLeeMax
Copy link
Contributor Author

VincentLeeMax commented Nov 16, 2022

The problem may be the missing entry cuGetProcAddress of cuda_entry_enum_t in cuda-helper.h. The load_necessary_data didn't load real cuGetProcAddress, so caused the panic in #20.

Thanks for replaying.

I think it's not the same problem with #20. Calling cudaMalloc still got cudaErrorDeviceUninitialized after initialization, see reduction.log(commit 72e0115).

After adding some log to cuGetProcAddress, I found that cuGetProcAddress return a different cuda version(3020) cudaMalloc comparing with the one stored in cuda_library_entry(2000).

CUresult cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion,
                          cuuint64_t flags) {
  CUresult ret;
  int i;

  load_necessary_data();
  if (!is_custom_config_path()) {
    pthread_once(&g_register_set, register_to_remote);
  }
  pthread_once(&g_init_set, initialization);
  if (!strcmp(symbol, "cuMemAlloc")) {
      LOGGER(1, "%s call version: %d.", symbol, cudaVersion);
      ret = CUDA_ENTRY_CALL(cuda_library_entry, cuGetProcAddress, symbol, pfn,
                            3020, flags);
      LOGGER(1, "cudaVersion 3020, cudaMalloc function ptr: %d.", *pfn);
      ret = CUDA_ENTRY_CALL(cuda_library_entry, cuGetProcAddress, symbol, pfn,
                            2000, flags);
      LOGGER(1, "cudaVersion 2000, cudaMalloc function ptr: %d.", *pfn);
      entry_t entry = cuda_library_entry[CUDA_ENTRY_ENUM(cuMemAlloc)];
      LOGGER(1, "in cuda_library_entry, %s function ptr: %d.", entry.name, entry.fn_ptr);
  }
  ret = CUDA_ENTRY_CALL(cuda_library_entry, cuGetProcAddress, symbol, pfn,
                        cudaVersion, flags);
  if (ret == CUDA_SUCCESS) {
    for (i = 0; i < cuda_hook_nums; i++) {
      if (!strcmp(symbol, cuda_hooks_entry[i].name)) {
        LOGGER(5, "Match hook %s", symbol);
        LOGGER(1, "%s call version: %d.", symbol, cudaVersion);
        *pfn = cuda_hooks_entry[i].fn_ptr;
        break;
      }
    }
  }

  return ret;
}

image

Since we redirect the function in cuda_hooks_entry and cuGetProcAddress may request a function of different cuda version, eg cuLaunchKernel, we should update the cuda_library_entry for the next call in the redirect function.

image

@hzliangbin
Copy link
Contributor

@VincentLeeMax @mYmNeo According to nvdia docs,

  • The base name of the driver API function to look for. As an example, for the driver API cuMemAlloc_v2, symbol would be cuMemAlloc and cudaVersion would be the ABI compatible CUDA version for the _v2 variant.

cuGetProcAddress should add some logic to deal with version, if version comes with 3020, it should selecte cuMemAlloc_v2 instead cuMemAlloc, although symbol is still the cuMemAlloc.

@@ -640,6 +639,7 @@ CUresult cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion,
for (i = 0; i < cuda_hook_nums; i++) {
if (!strcmp(symbol, cuda_hooks_entry[i].name)) {
LOGGER(5, "Match hook %s", symbol);
cuda_library_entry[cuda_hooks_entry[i].library_index].fn_ptr = *pfn;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'd better to match the function entry using both symbol and version.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants