From 85357d82e0a5b67b0b9ca044072e00199df431ba Mon Sep 17 00:00:00 2001 From: Ger Hobbelt Date: Tue, 13 Jul 2021 09:56:11 +0200 Subject: [PATCH] bugfixing the AVX2 Extract8+16 codes, where there's lines like `__m256d scale01234567 = _mm256_loadu_ps(scales)`, i.e. loading float vectors into double vector types. Extract from #3490. --- src/arch/intsimdmatrixavx2.cpp | 49 ++++++++++++++-------------------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/src/arch/intsimdmatrixavx2.cpp b/src/arch/intsimdmatrixavx2.cpp index 8f671f08a9..c87ef414a7 100644 --- a/src/arch/intsimdmatrixavx2.cpp +++ b/src/arch/intsimdmatrixavx2.cpp @@ -132,54 +132,45 @@ static inline __m128i load64_to_128(const int8_t *wi_) { } #if defined(FAST_FLOAT) -static inline void ExtractResults8(__m256i result, const int8_t *wi, const float *scales, - float *v) { - __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg + +static inline void ExtractResults8(__m256i result, const int8_t *wi, + const float *scales, float *v) { + __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); - __m256d scale01234567 = _mm256_loadu_ps(scales); - //~ __m256d scale4567 = _mm256_loadu_ps(scales + 8); + __m256 scale01234567 = _mm256_loadu_ps(scales); w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 result = _mm256_add_epi32(result, w256); // result += bias * 127 - __m256 res01234567 = _mm256_cvtepi32_ps(_mm256_castsi256_si128(result)); + __m256 res01234567 = _mm256_cvtepi32_ps(result); result = _mm256_permute4x64_epi64(result, 2 + (3 << 2)); - __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result)); - res01234567 = _mm256_mul_pd(res01234567, scale01234567); - //~ res4567 = _mm256_mul_pd(res4567, scale4567); + res01234567 = _mm256_mul_ps(res01234567, scale01234567); _mm256_storeu_ps(v, res01234567); - //~ 
_mm256_storeu_pd(v + 4, res4567); } -static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi, - const float *&scales, float *&v) { +static inline void ExtractResults16(__m256i result0, __m256i result1, + const int8_t *&wi, const float *&scales, + float *&v) { __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi)); // 8x8bit vals in bottom of 128bit reg - const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); + const __m256i bias_scale = + _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg - __m256d scale0123 = _mm256_loadu_ps(scales); - __m256d scale4567 = _mm256_loadu_ps(scales + 8); + __m256 scale01234567 = _mm256_loadu_ps(scales); w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 result0 = _mm256_add_epi32(result0, w256); // result += bias * 127 - __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0)); + __m256 res01234567 = _mm256_cvtepi32_ps(result0); result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2)); - __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0)); - res0123 = _mm256_mul_pd(res0123, scale0123); - res4567 = _mm256_mul_pd(res4567, scale4567); - _mm256_storeu_ps(v, res0123); - _mm256_storeu_ps(v + 8, res4567); + res01234567 = _mm256_mul_ps(res01234567, scale01234567); + _mm256_storeu_ps(v, res01234567); w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2)); w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg - scale0123 = _mm256_loadu_ps(scales + 16); - scale4567 = _mm256_loadu_ps(scales + 24); + scale01234567 = _mm256_loadu_ps(scales + 8); w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 result1 = _mm256_add_epi32(result1, w256); // result += bias * 127 - res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1)); + res01234567 = _mm256_cvtepi32_ps(result1); result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2)); - res4567 = 
_mm256_cvtepi32_pd(_mm256_castsi256_si128(result1)); - res0123 = _mm256_mul_pd(res0123, scale0123); - res4567 = _mm256_mul_pd(res4567, scale4567); - _mm256_storeu_ps(v + 16, res0123); - _mm256_storeu_ps(v + 24, res4567); + res01234567 = _mm256_mul_ps(res01234567, scale01234567); + _mm256_storeu_ps(v + 8, res01234567); wi += 16; scales += 16; v += 16;