Skip to content

Commit

Permalink
[aot] Add stream_ variable for CUDAContext to use a specific CUDA str…
Browse files Browse the repository at this point in the history
…eam to launch CUDA kernel (#8579)

### Brief Summary

copilot:summary

### Walkthrough

copilot:walkthrough

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Routhleck and pre-commit-ci[bot] authored Aug 15, 2024
1 parent 45b3275 commit e9f19b8
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 3 deletions.
7 changes: 7 additions & 0 deletions c_api/include/taichi/taichi_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,17 @@ ti_export_cuda_memory(TiRuntime runtime,
TiMemory memory,
TiCudaMemoryInteropInfo *interop_info);

// Function `ti_import_cuda_memory`
TI_DLL_EXPORT TiMemory TI_API_CALL ti_import_cuda_memory(TiRuntime runtime,
void *ptr,
size_t memory_size);

// Function `ti_set_cuda_stream`
TI_DLL_EXPORT void TI_API_CALL ti_set_cuda_stream(void *stream);

// Function `ti_get_cuda_stream`
TI_DLL_EXPORT void TI_API_CALL ti_get_cuda_stream(void **stream);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
21 changes: 21 additions & 0 deletions c_api/src/taichi_llvm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#ifdef TI_WITH_CUDA
#include "taichi/rhi/cuda/cuda_device.h"
#include "taichi/rhi/cuda/cuda_context.h"
#include "taichi/runtime/cuda/kernel_launcher.h"
#endif

Expand Down Expand Up @@ -242,4 +243,24 @@ TI_DLL_EXPORT TiMemory TI_API_CALL ti_import_cuda_memory(TiRuntime runtime,
#endif
}

// function.set_cuda_stream
TI_DLL_EXPORT void TI_API_CALL ti_set_cuda_stream(void *stream) {
#ifdef TI_WITH_CUDA
taichi::lang::CUDAContext::get_instance().set_stream(stream);

#else
TI_NOT_IMPLEMENTED;
#endif
}

// function.get_cuda_stream
TI_DLL_EXPORT void TI_API_CALL ti_get_cuda_stream(void **stream) {
#ifdef TI_WITH_CUDA
*stream = taichi::lang::CUDAContext::get_instance().get_stream();
#else
TI_NOT_IMPLEMENTED;

#endif
}

#endif // TI_WITH_LLVM
20 changes: 20 additions & 0 deletions c_api/tests/c_api_interop_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,23 @@ TEST_F(CapiTest, TestCUDAImport) {
EXPECT_EQ(data_out[3], 4.0);
}
#endif // TI_WITH_CUDA

#ifdef TI_WITH_CUDA
TEST_F(CapiTest, TestCUDAStreamSet) {
void *temp_stream = nullptr;

ti_get_cuda_stream(&temp_stream);
EXPECT_EQ(temp_stream, nullptr);

void *stream1 = reinterpret_cast<void *>(0x12345678);
void *stream2 = reinterpret_cast<void *>(0x87654321);

ti_set_cuda_stream(stream1);
ti_get_cuda_stream(&temp_stream);
EXPECT_EQ(temp_stream, stream1);

ti_set_cuda_stream(stream2);
ti_get_cuda_stream(&temp_stream);
EXPECT_EQ(temp_stream, stream2);
}
#endif
8 changes: 5 additions & 3 deletions taichi/rhi/cuda/cuda_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
namespace taichi::lang {

CUDAContext::CUDAContext()
: profiler_(nullptr), driver_(CUDADriver::get_instance_without_context()) {
: profiler_(nullptr),
driver_(CUDADriver::get_instance_without_context()),
stream_(nullptr) {
// CUDA initialization
dev_count_ = 0;
driver_.init(0);
Expand Down Expand Up @@ -156,14 +158,14 @@ void CUDAContext::launch(void *func,
dynamic_shared_mem_bytes);
}
driver_.launch_kernel(func, grid_dim, 1, 1, block_dim, 1, 1,
dynamic_shared_mem_bytes, nullptr,
dynamic_shared_mem_bytes, stream_,
arg_pointers.data(), nullptr);
}
if (profiler_)
profiler_->stop(task_handle);

if (debug_) {
driver_.stream_synchronize(nullptr);
driver_.stream_synchronize(stream_);
}
}

Expand Down
9 changes: 9 additions & 0 deletions taichi/rhi/cuda/cuda_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class CUDAContext {
int max_shared_memory_bytes_;
bool debug_;
bool supports_mem_pool_;
void *stream_;

public:
CUDAContext();
Expand Down Expand Up @@ -108,6 +109,14 @@ class CUDAContext {
}

static CUDAContext &get_instance();

void set_stream(void *stream) {
stream_ = stream;
}

void *get_stream() const {
return stream_;
}
};

} // namespace taichi::lang

0 comments on commit e9f19b8

Please sign in to comment.