diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f6426d4bad168..cac029b480b7a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10,6 +10,7 @@ #include #include #include +#include "rocblas/rocblas.h" #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT @@ -2531,6 +2532,10 @@ void ggml_init_cublas() { static bool initialized = false; if (!initialized) { +#ifdef GGML_USE_HIPBLAS + rocblas_initialize(); + hipDeviceSynchronize(); +#endif CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0;