Skip to content

Commit

Permalink
Check if have bfloat
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Sep 2, 2024
1 parent ef12e27 commit a2fa2ec
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions candle-metal-kernels/src/quantized.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,7 @@ kernel void kernel_mul_mv_f16_f32(
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}

#if defined(__HAVE_BFLOAT__)
void kernel_mul_mv_bf16_f32_1row_impl(
device const char * src0,
device const char * src1,
Expand Down Expand Up @@ -1578,9 +1579,11 @@ kernel void kernel_mul_mv_bf16_f32_1row(
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_bf16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
#endif

#define N_F16_F32 4
#define N_BF16_F32 4

#if defined(__HAVE_BFLOAT__)
void kernel_mul_mv_bf16_f32_impl(
device const char * src0,
device const char * src1,
Expand All @@ -1605,7 +1608,7 @@ void kernel_mul_mv_bf16_f32_impl(
uint tiisg[[thread_index_in_simdgroup]]) {

const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F16_F32;
const int64_t rb = tgpig.y*N_BF16_F32;
const int64_t im = tgpig.z;

const uint i12 = im%ne12;
Expand All @@ -1616,7 +1619,7 @@ void kernel_mul_mv_bf16_f32_impl(
device const bfloat * x = (device const bfloat *) (src0 + offset0);

if (ne00 < 128) {
for (int row = 0; row < N_F16_F32; ++row) {
for (int row = 0; row < N_BF16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
Expand All @@ -1636,7 +1639,7 @@ void kernel_mul_mv_bf16_f32_impl(
}
} else {
device const bfloat4 * x4 = (device const bfloat4 *)x;
for (int row = 0; row < N_F16_F32; ++row) {
for (int row = 0; row < N_BF16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
Expand Down Expand Up @@ -1684,7 +1687,9 @@ kernel void kernel_mul_mv_bf16_f32(
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_bf16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
#endif

#if defined(__HAVE_BFLOAT__)
// Assumes row size (ne00) is a multiple of 4
kernel void kernel_mul_mv_bf16_f32_l4(
device const char * src0,
Expand Down Expand Up @@ -1734,6 +1739,7 @@ kernel void kernel_mul_mv_bf16_f32_l4(
}
}
}
#endif

kernel void kernel_alibi_f32(
device const float * src0,
Expand Down

0 comments on commit a2fa2ec

Please sign in to comment.