From a2fa2ecb99986f7db893f34672cbb94e5cc1f050 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 2 Sep 2024 10:30:48 -0400 Subject: [PATCH] Check if have bfloat --- candle-metal-kernels/src/quantized.metal | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal index 5a1bccdc4..162b7a2d1 100644 --- a/candle-metal-kernels/src/quantized.metal +++ b/candle-metal-kernels/src/quantized.metal @@ -1495,6 +1495,7 @@ kernel void kernel_mul_mv_f16_f32( kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); } +#if defined(__HAVE_BFLOAT__) void kernel_mul_mv_bf16_f32_1row_impl( device const char * src0, device const char * src1, @@ -1578,9 +1579,11 @@ kernel void kernel_mul_mv_bf16_f32_1row( uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_bf16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); } +#endif -#define N_F16_F32 4 +#define N_BF16_F32 4 +#if defined(__HAVE_BFLOAT__) void kernel_mul_mv_bf16_f32_impl( device const char * src0, device const char * src1, @@ -1605,7 +1608,7 @@ void kernel_mul_mv_bf16_f32_impl( uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F32; + const int64_t rb = tgpig.y*N_BF16_F32; const int64_t im = tgpig.z; const uint i12 = im%ne12; @@ -1616,7 +1619,7 @@ void kernel_mul_mv_bf16_f32_impl( device const bfloat * x = (device const bfloat *) (src0 + offset0); if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { + for (int row = 0; row < N_BF16_F32; ++row) { int r1 = rb + row; if (r1 >= ne11) { break; @@ -1636,7 +1639,7 @@ void kernel_mul_mv_bf16_f32_impl( } } else { device const bfloat4 * x4 = (device const bfloat4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { + for (int row = 0; row < N_BF16_F32; ++row) { int r1 = rb + row; if (r1 >= ne11) { break; @@ -1684,7 +1687,9 @@ kernel void kernel_mul_mv_bf16_f32( uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_bf16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); } +#endif +#if defined(__HAVE_BFLOAT__) // Assumes row size (ne00) is a multiple of 4 kernel void kernel_mul_mv_bf16_f32_l4( device const char * src0, @@ -1734,6 +1739,7 @@ kernel void kernel_mul_mv_bf16_f32_l4( } } } +#endif kernel void kernel_alibi_f32( device const float * src0,