From d2eb7bdf4d35f90b14ef99f8dfd20db2363d8a0f Mon Sep 17 00:00:00 2001
From: Ger Hobbelt
Date: Tue, 13 Jul 2021 10:04:59 +0200
Subject: [PATCH] same as patch-4 (#3494) but now with reduced code
 duplication: for TFloat to work, we don't need to duplicate the integer work
 functions, as it's only the ExtractResults[8,16] functions that need
 different implementations for float vs. double. These are therefore common
 to both implementations:

```
static void PartialMatrixDotVector64(const int8_t *wi, const TFloat *scales, const int8_t *u,
                                     int num_in, TFloat *v) {
static void PartialMatrixDotVector32(const int8_t *wi, const TFloat *scales, const int8_t *u,
                                     int num_in, TFloat *v) {
static void PartialMatrixDotVector16(const int8_t *wi, const TFloat *scales, const int8_t *u,
                                     int num_in, TFloat *v) {
static inline void PartialMatrixDotVector8(const int8_t *wi, const TFloat *scales, const int8_t *u,
                                           int num_in, TFloat *v) {
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales,
                            const int8_t *u, TFloat *v) {
```
---
 src/arch/intsimdmatrixavx2.cpp | 307 ++++++---------------------
 1 file changed, 54 insertions(+), 253 deletions(-)

diff --git a/src/arch/intsimdmatrixavx2.cpp b/src/arch/intsimdmatrixavx2.cpp
index 8f671f08a9..a463408597 100644
--- a/src/arch/intsimdmatrixavx2.cpp
+++ b/src/arch/intsimdmatrixavx2.cpp
@@ -132,253 +132,52 @@ static inline __m128i load64_to_128(const int8_t *wi_) {
 }
 
 #if defined(FAST_FLOAT)
-static inline void ExtractResults8(__m256i result, const int8_t *wi, const float *scales,
-                                   float *v) {
-  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
-  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
-  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
-  __m256d scale01234567 = _mm256_loadu_ps(scales);
-  //~ __m256d scale4567 = _mm256_loadu_ps(scales + 8);
-  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32
-  result = _mm256_add_epi32(result, w256);     // result += bias * 127
-  __m256 res01234567 = _mm256_cvtepi32_ps(_mm256_castsi256_si128(result));
-  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
-  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
-  res01234567 = _mm256_mul_pd(res01234567, scale01234567);
-  //~ res4567 = _mm256_mul_pd(res4567, scale4567);
-  _mm256_storeu_ps(v, res01234567);
-  //~ _mm256_storeu_pd(v + 4, res4567);
-}
-
-static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
-                                    const float *&scales, float *&v) {
-  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
-  // 8x8bit vals in bottom of 128bit reg
-  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
-  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
-  __m256d scale0123 = _mm256_loadu_ps(scales);
-  __m256d scale4567 = _mm256_loadu_ps(scales + 8);
-  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32
-  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
-  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
-  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
-  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
-  res0123 = _mm256_mul_pd(res0123, scale0123);
-  res4567 = _mm256_mul_pd(res4567, scale4567);
-  _mm256_storeu_ps(v, res0123);
-  _mm256_storeu_ps(v + 8, res4567);
-  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
-  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
-  scale0123 = _mm256_loadu_ps(scales + 16);
-  scale4567 = _mm256_loadu_ps(scales + 24);
-  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32
-  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
-  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
-  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
-  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
-  res0123 = _mm256_mul_pd(res0123, scale0123);
-  res4567 = _mm256_mul_pd(res4567, scale4567);
-  _mm256_storeu_ps(v + 16, res0123);
-  _mm256_storeu_ps(v + 24, res4567);
-  wi += 16;
-  scales += 16;
-  v += 16;
-}
-
-// Computes part of matrix.vector v = Wu. Computes N=64 results.
-// The weights *must* be arranged so that consecutive reads from wi
-// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
-// (kNumInputsPerGroup inputs))). After that there must be N consecutive
-// bias weights, before continuing with any more weights.
-// u must be padded out with zeros to
-// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
-static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
-                                     int num_in, float *v) {
-  // Register containing 16-bit ones for horizontal add with 16->32 bit
-  // conversion.
-  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
-  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
-  // Initialize all the results to 0.
-  __m256i result0 = _mm256_setzero_si256();
-  __m256i result1 = _mm256_setzero_si256();
-  __m256i result2 = _mm256_setzero_si256();
-  __m256i result3 = _mm256_setzero_si256();
-  __m256i result4 = _mm256_setzero_si256();
-  __m256i result5 = _mm256_setzero_si256();
-  __m256i result6 = _mm256_setzero_si256();
-  __m256i result7 = _mm256_setzero_si256();
-  // Iterate over the input (u), one registerful at a time.
-  for (int j = 0; j < num_in;) {
-    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
-    // Inputs are processed in groups of kNumInputsPerGroup, replicated
-    // kNumInputGroups times.
-    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
-      // Replicate the low 32 bits (4 inputs) 8 times.
-      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
-      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
-      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
-      __m256i weights, reps;
-      // Mul-add, with horizontal add of the 4 inputs to each of the results.
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
-    }
-  }
-  ExtractResults16(result0, result1, wi, scales, v);
-  ExtractResults16(result2, result3, wi, scales, v);
-  ExtractResults16(result4, result5, wi, scales, v);
-  ExtractResults16(result6, result7, wi, scales, v);
-}
-
-// Computes part of matrix.vector v = Wu. Computes N=32 results.
-// For details see PartialMatrixDotVector64 with N=32.
-static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
-                                     int num_in, float *v) {
-  // Register containing 16-bit ones for horizontal add with 16->32 bit
-  // conversion.
-  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
-  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
-  // Initialize all the results to 0.
-  __m256i result0 = _mm256_setzero_si256();
-  __m256i result1 = _mm256_setzero_si256();
-  __m256i result2 = _mm256_setzero_si256();
-  __m256i result3 = _mm256_setzero_si256();
-  // Iterate over the input (u), one registerful at a time.
-  for (int j = 0; j < num_in;) {
-    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
-    // Inputs are processed in groups of kNumInputsPerGroup, replicated
-    // kNumInputGroups times.
-    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
-      // Replicate the low 32 bits (4 inputs) 8 times.
-      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
-      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
-      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
-      __m256i weights, reps;
-      // Mul-add, with horizontal add of the 4 inputs to each of the results.
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
-    }
-  }
-  ExtractResults16(result0, result1, wi, scales, v);
-  ExtractResults16(result2, result3, wi, scales, v);
-}
-// Computes part of matrix.vector v = Wu. Computes N=16 results.
-// For details see PartialMatrixDotVector64 with N=16.
-static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
-                                     int num_in, float *v) {
-  // Register containing 16-bit ones for horizontal add with 16->32 bit
-  // conversion.
-  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
-  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
-  // Initialize all the results to 0.
-  __m256i result0 = _mm256_setzero_si256();
-  __m256i result1 = _mm256_setzero_si256();
-  // Iterate over the input (u), one registerful at a time.
-  for (int j = 0; j < num_in;) {
-    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
-    // Inputs are processed in groups of kNumInputsPerGroup, replicated
-    // kNumInputGroups times.
-    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
-      // Replicate the low 32 bits (4 inputs) 8 times.
-      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
-      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
-      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
-      __m256i weights, reps;
-      // Mul-add, with horizontal add of the 4 inputs to each of the results.
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
-    }
-  }
-  ExtractResults16(result0, result1, wi, scales, v);
+static inline void ExtractResults8(__m256i result, const int8_t* wi, const TFloat* scales,
+                                   TFloat* v) {
+  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
+  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
+  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
+  __m256 scale01234567 = _mm256_loadu_ps(scales);
+  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32
+  result = _mm256_add_epi32(result, w256);     // result += bias * 127
+  __m256 res01234567 = _mm256_cvtepi32_ps(result);
+  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
+  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
+  _mm256_storeu_ps(v, res01234567);
 }
 
-// Computes part of matrix.vector v = Wu. Computes N=8 results.
-// For details see PartialMatrixDotVector64 with N=8.
-static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
-                                           int num_in, float *v) {
-  // Register containing 16-bit ones for horizontal add with 16->32 bit
-  // conversion.
-  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
-  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
-  // Initialize all the results to 0.
-  __m256i result0 = _mm256_setzero_si256();
-  // Iterate over the input (u), one registerful at a time.
-  for (int j = 0; j < num_in;) {
-    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
-    // Inputs are processed in groups of kNumInputsPerGroup, replicated
-    // kNumInputGroups times.
-    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
-      // Replicate the low 32 bits (4 inputs) 8 times.
-      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
-      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
-      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
-      __m256i weights, reps;
-      // Mul-add, with horizontal add of the 4 inputs to each of the results.
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
-    }
-  }
-  ExtractResults8(result0, wi, scales, v);
+static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t*& wi,
+                                    const TFloat*& scales, TFloat*& v) {
+  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(wi));
+  // 8x8bit vals in bottom of 128bit reg
+  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
+  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
+  __m256 scale01234567 = _mm256_loadu_ps(scales);
+  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32
+  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
+  __m256 res01234567 = _mm256_cvtepi32_ps(result0);
+  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
+  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
+  _mm256_storeu_ps(v, res01234567);
+  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
+  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
+  scale01234567 = _mm256_loadu_ps(scales + 8);
+  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32
+  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
+  res01234567 = _mm256_cvtepi32_ps(result1);
+  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
+  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
+  _mm256_storeu_ps(v + 8, res01234567);
+  wi += 16;
+  scales += 16;
+  v += 16;
 }
 
-static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
-                            const int8_t *u, float *v) {
-  const int num_out = dim1;
-  const int num_in = dim2 - 1;
-  // Each call to a partial_func_ produces group_size outputs, except the
-  // last one, which can produce less.
-  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
-  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
-  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
-  int output = 0;
-
-  int w_step = (rounded_num_in + 1) * group_size;
-
-  // Run with this group size, until it would produce too much output, then
-  // switch to a smaller size.
-  for (; output + group_size <= rounded_num_out; output += group_size) {
-    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
-    wi += w_step;
-    scales += group_size;
-    v += group_size;
-  }
-  group_size /= 2;
-  w_step /= 2;
-
-  if (output + group_size <= rounded_num_out) {
-    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
-    wi += w_step;
-    scales += group_size;
-    v += group_size;
-    output += group_size;
-  }
-  group_size /= 2;
-  w_step /= 2;
-
-  if (output + group_size <= rounded_num_out) {
-    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
-    wi += w_step;
-    scales += group_size;
-    v += group_size;
-    output += group_size;
-  }
-  group_size /= 2;
-  w_step /= 2;
-
-  if (output + group_size <= rounded_num_out) {
-    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
-  }
-}
 #else
-static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
-                                   double *v) {
+
+static inline void ExtractResults8(__m256i result, const int8_t *wi, const TFloat *scales,
+                                   TFloat *v) {
   __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
   __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
   __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
@@ -396,7 +195,7 @@ static inline void ExtractResults8(__m256i result, const int8_t *wi, const doubl
 }
 
 static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
-                                    const double *&scales, double *&v) {
+                                    const TFloat *&scales, TFloat *&v) {
   __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
   // 8x8bit vals in bottom of 128bit reg
   const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
@@ -430,6 +229,8 @@ static inline void ExtractResults16(__m256i result0, __m256i result1, const int8
   v += 16;
 }
 
+#endif
+
 // Computes part of matrix.vector v = Wu. Computes N=64 results.
 // The weights *must* be arranged so that consecutive reads from wi
 // provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
@@ -437,8 +238,8 @@ static inline void ExtractResults16(__m256i result0, __m256i result1, const int8
 // bias weights, before continuing with any more weights.
 // u must be padded out with zeros to
 // kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
-static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
-                                     int num_in, double *v) {
+static void PartialMatrixDotVector64(const int8_t *wi, const TFloat *scales, const int8_t *u,
+                                     int num_in, TFloat *v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
   __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -482,8 +283,8 @@ static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, con
 
 // Computes part of matrix.vector v = Wu. Computes N=32 results.
 // For details see PartialMatrixDotVector64 with N=32.
-static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
-                                     int num_in, double *v) {
+static void PartialMatrixDotVector32(const int8_t *wi, const TFloat *scales, const int8_t *u,
+                                     int num_in, TFloat *v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
   __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -517,8 +318,8 @@ static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, con
 
 // Computes part of matrix.vector v = Wu. Computes N=16 results.
 // For details see PartialMatrixDotVector64 with N=16.
-static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
-                                     int num_in, double *v) {
+static void PartialMatrixDotVector16(const int8_t *wi, const TFloat *scales, const int8_t *u,
+                                     int num_in, TFloat *v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
   __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -547,8 +348,8 @@ static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, con
 
 // Computes part of matrix.vector v = Wu. Computes N=8 results.
 // For details see PartialMatrixDotVector64 with N=8.
-static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
-                                           int num_in, double *v) {
+static inline void PartialMatrixDotVector8(const int8_t *wi, const TFloat *scales, const int8_t *u,
+                                           int num_in, TFloat *v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
   __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -573,8 +374,8 @@ static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scale
   ExtractResults8(result0, wi, scales, v);
 }
 
-static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
-                            const int8_t *u, double *v) {
+static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales,
+                            const int8_t *u, TFloat *v) {
   const int num_out = dim1;
   const int num_in = dim2 - 1;
   // Each call to a partial_func_ produces group_size outputs, except the
@@ -621,7 +422,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *
     PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
   }
 }
-#endif
+
 
 static const IntSimdMatrix simdMatrix = {
   // Function.
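The pattern the patch relies on can be shown without any AVX2 machinery. Below is a minimal scalar sketch, not the patched code: the plain row-major weight layout and the scalar loops are simplified stand-ins for the interleaved layout and the `MultiplyGroup` accumulator work. What it mirrors is the structure: `TFloat` is chosen once by `FAST_FLOAT`, the integer multiply-accumulate worker is phrased in `TFloat` and therefore compiles once for either precision, and only the extraction step is precision-sensitive.

```
#include <cstdint>

// TFloat selects the output precision at build time, as in the patch.
#if defined(FAST_FLOAT)
using TFloat = float;
#else
using TFloat = double;
#endif

// Scalar stand-in for ExtractResults8: fold in the bias row (bias * 127)
// and apply the per-output scale. In the patched file this is the only
// step that needs _ps vs. _pd intrinsics; a scalar version is generic.
static inline void ExtractResults8(const int32_t *result, const int8_t *bias,
                                   const TFloat *scales, TFloat *v) {
  for (int i = 0; i < 8; ++i) {
    v[i] = static_cast<TFloat>(result[i] + bias[i] * 127) * scales[i];
  }
}

// Scalar stand-in for PartialMatrixDotVector8: the work on int8 weights
// and inputs is precision-agnostic, so once the interface is phrased in
// TFloat a single definition serves both builds. The layout here is
// simplified to an 8 x num_in row-major block followed by 8 bias weights.
static void PartialMatrixDotVector8(const int8_t *wi, const TFloat *scales,
                                    const int8_t *u, int num_in, TFloat *v) {
  int32_t result[8] = {};
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < num_in; ++j) {
      result[i] += static_cast<int32_t>(wi[i * num_in + j]) * u[j];
    }
  }
  ExtractResults8(result, wi + 8 * num_in, scales, v);
}
```

This is why, after the patch, only ExtractResults8/ExtractResults16 sit inside the #if defined(FAST_FLOAT)/#else arms, while each PartialMatrixDotVector* body and matrixDotVector appear exactly once.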