diff --git a/libc/utils/gpu/loader/amdgpu/amdhsa-loader.cpp b/libc/utils/gpu/loader/amdgpu/amdhsa-loader.cpp
index ca13414519d4ca..cb81a866622f93 100644
--- a/libc/utils/gpu/loader/amdgpu/amdhsa-loader.cpp
+++ b/libc/utils/gpu/loader/amdgpu/amdhsa-loader.cpp
@@ -28,9 +28,11 @@
 #include "hsa/hsa_ext_amd.h"
 #endif
 
+#include <atomic>
 #include <cstdio>
 #include <cstdlib>
 #include <cstring>
+#include <thread>
 #include <tuple>
 #include <utility>
 
@@ -289,18 +291,26 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
   __atomic_store_n((uint32_t *)&packet->header, header_word, __ATOMIC_RELEASE);
   hsa_signal_store_relaxed(queue->doorbell_signal, packet_id);
 
+  std::atomic<bool> finished = false;
+  std::thread server(
+      [](std::atomic<bool> *finished, rpc_device_t device) {
+        while (!*finished) {
+          if (rpc_status_t err = rpc_handle_server(device))
+            handle_error(err);
+        }
+      },
+      &finished, device);
+
   // Wait until the kernel has completed execution on the device. Periodically
   // check the RPC client for work to be performed on the server.
-  while (hsa_signal_wait_scacquire(
-             packet->completion_signal, HSA_SIGNAL_CONDITION_EQ, 0,
-             /*timeout_hint=*/1024, HSA_WAIT_STATE_ACTIVE) != 0)
-    if (rpc_status_t err = rpc_handle_server(device))
-      handle_error(err);
+  while (hsa_signal_wait_scacquire(packet->completion_signal,
+                                   HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX,
+                                   HSA_WAIT_STATE_BLOCKED) != 0)
+    ;
 
-  // Handle the server one more time in case the kernel exited with a pending
-  // send still in flight.
-  if (rpc_status_t err = rpc_handle_server(device))
-    handle_error(err);
+  finished = true;
+  if (server.joinable())
+    server.join();
 
   // Destroy the resources acquired to launch the kernel and return.
   if (hsa_status_t err = hsa_amd_memory_pool_free(args))
diff --git a/libc/utils/gpu/loader/nvptx/nvptx-loader.cpp b/libc/utils/gpu/loader/nvptx/nvptx-loader.cpp
index 1b210b9e7a896f..58e5e5f04d0a70 100644
--- a/libc/utils/gpu/loader/nvptx/nvptx-loader.cpp
+++ b/libc/utils/gpu/loader/nvptx/nvptx-loader.cpp
@@ -20,10 +20,12 @@
 #include "llvm/Object/ELF.h"
 #include "llvm/Object/ELFObjectFile.h"
 
+#include <atomic>
 #include <cstddef>
 #include <cstdio>
 #include <cstdlib>
 #include <cstring>
+#include <thread>
 #include <vector>
 
 using namespace llvm;
@@ -224,6 +226,16 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
   if (print_resource_usage)
     print_kernel_resources(binary, kernel_name);
 
+  std::atomic<bool> finished = false;
+  std::thread server(
+      [](std::atomic<bool> *finished, rpc_device_t device) {
+        while (!*finished) {
+          if (rpc_status_t err = rpc_handle_server(device))
+            handle_error(err);
+        }
+      },
+      &finished, rpc_device);
+
   // Call the kernel with the given arguments.
   if (CUresult err = cuLaunchKernel(
           function, params.num_blocks_x, params.num_blocks_y,
@@ -231,17 +243,13 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
           params.num_threads_z, 0, stream, nullptr, args_config))
     handle_error(err);
 
-  // Wait until the kernel has completed execution on the device. Periodically
-  // check the RPC client for work to be performed on the server.
-  while (cuStreamQuery(stream) == CUDA_ERROR_NOT_READY)
-    if (rpc_status_t err = rpc_handle_server(rpc_device))
-      handle_error(err);
-
-  // Handle the server one more time in case the kernel exited with a pending
-  // send still in flight.
-  if (rpc_status_t err = rpc_handle_server(rpc_device))
+  if (CUresult err = cuStreamSynchronize(stream))
     handle_error(err);
 
+  finished = true;
+  if (server.joinable())
+    server.join();
+
   return CUDA_SUCCESS;
 }
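
Both loaders now follow the same pattern: a dedicated std::thread services rpc_handle_server() in a loop while the main thread blocks until the kernel finishes (HSA_WAIT_STATE_BLOCKED with an unlimited timeout on AMDGPU, cuStreamSynchronize on NVPTX), then flips a shared std::atomic<bool> and joins the thread. Below is a minimal, self-contained sketch of that shutdown-flag pattern; it is not part of the patch, poll_once() is a hypothetical stand-in for rpc_handle_server(), and the sleep stands in for the blocking wait on kernel completion.

// Minimal sketch (not LLVM code) of the shutdown-flag pattern used above.
// poll_once() is a hypothetical stand-in for rpc_handle_server().
#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

static bool poll_once() {
  // Pretend to check the RPC mailbox for pending work; always succeeds here.
  return true;
}

int main() {
  std::atomic<bool> finished = false;

  // Server thread: keep servicing requests until told to stop.
  std::thread server([&finished] {
    while (!finished.load(std::memory_order_acquire))
      if (!poll_once())
        std::fprintf(stderr, "server error\n");
  });

  // Stand-in for the blocking wait on kernel completion
  // (hsa_signal_wait_scacquire / cuStreamSynchronize in the patch).
  std::this_thread::sleep_for(std::chrono::milliseconds(10));

  // Signal the server thread to exit and join it before tearing down.
  finished.store(true, std::memory_order_release);
  if (server.joinable())
    server.join();
  return 0;
}

The sketch captures the flag by reference where the patch passes it by pointer as a thread argument; both are safe here because the thread is joined before the flag goes out of scope.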