From 59af8dd93b2b0cc1a5b8b26d63ff25c43515516e Mon Sep 17 00:00:00 2001
From: Stefan Weil
Date: Wed, 30 Jun 2021 07:35:36 +0200
Subject: [PATCH 01/11] Add TFloat data type for neural network

Up to now, Tesseract used double for training and recognition with
"best" models. This commit replaces double with a new data type,
TFloat, which is double by default, but float if FAST_FLOAT is
defined. Ideally this should allow faster training.

Signed-off-by: Stefan Weil
---
 Makefile.am                             |   6 +
 src/arch/dotproduct.cpp                 |   2 +-
 src/arch/dotproduct.h                   |  10 +-
 src/arch/dotproductavx.cpp              |  42 ++++++
 src/arch/dotproductfma.cpp              |  10 ++
 src/arch/dotproductsse.cpp              |  10 ++
 src/arch/intsimdmatrix.cpp              |   2 +-
 src/arch/intsimdmatrix.h                |  16 +-
 src/arch/intsimdmatrixavx2.cpp          |  15 +-
 src/arch/intsimdmatrixneon.cpp          |   7 +-
 src/arch/intsimdmatrixsse.cpp           |  15 +-
 src/arch/simddetect.cpp                 |  20 +--
 src/arch/simddetect.h                   |   3 +-
 src/ccutil/tfloat.h                     |  10 ++
 src/lstm/fullyconnected.cpp             |  18 +--
 src/lstm/fullyconnected.h               |  13 +-
 src/lstm/functions.cpp                  |   4 +-
 src/lstm/functions.h                    |  71 ++++-----
 src/lstm/lstm.cpp                       |  44 +++---
 src/lstm/lstm.h                         |   2 +-
 src/lstm/network.h                      |   2 +-
 src/lstm/networkio.cpp                  |  20 +--
 src/lstm/networkio.h                    |  24 +--
 src/lstm/networkscratch.h               |  14 +-
 src/lstm/plumbing.cpp                   |   2 +-
 src/lstm/plumbing.h                     |   2 +-
 src/lstm/weightmatrix.cpp               | 193 ++++++++++++++++--------
 src/lstm/weightmatrix.h                 |  37 +++--
 src/training/unicharset/lstmtrainer.cpp |  16 +-
 src/training/unicharset/lstmtrainer.h   |   2 +-
 unittest/intsimdmatrix_test.cc          |  20 ++-
 31 files changed, 409 insertions(+), 243 deletions(-)
 create mode 100644 src/ccutil/tfloat.h

diff --git a/Makefile.am b/Makefile.am
index e95f18ea17..afa75b311c 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -146,11 +146,13 @@ noinst_LTLIBRARIES += libtesseract_native.la
 libtesseract_native_la_CXXFLAGS = -O3 -ffast-math
+libtesseract_native_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
 if MARCH_NATIVE_OPT
 libtesseract_native_la_CXXFLAGS += -march=native -mtune=native
 endif
 libtesseract_native_la_SOURCES = src/arch/dotproduct.cpp
 
 if HAVE_AVX
 libtesseract_avx_la_CXXFLAGS = -mavx
+libtesseract_avx_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
 libtesseract_avx_la_SOURCES = src/arch/dotproductavx.cpp
 libtesseract_la_LIBADD += libtesseract_avx.la
 noinst_LTLIBRARIES += libtesseract_avx.la
@@ -158,6 +160,7 @@ endif
 
 if HAVE_AVX2
 libtesseract_avx2_la_CXXFLAGS = -mavx2
+libtesseract_avx2_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
 libtesseract_avx2_la_SOURCES = src/arch/intsimdmatrixavx2.cpp
 libtesseract_la_LIBADD += libtesseract_avx2.la
 noinst_LTLIBRARIES += libtesseract_avx2.la
@@ -165,6 +168,7 @@ endif
 
 if HAVE_FMA
 libtesseract_fma_la_CXXFLAGS = -mfma
+libtesseract_fma_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
 libtesseract_fma_la_SOURCES = src/arch/dotproductfma.cpp
 libtesseract_la_LIBADD += libtesseract_fma.la
 noinst_LTLIBRARIES += libtesseract_fma.la
@@ -172,6 +176,7 @@ endif
 
 if HAVE_SSE4_1
 libtesseract_sse_la_CXXFLAGS = -msse4.1
+libtesseract_sse_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
 libtesseract_sse_la_SOURCES = src/arch/dotproductsse.cpp src/arch/intsimdmatrixsse.cpp
 libtesseract_la_LIBADD += libtesseract_sse.la
 noinst_LTLIBRARIES += libtesseract_sse.la
@@ -179,6 +184,7 @@ endif
 
 if HAVE_NEON
 libtesseract_neon_la_CXXFLAGS = $(NEON_CXXFLAGS)
+libtesseract_neon_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
 libtesseract_neon_la_SOURCES = src/arch/intsimdmatrixneon.cpp
 libtesseract_la_LIBADD += libtesseract_neon.la
 noinst_LTLIBRARIES += libtesseract_neon.la
diff --git a/src/arch/dotproduct.cpp b/src/arch/dotproduct.cpp
index 62bcc00ce2..d0be2717fa 100644 --- a/src/arch/dotproduct.cpp +++ b/src/arch/dotproduct.cpp @@ -19,7 +19,7 @@ namespace tesseract { // Computes and returns the dot product of the two n-vectors u and v. -double DotProductNative(const double *u, const double *v, int n) { +TFloat DotProductNative(const TFloat *u, const TFloat *v, int n) { double total = 0.0; for (int k = 0; k < n; ++k) { total += u[k] * v[k]; diff --git a/src/arch/dotproduct.h b/src/arch/dotproduct.h index bbdf6df9bb..c64765597e 100644 --- a/src/arch/dotproduct.h +++ b/src/arch/dotproduct.h @@ -17,19 +17,21 @@ #ifndef TESSERACT_ARCH_DOTPRODUCT_H_ #define TESSERACT_ARCH_DOTPRODUCT_H_ +#include "tfloat.h" + namespace tesseract { // Computes and returns the dot product of the n-vectors u and v. -double DotProductNative(const double *u, const double *v, int n); +TFloat DotProductNative(const TFloat *u, const TFloat *v, int n); // Uses Intel AVX intrinsics to access the SIMD instruction set. -double DotProductAVX(const double *u, const double *v, int n); +TFloat DotProductAVX(const TFloat *u, const TFloat *v, int n); // Use Intel FMA. -double DotProductFMA(const double *u, const double *v, int n); +TFloat DotProductFMA(const TFloat *u, const TFloat *v, int n); // Uses Intel SSE intrinsics to access the SIMD instruction set. -double DotProductSSE(const double *u, const double *v, int n); +TFloat DotProductSSE(const TFloat *u, const TFloat *v, int n); } // namespace tesseract. diff --git a/src/arch/dotproductavx.cpp b/src/arch/dotproductavx.cpp index 3f243906db..f937b5d8c1 100644 --- a/src/arch/dotproductavx.cpp +++ b/src/arch/dotproductavx.cpp @@ -29,6 +29,47 @@ namespace tesseract { // Computes and returns the dot product of the n-vectors u and v. // Uses Intel AVX intrinsics to access the SIMD instruction set. +#if defined(FAST_FLOAT) +float DotProductAVX(const float *u, const float *v, int n) { +#ifndef FAST_FLOAT16 + const unsigned quot = n / 8; + const unsigned rem = n % 8; +#else + const unsigned quot = n / 16; + const unsigned rem = n % 16; +#endif + __m256 t0 = _mm256_setzero_ps(); +#ifdef FAST_FLOAT16 + __m256 t1 = _mm256_setzero_ps(); +#endif + for (unsigned k = 0; k < quot; k++) { + __m256 f0 = _mm256_loadu_ps(u); + __m256 f1 = _mm256_loadu_ps(v); + f0 = _mm256_mul_ps(f0, f1); + t0 = _mm256_add_ps(t0, f0); + u += 8; + v += 8; +#ifdef FAST_FLOAT16 + __m256 f2 = _mm256_loadu_ps(u); + __m256 f3 = _mm256_loadu_ps(v); + f2 = _mm256_mul_ps(f2, f3); + t1 = _mm256_add_ps(t1, f2); + u += 8; + v += 8; +#endif + } +#ifdef FAST_FLOAT16 + t0 = _mm256_hadd_ps(t0, t1); +#endif + alignas(32) float tmp[8]; + _mm256_store_ps(tmp, t0); + float result = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7]; + for (unsigned k = 0; k < rem; k++) { + result += *u++ * *v++; + } + return result; +} +#else double DotProductAVX(const double *u, const double *v, int n) { const unsigned quot = n / 8; const unsigned rem = n % 8; @@ -57,6 +98,7 @@ double DotProductAVX(const double *u, const double *v, int n) { } return result; } +#endif } // namespace tesseract. diff --git a/src/arch/dotproductfma.cpp b/src/arch/dotproductfma.cpp index ede46298e8..8ce7ae32d4 100644 --- a/src/arch/dotproductfma.cpp +++ b/src/arch/dotproductfma.cpp @@ -29,6 +29,15 @@ namespace tesseract { // Computes and returns the dot product of the n-vectors u and v. // Uses Intel FMA intrinsics to access the SIMD instruction set. 
+#if defined(FAST_FLOAT)
+TFloat DotProductFMA(const TFloat *u, const TFloat *v, int n) {
+  TFloat total = 0.0;
+  for (int k = 0; k < n; ++k) {
+    total += u[k] * v[k];
+  }
+  return total;
+}
+#else
 double DotProductFMA(const double *u, const double *v, int n) {
   const unsigned quot = n / 8;
   const unsigned rem = n % 8;
@@ -55,6 +64,7 @@ double DotProductFMA(const double *u, const double *v, int n) {
   }
   return result;
 }
+#endif
 
 } // namespace tesseract.
diff --git a/src/arch/dotproductsse.cpp b/src/arch/dotproductsse.cpp
index 1dbd18fb8e..ec94f50341 100644
--- a/src/arch/dotproductsse.cpp
+++ b/src/arch/dotproductsse.cpp
@@ -30,6 +30,15 @@ namespace tesseract {
 
 // Computes and returns the dot product of the n-vectors u and v.
 // Uses Intel SSE intrinsics to access the SIMD instruction set.
+#if defined(FAST_FLOAT)
+TFloat DotProductSSE(const TFloat *u, const TFloat *v, int n) {
+  TFloat total = 0.0;
+  for (int k = 0; k < n; ++k) {
+    total += u[k] * v[k];
+  }
+  return total;
+}
+#else
 double DotProductSSE(const double *u, const double *v, int n) {
   int max_offset = n - 2;
   int offset = 0;
@@ -78,6 +87,7 @@ double DotProductSSE(const double *u, const double *v, int n) {
   }
   return result;
 }
+#endif
 
 } // namespace tesseract.
diff --git a/src/arch/intsimdmatrix.cpp b/src/arch/intsimdmatrix.cpp
index 5d113542cd..fa4afa7c8a 100644
--- a/src/arch/intsimdmatrix.cpp
+++ b/src/arch/intsimdmatrix.cpp
@@ -76,7 +76,7 @@ void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t> &w, std::vector
 // u is imagined to have an extra element at the end with value 1, to
 // implement the bias, but it doesn't actually have it.
 void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t> &w,
-                                    const std::vector<double> &scales, const int8_t *u, double *v) {
+                                    const std::vector<TFloat> &scales, const int8_t *u, TFloat *v) {
   int num_out = w.dim1();
   int num_in = w.dim2() - 1;
   // Base implementation.
diff --git a/src/arch/intsimdmatrix.h b/src/arch/intsimdmatrix.h
index c2947b06f5..aa05a450ee 100644
--- a/src/arch/intsimdmatrix.h
+++ b/src/arch/intsimdmatrix.h
@@ -23,6 +23,8 @@
 #include <cstdint>
 #include <vector>
 
+#include "tfloat.h"
+
 namespace tesseract {
 
 template <class T>
@@ -78,8 +80,8 @@ struct TESS_API IntSimdMatrix {
   // u is imagined to have an extra element at the end with value 1, to
   // implement the bias, but it doesn't actually have it.
   // Computes the base C++ implementation.
-  static void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t> &w, const std::vector<double> &scales,
-                              const int8_t *u, double *v);
+  static void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t> &w, const std::vector<TFloat> &scales,
+                              const int8_t *u, TFloat *v);
 
   // Rounds the input up to a multiple of the given factor.
   static int Roundup(int input, int factor) {
@@ -95,8 +97,8 @@ struct TESS_API IntSimdMatrix {
   // RoundInputs above.
   // The input will be over-read to the extent of the padding. There are no
   // alignment requirements.
-  using MatrixDotVectorFunction = void (*)(int, int, const int8_t *, const double *, const int8_t *,
-                                           double *);
+  using MatrixDotVectorFunction = void (*)(int, int, const int8_t *, const TFloat *, const int8_t *,
+                                           TFloat *);
   MatrixDotVectorFunction matrixDotVectorFunction;
 
   // Number of 32 bit outputs held in each register.
@@ -112,10 +114,10 @@ struct TESS_API IntSimdMatrix {
   static const IntSimdMatrix *intSimdMatrix;
 
   // Only available with NEON.
-  static const IntSimdMatrix intSimdMatrixNEON;
+  static const IntSimdMatrix *intSimdMatrixNEON;
 
   // Only available with AVX2 / SSE.
-  static const IntSimdMatrix intSimdMatrixAVX2;
-  static const IntSimdMatrix intSimdMatrixSSE;
+  static const IntSimdMatrix *intSimdMatrixAVX2;
+  static const IntSimdMatrix *intSimdMatrixSSE;
 };
 
 } // namespace tesseract
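Note on the intsimdmatrix.h hunks above: turning the NEON/AVX2/SSE matrices from
static instances into pointers is what lets a FAST_FLOAT build stub out the int8
SIMD kernels that have no float port yet — the pointer is simply defined as
nullptr (see the new #elif branches below) and callers fall back to the generic
base implementation. A minimal self-contained sketch of that nullable-dispatch
pattern (illustrative names, not code from this patch):

    #include <cstdio>

    // Sketch: a nullable kernel pointer with a portable scalar fallback,
    // mirroring how IntSimdMatrix::intSimdMatrix is consumed after this change.
    using Kernel = float (*)(const float *, const float *, int);

    static float DotGeneric(const float *u, const float *v, int n) {
      float total = 0.0f;
      for (int k = 0; k < n; ++k) {
        total += u[k] * v[k];
      }
      return total;
    }

    // nullptr models the FAST_FLOAT + AVX2/SSE case where no kernel exists.
    static Kernel best_kernel = nullptr;

    static float Dot(const float *u, const float *v, int n) {
      return best_kernel ? best_kernel(u, v, n) : DotGeneric(u, v, n);
    }

    int main() {
      const float a[] = {1, 2, 3}, b[] = {4, 5, 6};
      std::printf("%g\n", Dot(a, b, 3)); // prints 32
    }
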
diff --git a/src/arch/intsimdmatrixavx2.cpp b/src/arch/intsimdmatrixavx2.cpp
index ce5d8ea9fe..d417869115 100644
--- a/src/arch/intsimdmatrixavx2.cpp
+++ b/src/arch/intsimdmatrixavx2.cpp
@@ -15,14 +15,18 @@
 // limitations under the License.
 ///////////////////////////////////////////////////////////////////////
 
+#include "intsimdmatrix.h"
+
 #if !defined(__AVX2__)
 # if defined(__i686__) || defined(__x86_64__)
 # error Implementation only for AVX2 capable architectures
 # endif
+#elif defined(FAST_FLOAT)
+namespace tesseract {
+const IntSimdMatrix *IntSimdMatrix::intSimdMatrixAVX2 = nullptr;
+}
 #else
-# include "intsimdmatrix.h"
-
 # include <immintrin.h>
 # include <algorithm>
 # include <cstdint>
@@ -331,7 +335,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *
   }
 }
 
-const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
+static const IntSimdMatrix simdMatrix = {
   // Function.
   matrixDotVector,
   // Number of 32 bit outputs held in each register.
@@ -341,7 +345,10 @@
   // Number of 8 bit inputs in the inputs register.
   kNumInputsPerRegister,
   // Number of inputs in each weight group.
-  kNumInputsPerGroup};
+  kNumInputsPerGroup
+};
+
+const IntSimdMatrix *IntSimdMatrix::intSimdMatrixAVX2 = &simdMatrix;
 
 } // namespace tesseract.
diff --git a/src/arch/intsimdmatrixneon.cpp b/src/arch/intsimdmatrixneon.cpp
index cd44c639d7..ae6608133d 100644
--- a/src/arch/intsimdmatrixneon.cpp
+++ b/src/arch/intsimdmatrixneon.cpp
@@ -186,7 +186,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *
                 num_out & (kNumOutputsPerRegister - 1));
 }
 
-const IntSimdMatrix IntSimdMatrix::intSimdMatrixNEON = {
+static const IntSimdMatrix intSimdMatrix = {
   // Function.
   matrixDotVector,
   // Number of 32 bit outputs held in each register.
@@ -196,7 +196,10 @@
   // Number of 8 bit inputs in the inputs register.
   kNumInputsPerRegister,
   // Number of inputs in each weight group.
-  kNumInputsPerGroup};
+  kNumInputsPerGroup
+};
+
+const IntSimdMatrix *IntSimdMatrix::intSimdMatrixNEON = &intSimdMatrix;
 
 } // namespace tesseract.
diff --git a/src/arch/intsimdmatrixsse.cpp b/src/arch/intsimdmatrixsse.cpp
index 7af6f81be7..7407f6f5a1 100644
--- a/src/arch/intsimdmatrixsse.cpp
+++ b/src/arch/intsimdmatrixsse.cpp
@@ -15,14 +15,18 @@
 // limitations under the License.
 ///////////////////////////////////////////////////////////////////////
 
+#include "intsimdmatrix.h"
+
 #if !defined(__SSE4_1__)
 # if defined(__i686__) || defined(__x86_64__)
 # error Implementation only for SSE 4.1 capable architectures
 # endif
+#elif defined(FAST_FLOAT)
+namespace tesseract {
+const IntSimdMatrix *IntSimdMatrix::intSimdMatrixSSE = nullptr;
+}
 #else
-# include "intsimdmatrix.h"
-
 # include <emmintrin.h>
 # include <smmintrin.h>
 # include <cstdint>
@@ -90,7 +94,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *
   }
 }
 
-const IntSimdMatrix IntSimdMatrix::intSimdMatrixSSE = {
+static const IntSimdMatrix simdMatrix = {
   matrixDotVector,
   // Number of 32 bit outputs held in each register.
   1,
@@ -99,7 +103,10 @@
   // Number of 8 bit inputs in the inputs register.
   1,
   // Number of inputs in each weight group.
- 1}; + 1 +}; + +const IntSimdMatrix *IntSimdMatrix::intSimdMatrixSSE = &simdMatrix; } // namespace tesseract. diff --git a/src/arch/simddetect.cpp b/src/arch/simddetect.cpp index a14bd19ac4..5b3edc60be 100644 --- a/src/arch/simddetect.cpp +++ b/src/arch/simddetect.cpp @@ -84,8 +84,8 @@ bool SIMDDetect::sse_available_; #endif // Computes and returns the dot product of the two n-vectors u and v. -static double DotProductGeneric(const double *u, const double *v, int n) { - double total = 0.0; +static TFloat DotProductGeneric(const TFloat *u, const TFloat *v, int n) { + TFloat total = 0.0; for (int k = 0; k < n; ++k) { total += u[k] * v[k]; } @@ -93,7 +93,7 @@ static double DotProductGeneric(const double *u, const double *v, int n) { } // Compute dot product using std::inner_product. -static double DotProductStdInnerProduct(const double *u, const double *v, int n) { +static TFloat DotProductStdInnerProduct(const TFloat *u, const TFloat *v, int n) { return std::inner_product(u, u + n, v, 0.0); } @@ -200,22 +200,22 @@ SIMDDetect::SIMDDetect() { #if defined(HAVE_AVX2) } else if (avx2_available_) { // AVX2 detected. - SetDotProduct(DotProductAVX, &IntSimdMatrix::intSimdMatrixAVX2); + SetDotProduct(DotProductAVX, IntSimdMatrix::intSimdMatrixAVX2); #endif #if defined(HAVE_AVX) } else if (avx_available_) { // AVX detected. - SetDotProduct(DotProductAVX, &IntSimdMatrix::intSimdMatrixSSE); + SetDotProduct(DotProductAVX, IntSimdMatrix::intSimdMatrixSSE); #endif #if defined(HAVE_SSE4_1) } else if (sse_available_) { // SSE detected. - SetDotProduct(DotProductSSE, &IntSimdMatrix::intSimdMatrixSSE); + SetDotProduct(DotProductSSE, IntSimdMatrix::intSimdMatrixSSE); #endif #if defined(HAVE_NEON) || defined(__aarch64__) } else if (neon_available_) { // NEON detected. - SetDotProduct(DotProduct, &IntSimdMatrix::intSimdMatrixNEON); + SetDotProduct(DotProduct, IntSimdMatrix::intSimdMatrixNEON); #endif } } @@ -237,13 +237,13 @@ void SIMDDetect::Update() { #if defined(HAVE_AVX2) } else if (!strcmp(dotproduct.c_str(), "avx2")) { // AVX2 selected by config variable. - SetDotProduct(DotProductAVX, &IntSimdMatrix::intSimdMatrixAVX2); + SetDotProduct(DotProductAVX, IntSimdMatrix::intSimdMatrixAVX2); dotproduct_method = "avx2"; #endif #if defined(HAVE_AVX) } else if (!strcmp(dotproduct.c_str(), "avx")) { // AVX selected by config variable. - SetDotProduct(DotProductAVX, &IntSimdMatrix::intSimdMatrixSSE); + SetDotProduct(DotProductAVX, IntSimdMatrix::intSimdMatrixSSE); dotproduct_method = "avx"; #endif #if defined(HAVE_FMA) @@ -255,7 +255,7 @@ void SIMDDetect::Update() { #if defined(HAVE_SSE4_1) } else if (!strcmp(dotproduct.c_str(), "sse")) { // SSE selected by config variable. - SetDotProduct(DotProductSSE, &IntSimdMatrix::intSimdMatrixSSE); + SetDotProduct(DotProductSSE, IntSimdMatrix::intSimdMatrixSSE); dotproduct_method = "sse"; #endif } else if (!strcmp(dotproduct.c_str(), "std::inner_product")) { diff --git a/src/arch/simddetect.h b/src/arch/simddetect.h index e986a1ecaa..1e070fe53c 100644 --- a/src/arch/simddetect.h +++ b/src/arch/simddetect.h @@ -18,11 +18,12 @@ #define TESSERACT_ARCH_SIMDDETECT_H_ #include +#include "tfloat.h" namespace tesseract { // Function pointer for best calculation of dot product. -using DotProductFunction = double (*)(const double *, const double *, int); +using DotProductFunction = TFloat (*)(const TFloat *, const TFloat *, int); extern DotProductFunction DotProduct; // Architecture detector. 
Add code here to detect any other architectures for diff --git a/src/ccutil/tfloat.h b/src/ccutil/tfloat.h new file mode 100644 index 0000000000..9e7cddc01d --- /dev/null +++ b/src/ccutil/tfloat.h @@ -0,0 +1,10 @@ +#ifndef TESSERACT_LSTM_TFLOAT_H +#define TESSERACT_LSTM_TFLOAT_H + +#ifdef FAST_FLOAT +typedef float TFloat; +#else +typedef double TFloat; +#endif + +#endif diff --git a/src/lstm/fullyconnected.cpp b/src/lstm/fullyconnected.cpp index 80f7f2a5ef..85989f407e 100644 --- a/src/lstm/fullyconnected.cpp +++ b/src/lstm/fullyconnected.cpp @@ -156,7 +156,7 @@ void FullyConnected::Forward(bool debug, const NetworkIO &input, // Thread-local pointer to temporary storage. int thread_id = 0; #endif - double *temp_line = temp_lines[thread_id]; + TFloat *temp_line = temp_lines[thread_id]; if (input.int_mode()) { ForwardTimeStep(input.i(t), t, temp_line); } else { @@ -200,7 +200,7 @@ void FullyConnected::SetupForward(const NetworkIO &input, const TransposedArray } } -void FullyConnected::ForwardTimeStep(int t, double *output_line) { +void FullyConnected::ForwardTimeStep(int t, TFloat *output_line) { if (type_ == NT_TANH) { FuncInplace(no_, output_line); } else if (type_ == NT_LOGISTIC) { @@ -218,7 +218,7 @@ void FullyConnected::ForwardTimeStep(int t, double *output_line) { } } -void FullyConnected::ForwardTimeStep(const double *d_input, int t, double *output_line) { +void FullyConnected::ForwardTimeStep(const TFloat *d_input, int t, TFloat *output_line) { // input is copied to source_ line-by-line for cache coherency. if (IsTraining() && external_source_ == nullptr) { source_t_.WriteStrided(t, d_input); @@ -227,7 +227,7 @@ void FullyConnected::ForwardTimeStep(const double *d_input, int t, double *outpu ForwardTimeStep(t, output_line); } -void FullyConnected::ForwardTimeStep(const int8_t *i_input, int t, double *output_line) { +void FullyConnected::ForwardTimeStep(const int8_t *i_input, int t, TFloat *output_line) { // input is copied to source_ line-by-line for cache coherency. weights_.MatrixDotVector(i_input, output_line); ForwardTimeStep(t, output_line); @@ -265,11 +265,11 @@ bool FullyConnected::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkSc for (int t = 0; t < width; ++t) { int thread_id = 0; #endif - double *backprop = nullptr; + TFloat *backprop = nullptr; if (needs_to_backprop_) { backprop = temp_backprops[thread_id]; } - double *curr_errors = errors[thread_id]; + TFloat *curr_errors = errors[thread_id]; BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop); if (backprop != nullptr) { back_deltas->WriteTimeStep(t, backprop); @@ -287,8 +287,8 @@ bool FullyConnected::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkSc return false; // No point going further back. } -void FullyConnected::BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, - TransposedArray *errors_t, double *backprop) { +void FullyConnected::BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors, + TransposedArray *errors_t, TFloat *backprop) { if (type_ == NT_TANH) { acts_.FuncMultiply(fwd_deltas, t, curr_errors); } else if (type_ == NT_LOGISTIC) { @@ -328,7 +328,7 @@ void FullyConnected::Update(float learning_rate, float momentum, float adam_beta // Sums the products of weight updates in *this and other, splitting into // positive (same direction) in *same and negative (different direction) in // *changed. 
-void FullyConnected::CountAlternators(const Network &other, double *same, double *changed) const { +void FullyConnected::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const { ASSERT_HOST(other.type() == type_); const auto *fc = static_cast(&other); weights_.CountAlternators(fc->weights_, same, changed); diff --git a/src/lstm/fullyconnected.h b/src/lstm/fullyconnected.h index 95f27f1ab8..1971a0a8ec 100644 --- a/src/lstm/fullyconnected.h +++ b/src/lstm/fullyconnected.h @@ -20,6 +20,7 @@ #include "network.h" #include "networkscratch.h" +#include "tfloat.h" namespace tesseract { @@ -90,17 +91,17 @@ class FullyConnected : public Network { NetworkScratch *scratch, NetworkIO *output) override; // Components of Forward so FullyConnected can be reused inside LSTM. void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose); - void ForwardTimeStep(int t, double *output_line); - void ForwardTimeStep(const double *d_input, int t, double *output_line); - void ForwardTimeStep(const int8_t *i_input, int t, double *output_line); + void ForwardTimeStep(int t, TFloat *output_line); + void ForwardTimeStep(const TFloat *d_input, int t, TFloat *output_line); + void ForwardTimeStep(const int8_t *i_input, int t, TFloat *output_line); // Runs backward propagation of errors on the deltas line. // See Network for a detailed discussion of the arguments. bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override; // Components of Backward so FullyConnected can be reused inside LSTM. - void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, - TransposedArray *errors_t, double *backprop); + void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors, + TransposedArray *errors_t, TFloat *backprop); void FinishBackward(const TransposedArray &errors_t); // Updates the weights using the given learning rate, momentum and adam_beta. @@ -109,7 +110,7 @@ class FullyConnected : public Network { // Sums the products of weight updates in *this and other, splitting into // positive (same direction) in *same and negative (different direction) in // *changed. - void CountAlternators(const Network &other, double *same, double *changed) const override; + void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override; protected: // Weight arrays of size [no, ni + 1]. diff --git a/src/lstm/functions.cpp b/src/lstm/functions.cpp index 46e1392c19..5eaf0252d0 100644 --- a/src/lstm/functions.cpp +++ b/src/lstm/functions.cpp @@ -1,7 +1,7 @@ // Generated code with lookup tables #include "functions.h" namespace tesseract { -const double TanhTable[] = { +const TFloat TanhTable[] = { 0.0, 0.00390623013190634, 0.007812341058161014, @@ -4099,7 +4099,7 @@ const double TanhTable[] = { 0.9999999999999742, 0.9999999999999745, }; -const double LogisticTable[] = { +const TFloat LogisticTable[] = { 0.5, 0.5009765612582384, 0.5019531150659532, diff --git a/src/lstm/functions.h b/src/lstm/functions.h index 65b4d33453..00bd784402 100644 --- a/src/lstm/functions.h +++ b/src/lstm/functions.h @@ -19,6 +19,7 @@ #define TESSERACT_LSTM_FUNCTIONS_H_ #include "helpers.h" +#include "tfloat.h" // Setting this to 1 or more causes massive dumps of debug data: weights, // updates, internal calculations etc, and reduces the number of test iterations @@ -33,14 +34,14 @@ namespace tesseract { // Size of static tables. constexpr int kTableSize = 4096; // Scale factor for float arg to int index. 
-constexpr double kScaleFactor = 256.0; +constexpr TFloat kScaleFactor = 256.0; // Generated lookup tables. -extern const double TanhTable[]; -extern const double LogisticTable[]; +extern const TFloat TanhTable[]; +extern const TFloat LogisticTable[]; // Non-linearity (sigmoid) functions with cache tables and clipping. -inline double Tanh(double x) { +inline TFloat Tanh(TFloat x) { if (x < 0.0) { return -Tanh(-x); } @@ -49,13 +50,13 @@ inline double Tanh(double x) { if (index >= (kTableSize - 1)) { return 1.0; } - double tanh_i0 = TanhTable[index]; - double tanh_i1 = TanhTable[index + 1]; + TFloat tanh_i0 = TanhTable[index]; + TFloat tanh_i1 = TanhTable[index + 1]; // Linear interpolation. return tanh_i0 + (tanh_i1 - tanh_i0) * (x - index); } -inline double Logistic(double x) { +inline TFloat Logistic(TFloat x) { if (x < 0.0) { return 1.0 - Logistic(-x); } @@ -64,25 +65,25 @@ inline double Logistic(double x) { if (index >= (kTableSize - 1)) { return 1.0; } - double l0 = LogisticTable[index]; - double l1 = LogisticTable[index + 1]; + TFloat l0 = LogisticTable[index]; + TFloat l1 = LogisticTable[index + 1]; // Linear interpolation. return l0 + (l1 - l0) * (x - index); } // Non-linearity (sigmoid) functions and their derivatives. struct FFunc { - inline double operator()(double x) const { + inline TFloat operator()(TFloat x) const { return Logistic(x); } }; struct FPrime { - inline double operator()(double y) const { + inline TFloat operator()(TFloat y) const { return y * (1.0 - y); } }; struct ClipFFunc { - inline double operator()(double x) const { + inline TFloat operator()(TFloat x) const { if (x <= 0.0) { return 0.0; } @@ -93,12 +94,12 @@ struct ClipFFunc { } }; struct ClipFPrime { - inline double operator()(double y) const { + inline TFloat operator()(TFloat y) const { return 0.0 < y && y < 1.0 ? 1.0 : 0.0; } }; struct Relu { - inline double operator()(double x) const { + inline TFloat operator()(TFloat x) const { if (x <= 0.0) { return 0.0; } @@ -106,22 +107,22 @@ struct Relu { } }; struct ReluPrime { - inline double operator()(double y) const { + inline TFloat operator()(TFloat y) const { return 0.0 < y ? 1.0 : 0.0; } }; struct GFunc { - inline double operator()(double x) const { + inline TFloat operator()(TFloat x) const { return Tanh(x); } }; struct GPrime { - inline double operator()(double y) const { + inline TFloat operator()(TFloat y) const { return 1.0 - y * y; } }; struct ClipGFunc { - inline double operator()(double x) const { + inline TFloat operator()(TFloat x) const { if (x <= -1.0) { return -1.0; } @@ -132,35 +133,35 @@ struct ClipGFunc { } }; struct ClipGPrime { - inline double operator()(double y) const { + inline TFloat operator()(TFloat y) const { return -1.0 < y && y < 1.0 ? 1.0 : 0.0; } }; struct HFunc { - inline double operator()(double x) const { + inline TFloat operator()(TFloat x) const { return Tanh(x); } }; struct HPrime { - inline double operator()(double y) const { - double u = Tanh(y); - return 1.0 - u * u; + inline TFloat operator()(TFloat y) const { + TFloat u = Tanh(y); + return 1 - u * u; } }; struct UnityFunc { - inline double operator()(double /*x*/) const { + inline TFloat operator()(TFloat /*x*/) const { return 1.0; } }; struct IdentityFunc { - inline double operator()(double x) const { + inline TFloat operator()(TFloat x) const { return x; } }; // Applies Func in-place to inout, of size n. 
template -inline void FuncInplace(int n, double *inout) { +inline void FuncInplace(int n, TFloat *inout) { Func f; for (int i = 0; i < n; ++i) { inout[i] = f(inout[i]); @@ -169,7 +170,7 @@ inline void FuncInplace(int n, double *inout) { // Applies Func to u and multiplies the result by v component-wise, // putting the product in out, all of size n. template -inline void FuncMultiply(const double *u, const double *v, int n, double *out) { +inline void FuncMultiply(const TFloat *u, const TFloat *v, int n, TFloat *out) { Func f; for (int i = 0; i < n; ++i) { out[i] = f(u[i]) * v[i]; @@ -206,34 +207,34 @@ inline void SoftmaxInPlace(int n, T *inout) { } // Copies n values of the given src vector to dest. -inline void CopyVector(int n, const double *src, double *dest) { +inline void CopyVector(int n, const TFloat *src, TFloat *dest) { memcpy(dest, src, n * sizeof(dest[0])); } // Adds n values of the given src vector to dest. -inline void AccumulateVector(int n, const double *src, double *dest) { +inline void AccumulateVector(int n, const TFloat *src, TFloat *dest) { for (int i = 0; i < n; ++i) { dest[i] += src[i]; } } // Multiplies n values of inout in-place element-wise by the given src vector. -inline void MultiplyVectorsInPlace(int n, const double *src, double *inout) { +inline void MultiplyVectorsInPlace(int n, const TFloat *src, TFloat *inout) { for (int i = 0; i < n; ++i) { inout[i] *= src[i]; } } // Multiplies n values of u by v, element-wise, accumulating to out. -inline void MultiplyAccumulate(int n, const double *u, const double *v, double *out) { +inline void MultiplyAccumulate(int n, const TFloat *u, const TFloat *v, TFloat *out) { for (int i = 0; i < n; i++) { out[i] += u[i] * v[i]; } } // Sums the given 5 n-vectors putting the result into sum. -inline void SumVectors(int n, const double *v1, const double *v2, const double *v3, - const double *v4, const double *v5, double *sum) { +inline void SumVectors(int n, const TFloat *v1, const TFloat *v2, const TFloat *v3, + const TFloat *v4, const TFloat *v5, TFloat *sum) { for (int i = 0; i < n; ++i) { sum[i] = v1[i] + v2[i] + v3[i] + v4[i] + v5[i]; } @@ -255,12 +256,12 @@ inline void ClipVector(int n, T lower, T upper, T *vec) { // Converts the given n-vector to a binary encoding of the maximum value, // encoded as vector of nf binary values. -inline void CodeInBinary(int n, int nf, double *vec) { +inline void CodeInBinary(int n, int nf, TFloat *vec) { if (nf <= 0 || n < nf) { return; } int index = 0; - double best_score = vec[0]; + TFloat best_score = vec[0]; for (int i = 1; i < n; ++i) { if (vec[i] > best_score) { best_score = vec[i]; diff --git a/src/lstm/lstm.cpp b/src/lstm/lstm.cpp index 9a8ab2cfe7..352b92f387 100644 --- a/src/lstm/lstm.cpp +++ b/src/lstm/lstm.cpp @@ -68,9 +68,9 @@ namespace tesseract { // Max absolute value of state_. It is reasonably high to enable the state // to count things. -const double kStateClip = 100.0; +const TFloat kStateClip = 100.0; // Max absolute value of gate_errors (the gradients). -const double kErrClip = 1.0f; +const TFloat kErrClip = 1.0f; // Calculate ceil(log2(n)). static inline uint32_t ceil_log2(uint32_t n) { @@ -312,9 +312,9 @@ void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *in // Single timestep buffers for the current/recurrent output and state. 
NetworkScratch::FloatVec curr_state, curr_output; curr_state.Init(ns_, scratch); - ZeroVector(ns_, curr_state); + ZeroVector(ns_, curr_state); curr_output.Init(ns_, scratch); - ZeroVector(ns_, curr_output); + ZeroVector(ns_, curr_output); // Rotating buffers of width buf_width allow storage of the state and output // for the other dimension, used only when working in true 2D mode. The width // is enough to hold an entire strip of the major direction. @@ -325,9 +325,9 @@ void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *in outputs.resize(buf_width); for (int i = 0; i < buf_width; ++i) { states[i].Init(ns_, scratch); - ZeroVector(ns_, states[i]); + ZeroVector(ns_, states[i]); outputs[i].Init(ns_, scratch); - ZeroVector(ns_, outputs[i]); + ZeroVector(ns_, outputs[i]); } } // Used only if a softmax LSTM. @@ -335,7 +335,7 @@ void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *in NetworkScratch::IO int_output; if (softmax_ != nullptr) { softmax_output.Init(no_, scratch); - ZeroVector(no_, softmax_output); + ZeroVector(no_, softmax_output); int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_); if (input.int_mode()) { int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch); @@ -429,7 +429,7 @@ void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *in int8_t *which_fg_col = which_fg_[t]; memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0])); if (valid_2d) { - const double *stepped_state = states[mod_t]; + const TFloat *stepped_state = states[mod_t]; for (int i = 0; i < ns_; ++i) { if (temp_lines[GF1][i] < temp_lines[GFS][i]) { curr_state[i] = temp_lines[GFS][i] * stepped_state[i]; @@ -440,7 +440,7 @@ void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *in } MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state); // Clip curr_state to a sane range. - ClipVector(ns_, -kStateClip, kStateClip, curr_state); + ClipVector(ns_, -kStateClip, kStateClip, curr_state); if (IsTraining()) { // Save the gate node values. node_values_[CI].WriteTimeStep(t, temp_lines[CI]); @@ -483,8 +483,8 @@ void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *in // Always zero the states at the end of every row, but only for the major // direction. The 2-D state remains intact. if (src_index.IsLast(FD_WIDTH)) { - ZeroVector(ns_, curr_state); - ZeroVector(ns_, curr_output); + ZeroVector(ns_, curr_state); + ZeroVector(ns_, curr_output); } } while (src_index.Increment()); #if DEBUG_DETAIL > 0 @@ -520,8 +520,8 @@ bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scr NetworkScratch::FloatVec curr_stateerr, curr_sourceerr; curr_stateerr.Init(ns_, scratch); curr_sourceerr.Init(na_, scratch); - ZeroVector(ns_, curr_stateerr); - ZeroVector(na_, curr_sourceerr); + ZeroVector(ns_, curr_stateerr); + ZeroVector(na_, curr_sourceerr); // Errors in the gates. NetworkScratch::FloatVec gate_errors[WT_COUNT]; for (auto &gate_error : gate_errors) { @@ -537,8 +537,8 @@ bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scr for (int t = 0; t < buf_width; ++t) { stateerr[t].Init(ns_, scratch); sourceerr[t].Init(na_, scratch); - ZeroVector(ns_, stateerr[t]); - ZeroVector(na_, sourceerr[t]); + ZeroVector(ns_, stateerr[t]); + ZeroVector(na_, sourceerr[t]); } } // Parallel-generated sourceerr from each of the gates. 
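Note on the scratch buffers in this section: NetworkScratch::FloatVec now hands
out TFloat* instead of double*, so every explicit instantiation of the small
helpers (ZeroVector, ClipVector) has to name TFloat as well. A standalone
sketch of why the template argument must track the buffer type (helper
simplified from the one in helpers.h):

    #include <cstdio>
    #include <cstring>

    #ifdef FAST_FLOAT
    typedef float TFloat; // as in src/ccutil/tfloat.h
    #else
    typedef double TFloat;
    #endif

    // Simplified ZeroVector: sizeof(*vec) is why T must match the buffer.
    template <typename T>
    void ZeroVector(int n, T *vec) {
      std::memset(vec, 0, n * sizeof(*vec));
    }

    int main() {
      TFloat state[8];
      // ZeroVector<double> would clear the wrong byte count under FAST_FLOAT.
      ZeroVector<TFloat>(8, state);
      std::printf("%g\n", static_cast<double>(state[0])); // prints 0
    }
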
@@ -559,7 +559,7 @@ bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scr softmax_errors.Init(no_, scratch); softmax_errors_t.Init(no_, width, scratch); } - double state_clip = Is2D() ? 9.0 : 4.0; + TFloat state_clip = Is2D() ? 9.0 : 4.0; #if DEBUG_DETAIL > 1 tprintf("fwd_deltas:%s\n", name_.c_str()); fwd_deltas.Print(10); @@ -594,8 +594,8 @@ bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scr int mod_t = Modulo(t, buf_width); // Current timestep. // Zero the state in the major direction only at the end of every row. if (at_last_x) { - ZeroVector(na_, curr_sourceerr); - ZeroVector(ns_, curr_stateerr); + ZeroVector(na_, curr_sourceerr); + ZeroVector(ns_, curr_stateerr); } // Setup the outputerr. if (type_ == NT_LSTM_SUMMARY) { @@ -603,7 +603,7 @@ bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scr fwd_deltas.ReadTimeStep(src_index.t(), outputerr); src_index.Decrement(); } else { - ZeroVector(ns_, outputerr); + ZeroVector(ns_, outputerr); } } else if (softmax_ == nullptr) { fwd_deltas.ReadTimeStep(t, outputerr); @@ -631,7 +631,7 @@ bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scr } if (down_pos >= 0) { const float *right_node_gfs = node_values_[GFS].f(down_pos); - const double *right_stateerr = stateerr[mod_t]; + const TFloat *right_stateerr = stateerr[mod_t]; for (int i = 0; i < ns_; ++i) { if (which_fg_[down_pos][i] == 2) { curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i]; @@ -641,7 +641,7 @@ bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scr } state_.FuncMultiply3Add(node_values_[GO], t, outputerr, curr_stateerr); // Clip stateerr_ to a sane range. - ClipVector(ns_, -state_clip, state_clip, curr_stateerr); + ClipVector(ns_, -state_clip, state_clip, curr_stateerr); #if DEBUG_DETAIL > 1 if (t + 10 > width) { tprintf("t=%d, stateerr=", t); @@ -758,7 +758,7 @@ void LSTM::Update(float learning_rate, float momentum, float adam_beta, int num_ // Sums the products of weight updates in *this and other, splitting into // positive (same direction) in *same and negative (different direction) in // *changed. -void LSTM::CountAlternators(const Network &other, double *same, double *changed) const { +void LSTM::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const { ASSERT_HOST(other.type() == type_); const LSTM *lstm = static_cast(&other); for (int w = 0; w < WT_COUNT; ++w) { diff --git a/src/lstm/lstm.h b/src/lstm/lstm.h index 4d399b1dbc..7f7cace1b2 100644 --- a/src/lstm/lstm.h +++ b/src/lstm/lstm.h @@ -109,7 +109,7 @@ class LSTM : public Network { // Sums the products of weight updates in *this and other, splitting into // positive (same direction) in *same and negative (different direction) in // *changed. - void CountAlternators(const Network &other, double *same, double *changed) const override; + void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override; // Prints the weights for debug purposes. void PrintW(); // Prints the weight deltas for debug purposes. diff --git a/src/lstm/network.h b/src/lstm/network.h index 4faac88de4..3de7bae99f 100644 --- a/src/lstm/network.h +++ b/src/lstm/network.h @@ -235,7 +235,7 @@ class TESS_API Network { // Sums the products of weight updates in *this and other, splitting into // positive (same direction) in *same and negative (different direction) in // *changed. 
- virtual void CountAlternators(const Network &other, double *same, double *changed) const {} + virtual void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {} // Reads from the given file. Returns nullptr in case of error. // Determines the type of the serialized class and calls its DeSerialize diff --git a/src/lstm/networkio.cpp b/src/lstm/networkio.cpp index 72f33efcb4..ce4d63651b 100644 --- a/src/lstm/networkio.cpp +++ b/src/lstm/networkio.cpp @@ -531,7 +531,7 @@ int NetworkIO::PositionOfBestMatch(const std::vector &labels, int start, in int best_start = -1; double best_score = 0.0; for (int s = start; s <= last_start; ++s) { - double score = ScoreOfLabels(labels, s); + TFloat score = ScoreOfLabels(labels, s); if (score > best_score || best_start < 0) { best_score = score; best_start = s; @@ -542,9 +542,9 @@ int NetworkIO::PositionOfBestMatch(const std::vector &labels, int start, in // Returns the cumulative score of the given labels starting at start, and // using one label per time-step. -double NetworkIO::ScoreOfLabels(const std::vector &labels, int start) const { +TFloat NetworkIO::ScoreOfLabels(const std::vector &labels, int start) const { int length = labels.size(); - double score = 0.0; + TFloat score = 0.0; for (int i = 0; i < length; ++i) { score += f_(start + i, labels[i]); } @@ -615,27 +615,27 @@ bool NetworkIO::AnySuspiciousTruth(float confidence_thr) const { } // Reads a single timestep to floats in the range [-1, 1]. -void NetworkIO::ReadTimeStep(int t, double *output) const { +void NetworkIO::ReadTimeStep(int t, TFloat *output) const { if (int_mode_) { const int8_t *line = i_[t]; for (int i = 0; i < i_.dim2(); ++i) { - output[i] = static_cast(line[i]) / INT8_MAX; + output[i] = static_cast(line[i]) / INT8_MAX; } } else { const float *line = f_[t]; for (int i = 0; i < f_.dim2(); ++i) { - output[i] = static_cast(line[i]); + output[i] = static_cast(line[i]); } } } // Adds a single timestep to floats. -void NetworkIO::AddTimeStep(int t, double *inout) const { +void NetworkIO::AddTimeStep(int t, TFloat *inout) const { int num_features = NumFeatures(); if (int_mode_) { const int8_t *line = i_[t]; for (int i = 0; i < num_features; ++i) { - inout[i] += static_cast(line[i]) / INT8_MAX; + inout[i] += static_cast(line[i]) / INT8_MAX; } } else { const float *line = f_[t]; @@ -661,13 +661,13 @@ void NetworkIO::AddTimeStepPart(int t, int offset, int num_features, float *inou } // Writes a single timestep from floats in the range [-1, 1]. -void NetworkIO::WriteTimeStep(int t, const double *input) { +void NetworkIO::WriteTimeStep(int t, const TFloat *input) { WriteTimeStepPart(t, 0, NumFeatures(), input); } // Writes a single timestep from floats in the range [-1, 1] writing only // num_features elements of input to (*this)[t], starting at offset. -void NetworkIO::WriteTimeStepPart(int t, int offset, int num_features, const double *input) { +void NetworkIO::WriteTimeStepPart(int t, int offset, int num_features, const TFloat *input) { if (int_mode_) { int8_t *line = i_[t] + offset; for (int i = 0; i < num_features; ++i) { diff --git a/src/lstm/networkio.h b/src/lstm/networkio.h index e170bc8537..b1caadcb75 100644 --- a/src/lstm/networkio.h +++ b/src/lstm/networkio.h @@ -172,7 +172,7 @@ class TESS_API NetworkIO { int PositionOfBestMatch(const std::vector &labels, int start, int end) const; // Returns the cumulative score of the given labels starting at start, and // using one label per time-step. 
- double ScoreOfLabels(const std::vector &labels, int start) const; + TFloat ScoreOfLabels(const std::vector &labels, int start) const; // Helper function sets all the outputs for a single timestep, such that // label has value ok_score, and the other labels share 1 - ok_score. // Assumes float mode. @@ -193,16 +193,16 @@ class TESS_API NetworkIO { bool AnySuspiciousTruth(float confidence_thr) const; // Reads a single timestep to floats in the range [-1, 1]. - void ReadTimeStep(int t, double *output) const; + void ReadTimeStep(int t, TFloat *output) const; // Adds a single timestep to floats. - void AddTimeStep(int t, double *inout) const; + void AddTimeStep(int t, TFloat *inout) const; // Adds part of a single timestep to floats. void AddTimeStepPart(int t, int offset, int num_features, float *inout) const; // Writes a single timestep from floats in the range [-1, 1]. - void WriteTimeStep(int t, const double *input); + void WriteTimeStep(int t, const TFloat *input); // Writes a single timestep from floats in the range [-1, 1] writing only // num_features elements of input to (*this)[t], starting at offset. - void WriteTimeStepPart(int t, int offset, int num_features, const double *input); + void WriteTimeStepPart(int t, int offset, int num_features, const TFloat *input); // Maxpools a single time step from src. void MaxpoolTimeStep(int dest_t, const NetworkIO &src, int src_t, int *max_line); // Runs maxpool backward, using maxes to index timesteps in *this. @@ -253,9 +253,9 @@ class TESS_API NetworkIO { // Applies Func to timestep t of *this (u) and multiplies the result by v // component-wise, putting the product in *product. - // *this and v may be int or float, but must match. The outputs are double. + // *this and v may be int or float, but must match. The outputs are TFloat. template - void FuncMultiply(const NetworkIO &v_io, int t, double *product) { + void FuncMultiply(const NetworkIO &v_io, int t, TFloat *product) { Func f; ASSERT_HOST(!int_mode_); ASSERT_HOST(!v_io.int_mode_); @@ -264,7 +264,7 @@ class TESS_API NetworkIO { const int8_t *u = i_[t]; const int8_t *v = v_io.i_[t]; for (int i = 0; i < dim; ++i) { - product[i] = f(u[i] / static_cast(INT8_MAX)) * v[i] / static_cast(INT8_MAX); + product[i] = f(u[i] / static_cast(INT8_MAX)) * v[i] / INT8_MAX; } } else { const float *u = f_[t]; @@ -278,8 +278,8 @@ class TESS_API NetworkIO { // component-wise, putting the product in *product. // All NetworkIOs are assumed to be float. template - void FuncMultiply3(int u_t, const NetworkIO &v_io, int v_t, const double *w, - double *product) const { + void FuncMultiply3(int u_t, const NetworkIO &v_io, int v_t, const TFloat *w, + TFloat *product) const { ASSERT_HOST(!int_mode_); ASSERT_HOST(!v_io.int_mode_); Func f; @@ -294,7 +294,7 @@ class TESS_API NetworkIO { // component-wise, adding the product to *product. // All NetworkIOs are assumed to be float. template - void FuncMultiply3Add(const NetworkIO &v_io, int t, const double *w, double *product) const { + void FuncMultiply3Add(const NetworkIO &v_io, int t, const TFloat *w, TFloat *product) const { ASSERT_HOST(!int_mode_); ASSERT_HOST(!v_io.int_mode_); Func f; @@ -309,7 +309,7 @@ class TESS_API NetworkIO { // component-wise, putting the product in product, all at timestep t, except // w, which is a simple array. All NetworkIOs are assumed to be float. 
 template <class Func1, class Func2>
-  void Func2Multiply3(const NetworkIO &v_io, int t, const double *w, double *product) const {
+  void Func2Multiply3(const NetworkIO &v_io, int t, const TFloat *w, TFloat *product) const {
     ASSERT_HOST(!int_mode_);
     ASSERT_HOST(!v_io.int_mode_);
     Func1 f;
diff --git a/src/lstm/networkscratch.h b/src/lstm/networkscratch.h
index 703a9b97aa..869560e1b6 100644
--- a/src/lstm/networkscratch.h
+++ b/src/lstm/networkscratch.h
@@ -156,25 +156,25 @@ class NetworkScratch {
     }
 
     // Use the cast operator instead of operator[] so the FloatVec can be used
-    // as a double* argument to a function call.
-    operator double *() const {
+    // as a TFloat* argument to a function call.
+    operator TFloat *() const {
       return data_;
     }
-    double *get() {
+    TFloat *get() {
       return data_;
     }
 
   private:
     // Vector borrowed from the scratch space. Use Return to free it.
-    std::vector<double> *vec_;
+    std::vector<TFloat> *vec_;
     // Short-cut pointer to the underlying array.
-    double *data_;
+    TFloat *data_;
    // The source scratch_space_. Borrowed pointer, used to free the
    // vector. Don't delete!
    NetworkScratch *scratch_space_;
  }; // class FloatVec
 
-  // Class that acts like a 2-D array of double, yet actually uses space
+  // Class that acts like a 2-D array of TFloat, yet actually uses space
   // from the source NetworkScratch, and knows how to unstack the borrowed
   // array on destruction.
   class GradientStore {
@@ -270,7 +270,7 @@ class NetworkScratch {
   // deleted until the NetworkScratch is deleted.
   Stack<std::vector<int8_t>> int_stack_;
   Stack<std::vector<float>> float_stack_;
-  Stack<std::vector<double>> vec_stack_;
+  Stack<std::vector<TFloat>> vec_stack_;
   Stack<TransposedArray> array_stack_;
 };
diff --git a/src/lstm/plumbing.cpp b/src/lstm/plumbing.cpp
index 98ec64ef28..34f9be8f55 100644
--- a/src/lstm/plumbing.cpp
+++ b/src/lstm/plumbing.cpp
@@ -255,7 +255,7 @@ void Plumbing::Update(float learning_rate, float momentum, float adam_beta, int
 // Sums the products of weight updates in *this and other, splitting into
 // positive (same direction) in *same and negative (different direction) in
 // *changed.
-void Plumbing::CountAlternators(const Network &other, double *same, double *changed) const {
+void Plumbing::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {
   ASSERT_HOST(other.type() == type_);
   const auto *plumbing = static_cast<const Plumbing *>(&other);
   ASSERT_HOST(plumbing->stack_.size() == stack_.size());
diff --git a/src/lstm/plumbing.h b/src/lstm/plumbing.h
index fe0f499b21..c1ecc2f23e 100644
--- a/src/lstm/plumbing.h
+++ b/src/lstm/plumbing.h
@@ -143,7 +143,7 @@ class Plumbing : public Network {
   // Sums the products of weight updates in *this and other, splitting into
   // positive (same direction) in *same and negative (different direction) in
   // *changed.
-  void CountAlternators(const Network &other, double *same, double *changed) const override;
+  void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override;
 
 protected:
   // The networks.
diff --git a/src/lstm/weightmatrix.cpp b/src/lstm/weightmatrix.cpp
index e24f95f083..ba7b8653e7 100644
--- a/src/lstm/weightmatrix.cpp
+++ b/src/lstm/weightmatrix.cpp
@@ -21,12 +21,12 @@
 #include "intsimdmatrix.h"
 #include "simddetect.h" // for DotProduct
 #include "statistc.h"
-#include "tprintf.h"
+#include "tprintf.h" // for TFloat
 
 namespace tesseract {
 
 #if defined(ANDROID)
-static inline double log2(double n) {
+static inline TFloat log2(TFloat n) {
   return log(n) / log(2.0);
 }
 #endif // ANDROID
 
 // Number of iterations after which the correction effectively becomes unity.
 const int kAdamCorrectionIterations = 200000;
 // Epsilon in Adam to prevent division by zero.
-const double kAdamEpsilon = 1e-8;
+const TFloat kAdamEpsilon = 1e-8;
+
+// Utility functions convert between double and float arrays.
+#ifdef FAST_FLOAT
+static void DoubleToFloat(const GENERIC_2D_ARRAY<double> &src, GENERIC_2D_ARRAY<float> &dst) {
+  const auto dim1 = src.dim1();
+  const auto dim2 = src.dim2();
+  dst.ResizeNoInit(dim1, dim2);
+  for (int i = 0; i < dim1; ++i) {
+    const auto *src_i = src[i];
+    auto *dst_i = dst[i];
+    for (int j = 0; j < dim2; ++j) {
+      dst_i[j] = static_cast<float>(src_i[j]);
+    }
+  }
+}
+#endif
+
+static void FloatToDouble(const GENERIC_2D_ARRAY<float> &src, GENERIC_2D_ARRAY<double> &dst) {
+  const auto dim1 = src.dim1();
+  const auto dim2 = src.dim2();
+  dst.ResizeNoInit(dim1, dim2);
+  for (int i = 0; i < dim1; ++i) {
+    const auto *src_i = src[i];
+    auto *dst_i = dst[i];
+    for (int j = 0; j < dim2; ++j) {
+      dst_i[j] = static_cast<double>(src_i[j]);
+    }
+  }
+}
+
+static bool DeSerialize(TFile *fp, GENERIC_2D_ARRAY<TFloat> &tfloat_array) {
+#ifdef FAST_FLOAT
+  GENERIC_2D_ARRAY<double> double_array;
+  if (!double_array.DeSerialize(fp)) {
+    return false;
+  }
+  DoubleToFloat(double_array, tfloat_array);
+  return true;
+#else
+  return tfloat_array.DeSerialize(fp);
+#endif
+}
+
+static bool Serialize(TFile *fp, const GENERIC_2D_ARRAY<TFloat> &tfloat_array) {
+#ifdef FAST_FLOAT
+  GENERIC_2D_ARRAY<double> double_array;
+  FloatToDouble(tfloat_array, double_array);
+  return double_array.Serialize(fp);
+#else
+  return tfloat_array.Serialize(fp);
+#endif
+}
 
 // Computes matrix.vector v = Wu.
 // u is of size W.dim2() - add_bias_fwd and the output v is of size
@@ -44,13 +96,13 @@
 // If skip_bias_back, we are actually performing the backwards product on a
 // transposed matrix, so we need to drop the v output corresponding to the last
 // element in dim1.
-static inline void MatrixDotVectorInternal(const GENERIC_2D_ARRAY<double> &w, bool add_bias_fwd,
-                                           bool skip_bias_back, const double *u, double *v) {
+static inline void MatrixDotVectorInternal(const GENERIC_2D_ARRAY<TFloat> &w, bool add_bias_fwd,
+                                           bool skip_bias_back, const TFloat *u, TFloat *v) {
   int num_results = w.dim1() - skip_bias_back;
   int extent = w.dim2() - add_bias_fwd;
   for (int i = 0; i < num_results; ++i) {
-    const double *wi = w[i];
-    double total = DotProduct(wi, u, extent);
+    const TFloat *wi = w[i];
+    TFloat total = DotProduct(wi, u, extent);
     if (add_bias_fwd) {
       total += wi[extent]; // The bias value.
     }
   }
 }
 
-// Copies the whole input transposed, converted to double, into *this.
-void TransposedArray::Transpose(const GENERIC_2D_ARRAY<double> &input) {
+// Copies the whole input transposed, converted to TFloat, into *this.
+void TransposedArray::Transpose(const GENERIC_2D_ARRAY<TFloat> &input) {
   int width = input.dim1();
   int num_features = input.dim2();
   ResizeNoInit(num_features, width);
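Note on the static Serialize/DeSerialize helpers just added: they pin the
on-disk representation to double even when FAST_FLOAT makes TFloat a float, so
traineddata files stay interchangeable between both build types, at the cost of
a conversion pass (and some lost precision when writing from a float build). A
minimal standalone illustration of the same round-trip (hypothetical
buffer-based helpers for the example, not the TFile API):

    #include <cstdio>
    #include <vector>

    #ifdef FAST_FLOAT
    typedef float TFloat;
    #else
    typedef double TFloat;
    #endif

    // Write TFloat weights as double, the fixed on-disk type.
    static std::vector<double> ToDisk(const std::vector<TFloat> &w) {
      return std::vector<double>(w.begin(), w.end());
    }

    // Read double weights back into whatever TFloat currently is.
    static std::vector<TFloat> FromDisk(const std::vector<double> &d) {
      return std::vector<TFloat>(d.begin(), d.end());
    }

    int main() {
      std::vector<TFloat> w = {0.25, -1.5};
      std::vector<TFloat> back = FromDisk(ToDisk(w));
      std::printf("%g %g\n", (double)back[0], (double)back[1]);
    }
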
@@ -97,25 +149,25 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool use_adam, float weight_r
 // for all outputs with negative code_map entries. Returns the new number of
 // weights.
 int WeightMatrix::RemapOutputs(const std::vector<int> &code_map) {
-  GENERIC_2D_ARRAY<double> old_wf(wf_);
+  GENERIC_2D_ARRAY<TFloat> old_wf(wf_);
   int old_no = wf_.dim1();
   int new_no = code_map.size();
   int ni = wf_.dim2();
-  std::vector<double> means(ni, 0.0);
+  std::vector<TFloat> means(ni, 0.0);
   for (int c = 0; c < old_no; ++c) {
-    const double *weights = wf_[c];
+    const TFloat *weights = wf_[c];
     for (int i = 0; i < ni; ++i) {
       means[i] += weights[i];
     }
   }
-  for (double &mean : means) {
+  for (TFloat &mean : means) {
     mean /= old_no;
   }
   wf_.Resize(new_no, ni, 0.0);
   InitBackward();
   for (int dest = 0; dest < new_no; ++dest) {
     int src = code_map[dest];
-    const double *src_data = src >= 0 ? old_wf[src] : means.data();
+    const TFloat *src_data = src >= 0 ? old_wf[src] : means.data();
     memcpy(wf_[dest], src_data, ni * sizeof(*src_data));
   }
   return ni * new_no;
 }
@@ -126,23 +178,23 @@
 // Compute the max absolute value of the weight set.
 // Scale so the max absolute value becomes INT8_MAX.
 // Round to integer.
-// Store a multiplicative scale factor (as a double) that will reproduce
+// Store a multiplicative scale factor (as a TFloat) that will reproduce
 // the original value, subject to rounding errors.
 void WeightMatrix::ConvertToInt() {
   wi_.ResizeNoInit(wf_.dim1(), wf_.dim2());
   scales_.reserve(wi_.dim1());
   int dim2 = wi_.dim2();
   for (int t = 0; t < wi_.dim1(); ++t) {
-    double *f_line = wf_[t];
+    TFloat *f_line = wf_[t];
     int8_t *i_line = wi_[t];
-    double max_abs = 0.0;
+    TFloat max_abs = 0.0;
     for (int f = 0; f < dim2; ++f) {
-      double abs_val = fabs(f_line[f]);
+      TFloat abs_val = fabs(f_line[f]);
       if (abs_val > max_abs) {
         max_abs = abs_val;
       }
     }
-    double scale = max_abs / INT8_MAX;
+    TFloat scale = max_abs / INT8_MAX;
     scales_.push_back(scale / INT8_MAX);
     if (scale == 0.0) {
       scale = 1.0;
@@ -177,9 +229,9 @@ void WeightMatrix::InitBackward() {
 const int kInt8Flag = 1;
 // Flag on mode to indicate that this weightmatrix uses adam.
 const int kAdamFlag = 4;
-// Flag on mode to indicate that this weightmatrix uses double. Set
+// Flag on mode to indicate that this weightmatrix uses TFloat. Set
 // independently of kInt8Flag as even in int mode the scales can
-// be float or double.
+// be float or TFloat.
 const int kDoubleFlag = 128;
 
 // Writes to the given file. Returns false in case of error.
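Note on ConvertToInt above: each weight row is quantized so its largest
absolute weight maps to INT8_MAX, and the factor pushed into scales_ is
scale / INT8_MAX — the extra division appears to pre-fold the 1/INT8_MAX needed
to undo the int8 scaling of the input vector in the integer dot product. A
small standalone sketch of that row quantization (simplified; the patch itself
uses IntCastRounded and fills the whole row):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const double w[4] = {0.5, -0.25, 0.125, 1.0};
      double max_abs = 0.0;
      for (double v : w) {
        max_abs = std::max(max_abs, std::fabs(v));
      }
      double scale = max_abs / INT8_MAX;  // value of one int8 step in weight units
      double stored = scale / INT8_MAX;   // what scales_ keeps
      int8_t q[4];
      for (int i = 0; i < 4; ++i) {
        q[i] = static_cast<int8_t>(std::round(w[i] / scale));
      }
      // q[i] * scale approximately reproduces the original weight.
      std::printf("w0=%g q0=%d back=%g stored_scale=%g\n",
                  w[0], q[0], q[0] * scale, stored);
    }
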
@@ -205,18 +257,25 @@ bool WeightMatrix::Serialize(bool training, TFile *fp) const { if (!fp->Serialize(&size)) { return false; } +#ifdef FAST_FLOAT + assert(!"not implemented"); + return false; +#else if (!fp->Serialize(&scales[0], size)) { return false; } +#endif } else { - if (!wf_.Serialize(fp)) { - return false; - } - if (training && !updates_.Serialize(fp)) { + if (!tesseract::Serialize(fp, wf_)) { return false; } - if (training && use_adam_ && !dw_sq_sum_.Serialize(fp)) { - return false; + if (training) { + if (!tesseract::Serialize(fp, updates_)) { + return false; + } + if (use_adam_ && !tesseract::Serialize(fp, dw_sq_sum_)) { + return false; + } } } return true; @@ -242,6 +301,16 @@ bool WeightMatrix::DeSerialize(bool training, TFile *fp) { if (!fp->DeSerialize(&size)) { return false; } +#ifdef FAST_FLOAT + scales_.reserve(size); + for (auto n = size; n > 0; n--) { + double val; + if (!fp->DeSerialize(&val)) { + return false; + } + scales_.push_back(val / INT8_MAX); + } +#else scales_.resize(size); if (!fp->DeSerialize(&scales_[0], size)) { return false; @@ -249,22 +318,25 @@ bool WeightMatrix::DeSerialize(bool training, TFile *fp) { for (auto &scale : scales_) { scale /= INT8_MAX; } +#endif if (IntSimdMatrix::intSimdMatrix) { int32_t rounded_num_out; IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_, rounded_num_out); scales_.resize(rounded_num_out); } } else { - if (!wf_.DeSerialize(fp)) { + if (!tesseract::DeSerialize(fp, wf_)) { return false; } if (training) { InitBackward(); - if (!updates_.DeSerialize(fp)) { + if (!tesseract::DeSerialize(fp, updates_)) { return false; } - if (use_adam_ && !dw_sq_sum_.DeSerialize(fp)) { - return false; + if (use_adam_) { + if (!tesseract::DeSerialize(fp, dw_sq_sum_)) { + return false; + } } } } @@ -274,7 +346,11 @@ bool WeightMatrix::DeSerialize(bool training, TFile *fp) { // As DeSerialize, but reads an old (float) format WeightMatrix for // backward compatibility. bool WeightMatrix::DeSerializeOld(bool training, TFile *fp) { - GENERIC_2D_ARRAY float_array; +#ifdef FAST_FLOAT + // Not implemented. + assert(!"not implemented"); + return false; +#else if (int_mode_) { if (!wi_.DeSerialize(fp)) { return false; @@ -288,23 +364,26 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile *fp) { scales_.push_back(old_scale); } } else { + GENERIC_2D_ARRAY float_array; if (!float_array.DeSerialize(fp)) { return false; } - FloatToDouble(float_array, &wf_); + FloatToDouble(float_array, wf_); } if (training) { InitBackward(); + GENERIC_2D_ARRAY float_array; if (!float_array.DeSerialize(fp)) { return false; } - FloatToDouble(float_array, &updates_); + FloatToDouble(float_array, updates_); // Errs was only used in int training, which is now dead. if (!float_array.DeSerialize(fp)) { return false; } } return true; +#endif } // Computes matrix.vector v = Wu. @@ -312,12 +391,12 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile *fp) { // u is imagined to have an extra element at the end with value 1, to // implement the bias, but it doesn't actually have it. // Asserts that the call matches what we have. 
-void WeightMatrix::MatrixDotVector(const double *u, double *v) const { +void WeightMatrix::MatrixDotVector(const TFloat *u, TFloat *v) const { assert(!int_mode_); MatrixDotVectorInternal(wf_, true, false, u, v); } -void WeightMatrix::MatrixDotVector(const int8_t *u, double *v) const { +void WeightMatrix::MatrixDotVector(const int8_t *u, TFloat *v) const { assert(int_mode_); if (IntSimdMatrix::intSimdMatrix) { IntSimdMatrix::intSimdMatrix->matrixDotVectorFunction(wi_.dim1(), wi_.dim2(), &shaped_w_[0], @@ -329,11 +408,11 @@ void WeightMatrix::MatrixDotVector(const int8_t *u, double *v) const { // MatrixDotVector for peep weights, MultiplyAccumulate adds the // component-wise products of *this[0] and v to inout. -void WeightMatrix::MultiplyAccumulate(const double *v, double *inout) { +void WeightMatrix::MultiplyAccumulate(const TFloat *v, TFloat *inout) { assert(!int_mode_); assert(wf_.dim1() == 1); int n = wf_.dim2(); - const double *u = wf_[0]; + const TFloat *u = wf_[0]; for (int i = 0; i < n; ++i) { inout[i] += u[i] * v[i]; } @@ -343,7 +422,7 @@ void WeightMatrix::MultiplyAccumulate(const double *v, double *inout) { // u is of size W.dim1() and the output v is of size W.dim2() - 1. // The last result is discarded, as v is assumed to have an imaginary // last value of 1, as with MatrixDotVector. -void WeightMatrix::VectorDotMatrix(const double *u, double *v) const { +void WeightMatrix::VectorDotMatrix(const TFloat *u, TFloat *v) const { assert(!int_mode_); MatrixDotVectorInternal(wf_t_, false, true, u, v); } @@ -367,13 +446,13 @@ void WeightMatrix::SumOuterTransposed(const TransposedArray &u, const Transposed # pragma omp parallel for num_threads(4) if (in_parallel) #endif for (int i = 0; i < num_outputs; ++i) { - double *dwi = dw_[i]; - const double *ui = u[i]; + TFloat *dwi = dw_[i]; + const TFloat *ui = u[i]; for (int j = 0; j < num_inputs; ++j) { dwi[j] = DotProduct(ui, v[j], num_samples); } // The last element of v is missing, presumed 1.0f. - double total = 0.0; + TFloat total = 0.0; for (int k = 0; k < num_samples; ++k) { total += ui[k]; } @@ -419,17 +498,17 @@ void WeightMatrix::AddDeltas(const WeightMatrix &other) { // Sums the products of weight updates in *this and other, splitting into // positive (same direction) in *same and negative (different direction) in // *changed. -void WeightMatrix::CountAlternators(const WeightMatrix &other, double *same, - double *changed) const { +void WeightMatrix::CountAlternators(const WeightMatrix &other, TFloat *same, + TFloat *changed) const { int num_outputs = updates_.dim1(); int num_inputs = updates_.dim2(); assert(num_outputs == other.updates_.dim1()); assert(num_inputs == other.updates_.dim2()); for (int i = 0; i < num_outputs; ++i) { - const double *this_i = updates_[i]; - const double *other_i = other.updates_[i]; + const TFloat *this_i = updates_[i]; + const TFloat *other_i = other.updates_[i]; for (int j = 0; j < num_inputs; ++j) { - double product = this_i[j] * other_i[j]; + TFloat product = this_i[j] * other_i[j]; if (product < 0.0) { *changed -= product; } else { @@ -442,10 +521,10 @@ void WeightMatrix::CountAlternators(const WeightMatrix &other, double *same, // Helper computes an integer histogram bucket for a weight and adds it // to the histogram. 
const int kHistogramBuckets = 16; -static void HistogramWeight(double weight, STATS *histogram) { +static void HistogramWeight(TFloat weight, STATS *histogram) { int bucket = kHistogramBuckets - 1; if (weight != 0.0) { - double logval = -log2(fabs(weight)); + TFloat logval = -log2(fabs(weight)); bucket = ClipToRange(IntCastRounded(logval), 0, kHistogramBuckets - 1); } histogram->add(bucket, 1); @@ -470,20 +549,4 @@ void WeightMatrix::Debug2D(const char *msg) { histogram.print(); } -// Utility function converts an array of float to the corresponding array -// of double. -/* static */ -void WeightMatrix::FloatToDouble(const GENERIC_2D_ARRAY<float> &wf, GENERIC_2D_ARRAY<double> *wd) { - int dim1 = wf.dim1(); - int dim2 = wf.dim2(); - wd->ResizeNoInit(dim1, dim2); - for (int i = 0; i < dim1; ++i) { - const float *wfi = wf[i]; - double *wdi = (*wd)[i]; - for (int j = 0; j < dim2; ++j) { - wdi[j] = static_cast<double>(wfi[j]); - } - } -} - } // namespace tesseract. diff --git a/src/lstm/weightmatrix.h b/src/lstm/weightmatrix.h index bdcdc948ab..9336de4324 100644 --- a/src/lstm/weightmatrix.h +++ b/src/lstm/weightmatrix.h @@ -23,16 +23,17 @@ #include "intsimdmatrix.h" #include "matrix.h" #include "tprintf.h" +#include "tfloat.h" namespace tesseract { -// Convenience instantiation of GENERIC_2D_ARRAY<double> with additional +// Convenience instantiation of GENERIC_2D_ARRAY<TFloat> with additional // operations to write a strided vector, so the transposed form of the input // is memory-contiguous. -class TransposedArray : public GENERIC_2D_ARRAY<double> { +class TransposedArray : public GENERIC_2D_ARRAY<TFloat> { public: - // Copies the whole input transposed, converted to double, into *this. - void Transpose(const GENERIC_2D_ARRAY<double> &input); + // Copies the whole input transposed, converted to TFloat, into *this. + void Transpose(const GENERIC_2D_ARRAY<TFloat> &input); // Writes a vector of data representing a timestep (gradients or sources). // The data is assumed to be of size1 in size (the strided dimension). ~TransposedArray() override; @@ -107,11 +108,11 @@ class WeightMatrix { return int_mode_ ? wi_.dim1() : wf_.dim1(); } // Provides one set of weights. Only used by peep weight maxpool. - const double *GetWeights(int index) const { + const TFloat *GetWeights(int index) const { return wf_[index]; } // Provides access to the deltas (dw_). - double GetDW(int i, int j) const { + TFloat GetDW(int i, int j) const { return dw_(i, j); } @@ -132,16 +133,16 @@ class WeightMatrix { // u is imagined to have an extra element at the end with value 1, to // implement the bias, but it doesn't actually have it. // Asserts that the call matches what we have. - void MatrixDotVector(const double *u, double *v) const; - void MatrixDotVector(const int8_t *u, double *v) const; + void MatrixDotVector(const TFloat *u, TFloat *v) const; + void MatrixDotVector(const int8_t *u, TFloat *v) const; // MatrixDotVector for peep weights, MultiplyAccumulate adds the // component-wise products of *this[0] and v to inout. - void MultiplyAccumulate(const double *v, double *inout); + void MultiplyAccumulate(const TFloat *v, TFloat *inout); // Computes vector.matrix v = uW. // u is of size W.dim1() and the output v is of size W.dim2() - 1. // The last result is discarded, as v is assumed to have an imaginary // last value of 1, as with MatrixDotVector. - void VectorDotMatrix(const double *u, double *v) const; + void VectorDotMatrix(const TFloat *u, TFloat *v) const; // Fills dw_[i][j] with the dot product u[i][] .
v[j][], using elements // from u and v, starting with u[i][offset] and v[j][offset]. // Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0. @@ -155,17 +156,13 @@ class WeightMatrix { // Sums the products of weight updates in *this and other, splitting into // positive (same direction) in *same and negative (different direction) in // *changed. - void CountAlternators(const WeightMatrix &other, double *same, double *changed) const; + void CountAlternators(const WeightMatrix &other, TFloat *same, TFloat *changed) const; void Debug2D(const char *msg); - // Utility function converts an array of float to the corresponding array - // of double. - static void FloatToDouble(const GENERIC_2D_ARRAY<float> &wf, GENERIC_2D_ARRAY<double> *wd); - private: // Choice between float and 8 bit int implementations. - GENERIC_2D_ARRAY<double> wf_; + GENERIC_2D_ARRAY<TFloat> wf_; GENERIC_2D_ARRAY<int8_t> wi_; // Transposed copy of wf_, used only for Backward, and set with each Update. TransposedArray wf_t_; @@ -175,14 +172,14 @@ class WeightMatrix { bool use_adam_; // If we are using wi_, then scales_ is a factor to restore the row product // with a vector to the correct range. - std::vector<double> scales_; + std::vector<TFloat> scales_; // Weight deltas. dw_ is the new delta, and updates_ the momentum-decaying // amount to be added to wf_/wi_. - GENERIC_2D_ARRAY<double> dw_; - GENERIC_2D_ARRAY<double> updates_; + GENERIC_2D_ARRAY<TFloat> dw_; + GENERIC_2D_ARRAY<TFloat> updates_; // Iff use_adam_, the sum of squares of dw_. The number of samples is // given to Update(). Serialized iff use_adam_. - GENERIC_2D_ARRAY<double> dw_sq_sum_; + GENERIC_2D_ARRAY<TFloat> dw_sq_sum_; // The weights matrix reorganized in whatever way suits this instance. std::vector<int8_t> shaped_w_; }; diff --git a/src/training/unicharset/lstmtrainer.cpp b/src/training/unicharset/lstmtrainer.cpp index f21c0d3d2d..1f8d638ae5 100644 --- a/src/training/unicharset/lstmtrainer.cpp +++ b/src/training/unicharset/lstmtrainer.cpp @@ -661,7 +661,7 @@ void LSTMTrainer::ReduceLearningRates(LSTMTrainer *samples_trainer, std::string // Even if it looks like all weights should remain the same, an adjustment // will be made to guarantee a different result when reverting to an old best. // Returns the number of layer learning rates that were reduced. -int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples, +int LSTMTrainer::ReduceLayerLearningRates(TFloat factor, int num_samples, LSTMTrainer *samples_trainer) { enum WhichWay { LR_DOWN, // Learning rate will go down by factor.
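// A note on the momentum_factor introduced in the next hunk: with momentum m,
// each raw gradient ends up contributing 1 + m + m^2 + ... = 1/(1 - m) times
// its value through the decaying updates_, so sums over updates_ are rescaled
// by momentum_factor = 1/(1 - momentum_) before learning rates are compared.
// For example, momentum_ = 0.9 gives a factor of 10.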
@@ -671,13 +671,13 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples, std::vector<std::string> layers = EnumerateLayers(); int num_layers = layers.size(); std::vector<int32_t> num_weights(num_layers); - std::vector<double> bad_sums[LR_COUNT]; - std::vector<double> ok_sums[LR_COUNT]; + std::vector<TFloat> bad_sums[LR_COUNT]; + std::vector<TFloat> ok_sums[LR_COUNT]; for (int i = 0; i < LR_COUNT; ++i) { bad_sums[i].resize(num_layers, 0.0); ok_sums[i].resize(num_layers, 0.0); } - double momentum_factor = 1.0 / (1.0 - momentum_); + TFloat momentum_factor = 1.0 / (1.0 - momentum_); std::vector<char> orig_trainer; samples_trainer->SaveTrainingDump(LIGHT, *this, &orig_trainer); for (int i = 0; i < num_layers; ++i) { @@ -748,10 +748,10 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples, } Network *layer = GetLayer(layers[i]); float lr = GetLayerLearningRate(layers[i]); - double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i]; - double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i]; - double frac_down = bad_sums[LR_DOWN][i] / total_down; - double frac_same = bad_sums[LR_SAME][i] / total_same; + TFloat total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i]; + TFloat total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i]; + TFloat frac_down = bad_sums[LR_DOWN][i] / total_down; + TFloat frac_same = bad_sums[LR_SAME][i] / total_same; tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().c_str(), lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same); if (frac_down < frac_same * kImprovementFraction) { diff --git a/src/training/unicharset/lstmtrainer.h b/src/training/unicharset/lstmtrainer.h index a57c819083..cdebf9897d 100644 --- a/src/training/unicharset/lstmtrainer.h +++ b/src/training/unicharset/lstmtrainer.h @@ -237,7 +237,7 @@ class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer { // Even if it looks like all weights should remain the same, an adjustment // will be made to guarantee a different result when reverting to an old best. // Returns the number of layer learning rates that were reduced. - int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer); + int ReduceLayerLearningRates(TFloat factor, int num_samples, LSTMTrainer *samples_trainer); // Converts the string to integer class labels, with appropriate null_char_s // in between if not in SimpleTextOutput mode. Returns false on failure. diff --git a/unittest/intsimdmatrix_test.cc b/unittest/intsimdmatrix_test.cc index a1411a0a10..6ed2bd2ed5 100644 --- a/unittest/intsimdmatrix_test.cc +++ b/unittest/intsimdmatrix_test.cc @@ -52,8 +52,8 @@ class IntSimdMatrixTest : public ::testing::Test { return v; } // Makes a random scales vector of the given size. - std::vector<double> RandomScales(int size) { - std::vector<double> v(size); + std::vector<TFloat> RandomScales(int size) { + std::vector<TFloat> v(size); for (int i = 0; i < size; ++i) { v[i] = (1.0 + random_.SignedRand(1.0)) / INT8_MAX; } return v; } // Tests a range of sizes and compares the results against the generic version.
void ExpectEqualResults(const IntSimdMatrix &matrix) { - double total = 0.0; + TFloat total = 0.0; for (int num_out = 1; num_out < 130; ++num_out) { for (int num_in = 1; num_in < 130; ++num_in) { GENERIC_2D_ARRAY<int8_t> w = InitRandom(num_out, num_in + 1); std::vector<int8_t> u = RandomVector(num_in, matrix); - std::vector<double> scales = RandomScales(num_out); + std::vector<TFloat> scales = RandomScales(num_out); int ro = num_out; if (IntSimdMatrix::intSimdMatrix) { ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro); } - std::vector<double> base_result(num_out); + std::vector<TFloat> base_result(num_out); IntSimdMatrix::MatrixDotVector(w, scales, u.data(), base_result.data()); - std::vector<double> test_result(ro); + std::vector<TFloat> test_result(ro); std::vector<int8_t> shaped_wi; int32_t rounded_num_out; matrix.Init(w, shaped_wi, rounded_num_out); @@ -91,7 +91,11 @@ class IntSimdMatrixTest : public ::testing::Test { } } // Compare sum of all results with expected value. +#ifdef FAST_FLOAT + EXPECT_FLOAT_EQ(total, 337852.16f); +#else EXPECT_FLOAT_EQ(total, 337849.39354684710); +#endif } TRand random_; @@ -110,7 +114,7 @@ TEST_F(IntSimdMatrixTest, SSE) { GTEST_LOG_(INFO) << "No SSE found! Not tested!"; GTEST_SKIP(); } - ExpectEqualResults(IntSimdMatrix::intSimdMatrixSSE); + ExpectEqualResults(*IntSimdMatrix::intSimdMatrixSSE); #else GTEST_LOG_(INFO) << "SSE unsupported! Not tested!"; GTEST_SKIP(); #endif @@ -124,7 +128,7 @@ TEST_F(IntSimdMatrixTest, AVX2) { GTEST_LOG_(INFO) << "No AVX2 found! Not tested!"; GTEST_SKIP(); } - ExpectEqualResults(IntSimdMatrix::intSimdMatrixAVX2); + ExpectEqualResults(*IntSimdMatrix::intSimdMatrixAVX2); #else GTEST_LOG_(INFO) << "AVX2 unsupported! Not tested!"; GTEST_SKIP(); From c64ab2e0580bf2e40e184435b28b7e10960cee6a Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Wed, 7 Jul 2021 21:14:12 +0200 Subject: [PATCH 02/11] Fix some compiler warnings Signed-off-by: Stefan Weil --- src/ccstruct/blobbox.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/ccstruct/blobbox.h b/src/ccstruct/blobbox.h index f2b935e2e1..4784c09396 100644 --- a/src/ccstruct/blobbox.h +++ b/src/ccstruct/blobbox.h @@ -740,8 +740,11 @@ class TESS_API TO_BLOCK : public ELIST_LINK { TO_ROW_IT row_it = &row_list; for (row_it.mark_cycle_pt(); !row_it.cycled_list(); row_it.forward()) { auto row = row_it.data(); - tprintf("Row range (%g,%g), para_c=%g, blobcount=%" PRId32 "\n", row->min_y(), row->max_y(), - row->parallel_c(), row->blob_list()->length()); + tprintf("Row range (%g,%g), para_c=%g, blobcount=%" PRId32 "\n", + static_cast<double>(row->min_y()), + static_cast<double>(row->max_y()), + static_cast<double>(row->parallel_c()), + row->blob_list()->length()); } } From 78871a9adff2e69deb136bacfc7d463d72560af2 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Wed, 7 Jul 2021 21:15:48 +0200 Subject: [PATCH 03/11] Optimize DotProductStdInnerProduct for float Signed-off-by: Stefan Weil --- src/arch/simddetect.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/arch/simddetect.cpp b/src/arch/simddetect.cpp index 5b3edc60be..af52f6091f 100644 --- a/src/arch/simddetect.cpp +++ b/src/arch/simddetect.cpp @@ -94,7 +94,7 @@ static TFloat DotProductGeneric(const TFloat *u, const TFloat *v, int n) { // Compute dot product using std::inner_product.
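// Background for the one-line change below: std::inner_product accumulates in
// the type of its initial value, so the double literal 0.0 forces a
// float-to-double conversion on every step when TFloat is float. A minimal
// standalone illustration (not part of the patch):
//
//   float u[4] = {1, 2, 3, 4}, v[4] = {4, 3, 2, 1};
//   auto slow = std::inner_product(u, u + 4, v, 0.0);  // accumulates in double
//   auto fast = std::inner_product(u, u + 4, v, 0.0f); // accumulates in float
//
// static_cast<TFloat>(0) selects the right accumulator type for either build.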
static TFloat DotProductStdInnerProduct(const TFloat *u, const TFloat *v, int n) { - return std::inner_product(u, u + n, v, 0.0); + return std::inner_product(u, u + n, v, static_cast<TFloat>(0)); } static void SetDotProduct(DotProductFunction f, const IntSimdMatrix *m = nullptr) { @@ -110,6 +110,12 @@ static void SetDotProduct(DotProductFunction f, const IntSimdMatrix *m = nullptr SIMDDetect::SIMDDetect() { // The fallback is a generic dot product calculation. SetDotProduct(DotProductGeneric); + const char *env = getenv("dotproduct"); + if (env) { + dotproduct = env; + Update(); + return; + } #if defined(HAS_CPUID) # if defined(__GNUC__) @@ -239,6 +245,9 @@ void SIMDDetect::Update() { // AVX2 selected by config variable. SetDotProduct(DotProductAVX, IntSimdMatrix::intSimdMatrixAVX2); dotproduct_method = "avx2"; + } else if (dotproduct == "avx-1") { + SetDotProduct(DotProductAVX1, IntSimdMatrix::intSimdMatrixAVX2); + dotproduct_method = "avx-1"; #endif #if defined(HAVE_AVX) } else if (!strcmp(dotproduct.c_str(), "avx")) { From 1b9e4629017cac905f391a289327242c1c616012 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Tue, 13 Jul 2021 07:12:11 +0200 Subject: [PATCH 04/11] Avoid double / float conversion Signed-off-by: Stefan Weil --- src/ccstruct/matrix.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ccstruct/matrix.h b/src/ccstruct/matrix.h index 0741967127..a97912ad78 100644 --- a/src/ccstruct/matrix.h +++ b/src/ccstruct/matrix.h @@ -417,7 +417,7 @@ class GENERIC_2D_ARRAY { // Accumulates the element-wise sums of squares of src into *this. void SumSquares(const GENERIC_2D_ARRAY<T> &src, const T &decay_factor) { - T update_factor = 1.0 - decay_factor; + T update_factor = 1 - decay_factor; int size = num_elements(); for (int i = 0; i < size; ++i) { array_[i] = array_[i] * decay_factor + update_factor * src.array_[i] * src.array_[i]; From 93e90220dcdf5e7afc52fab4aa9ce5be8b38f7f6 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Tue, 13 Jul 2021 07:14:17 +0200 Subject: [PATCH 05/11] Implement TFloat for IntSimdMatrix Signed-off-by: Stefan Weil --- src/arch/intsimdmatrixavx2.cpp | 292 ++++++++++++++++++++++++++++++++- 1 file changed, 290 insertions(+), 2 deletions(-) diff --git a/src/arch/intsimdmatrixavx2.cpp b/src/arch/intsimdmatrixavx2.cpp index d417869115..8f671f08a9 100644 --- a/src/arch/intsimdmatrixavx2.cpp +++ b/src/arch/intsimdmatrixavx2.cpp @@ -23,10 +23,51 @@ # endif #elif defined(FAST_FLOAT) namespace tesseract { -const IntSimdMatrix *IntSimdMatrix::intSimdMatrixAVX2 = nullptr; + +static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales, + const int8_t *u, float *v) { + const int num_out = dim1; + const int num_in = dim2 - 1; + for (int i = 0; i < num_out; ++i) { + // Scalar base implementation; see the reference kept under #if 0 below. + int total = 0; + for (int j = 0; j < num_in; ++j) { + total += wi[j] * u[j]; + } + // Add in the bias and correct for integer values. + v[i] = (total + wi[num_in] * INT8_MAX) * scales[i]; + wi += dim2; + } +} } -#else +#if 0 +void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t> &w, + const std::vector<TFloat> &scales, const int8_t *u, TFloat *v) { + int num_out = w.dim1(); + int num_in = w.dim2() - 1; + // Base implementation. + for (int i = 0; i < num_out; ++i) { + const int8_t *wi = w[i]; + int total = 0; + for (int j = 0; j < num_in; ++j) { + total += wi[j] * u[j]; + } + // Add in the bias and correct for integer values. + v[i] = (total + wi[num_in] * INT8_MAX) * scales[i]; + } +} +#endif + +static const IntSimdMatrix simdMatrix = { + // Function. + matrixDotVector, + // Number of 32 bit outputs held in each register. + 1, + // Maximum number of registers that we will use to hold outputs. + 1, + // Number of 8 bit inputs in the inputs register.
+ 1, + // Number of inputs in each weight group. + 1 +}; + +const IntSimdMatrix *IntSimdMatrix::intSimdMatrixAVX2 = &simdMatrix; +} +#else # include <immintrin.h> # include <algorithm> # include <cstdint> @@ -90,6 +131,252 @@ static inline __m128i load64_to_128(const int8_t *wi_) { return _mm_set_epi64x(0, wi[0]); } +#if defined(FAST_FLOAT) +static inline void ExtractResults8(__m256i result, const int8_t *wi, const float *scales, + float *v) { + __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg + __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg + __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); + __m256d scale01234567 = _mm256_loadu_ps(scales); + //~ __m256d scale4567 = _mm256_loadu_ps(scales + 8); + w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 + result = _mm256_add_epi32(result, w256); // result += bias * 127 + __m256 res01234567 = _mm256_cvtepi32_ps(_mm256_castsi256_si128(result)); + result = _mm256_permute4x64_epi64(result, 2 + (3 << 2)); + __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result)); + res01234567 = _mm256_mul_pd(res01234567, scale01234567); + //~ res4567 = _mm256_mul_pd(res4567, scale4567); + _mm256_storeu_ps(v, res01234567); + //~ _mm256_storeu_pd(v + 4, res4567); +} + +static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi, + const float *&scales, float *&v) { + __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi)); + // 8x8bit vals in bottom of 128bit reg + const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); + __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg + __m256d scale0123 = _mm256_loadu_ps(scales); + __m256d scale4567 = _mm256_loadu_ps(scales + 8); + w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 + result0 = _mm256_add_epi32(result0, w256); // result += bias * 127 + __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0)); + result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2)); + __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0)); + res0123 = _mm256_mul_pd(res0123, scale0123); + res4567 = _mm256_mul_pd(res4567, scale4567); + _mm256_storeu_ps(v, res0123); + _mm256_storeu_ps(v + 8, res4567); + w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2)); + w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg + scale0123 = _mm256_loadu_ps(scales + 16); + scale4567 = _mm256_loadu_ps(scales + 24); + w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 + result1 = _mm256_add_epi32(result1, w256); // result += bias * 127 + res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1)); + result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2)); + res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1)); + res0123 = _mm256_mul_pd(res0123, scale0123); + res4567 = _mm256_mul_pd(res4567, scale4567); + _mm256_storeu_ps(v + 16, res0123); + _mm256_storeu_ps(v + 24, res4567); + wi += 16; + scales += 16; + v += 16; +} + +// Computes part of matrix.vector v = Wu. Computes N=64 results. +// The weights *must* be arranged so that consecutive reads from wi +// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of +// (kNumInputsPerGroup inputs))). After that there must be N consecutive +// bias weights, before continuing with any more weights. +// u must be padded out with zeros to +// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
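// Concrete example of the padding rule above, assuming kNumInputsPerGroup == 4
// as configured in this file: num_in = 49 is rounded up to
// 4 * ceil(49 / 4) = 52 inputs, and u[49..51] must be zero so that the extra
// SIMD lanes add nothing to the result.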
+static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u, + int num_in, float *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + __m256i result1 = _mm256_setzero_si256(); + __m256i result2 = _mm256_setzero_si256(); + __m256i result3 = _mm256_setzero_si256(); + __m256i result4 = _mm256_setzero_si256(); + __m256i result5 = _mm256_setzero_si256(); + __m256i result6 = _mm256_setzero_si256(); + __m256i result7 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. + inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + MultiplyGroup(rep_input, ones, wi, weights, reps, result1); + MultiplyGroup(rep_input, ones, wi, weights, reps, result2); + MultiplyGroup(rep_input, ones, wi, weights, reps, result3); + MultiplyGroup(rep_input, ones, wi, weights, reps, result4); + MultiplyGroup(rep_input, ones, wi, weights, reps, result5); + MultiplyGroup(rep_input, ones, wi, weights, reps, result6); + MultiplyGroup(rep_input, ones, wi, weights, reps, result7); + } + } + ExtractResults16(result0, result1, wi, scales, v); + ExtractResults16(result2, result3, wi, scales, v); + ExtractResults16(result4, result5, wi, scales, v); + ExtractResults16(result6, result7, wi, scales, v); +} + +// Computes part of matrix.vector v = Wu. Computes N=32 results. +// For details see PartialMatrixDotVector64 with N=32. +static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u, + int num_in, float *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + __m256i result1 = _mm256_setzero_si256(); + __m256i result2 = _mm256_setzero_si256(); + __m256i result3 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
+ inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + MultiplyGroup(rep_input, ones, wi, weights, reps, result1); + MultiplyGroup(rep_input, ones, wi, weights, reps, result2); + MultiplyGroup(rep_input, ones, wi, weights, reps, result3); + } + } + ExtractResults16(result0, result1, wi, scales, v); + ExtractResults16(result2, result3, wi, scales, v); +} + +// Computes part of matrix.vector v = Wu. Computes N=16 results. +// For details see PartialMatrixDotVector64 with N=16. +static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u, + int num_in, float *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + __m256i result1 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. + inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + MultiplyGroup(rep_input, ones, wi, weights, reps, result1); + } + } + ExtractResults16(result0, result1, wi, scales, v); +} + +// Computes part of matrix.vector v = Wu. Computes N=8 results. +// For details see PartialMatrixDotVector64 with N=8. +static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u, + int num_in, float *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. + inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results.
+ MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + } + } + ExtractResults8(result0, wi, scales, v); +} + +static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales, + const int8_t *u, float *v) { + const int num_out = dim1; + const int num_in = dim2 - 1; + // Each call to a partial_func_ produces group_size outputs, except the + // last one, which can produce less. + const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup); + const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister); + int group_size = kNumOutputsPerRegister * kMaxOutputRegisters; + int output = 0; + + int w_step = (rounded_num_in + 1) * group_size; + + // Run with this group size, until it would produce too much output, then + // switch to a smaller size. + for (; output + group_size <= rounded_num_out; output += group_size) { + PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v); + wi += w_step; + scales += group_size; + v += group_size; + } + group_size /= 2; + w_step /= 2; + + if (output + group_size <= rounded_num_out) { + PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v); + wi += w_step; + scales += group_size; + v += group_size; + output += group_size; + } + group_size /= 2; + w_step /= 2; + + if (output + group_size <= rounded_num_out) { + PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v); + wi += w_step; + scales += group_size; + v += group_size; + output += group_size; + } + group_size /= 2; + w_step /= 2; + + if (output + group_size <= rounded_num_out) { + PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v); + } +} +#else static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales, double *v) { __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg @@ -334,6 +621,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double * PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v); } } +#endif static const IntSimdMatrix simdMatrix = { // Function. From 00e428391799ef5eab559a28b063f6fef1aefb08 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Tue, 13 Jul 2021 07:15:03 +0200 Subject: [PATCH 06/11] Test more implementations of DotProduct Signed-off-by: Stefan Weil --- src/arch/dotproduct.h | 4 +++ src/arch/dotproductavx.cpp | 64 ++++++++++++++++++++++++++++---------- src/arch/dotproductfma.cpp | 27 +++++++++++++--- 3 files changed, 75 insertions(+), 20 deletions(-) diff --git a/src/arch/dotproduct.h b/src/arch/dotproduct.h index c64765597e..918d1c4bab 100644 --- a/src/arch/dotproduct.h +++ b/src/arch/dotproduct.h @@ -26,6 +26,10 @@ TFloat DotProductNative(const TFloat *u, const TFloat *v, int n); // Uses Intel AVX intrinsics to access the SIMD instruction set. TFloat DotProductAVX(const TFloat *u, const TFloat *v, int n); +TFloat DotProductAVX1(const TFloat *u, const TFloat *v, int n); +TFloat DotProductAVX2(const TFloat *u, const TFloat *v, int n); +TFloat DotProductAVX3(const TFloat *u, const TFloat *v, int n); +TFloat DotProductAVX4(const TFloat *u, const TFloat *v, int n); // Use Intel FMA. TFloat DotProductFMA(const TFloat *u, const TFloat *v, int n); diff --git a/src/arch/dotproductavx.cpp b/src/arch/dotproductavx.cpp index f937b5d8c1..4c49e9e4a3 100644 --- a/src/arch/dotproductavx.cpp +++ b/src/arch/dotproductavx.cpp @@ -31,17 +31,9 @@ namespace tesseract { // Uses Intel AVX intrinsics to access the SIMD instruction set. 
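// In the float build below, DotProductAVX processes 8 floats per iteration
// with a single accumulator, while DotProductAVX1 unrolls to two independent
// accumulators (16 floats per iteration) so consecutive vector adds do not
// wait on each other; _mm256_hadd_ps then folds the two partial sums before
// the final horizontal reduction and the scalar remainder loop.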
#if defined(FAST_FLOAT) float DotProductAVX(const float *u, const float *v, int n) { -#ifndef FAST_FLOAT16 const unsigned quot = n / 8; const unsigned rem = n % 8; -#else - const unsigned quot = n / 16; - const unsigned rem = n % 16; -#endif __m256 t0 = _mm256_setzero_ps(); -#ifdef FAST_FLOAT16 - __m256 t1 = _mm256_setzero_ps(); -#endif for (unsigned k = 0; k < quot; k++) { __m256 f0 = _mm256_loadu_ps(u); __m256 f1 = _mm256_loadu_ps(v); @@ -49,18 +41,33 @@ float DotProductAVX(const float *u, const float *v, int n) { t0 = _mm256_add_ps(t0, f0); u += 8; v += 8; -#ifdef FAST_FLOAT16 - __m256 f2 = _mm256_loadu_ps(u); - __m256 f3 = _mm256_loadu_ps(v); + } + alignas(32) float tmp[8]; + _mm256_store_ps(tmp, t0); + float result = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7]; + for (unsigned k = 0; k < rem; k++) { + result += *u++ * *v++; + } + return result; +} +float DotProductAVX1(const float *u, const float *v, int n) { + const unsigned quot = n / 16; + const unsigned rem = n % 16; + __m256 t0 = _mm256_setzero_ps(); + __m256 t1 = _mm256_setzero_ps(); + for (unsigned k = 0; k < quot; k++) { + __m256 f0 = _mm256_loadu_ps(u); + __m256 f1 = _mm256_loadu_ps(v); + __m256 f2 = _mm256_loadu_ps(u + 8); + __m256 f3 = _mm256_loadu_ps(v + 8); + f0 = _mm256_mul_ps(f0, f1); f2 = _mm256_mul_ps(f2, f3); + t0 = _mm256_add_ps(t0, f0); t1 = _mm256_add_ps(t1, f2); - u += 8; - v += 8; -#endif + u += 16; + v += 16; } -#ifdef FAST_FLOAT16 t0 = _mm256_hadd_ps(t0, t1); -#endif alignas(32) float tmp[8]; _mm256_store_ps(tmp, t0); float result = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7]; @@ -70,6 +77,31 @@ float DotProductAVX(const float *u, const float *v, int n) { return result; } #else +double DotProductAVX1(const double *u, const double *v, int n) { + __m256d t0 = _mm256_setzero_pd(); + __m256d t1 = _mm256_setzero_pd(); + for (unsigned quot = n / 8; quot > 0; quot--) { + __m256d f0 = _mm256_loadu_pd(u); + __m256d f1 = _mm256_loadu_pd(v); + __m256d f2 = _mm256_loadu_pd(u + 4); + __m256d f3 = _mm256_loadu_pd(v + 4); + f0 = _mm256_mul_pd(f0, f1); + f2 = _mm256_mul_pd(f2, f3); + t0 = _mm256_add_pd(t0, f0); + t1 = _mm256_add_pd(t1, f2); + u += 8; + v += 8; + } + t0 = _mm256_hadd_pd(t0, t1); + alignas(32) double tmp[4]; + _mm256_store_pd(tmp, t0); + double result = tmp[0] + tmp[1] + tmp[2] + tmp[3]; + for (unsigned rem = n % 8; rem > 0; rem--) { + result += *u++ * *v++; + } + return result; +} + double DotProductAVX(const double *u, const double *v, int n) { const unsigned quot = n / 8; const unsigned rem = n % 8; diff --git a/src/arch/dotproductfma.cpp b/src/arch/dotproductfma.cpp index 8ce7ae32d4..32154283ae 100644 --- a/src/arch/dotproductfma.cpp +++ b/src/arch/dotproductfma.cpp @@ -31,11 +31,30 @@ namespace tesseract { // Uses Intel FMA intrinsics to access the SIMD instruction set. 
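// _mm256_fmadd_ps(f0, f1, t) computes t + f0 * f1 per lane in one fused
// instruction with a single rounding, so the loop below needs no separate
// multiply and add. The scalar equivalent of one lane is
// t = std::fma(u[k], v[k], t).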
#if defined(FAST_FLOAT) TFloat DotProductFMA(const TFloat *u, const TFloat *v, int n) { - TFloat total = 0.0; - for (int k = 0; k < n; ++k) { - total += u[k] * v[k]; + const unsigned quot = n / 16; + const unsigned rem = n % 16; + __m256 t0 = _mm256_setzero_ps(); + __m256 t1 = _mm256_setzero_ps(); + for (unsigned k = 0; k < quot; k++) { + __m256 f0 = _mm256_loadu_ps(u); + __m256 f1 = _mm256_loadu_ps(v); + t0 = _mm256_fmadd_ps(f0, f1, t0); + u += 8; + v += 8; + __m256 f2 = _mm256_loadu_ps(u); + __m256 f3 = _mm256_loadu_ps(v); + t1 = _mm256_fmadd_ps(f2, f3, t1); + u += 8; + v += 8; } - return total; + t0 = _mm256_hadd_ps(t0, t1); + alignas(32) float tmp[8]; + _mm256_store_ps(tmp, t0); + float result = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7]; + for (unsigned k = 0; k < rem; k++) { + result += *u++ * *v++; + } + return result; } #else double DotProductFMA(const double *u, const double *v, int n) { From e2529ddb40c6e2bc57afe9677cb0162161c38079 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Sat, 10 Jul 2021 16:27:21 +0200 Subject: [PATCH 07/11] Add unittest for dotproduct Signed-off-by: Stefan Weil --- Makefile.am | 11 ++++ unittest/dotproduct_test.cc | 121 ++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 unittest/dotproduct_test.cc diff --git a/Makefile.am b/Makefile.am index afa75b311c..515585055a 100644 --- a/Makefile.am +++ b/Makefile.am @@ -1236,6 +1236,7 @@ check_PROGRAMS += commandlineflags_test check_PROGRAMS += dawg_test endif # ENABLE_TRAINING check_PROGRAMS += denorm_test +check_PROGRAMS += dotproduct_test if !DISABLED_LEGACY_ENGINE check_PROGRAMS += equationdetect_test endif # !DISABLED_LEGACY_ENGINE @@ -1362,6 +1363,16 @@ denorm_test_SOURCES = unittest/denorm_test.cc denorm_test_CPPFLAGS = $(unittest_CPPFLAGS) denorm_test_LDADD = $(TESS_LIBS) +dotproduct_test_SOURCES = unittest/dotproduct_test.cc +dotproduct_test_CPPFLAGS = $(unittest_CPPFLAGS) +if HAVE_AVX2 +dotproduct_test_CPPFLAGS += -DHAVE_AVX2 +endif +if HAVE_SSE4_1 +dotproduct_test_CPPFLAGS += -DHAVE_SSE4_1 +endif +dotproduct_test_LDADD = $(TESS_LIBS) + if !DISABLED_LEGACY_ENGINE equationdetect_test_SOURCES = unittest/equationdetect_test.cc equationdetect_test_CPPFLAGS = $(unittest_CPPFLAGS) diff --git a/unittest/dotproduct_test.cc b/unittest/dotproduct_test.cc new file mode 100644 index 0000000000..36b0fbd97e --- /dev/null +++ b/unittest/dotproduct_test.cc @@ -0,0 +1,121 @@ +/////////////////////////////////////////////////////////////////////// +// File: dotproduct_test.cc +// Author: Stefan Weil +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+/////////////////////////////////////////////////////////////////////// + +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <locale> +#include "include_gunit.h" +#include "dotproduct.h" +#include "matrix.h" +#include "simddetect.h" +#include "tprintf.h" + +namespace tesseract { +class DotProductTest : public ::testing::Test { +protected: + void SetUp() override { + std::locale::global(std::locale("")); + } + void RunTest(TFloat (*f)(const TFloat *u, const TFloat *v, int n)); + static const size_t multiplications = 500000000; + static const size_t n = 40; + //static const size_t n = 1000000; + TFloat u[n]; + TFloat v[n]; +}; + +void DotProductTest::RunTest(TFloat (*f)(const TFloat *u, const TFloat *v, int n)) { + for (auto i = multiplications / n; i > 0; i--) { + f(u, v, n); + } +} + +TFloat DotProductGeneric(const TFloat *u, const TFloat *v, int n); +TFloat DotProductGeneric(const TFloat *u, const TFloat *v, int n) { + TFloat total = 0; +#pragma omp simd reduction(+:total) + for (int k = 0; k < n; ++k) { + total += u[k] * v[k]; + } + return total; +} + +// Test the C++ implementation without SIMD. +TEST_F(DotProductTest, C) { + RunTest(DotProductGeneric); +} + +TEST_F(DotProductTest, Native) { + RunTest(DotProductNative); +} + +// Tests that the SSE implementation gets the same result as the vanilla. +TEST_F(DotProductTest, SSE) { +#if defined(HAVE_SSE4_1) + if (!SIMDDetect::IsSSEAvailable()) { + GTEST_LOG_(INFO) << "No SSE found! Not tested!"; + GTEST_SKIP(); + } + RunTest(DotProductSSE); +#else + GTEST_LOG_(INFO) << "SSE unsupported! Not tested!"; + GTEST_SKIP(); +#endif +} + +// Tests that the AVX implementation gets the same result as the vanilla. +TEST_F(DotProductTest, AVX) { +#if defined(HAVE_AVX2) + if (!SIMDDetect::IsAVX2Available()) { + GTEST_LOG_(INFO) << "No AVX2 found! Not tested!"; + GTEST_SKIP(); + } + RunTest(DotProductAVX); +#else + GTEST_LOG_(INFO) << "AVX2 unsupported! Not tested!"; + GTEST_SKIP(); +#endif +} + +// Tests that the AVX1 implementation gets the same result as the vanilla. +TEST_F(DotProductTest, AVX1) { +#if defined(HAVE_AVX2) + if (!SIMDDetect::IsAVX2Available()) { + GTEST_LOG_(INFO) << "No AVX2 found! Not tested!"; + GTEST_SKIP(); + } + RunTest(DotProductAVX1); +#else + GTEST_LOG_(INFO) << "AVX2 unsupported! Not tested!"; + GTEST_SKIP(); +#endif +} + +// Tests that the FMA implementation gets the same result as the vanilla. +TEST_F(DotProductTest, FMA) { +#if defined(HAVE_FMA) + if (!SIMDDetect::IsFMAAvailable()) { + GTEST_LOG_(INFO) << "No FMA found! Not tested!"; + GTEST_SKIP(); + } + RunTest(DotProductFMA); +#else + GTEST_LOG_(INFO) << "FMA unsupported! Not tested!";
+ GTEST_SKIP(); +#endif +} + +} // namespace tesseract From 01ae69ed95d24e89af96dad49a66562c60ceb26c Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Sun, 28 Feb 2021 12:04:17 +0100 Subject: [PATCH 08/11] Support Apple Accelerate framework for training and best models Signed-off-by: Stefan Weil --- configure.ac | 13 ++++++----- src/arch/simddetect.cpp | 51 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/configure.ac b/configure.ac index 37a6bf4520..9c946986ef 100644 --- a/configure.ac +++ b/configure.ac @@ -284,7 +284,7 @@ m4_define([MY_CHECK_FRAMEWORK], ]) if test "$my_cv_framework_$1"="yes"; then AC_DEFINE(AS_TR_CPP([HAVE_FRAMEWORK_$1]), 1, - [Define if you have the $1 framework]) + [Define if you have the $1 framework]) AS_TR_CPP([FRAMEWORK_$1])="-framework $1" AC_SUBST(AS_TR_CPP([FRAMEWORK_$1])) fi] @@ -295,13 +295,14 @@ OPENCL_CPPFLAGS='' OPENCL_LDFLAGS='' case "${host_os}" in *darwin* | *-macos10*) - echo "checking for OpenCL framework" - MY_CHECK_FRAMEWORK([OpenCL]) - if test $my_cv_framework_OpenCL = yes; then - have_opencl_lib=true + MY_CHECK_FRAMEWORK([Accelerate]) + if test $my_cv_framework_Accelerate = yes; then + AM_CPPFLAGS="-DHAVE_FRAMEWORK_ACCELERATE $AM_CPPFLAGS" + LDFLAGS="$LDFLAGS -framework Accelerate" fi + MY_CHECK_FRAMEWORK([OpenCL]) if test "$enable_opencl" = "yes"; then - if !($have_opencl_lib); then + if test $my_cv_framework_OpenCL = no; then AC_MSG_ERROR([Required OpenCL library not found!]) fi AM_CPPFLAGS="-DUSE_OPENCL $AM_CPPFLAGS" diff --git a/src/arch/simddetect.cpp b/src/arch/simddetect.cpp index af52f6091f..f5fd25b264 100644 --- a/src/arch/simddetect.cpp +++ b/src/arch/simddetect.cpp @@ -25,6 +25,23 @@ #include "simddetect.h" #include "tprintf.h" // for tprintf +#if defined(HAVE_FRAMEWORK_ACCELERATE) + +// Use Apple Accelerate framework. +// https://developer.apple.com/documentation/accelerate/simd + +// Comparison of execution time with different dot product implementations. +// time DOTPRODUCT=accelerate lstm_squashed_test +// Results for Apple M1: +// DotProductGeneric 64 s +// DotProduct 60 s +// DotProductAccelerate 33 s +// DotProductNative 30 s + +#include <Accelerate/Accelerate.h> + +#endif + #if defined(HAVE_AVX) || defined(HAVE_AVX2) || defined(HAVE_FMA) || defined(HAVE_SSE4_1) # define HAS_CPUID #endif @@ -83,6 +100,15 @@ bool SIMDDetect::fma_available_; bool SIMDDetect::sse_available_; #endif +#if defined(HAVE_FRAMEWORK_ACCELERATE) +static double DotProductAccelerate(const double* u, const double* v, int n) { + double total = 0.0; + const int stride = 1; + vDSP_dotprD(u, stride, v, stride, &total, n); + return total; +} +#endif + // Computes and returns the dot product of the two n-vectors u and v. static TFloat DotProductGeneric(const TFloat *u, const TFloat *v, int n) { TFloat total = 0.0; @@ -110,6 +136,17 @@ static void SetDotProduct(DotProductFunction f, const IntSimdMatrix *m = nullptr SIMDDetect::SIMDDetect() { // The fallback is a generic dot product calculation.
SetDotProduct(DotProductGeneric); - const char *env = getenv("dotproduct"); - if (env) { + const char* dotproduct_env = getenv("DOTPRODUCT"); + if (dotproduct_env != nullptr) { dotproduct = env; Update(); + if (strcmp(dotproduct_env, "native") == 0) { + SetDotProduct(DotProductNative); +#if defined(HAVE_FRAMEWORK_ACCELERATE) + } else if (strcmp(dotproduct_env, "accelerate") == 0) { + SetDotProduct(DotProductAccelerate); +#endif + } return; } @@ -240,6 +273,11 @@ void SIMDDetect::Update() { // Native optimized code selected by config variable. SetDotProduct(DotProductNative); dotproduct_method = "native"; +#if defined(HAVE_FRAMEWORK_ACCELERATE) + } else if (dotproduct == "accelerate") { + SetDotProduct(DotProductAccelerate); + dotproduct_method = "accelerate"; +#endif #if defined(HAVE_AVX2) } else if (!strcmp(dotproduct.c_str(), "avx2")) { // AVX2 selected by config variable. @@ -277,9 +315,18 @@ void SIMDDetect::Update() { dotproduct.c_str()); tprintf( "Support values for dotproduct: auto generic native" +#if defined(HAVE_FRAMEWORK_ACCELERATE) + " accelerate" +#endif +#if defined(HAVE_AVX2) + " avx2" +#endif #if defined(HAVE_AVX) " avx" #endif +#if defined(HAVE_FMA) + " fma" +#endif #if defined(HAVE_SSE4_1) " sse" #endif From a09531a1a8afe282919f38da00587c8ed4437007 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Tue, 13 Jul 2021 09:18:27 +0200 Subject: [PATCH 09/11] Fix TFloat builds for Apple M1 Signed-off-by: Stefan Weil --- Makefile.am | 2 +- src/arch/dotproduct.h | 1 + src/arch/intsimdmatrixneon.cpp | 13 +++++++++++-- src/arch/simddetect.cpp | 10 +++++++--- unittest/dotproduct_test.cc | 24 +++++++++++++++++++++++- 5 files changed, 43 insertions(+), 7 deletions(-) diff --git a/Makefile.am b/Makefile.am index 515585055a..486de36eb4 100644 --- a/Makefile.am +++ b/Makefile.am @@ -146,8 +146,8 @@ noinst_LTLIBRARIES += libtesseract_native.la libtesseract_native_la_CXXFLAGS = -O3 -ffast-math if MARCH_NATIVE_OPT libtesseract_native_la_CXXFLAGS += -march=native -mtune=native -libtesseract_native_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil endif +libtesseract_native_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil libtesseract_native_la_SOURCES = src/arch/dotproduct.cpp if HAVE_AVX diff --git a/src/arch/dotproduct.h b/src/arch/dotproduct.h index 918d1c4bab..c9b2756e2c 100644 --- a/src/arch/dotproduct.h +++ b/src/arch/dotproduct.h @@ -37,6 +37,7 @@ TFloat DotProductFMA(const TFloat *u, const TFloat *v, int n); // Uses Intel SSE intrinsics to access the SIMD instruction set. TFloat DotProductSSE(const TFloat *u, const TFloat *v, int n); +TFloat DotProductAccelerate(const TFloat *u, const TFloat *v, int n); } // namespace tesseract. #endif // TESSERACT_ARCH_DOTPRODUCT_H_ diff --git a/src/arch/intsimdmatrixneon.cpp b/src/arch/intsimdmatrixneon.cpp index ae6608133d..260f747d48 100644 --- a/src/arch/intsimdmatrixneon.cpp +++ b/src/arch/intsimdmatrixneon.cpp @@ -19,6 +19,7 @@ #if defined(__ARM_NEON) # include "intsimdmatrix.h" +# include "tfloat.h" # include <algorithm> # include <cstdint> @@ -27,6 +28,12 @@ namespace tesseract { +#if defined(FAST_FLOAT) + +const IntSimdMatrix *IntSimdMatrix::intSimdMatrixNEON = nullptr; + +#else + // Number of outputs held in each register. (Actually, we use a // pair of 4x32 registers, so 8 x 32 bit ints). constexpr int kNumOutputsPerRegister = 8; @@ -186,7 +193,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double * num_out & (kNumOutputsPerRegister - 1)); } -static const IntSimdMatrix intSimdMatrix = { // Function.
matrixDotVector, // Number of 32 bit outputs held in each register. @@ -199,7 +206,9 @@ static const IntSimdMatrix intSimdMatrix = { kNumInputsPerGroup }; -const IntSimdMatrix *IntSimdMatrix::intSimdMatrixNEON = &intSimdMatrix; +const IntSimdMatrix *IntSimdMatrix::intSimdMatrixNEON = &simdMatrix; + +#endif // FAST_FLOAT } // namespace tesseract. diff --git a/src/arch/simddetect.cpp b/src/arch/simddetect.cpp index f5fd25b264..6c7f822239 100644 --- a/src/arch/simddetect.cpp +++ b/src/arch/simddetect.cpp @@ -101,10 +101,14 @@ bool SIMDDetect::sse_available_; #endif #if defined(HAVE_FRAMEWORK_ACCELERATE) -static double DotProductAccelerate(const double* u, const double* v, int n) { - double total = 0.0; +TFloat DotProductAccelerate(const TFloat* u, const TFloat* v, int n) { + TFloat total = 0; const int stride = 1; +#if defined(FAST_FLOAT) + vDSP_dotpr(u, stride, v, stride, &total, n); +#else vDSP_dotprD(u, stride, v, stride, &total, n); +#endif return total; } #endif @@ -138,7 +142,7 @@ SIMDDetect::SIMDDetect() { SetDotProduct(DotProductGeneric); const char* dotproduct_env = getenv("DOTPRODUCT"); if (dotproduct_env != nullptr) { - dotproduct = env; + dotproduct = dotproduct_env; Update(); if (strcmp(dotproduct_env, "native") == 0) { SetDotProduct(DotProductNative); diff --git a/unittest/dotproduct_test.cc b/unittest/dotproduct_test.cc index 36b0fbd97e..2682243da5 100644 --- a/unittest/dotproduct_test.cc +++ b/unittest/dotproduct_test.cc @@ -30,7 +30,7 @@ class DotProductTest : public ::testing::Test { std::locale::global(std::locale("")); } void RunTest(TFloat (*f)(const TFloat *u, const TFloat *v, int n)); - static const size_t multiplications = 500000000; + static const size_t multiplications = 5000000000U; static const size_t n = 40; //static const size_t n = 1000000; TFloat u[n]; @@ -118,4 +118,26 @@ TEST_F(DotProductTest, FMA) { #endif } +#if defined(HAVE_FRAMEWORK_ACCELERATE) +TEST_F(DotProductTest, Accelerate) { + RunTest(DotProductAccelerate); +} +#endif + +#if 0 +// Tests that the NEON implementation gets the same result as the vanilla. +TEST_F(DotProductTest, NEON) { +#if defined(HAVE_NEON) + if (!SIMDDetect::IsNEONAvailable()) { + GTEST_LOG_(INFO) << "No NEON found! Not tested!"; + GTEST_SKIP(); + } + RunTest(DotProductNEON); +#else + GTEST_LOG_(INFO) << "NEON unsupported! Not tested!"; + GTEST_SKIP(); +#endif +} +#endif + } // namespace tesseract From 1a59b6f22b49c212beff0dd10d4e76fe8d8b18ee Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Tue, 13 Jul 2021 09:21:02 +0200 Subject: [PATCH 10/11] Fix DotProductNative for TFloat Signed-off-by: Stefan Weil --- src/arch/dotproduct.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/arch/dotproduct.cpp b/src/arch/dotproduct.cpp index d0be2717fa..b8764ec16d 100644 --- a/src/arch/dotproduct.cpp +++ b/src/arch/dotproduct.cpp @@ -20,8 +20,9 @@ namespace tesseract { // Computes and returns the dot product of the two n-vectors u and v. 
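// The pragma added below is a vectorization hint: #pragma omp simd
// reduction(+:total) lets the compiler vectorize the loop with per-lane
// partial sums that are combined at the end. It only takes effect when
// OpenMP SIMD handling is enabled (for example -fopenmp or -fopenmp-simd
// with GCC/Clang) and is otherwise ignored, leaving the plain scalar loop.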
TFloat DotProductNative(const TFloat *u, const TFloat *v, int n) { - double total = 0.0; - for (int k = 0; k < n; ++k) { + TFloat total = 0; +#pragma omp simd reduction(+:total) + for (int k = 0; k < n; k++) { total += u[k] * v[k]; } return total; From d2eb7bdf4d35f90b14ef99f8dfd20db2363d8a0f Mon Sep 17 00:00:00 2001 From: Ger Hobbelt Date: Tue, 13 Jul 2021 10:04:59 +0200 Subject: [PATCH 11/11] same as patch-4 (#3494) but now with reduced code duplication: for TFloat to work, we don't need to duplicate the integer work functions, as it's only the ExtractResults[8,16] functions that need different implementations for float vs. double. These are therefore common to both implementations: ``` static void PartialMatrixDotVector64(const int8_t *wi, const TFloat *scales, const int8_t *u, int num_in, TFloat *v) { static void PartialMatrixDotVector32(const int8_t *wi, const TFloat *scales, const int8_t *u, int num_in, TFloat *v) { static void PartialMatrixDotVector16(const int8_t *wi, const TFloat *scales, const int8_t *u, int num_in, TFloat *v) { static inline void PartialMatrixDotVector8(const int8_t *wi, const TFloat *scales, const int8_t *u, int num_in, TFloat *v) { static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales, const int8_t *u, TFloat *v) { ``` --- src/arch/intsimdmatrixavx2.cpp | 307 ++++++--------------------------- 1 file changed, 54 insertions(+), 253 deletions(-) diff --git a/src/arch/intsimdmatrixavx2.cpp b/src/arch/intsimdmatrixavx2.cpp index 8f671f08a9..a463408597 100644 --- a/src/arch/intsimdmatrixavx2.cpp +++ b/src/arch/intsimdmatrixavx2.cpp @@ -132,253 +132,52 @@ static inline __m128i load64_to_128(const int8_t *wi_) { } #if defined(FAST_FLOAT) -static inline void ExtractResults8(__m256i result, const int8_t *wi, const float *scales, - float *v) { - __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg - __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg - __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); - __m256d scale01234567 = _mm256_loadu_ps(scales); - //~ __m256d scale4567 = _mm256_loadu_ps(scales + 8); - w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 - result = _mm256_add_epi32(result, w256); // result += bias * 127 - __m256 res01234567 = _mm256_cvtepi32_ps(_mm256_castsi256_si128(result)); - result = _mm256_permute4x64_epi64(result, 2 + (3 << 2)); - __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result)); - res01234567 = _mm256_mul_pd(res01234567, scale01234567); - //~ res4567 = _mm256_mul_pd(res4567, scale4567); - _mm256_storeu_ps(v, res01234567); - //~ _mm256_storeu_pd(v + 4, res4567); -} - -static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi, - const float *&scales, float *&v) { - __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi)); - // 8x8bit vals in bottom of 128bit reg - const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); - __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg - __m256d scale0123 = _mm256_loadu_ps(scales); - __m256d scale4567 = _mm256_loadu_ps(scales + 8); - w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 - result0 = _mm256_add_epi32(result0, w256); // result += bias * 127 - __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0)); - result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2)); - __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0)); - res0123 = _mm256_mul_pd(res0123, scale0123); - res4567 = _mm256_mul_pd(res4567, scale4567);
- _mm256_storeu_ps(v, res0123); - _mm256_storeu_ps(v + 8, res4567); - w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2)); - w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg - scale0123 = _mm256_loadu_ps(scales + 16); - scale4567 = _mm256_loadu_ps(scales + 24); - w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 - result1 = _mm256_add_epi32(result1, w256); // result += bias * 127 - res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1)); - result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2)); - res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1)); - res0123 = _mm256_mul_pd(res0123, scale0123); - res4567 = _mm256_mul_pd(res4567, scale4567); - _mm256_storeu_ps(v + 16, res0123); - _mm256_storeu_ps(v + 24, res4567); - wi += 16; - scales += 16; - v += 16; -} - -// Computes part of matrix.vector v = Wu. Computes N=64 results. -// The weights *must* be arranged so that consecutive reads from wi -// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of -// (kNumInputsPerGroup inputs))). After that there must be N consecutive -// bias weights, before continuing with any more weights. -// u must be padded out with zeros to -// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements. -static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u, - int num_in, float *v) { - // Register containing 16-bit ones for horizontal add with 16->32 bit - // conversion. - __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); - __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); - // Initialize all the results to 0. - __m256i result0 = _mm256_setzero_si256(); - __m256i result1 = _mm256_setzero_si256(); - __m256i result2 = _mm256_setzero_si256(); - __m256i result3 = _mm256_setzero_si256(); - __m256i result4 = _mm256_setzero_si256(); - __m256i result5 = _mm256_setzero_si256(); - __m256i result6 = _mm256_setzero_si256(); - __m256i result7 = _mm256_setzero_si256(); - // Iterate over the input (u), one registerful at a time. - for (int j = 0; j < num_in;) { - __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); - // Inputs are processed in groups of kNumInputsPerGroup, replicated - // kNumInputGroups times. - for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { - // Replicate the low 32 bits (4 inputs) 8 times. - __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); - // Rotate the inputs in groups of 4, so the next 4 inputs are ready. - inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); - __m256i weights, reps; - // Mul-add, with horizontal add of the 4 inputs to each of the results. - MultiplyGroup(rep_input, ones, wi, weights, reps, result0); - MultiplyGroup(rep_input, ones, wi, weights, reps, result1); - MultiplyGroup(rep_input, ones, wi, weights, reps, result2); - MultiplyGroup(rep_input, ones, wi, weights, reps, result3); - MultiplyGroup(rep_input, ones, wi, weights, reps, result4); - MultiplyGroup(rep_input, ones, wi, weights, reps, result5); - MultiplyGroup(rep_input, ones, wi, weights, reps, result6); - MultiplyGroup(rep_input, ones, wi, weights, reps, result7); - } - } - ExtractResults16(result0, result1, wi, scales, v); - ExtractResults16(result2, result3, wi, scales, v); - ExtractResults16(result4, result5, wi, scales, v); - ExtractResults16(result6, result7, wi, scales, v); -} - -// Computes part of matrix.vector v = Wu. Computes N=32 results.
-// For details see PartialMatrixDotVector64 with N=32. -static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u, - int num_in, float *v) { - // Register containing 16-bit ones for horizontal add with 16->32 bit - // conversion. - __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); - __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); - // Initialize all the results to 0. - __m256i result0 = _mm256_setzero_si256(); - __m256i result1 = _mm256_setzero_si256(); - __m256i result2 = _mm256_setzero_si256(); - __m256i result3 = _mm256_setzero_si256(); - // Iterate over the input (u), one registerful at a time. - for (int j = 0; j < num_in;) { - __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); - // Inputs are processed in groups of kNumInputsPerGroup, replicated - // kNumInputGroups times. - for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { - // Replicate the low 32 bits (4 inputs) 8 times. - __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); - // Rotate the inputs in groups of 4, so the next 4 inputs are ready. - inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); - __m256i weights, reps; - // Mul-add, with horizontal add of the 4 inputs to each of the results. - MultiplyGroup(rep_input, ones, wi, weights, reps, result0); - MultiplyGroup(rep_input, ones, wi, weights, reps, result1); - MultiplyGroup(rep_input, ones, wi, weights, reps, result2); - MultiplyGroup(rep_input, ones, wi, weights, reps, result3); - } - } - ExtractResults16(result0, result1, wi, scales, v); - ExtractResults16(result2, result3, wi, scales, v); -} -// Computes part of matrix.vector v = Wu. Computes N=16 results. -// For details see PartialMatrixDotVector64 with N=16. -static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u, - int num_in, float *v) { - // Register containing 16-bit ones for horizontal add with 16->32 bit - // conversion. - __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); - __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); - // Initialize all the results to 0. - __m256i result0 = _mm256_setzero_si256(); - __m256i result1 = _mm256_setzero_si256(); - // Iterate over the input (u), one registerful at a time. - for (int j = 0; j < num_in;) { - __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j)); - // Inputs are processed in groups of kNumInputsPerGroup, replicated - // kNumInputGroups times. - for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { - // Replicate the low 32 bits (4 inputs) 8 times. - __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); - // Rotate the inputs in groups of 4, so the next 4 inputs are ready. - inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); - __m256i weights, reps; - // Mul-add, with horizontal add of the 4 inputs to each of the results.
-// Computes part of matrix.vector v = Wu. Computes N=16 results.
-// For details see PartialMatrixDotVector64 with N=16.
-static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
-                                     int num_in, float *v) {
-  // Register containing 16-bit ones for horizontal add with 16->32 bit
-  // conversion.
-  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
-  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
-  // Initialize all the results to 0.
-  __m256i result0 = _mm256_setzero_si256();
-  __m256i result1 = _mm256_setzero_si256();
-  // Iterate over the input (u), one registerful at a time.
-  for (int j = 0; j < num_in;) {
-    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
-    // Inputs are processed in groups of kNumInputsPerGroup, replicated
-    // kNumInputGroups times.
-    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
-      // Replicate the low 32 bits (4 inputs) 8 times.
-      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
-      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
-      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
-      __m256i weights, reps;
-      // Mul-add, with horizontal add of the 4 inputs to each of the results.
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
-    }
-  }
-  ExtractResults16(result0, result1, wi, scales, v);
+static inline void ExtractResults8(__m256i result, const int8_t *wi, const TFloat *scales,
+                                   TFloat *v) {
+  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
+  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
+  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
+  __m256 scale01234567 = _mm256_loadu_ps(scales);
+  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32
+  result = _mm256_add_epi32(result, w256);     // result += bias * 127
+  __m256 res01234567 = _mm256_cvtepi32_ps(result);
+  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
+  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
+  _mm256_storeu_ps(v, res01234567);
 }
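// [Editorial sketch, not part of the patch.] The new ExtractResults8 above,
// written out in scalar form: fold the bias row into the int32 accumulators,
// then rescale to TFloat. Note the permute after the conversion is a leftover
// of the double code path and has no effect on the stored result.
#include <cstdint>
static void ExtractResults8ScalarModel(const int32_t result[8], const int8_t *wi,
                                       const TFloat *scales, TFloat *v) {
  for (int i = 0; i < 8; ++i) {
    // Inputs were quantized to [-127, 127], so a bias input of 1.0 shows up
    // as bias_weight * 127 in the integer domain.
    v[i] = (result[i] + int32_t(wi[i]) * 127) * scales[i];
  }
}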
-// Computes part of matrix.vector v = Wu. Computes N=8 results.
-// For details see PartialMatrixDotVector64 with N=8.
-static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
-                                           int num_in, float *v) {
-  // Register containing 16-bit ones for horizontal add with 16->32 bit
-  // conversion.
-  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
-  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
-  // Initialize all the results to 0.
-  __m256i result0 = _mm256_setzero_si256();
-  // Iterate over the input (u), one registerful at a time.
-  for (int j = 0; j < num_in;) {
-    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
-    // Inputs are processed in groups of kNumInputsPerGroup, replicated
-    // kNumInputGroups times.
-    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
-      // Replicate the low 32 bits (4 inputs) 8 times.
-      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
-      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
-      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
-      __m256i weights, reps;
-      // Mul-add, with horizontal add of the 4 inputs to each of the results.
-      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
-    }
-  }
-  ExtractResults8(result0, wi, scales, v);
+static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
+                                    const TFloat *&scales, TFloat *&v) {
+  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
+  // 8x8bit vals in bottom of 128bit reg
+  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
+  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
+  __m256 scale01234567 = _mm256_loadu_ps(scales);
+  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32
+  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
+  __m256 res01234567 = _mm256_cvtepi32_ps(result0);
+  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
+  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
+  _mm256_storeu_ps(v, res01234567);
+  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
+  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
+  scale01234567 = _mm256_loadu_ps(scales + 8);
+  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32
+  result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
+  res01234567 = _mm256_cvtepi32_ps(result1);
+  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
+  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
+  _mm256_storeu_ps(v + 8, res01234567);
+  wi += 16;
+  scales += 16;
+  v += 16;
 }
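// [Editorial note, not part of this file.] These signatures rely on the
// TFloat alias introduced by src/ccutil/tfloat.h earlier in this series.
// A minimal sketch consistent with the commit message (double by default,
// float when FAST_FLOAT is defined):
#ifndef TESSERACT_CCUTIL_TFLOAT_H_
#define TESSERACT_CCUTIL_TFLOAT_H_
namespace tesseract {
#ifdef FAST_FLOAT
using TFloat = float;
#else
using TFloat = double;
#endif
} // namespace tesseract
#endif // TESSERACT_CCUTIL_TFLOAT_H_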
-
-static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
-                            const int8_t *u, float *v) {
-  const int num_out = dim1;
-  const int num_in = dim2 - 1;
-  // Each call to a partial_func_ produces group_size outputs, except the
-  // last one, which can produce less.
-  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
-  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
-  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
-  int output = 0;
-
-  int w_step = (rounded_num_in + 1) * group_size;
-
-  // Run with this group size, until it would produce too much output, then
-  // switch to a smaller size.
-  for (; output + group_size <= rounded_num_out; output += group_size) {
-    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
-    wi += w_step;
-    scales += group_size;
-    v += group_size;
-  }
-  group_size /= 2;
-  w_step /= 2;
-
-  if (output + group_size <= rounded_num_out) {
-    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
-    wi += w_step;
-    scales += group_size;
-    v += group_size;
-    output += group_size;
-  }
-  group_size /= 2;
-  w_step /= 2;
-
-  if (output + group_size <= rounded_num_out) {
-    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
-    wi += w_step;
-    scales += group_size;
-    v += group_size;
-    output += group_size;
-  }
-  group_size /= 2;
-  w_step /= 2;
-
-  if (output + group_size <= rounded_num_out) {
-    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
-  }
-}
 #else
-static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
-                                   double *v) {
+
+static inline void ExtractResults8(__m256i result, const int8_t *wi, const TFloat *scales,
+                                   TFloat *v) {
   __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
   __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
   __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
@@ -396,7 +195,7 @@ static inline void ExtractResults8(__m256i result, const int8_t *wi, const doubl
 }
 
 static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
-                                    const double *&scales, double *&v) {
+                                    const TFloat *&scales, TFloat *&v) {
   __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
   // 8x8bit vals in bottom of 128bit reg
   const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
@@ -430,6 +229,8 @@ static inline void ExtractResults16(__m256i result0, __m256i result1, const int8
   v += 16;
 }
 
+#endif
+
 // Computes part of matrix.vector v = Wu. Computes N=64 results.
 // The weights *must* be arranged so that consecutive reads from wi
 // provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
 // (kNumInputsPerGroup inputs))). After that there must be N consecutive
 // bias weights, before continuing with any more weights.
 // u must be padded out with zeros to
 // kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
-static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
-                                     int num_in, double *v) {
+static void PartialMatrixDotVector64(const int8_t *wi, const TFloat *scales, const int8_t *u,
+                                     int num_in, TFloat *v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
   __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -482,8 +283,8 @@ static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, con
 
 // Computes part of matrix.vector v = Wu. Computes N=32 results.
 // For details see PartialMatrixDotVector64 with N=32.
-static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
-                                     int num_in, double *v) {
+static void PartialMatrixDotVector32(const int8_t *wi, const TFloat *scales, const int8_t *u,
+                                     int num_in, TFloat *v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
   __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
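// [Editorial sketch, not part of the patch.] The float matrixDotVector removed
// above and the TFloat version kept below share the same tiling logic,
// condensed here. The "+ 1" in w_step accounts for the bias row that each
// ExtractResults* call consumes after the weights of a tile.
static void MatrixDotVectorSketch(int rounded_num_out, int rounded_num_in, const int8_t *wi,
                                  const TFloat *scales, const int8_t *u, TFloat *v) {
  using PartialFunc = void (*)(const int8_t *, const TFloat *, const int8_t *, int, TFloat *);
  const PartialFunc partials[] = {PartialMatrixDotVector64, PartialMatrixDotVector32,
                                  PartialMatrixDotVector16, PartialMatrixDotVector8};
  int group_size = 64; // kNumOutputsPerRegister * kMaxOutputRegisters
  int output = 0;
  for (PartialFunc f : partials) {
    const int w_step = (rounded_num_in + 1) * group_size;
    // Only the 64-wide tile can repeat; each smaller tile fits at most once.
    while (output + group_size <= rounded_num_out) {
      f(wi, scales, u, rounded_num_in, v);
      wi += w_step;
      scales += group_size;
      v += group_size;
      output += group_size;
    }
    group_size /= 2;
  }
}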
@@ -517,8 +318,8 @@ static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, con
 // Computes part of matrix.vector v = Wu. Computes N=16 results.
 // For details see PartialMatrixDotVector64 with N=16.
-static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
-                                     int num_in, double *v) {
+static void PartialMatrixDotVector16(const int8_t *wi, const TFloat *scales, const int8_t *u,
+                                     int num_in, TFloat *v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
   __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -547,8 +348,8 @@ static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, con
 
 // Computes part of matrix.vector v = Wu. Computes N=8 results.
 // For details see PartialMatrixDotVector64 with N=8.
-static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
-                                           int num_in, double *v) {
+static inline void PartialMatrixDotVector8(const int8_t *wi, const TFloat *scales, const int8_t *u,
+                                           int num_in, TFloat *v) {
   // Register containing 16-bit ones for horizontal add with 16->32 bit
   // conversion.
   __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
@@ -573,8 +374,8 @@ static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scale
   ExtractResults8(result0, wi, scales, v);
 }
 
-static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
-                            const int8_t *u, double *v) {
+static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales,
+                            const int8_t *u, TFloat *v) {
   const int num_out = dim1;
   const int num_in = dim2 - 1;
   // Each call to a partial_func_ produces group_size outputs, except the
@@ -621,7 +422,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *
     PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
   }
 }
-#endif
+
 static const IntSimdMatrix simdMatrix = {
     // Function.