same as patch-4 (tesseract-ocr#3494) but now with reduced code duplication: for TFloat to work, we don't need to duplicate the integer work functions, as it's only the ExtractResults[8,16] functions that need different implementations for float vs. double. The following are therefore common to both implementations:

```
static void PartialMatrixDotVector64(const int8_t *wi, const TFloat *scales, const int8_t *u,
                                     int num_in, TFloat *v) {

static void PartialMatrixDotVector32(const int8_t *wi, const TFloat *scales, const int8_t *u,
                                     int num_in, TFloat *v) {

static void PartialMatrixDotVector16(const int8_t *wi, const TFloat *scales, const int8_t *u,
                                     int num_in, TFloat *v) {

static inline void PartialMatrixDotVector8(const int8_t *wi, const TFloat *scales, const int8_t *u,
                                           int num_in, TFloat *v) {

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales,
                            const int8_t *u, TFloat *v) {
```
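For context, the split works because TFloat is a build-time alias over float or double. A minimal sketch of that alias follows; in Tesseract it lives in a shared header, so the exact placement and comments here are illustrative only:

```
#if defined(FAST_FLOAT)
using TFloat = float;   // single-precision build
#else
using TFloat = double;  // default double-precision build
#endif
```

The only genuinely precision-specific work is the accumulator-to-TFloat conversion inside ExtractResults[8,16]: the float build converts all eight 32-bit lanes in one step, whereas the double build needs two four-lane conversions with separately loaded scales. A condensed sketch of that contrast, using the same intrinsics as the diff below (bias handling and pointer advancing elided):

```
// FAST_FLOAT build: one 8-lane int32 -> float convert, one multiply by the scales.
__m256 res = _mm256_cvtepi32_ps(result0);
res = _mm256_mul_ps(res, _mm256_loadu_ps(scales));
_mm256_storeu_ps(v, res);

// double build: two 4-lane int32 -> double converts, each scaled separately.
__m256d lo = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
__m256d hi = _mm256_cvtepi32_pd(
    _mm256_castsi256_si128(_mm256_permute4x64_epi64(result0, 2 + (3 << 2))));
_mm256_storeu_pd(v, _mm256_mul_pd(lo, _mm256_loadu_pd(scales)));
_mm256_storeu_pd(v + 4, _mm256_mul_pd(hi, _mm256_loadu_pd(scales + 4)));
```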
GerHobbelt committed Jul 13, 2021
1 parent 1a59b6f commit d2eb7bd
307 changes: 54 additions & 253 deletions src/arch/intsimdmatrixavx2.cpp
@@ -132,253 +132,52 @@ static inline __m128i load64_to_128(const int8_t *wi_) {
}

#if defined(FAST_FLOAT)
static inline void ExtractResults8(__m256i result, const int8_t *wi, const float *scales,
float *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256d scale01234567 = _mm256_loadu_ps(scales);
//~ __m256d scale4567 = _mm256_loadu_ps(scales + 8);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result = _mm256_add_epi32(result, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(_mm256_castsi256_si128(result));
result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
__m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
res01234567 = _mm256_mul_pd(res01234567, scale01234567);
//~ res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v, res01234567);
//~ _mm256_storeu_pd(v + 4, res4567);
}

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
const float *&scales, float *&v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
__m256d scale0123 = _mm256_loadu_ps(scales);
__m256d scale4567 = _mm256_loadu_ps(scales + 8);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
__m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
__m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v, res0123);
_mm256_storeu_ps(v + 8, res4567);
w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
scale0123 = _mm256_loadu_ps(scales + 16);
scale4567 = _mm256_loadu_ps(scales + 24);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v + 16, res0123);
_mm256_storeu_ps(v + 24, res4567);
wi += 16;
scales += 16;
v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
__m256i result1 = _mm256_setzero_si256();
__m256i result2 = _mm256_setzero_si256();
__m256i result3 = _mm256_setzero_si256();
__m256i result4 = _mm256_setzero_si256();
__m256i result5 = _mm256_setzero_si256();
__m256i result6 = _mm256_setzero_si256();
__m256i result7 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
}
}
ExtractResults16(result0, result1, wi, scales, v);
ExtractResults16(result2, result3, wi, scales, v);
ExtractResults16(result4, result5, wi, scales, v);
ExtractResults16(result6, result7, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
__m256i result1 = _mm256_setzero_si256();
__m256i result2 = _mm256_setzero_si256();
__m256i result3 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
}
}
ExtractResults16(result0, result1, wi, scales, v);
ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
__m256i result1 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
}
}
ExtractResults16(result0, result1, wi, scales, v);
static inline void ExtractResults8(__m256i result, const int8_t* wi, const TFloat* scales,
TFloat* v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result = _mm256_add_epi32(result, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(result);
result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
}
}
ExtractResults8(result0, wi, scales, v);
static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t*& wi,
const TFloat*& scales, TFloat*& v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(result0);
result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
scale01234567 = _mm256_loadu_ps(scales + 8);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
res01234567 = _mm256_cvtepi32_ps(result1);
result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v + 8, res01234567);
wi += 16;
scales += 16;
v += 16;
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
const int8_t *u, float *v) {
const int num_out = dim1;
const int num_in = dim2 - 1;
// Each call to a partial_func_ produces group_size outputs, except the
// last one, which can produce less.
const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
int output = 0;

int w_step = (rounded_num_in + 1) * group_size;

// Run with this group size, until it would produce too much output, then
// switch to a smaller size.
for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
}
group_size /= 2;
w_step /= 2;

if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
output += group_size;
}
group_size /= 2;
w_step /= 2;

if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
output += group_size;
}
group_size /= 2;
w_step /= 2;

if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
}
}
#else
static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
double *v) {

static inline void ExtractResults8(__m256i result, const int8_t *wi, const TFloat *scales,
TFloat *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
@@ -396,7 +195,7 @@ static inline void ExtractResults8(__m256i result, const int8_t *wi, const doubl
}

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
const double *&scales, double *&v) {
const TFloat *&scales, TFloat *&v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
@@ -430,15 +229,17 @@ static inline void ExtractResults16(__m256i result0, __m256i result1, const int8
v += 16;
}

#endif

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
int num_in, double *v) {
static void PartialMatrixDotVector64(const int8_t *wi, const TFloat *scales, const int8_t *u,
int num_in, TFloat *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -482,8 +283,8 @@ static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, con

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
int num_in, double *v) {
static void PartialMatrixDotVector32(const int8_t *wi, const TFloat *scales, const int8_t *u,
int num_in, TFloat *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -517,8 +318,8 @@ static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, con

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
int num_in, double *v) {
static void PartialMatrixDotVector16(const int8_t *wi, const TFloat *scales, const int8_t *u,
int num_in, TFloat *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -547,8 +348,8 @@ static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, con

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
int num_in, double *v) {
static inline void PartialMatrixDotVector8(const int8_t *wi, const TFloat *scales, const int8_t *u,
int num_in, TFloat *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -573,8 +374,8 @@ static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scale
ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
const int8_t *u, double *v) {
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales,
const int8_t *u, TFloat *v) {
const int num_out = dim1;
const int num_in = dim2 - 1;
// Each call to a partial_func_ produces group_size outputs, except the
@@ -621,7 +422,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
}
}
#endif


static const IntSimdMatrix simdMatrix = {
// Function.
