diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h index d7903c666cebfc..71bb6dcf07975e 100644 --- a/ggml/include/ggml-cuda.h +++ b/ggml/include/ggml-cuda.h @@ -6,6 +6,9 @@ #ifdef GGML_USE_HIPBLAS #define GGML_CUDA_NAME "ROCm" #define GGML_CUBLAS_NAME "hipBLAS" +#elif defined(GGML_USE_MUSA) +#define GGML_CUDA_NAME "MUSA" +#define GGML_CUBLAS_NAME "muBLAS" #else #define GGML_CUDA_NAME "CUDA" #define GGML_CUBLAS_NAME "cuBLAS" diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index fafd5fa7ae000d..e40057632fc5aa 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -19,7 +19,11 @@ typedef half2 ggml_half2; #define GGML_COMMON_DECL #elif defined(GGML_COMMON_DECL_CUDA) +#if defined(GGML_COMMON_DECL_MUSA) +#include +#else #include +#endif #include typedef half ggml_half; @@ -415,7 +419,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_ #define GGML_TABLE_END() }; #define GGML_COMMON_IMPL -#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) +#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA) #include #define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = { diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index ed784ea1c20246..091c675e3d4203 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1341,7 +1341,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { static cudaError_t ggml_cuda_Memcpy2DPeerAsync( void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices cudaMemcpy3DPeerParms p = {}; p.dstDevice = dstDevice; @@ -1355,7 +1355,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync( GGML_UNUSED(dstDevice); GGML_UNUSED(srcDevice); return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream); -#endif // !defined(GGML_USE_HIPBLAS) +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) } static void ggml_cuda_op_mul_mat( @@ -1821,6 +1821,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } #else +#ifdef GGML_USE_MUSA + GGML_ASSERT(false); +#else // !GGML_USE_MUSA if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx @@ -1863,6 +1866,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } +#endif // GGML_USE_MUSA #endif if (dst->op_params[0] == GGML_PREC_DEFAULT) { @@ -1902,11 +1906,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor const int cc = ggml_cuda_info().devices[id].cc; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); +#ifdef GGML_USE_MUSA + use_mul_mat_vec_q = false; +#endif // GGML_USE_MUSA } } else { const int cc = ggml_cuda_info().devices[ctx.device].cc; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); +#ifdef GGML_USE_MUSA + use_mul_mat_vec_q = false; +#endif // GGML_USE_MUSA } // debug helpers @@ -3019,7 +3029,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size return false; } -#if CUDART_VERSION >= 11100 +#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 26d9412a23eb6f..2b89850c668d47 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -12,6 +12,10 @@ #else #define GGML_COMMON_DECL_CUDA #define GGML_COMMON_IMPL_CUDA +#if defined(GGML_USE_MUSA) +#define GGML_COMMON_DECL_MUSA +#define GGML_COMMON_IMPL_MUSA +#endif #endif #include "ggml-common.h" @@ -114,6 +118,151 @@ #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED +#elif defined(GGML_USE_MUSA) +#include +#include +#include +#include +// XXX: Keep the following order the same as hipBLAS +// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F +// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F +#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F +#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N MUBLAS_OP_N +#define CUBLAS_OP_T MUBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS +// #define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F MUSA_R_16F +#define CUDA_R_32F MUSA_R_32F +// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +// #define cublasComputeType_t mublasComputeType_t +#define cublasCreate mublasCreate +#define cublasDestroy mublasDestroy +#define cublasGemmEx mublasGemmEx +#define cublasGemmBatchedEx mublasGemmBatchedEx +#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx +#define cublasHandle_t mublasHandle_t +// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetMathMode mublasSetMathMode +#define cublasSetStream mublasSetStream +#define cublasSgemm mublasSgemm +#define cublasStatus_t mublasStatus_t +#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6 +#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceProp musaDeviceProp +#define cudaDeviceSynchronize musaDeviceSynchronize +#define cudaError_t musaError_t +#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags musaEventCreateWithFlags +#define cudaEventDisableTiming musaEventDisableTiming +#define cudaEventRecord musaEventRecord +#define cudaEventSynchronize musaEventSynchronize +#define cudaEvent_t musaEvent_t +#define cudaEventDestroy musaEventDestroy +#define cudaFree musaFree +#define cudaFreeHost musaFreeHost +#define cudaGetDevice musaGetDevice +#define cudaGetDeviceCount musaGetDeviceCount +#define cudaGetDeviceProperties musaGetDeviceProperties +#define cudaGetErrorString musaGetErrorString +#define cudaGetLastError musaGetLastError +#define cudaHostRegister musaHostRegister +#define cudaHostRegisterPortable musaHostRegisterPortable +#define cudaHostRegisterReadOnly musaHostRegisterReadOnly +#define cudaHostUnregister musaHostUnregister +#define cudaLaunchHostFunc musaLaunchHostFunc +#define cudaMalloc musaMalloc +#define cudaMallocHost musaMallocHost +#define cudaMemcpy musaMemcpy +#define cudaMemcpyAsync musaMemcpyAsync +#define cudaMemcpyPeerAsync musaMemcpyPeerAsync +#define cudaMemcpy2DAsync musaMemcpy2DAsync +#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost +#define cudaMemcpyHostToDevice musaMemcpyHostToDevice +#define cudaMemcpyKind musaMemcpyKind +#define cudaMemset musaMemset +#define cudaMemsetAsync musaMemsetAsync +#define cudaMemGetInfo musaMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize +#define cudaSetDevice musaSetDevice +#define cudaStreamCreateWithFlags musaStreamCreateWithFlags +#define cudaStreamDestroy musaStreamDestroy +#define cudaStreamFireAndForget musaStreamFireAndForget +#define cudaStreamNonBlocking musaStreamNonBlocking +#define cudaStreamPerThread musaStreamPerThread +#define cudaStreamSynchronize musaStreamSynchronize +#define cudaStreamWaitEvent musaStreamWaitEvent +#define cudaStream_t musaStream_t +#define cudaSuccess musaSuccess + +// XXX: Other CUDA => MUSA mapping +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED +#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED +#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE +#define CUdevice MUdevice +#define CUdeviceptr MUdeviceptr +#define CUmemAccessDesc MUmemAccessDesc +#define CUmemAllocationProp MUmemAllocationProp +#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle +#define cuDeviceGet muDeviceGet +#define cuDeviceGetAttribute muDeviceGetAttribute +#define cuMemAddressFree muMemAddressFree +#define cuMemAddressReserve muMemAddressReserve +#define cuMemCreate muMemCreate +#define cuMemGetAllocationGranularity muMemGetAllocationGranularity +#define cuMemMap muMemMap +#define cuMemRelease muMemRelease +#define cuMemSetAccess muMemSetAccess +#define cuMemUnmap muMemUnmap +#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize +#define cudaFuncSetAttribute musaFuncSetAttribute +#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms +#define make_cudaExtent make_musaExtent +#define make_cudaPitchedPtr make_musaPitchedPtr + +// XXX: USE_CUDA_GRAPH +#define CUDA_SUCCESS MUSA_SUCCESS +#define CUresult MUresult +#define cuGetErrorString muGetErrorString +#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure +#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction +#define cudaGraphDestroy musaGraphDestroy +#define cudaGraphExecDestroy musaGraphExecDestroy +#define cudaGraphExec_t musaGraphExec_t +#define cudaGraphExecUpdate musaGraphExecUpdate +#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult +#define cudaGraphGetNodes musaGraphGetNodes +#define cudaGraphInstantiate musaGraphInstantiate +#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams +#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams +#define cudaGraphLaunch musaGraphLaunch +#define cudaGraphNodeGetType musaGraphNodeGetType +#define cudaGraphNode_t musaGraphNode_t +#define cudaGraphNodeType musaGraphNodeType +#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel +#define cudaGraph_t musaGraph_t +#define cudaKernelNodeParams musaKernelNodeParams +#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed +#define cudaStreamEndCapture musaStreamEndCapture + +// XXX: cuBLAS => muBLAS mapping +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define cublasComputeType_t cudaDataType_t + +// XXX: Clang builtins mapping +#define __vsubss4 __vsubss4_musa +#define __vsub4 __vsub4_musa +#define __vcmpeq4 __vcmpeq4_musa +#define __vcmpne4 __vcmpne4_musa #else #include #include @@ -168,9 +317,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) -#if CUDART_VERSION >= 12000 - static const char * cublas_get_error_str(const cublasStatus_t err) { +#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) + static const char * cublas_get_error_str(const mublasStatus_t err) { +#ifndef GGML_USE_MUSA return cublasGetStatusString(err); +#else + return mublasStatus_to_string(err); +#endif // GGML_USE_MUSA } #else static const char * cublas_get_error_str(const cublasStatus_t err) { @@ -200,7 +353,7 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif -#if CUDART_VERSION >= 11100 +#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else #define GGML_CUDA_ASSUME(x) @@ -214,6 +367,62 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif //GGML_CUDA_F16 +#if defined(GGML_USE_MUSA) +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); +typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); +static __device__ __forceinline__ int __vsubss4_musa(const int a, const int b) { + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); +#if __has_builtin(__builtin_elementwise_sub_sat) + const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); + return reinterpret_cast(c); +#else + int8x4_t c; + int16_t tmp; +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp = va[i] - vb[i]; + if(tmp > std::numeric_limits::max()) tmp = std::numeric_limits::max(); + if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); + c[i] = tmp; + } + return reinterpret_cast(c); +#endif // __has_builtin(__builtin_elementwise_sub_sat) +} + +static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) { + return __vsubss4_musa(a, b); +} + +static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast(a); + const uint8x4_t& vb = reinterpret_cast(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0xff : 0x00; + } + return c; +} + +static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast(a); + const uint8x4_t& vb = reinterpret_cast(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0x00 : 0xff; + } + return c; +} +#endif // defined(GGML_USE_MUSA) + #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -455,7 +664,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); return mask_low | mask_high; } -#endif // CUDART_VERSION < 12000 +#endif // CUDART_VERSION < CUDART_HMASK static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)