Skip to content

Commit

Permalink
Support Apple Accelerate framework for training and best models
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan Weil <[email protected]>
  • Loading branch information
stweil committed Jul 13, 2021
1 parent e2529dd commit 01ae69e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 8 deletions.
13 changes: 7 additions & 6 deletions configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ m4_define([MY_CHECK_FRAMEWORK],
])
if test "$my_cv_framework_$1"="yes"; then
AC_DEFINE(AS_TR_CPP([HAVE_FRAMEWORK_$1]), 1,
[Define if you have the $1 framework])
[Define if you have the $1 framework])
AS_TR_CPP([FRAMEWORK_$1])="-framework $1"
AC_SUBST(AS_TR_CPP([FRAMEWORK_$1]))
fi]
Expand All @@ -295,13 +295,14 @@ OPENCL_CPPFLAGS=''
OPENCL_LDFLAGS=''
case "${host_os}" in
*darwin* | *-macos10*)
echo "checking for OpenCL framework"
MY_CHECK_FRAMEWORK([OpenCL])
if test $my_cv_framework_OpenCL = yes; then
have_opencl_lib=true
MY_CHECK_FRAMEWORK([Accelerate])
if test $my_cv_framework_Accelerate = yes; then
AM_CPPFLAGS="-DHAVE_FRAMEWORK_ACCELERATE $AM_CPPFLAGS"
LDFLAGS="$LDFLAGS -framework Accelerate"
fi
MY_CHECK_FRAMEWORK([OpenCL])
if test "$enable_opencl" = "yes"; then
if !($have_opencl_lib); then
if test $my_cv_framework_OpenCL = no; then
AC_MSG_ERROR([Required OpenCL library not found!])
fi
AM_CPPFLAGS="-DUSE_OPENCL $AM_CPPFLAGS"
Expand Down
51 changes: 49 additions & 2 deletions src/arch/simddetect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,23 @@
#include "simddetect.h"
#include "tprintf.h" // for tprintf

#if defined(HAVE_FRAMEWORK_ACCELERATE)

// Use Apple Accelerate framework.
// https://developer.apple.com/documentation/accelerate/simd

// Comparison of execution time with different dot product implementations.
// time DOTPRODUCT=accelerate lstm_squashed_test
// Results for Apple M1:
// DotProductGeneric 64 s
// DotProduct 60 s
// DotProductAccelerate 33 s
// DotProductNative 30 s

#include <Accelerate/Accelerate.h>

#endif

#if defined(HAVE_AVX) || defined(HAVE_AVX2) || defined(HAVE_FMA) || defined(HAVE_SSE4_1)
# define HAS_CPUID
#endif
Expand Down Expand Up @@ -83,6 +100,15 @@ bool SIMDDetect::fma_available_;
bool SIMDDetect::sse_available_;
#endif

#if defined(HAVE_FRAMEWORK_ACCELERATE)
static double DotProductAccelerate(const double* u, const double* v, int n) {
double total = 0.0;
const int stride = 1;
vDSP_dotprD(u, stride, v, stride, &total, n);
return total;
}
#endif

// Computes and returns the dot product of the two n-vectors u and v.
static TFloat DotProductGeneric(const TFloat *u, const TFloat *v, int n) {
TFloat total = 0.0;
Expand Down Expand Up @@ -110,10 +136,17 @@ static void SetDotProduct(DotProductFunction f, const IntSimdMatrix *m = nullptr
SIMDDetect::SIMDDetect() {
// The fallback is a generic dot product calculation.
SetDotProduct(DotProductGeneric);
const char *env = getenv("dotproduct");
if (env) {
const char* dotproduct_env = getenv("DOTPRODUCT");
if (dotproduct_env != nullptr) {
dotproduct = env;
Update();
if (strcmp(dotproduct_env, "native") == 0) {
SetDotProduct(DotProductNative);
#if defined(HAVE_FRAMEWORK_ACCELERATE)
} else if (strcmp(dotproduct_env, "accelerate") == 0) {
SetDotProduct(DotProductAccelerate);
#endif
}
return;
}

Expand Down Expand Up @@ -240,6 +273,11 @@ void SIMDDetect::Update() {
// Native optimized code selected by config variable.
SetDotProduct(DotProductNative);
dotproduct_method = "native";
#if defined(HAVE_FRAMEWORK_ACCELERATE)
} else if (dotproduct == "accelerate") {
SetDotProduct(DotProductAccelerate);
dotproduct_method = "accelerate";
#endif
#if defined(HAVE_AVX2)
} else if (!strcmp(dotproduct.c_str(), "avx2")) {
// AVX2 selected by config variable.
Expand Down Expand Up @@ -277,9 +315,18 @@ void SIMDDetect::Update() {
dotproduct.c_str());
tprintf(
"Support values for dotproduct: auto generic native"
#if defined(HAVE_FRAMEWORK_ACCELERATE)
" accelerate"
#endif
#if defined(HAVE_AVX2)
" avx2"
#endif
#if defined(HAVE_AVX)
" avx"
#endif
#if defined(HAVE_FMA)
" fma"
#endif
#if defined(HAVE_SSE4_1)
" sse"
#endif
Expand Down

0 comments on commit 01ae69e

Please sign in to comment.