From 25b41a3d9acd9cadb463fe334005744e1f1d38ef Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 22 Apr 2023 10:37:19 +0300 Subject: [PATCH] ggml : slight improvement of Q4_3 - no need for loop unrolling --- ggml.c | 48 ++++++++++++------------------------------------ 1 file changed, 12 insertions(+), 36 deletions(-) diff --git a/ggml.c b/ggml.c index e7a4f9a914d6ae..488da236978ffc 100644 --- a/ggml.c +++ b/ggml.c @@ -2978,77 +2978,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - float summs = 0.0f; + float summs0 = 0.0f; + float summs1 = 0.0f; - for (int i = 0; i < nb; i += 2) { + for (int i = 0; i < nb; ++i) { const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0]; const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1]; - const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0]; - const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1]; const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; - summs += GGML_FP16_TO_FP32(x0_0->m) * y0->s0 + GGML_FP16_TO_FP32(x0_1->m) * y0->s1; - summs += GGML_FP16_TO_FP32(x1_0->m) * y1->s0 + GGML_FP16_TO_FP32(x1_1->m) * y1->s1; - - const uint8x16_t m4b = vdupq_n_u8(0xf); - - const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); - const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); - const float x1_0d = GGML_FP16_TO_FP32(x1_0->d); - const float x1_1d = GGML_FP16_TO_FP32(x1_1->d); + summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0; + summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1; const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); - const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs)); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf))); const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); // interleave const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); - const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h); - const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h); // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); + const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h)); - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d); #endif } - sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs; + sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1; #else // scalar for (int i = 0; i < nb; i++) {