From 48c84f77feda27fb0be96f21dbb95a04b82f1c1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5kon=20H=2E=20Hitland?= Date: Tue, 18 Apr 2023 23:07:03 +0200 Subject: [PATCH] q4_0c: AVX512 vec_dot and quantize impl --- ggml.c | 141 +++++++++++++++++++++++++++++++++++++++++++++++------- llama.cpp | 2 +- 2 files changed, 125 insertions(+), 18 deletions(-) diff --git a/ggml.c b/ggml.c index 73bcbe7668daa4..72d1e6d22bfbb4 100644 --- a/ggml.c +++ b/ggml.c @@ -1437,8 +1437,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int // reference implementation for deterministic creation of model files static void quantize_row_q8_0c_reference(const float * restrict x, void * restrict y, int k) { - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; + assert(k % QK8_0C == 0); + const int nb = k / QK8_0C; uint8_t * restrict qs = y; float * restrict ds = (float *) ((uint8_t *) y + QK8_0C * nb); @@ -1446,8 +1446,8 @@ static void quantize_row_q8_0c_reference(const float * restrict x, void * restri for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max - for (int l = 0; l < QK8_0; l++) { - const float v = x[i*QK8_0 + l]; + for (int l = 0; l < QK8_0C; l++) { + const float v = x[i*QK8_0C + l]; amax = MAX(amax, fabsf(v)); } @@ -1456,17 +1456,46 @@ static void quantize_row_q8_0c_reference(const float * restrict x, void * restri ds[i] = d; - for (int l = 0; l < QK8_0; ++l) { - const float v = x[i*QK8_0 + l]*id; - qs[i*QK8_0 + l] = roundf(v); + for (int l = 0; l < QK8_0C; ++l) { + const float v = x[i*QK8_0C + l]*id; + qs[i*QK8_0C + l] = roundf(v); } } } static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int k) { - assert(k % QK8_0 == 0); + assert(k % QK8_0C == 0); + const int nb = k / QK8_0C; + + int8_t * restrict qs = vy; + float * restrict ds = (float *) ((uint8_t *) vy + nb*QK8_0C); + +#if __AVX512F__ + for (int i = 0; i < nb; i++) { + const __m512 x0 = _mm512_loadu_ps( x + i*QK8_0C ); + const __m512 x1 = _mm512_loadu_ps( x + i*QK8_0C + QK8_0C/2); + + // Find absolute max + const __m512 x0abs = _mm512_abs_ps(x0); + const __m512 x1abs = _mm512_abs_ps(x1); + const float amax = _mm512_reduce_max_ps(_mm512_max_ps(x0abs, x1abs)); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + ds[i] = d; + const __m512 mul = _mm512_set1_ps( id ); + const __m512i x0q = _mm512_cvt_roundps_epi32(_mm512_mul_ps(x0, mul), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + const __m512i x1q = _mm512_cvt_roundps_epi32(_mm512_mul_ps(x1, mul), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + _mm512_mask_cvtepi32_storeu_epi8(qs + i*QK8_0C, 0xffff, x0q); + _mm512_mask_cvtepi32_storeu_epi8(qs + i*QK8_0C + QK8_0C/2, 0xffff, x1q); + } +#else + // scalar quantize_row_q8_0c_reference(x, vy, k); +#endif } static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { @@ -2364,6 +2393,73 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float *s = sumf; } +#if __AVX512F__ && QK4_0 == 32 + +// Dot product of four blocks of q4_0c with four blocks of q8_0c +static inline __m512 dot_q4_0c_fourblocks_avx512( + __m512 acc, + const uint8_t * restrict xqs, + const float * restrict xds, + const int8_t * restrict yqs, + const float * restrict yds +) { + // load quantized bytes + // TODO: change back to aligned loads + const __m512i xqs0123 = _mm512_loadu_epi64( xqs ); + const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf ); + const __m512i xqs01 = _mm512_and_si512( low_nibble_mask, xqs0123 ); + // TODO: try srlv/i? + const __m512i xqs23 = _mm512_and_si512( low_nibble_mask, _mm512_srli_epi32( xqs0123, 4 ) ); + const __m512i yqs01 = _mm512_loadu_epi64( yqs ); + const __m512i yqs23 = _mm512_loadu_epi64( yqs + 2*QK8_0C ); + + // load scales + const __m512i scale_mask0 = _mm512_set_epi32(1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0); + const __m512i scale_mask1 = _mm512_set_epi32(3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2); + const __m128 xyds = _mm_mul_ps(_mm_load_ps(xds), _mm_load_ps(yds)); + const __m512 xyds0123 = _mm512_broadcast_f32x4(xyds); + const __m512 xyds01 = _mm512_permutevar_ps(xyds0123, scale_mask0); + const __m512 xyds23 = _mm512_permutevar_ps(xyds0123, scale_mask1); + + // take dot product of x and y bytes + const __m512i plus_8 = _mm512_set1_epi8( 8 ); +#ifdef __AVX512VNNI__ + // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch: + // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8 + // from each nibble, so they can be negative. So, instead of `(xqs01 - 8) * yqs01`, + // we compute `xqs01 * yqs01 - 8 * yqks`. + const __m512i zero = _mm512_setzero_epi32(); + const __m512i yqs01_mul8 = _mm512_dpbusds_epi32( zero, plus_8, yqs01 ); + const __m512i yqs23_mul8 = _mm512_dpbusds_epi32( zero, plus_8, yqs23 ); + const __m512i xy01 = _mm512_dpbusds_epi32( zero, xqs01, yqs01 ); + const __m512i xy23 = _mm512_dpbusds_epi32( zero, xqs23, yqs23 ); + const __m512i res0_int = _mm512_sub_epi32( xy01, yqs01_mul8 ); + const __m512i res1_int = _mm512_sub_epi32( xy23, yqs23_mul8 ); +#else + // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones. + // It has the same catch as VPDPBUSDS: the left operand should be unsigned. + // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me + // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119 + const __m512i one = _mm512_set1_epi16( 1 ); + const __m512i prod_0 = _mm512_maddubs_epi16( xqs01, yqs01 ); + const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, yqs01 ); + const __m512i prod_2 = _mm512_maddubs_epi16( xqs23, yqs23 ); + const __m512i prod_3 = _mm512_maddubs_epi16( plus_8, yqs23 ); + const __m512i diff0 = _mm512_sub_epi16( prod_0, prod_1 ); + const __m512i diff1 = _mm512_sub_epi16( prod_2, prod_3 ); + const __m512i res0_int = _mm512_madd_epi16( diff0, one ); + const __m512i res1_int = _mm512_madd_epi16( diff1, one ); +#endif + + // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate. + const __m512 res0_float = _mm512_cvtepi32_ps( res0_int ); + const __m512 res1_float = _mm512_cvtepi32_ps( res1_int ); + + return _mm512_fmadd_ps( xyds23, res1_float, + _mm512_fmadd_ps( xyds01, res0_float, acc )); +} +#endif + inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { ggml_float sumf = 0.0; @@ -2610,6 +2706,15 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void float sumf = 0.0; +#if __AVX512F__ + // Initialize accumulator with zeros + __m512 acc = _mm512_setzero_ps(); + for (int i = 0; i < nb; i += 4) { + acc = dot_q4_0c_fourblocks_avx512(acc, xqs + i*QK4_0/2, xds + i, yqs + i*QK8_0, yds + i); + } + // Horizontal sum of all lanes of the accumulator + sumf = _mm512_reduce_add_ps( acc ); +#else // scalar for (int i = 0; i < nb/2; i++) { const int dst0 = i + i/2*2; // 0, 1, 4, 5, 8, 9, ... @@ -2620,23 +2725,25 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void const float dy0 = yds[dst0]; const float dy1 = yds[dst1]; - int sumi0 = 0; - int sumi1 = 0; + // NOTE: having these as plain int triggers a bug with AVX512 on GCC 12.2 + int64_t sumi0 = 0; + int64_t sumi1 = 0; for (int l = 0; l < QK4_0; l++) { - const uint8_t v0 = xqs[i*QK4_0 + l]; + const uint8_t v0 = xqs[i*QK4_0 + l]; - const int i0 = (int8_t) (v0 & 0xf) - 8; - const int i1 = (int8_t) (v0 >> 4) - 8; + const int i0 = (int) (v0 & 0xf) - 8; + const int i1 = (int) (v0 >> 4) - 8; - const int i2 = yqs[dst0*QK4_0 + l]; - const int i3 = yqs[dst1*QK4_0 + l]; + const int i2 = yqs[dst0*QK4_0 + l]; + const int i3 = yqs[dst1*QK4_0 + l]; - sumi0 += i0*i2; - sumi1 += i1*i3; + sumi0 += i0*i2; + sumi1 += i1*i3; } sumf += dx0*dy0*sumi0 + dx1*dy1*sumi1; } +#endif *s = sumf; } diff --git a/llama.cpp b/llama.cpp index 9fd93f4a271228..d995f626d3194f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -839,7 +839,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "mostly F16"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "mostly Q4_0"; - case LLAMA_FTYPE_MOSTLY_Q4_0C: return "mostly Q4_1C"; + case LLAMA_FTYPE_MOSTLY_Q4_0C: return "mostly Q4_0C"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "mostly Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: return "mostly Q4_1, some F16";