Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-33882][ML] Add a vectorized BLAS implementation #30810

Closed
wants to merge 20 commits into from

Conversation

luhenry
Copy link
Contributor

@luhenry luhenry commented Dec 16, 2020

What changes were proposed in this pull request?

This patch introduces a VectorizedBLAS class which implements such hardware-accelerated BLAS operations. This feature is hidden behind the "jvm-vectorized" profile that you can enable by passing "-Pjvm-vectorized" to sbt or maven.

The Vector API has been introduced in JDK 16. Following discussion on the mailing list, this API is introduced transparently and needs to be enabled explicitely.

Why are the changes needed?

Whenever a native BLAS implementation isn't available on the system, Spark automatically falls back onto a Java implementation. With the recent release of the Vector API in the OpenJDK [1], we can use hardware acceleration for such operations.

This change was also discussed on the mailing list. [2]

Does this PR introduce any user-facing change?

It introduces a build-time profile called jvm-vectorized. You can pass it to sbt and mvn with -Pjvm-vectorized. There is no change to the end-user of Spark and it should only impact Spark developpers. It is also disabled by default.

How was this patch tested?

It passes build/sbt mllib-local/test with and without -Pjvm-vectorized with JDK 16. This patch also introduces benchmarks for BLAS.

The benchmark results are as follows:

[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  37             37           0        271.5           3.7       1.0X
[info] vector                                               24             25           4        416.1           2.4       1.5X
[info] 
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  70             70           0        143.2           7.0       1.0X
[info] vector                                               35             35           2        288.7           3.5       2.0X
[info] 
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  50             51           1        199.8           5.0       1.0X
[info] vector                                               15             15           0        648.7           1.5       3.2X
[info] 
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  34             34           0        295.6           3.4       1.0X
[info] vector                                               19             19           0        531.2           1.9       1.8X
[info] 
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  25             25           1        399.0           2.5       1.0X
[info] vector                                                8              9           1       1177.3           0.8       3.0X
[info] 
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  27             27           0          0.0       26651.5       1.0X
[info] vector                                               21             21           0          0.0       20646.3       1.3X
[info] 
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  36             36           0          0.0       35501.4       1.0X
[info] vector                                               22             22           0          0.0       21930.3       1.6X
[info] 
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  20             20           0          0.0       20283.3       1.0X
[info] vector                                                9              9           0          0.1        8657.7       2.3X
[info] 
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  30             30           0          0.0       29845.8       1.0X
[info] vector                                               10             10           1          0.1        9695.4       3.1X
[info] 
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 182            182           0          0.5        1820.0       1.0X
[info] vector                                              160            160           1          0.6        1597.6       1.1X
[info] 
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 211            211           1          0.5        2106.2       1.0X
[info] vector                                              156            157           0          0.6        1564.4       1.3X
[info] 
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 276            276           0          0.4        2757.8       1.0X
[info] vector                                              137            137           0          0.7        1365.1       2.0X

/cc @srowen @xkrogen

[1] https://openjdk.java.net/jeps/338
[2] https://mail-archives.apache.org/mod_mbox/spark-dev/202012.mbox/%3cDM5PR2101MB11106162BB3AF32AD29C6C79B0C69@DM5PR2101MB1110.namprd21.prod.outlook.com%3e

Whenever a native BLAS implementation isn't available on the system, Spark
automatically falls back onto a Java implementation. With the recent release
of the Vector API in the OpenJDK [1], we can use hardware acceleration for such
operations.

This patch introduces a VectorizedBLAS class which implements such
hardware-accelerated BLAS operations. This feature is hidden behind the
"vectorized" profile that you can enable by passing "-Pvectorized" to sbt or
maven.

The Vector API has been introduced in JDK 16. Following discussion on the
mailing list, this API is introduced transparently and needs to be enabled
explicitely.

[1] https://openjdk.java.net/jeps/338
Copy link
Member

@srowen srowen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the right general approach - do we yet have benchmarks from the benchmark test here, vs OpenBLAS for example?

mllib-local/pom.xml Outdated Show resolved Hide resolved
@srowen
Copy link
Member

srowen commented Dec 16, 2020

This will eventually need a JIRA, but I'm going to mark it WIP for now

@srowen srowen marked this pull request as draft December 16, 2020 17:31
@srowen srowen changed the title [ML] Add a vectorized BLAS implementation [WIP][ML] Add a vectorized BLAS implementation Dec 16, 2020
@luhenry
Copy link
Contributor Author

luhenry commented Dec 17, 2020

I've run some benchmarks overnight comparing the implementation of the native, f2j, and vectorized implementations. You can find the results at https://gist.github.com/luhenry/2cda93cb40f3edef76cb499c896608a9

Some things I noted which are noteworthy or which I need to investigate further:

  1. for daxpy and ddot, the vectorized version is equivalent or faster than the native implementation
  2. for dscal, dspr, dsyr, and sdot, the vectorized version is faster than native on small vectors/matrices but slower on large ones; that is something I need to investigate in the Vector API implementation itself
  3. for dgemm and dgemv, the performance is at best equivalent to f2j and native; I'm focusing on that now since dgemm is a major component of neural net training and logistical regressions; the problem is with my naive implementation of the algorithms and not necessarily with the Vector API itself.

All implmentations beat f2j on microbenchmarks on x86 (w/ AVX-2).

See https://github.com/luhenry/vectorizedblas/releases/tag/v0.1.5 for more details
It simplifies the build process, and will allow for faster iterations on my end
We still use it for the benchmarks but it should go away at some point
@luhenry
Copy link
Contributor Author

luhenry commented Dec 20, 2020

I updated the PR to depend on the package dev.ludovic.vectorizedblas-blas instead as it makes it a lot easier for me to go and evolve the algorithms independently of Spark, and for the integration to the build system. Let me know if that compromise of good for you.

As for the latest results, it's looking much better:

[info] f2jBLAS    = com.github.fommil.netlib.F2jBLAS
[info] vectorBLAS = dev.ludovic.blas.VectorizedBLAS
[info] 
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  45             45           1        223.7           4.5       1.0X
[info] vector                                               26             26           3        389.4           2.6       1.7X
[info] 
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  53             53           1        190.4           5.3       1.0X
[info] vector                                               16             17           1        607.4           1.6       3.2X
[info] 
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  73             73           1        137.3           7.3       1.0X
[info] vector                                               35             36           3        282.8           3.5       2.1X
[info] 
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  36             36           1        279.5           3.6       1.0X
[info] vector                                               21             21           0        481.4           2.1       1.7X
[info] 
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  35             36           0          0.0       35458.6       1.0X
[info] vector                                               23             23           0          0.0       23415.6       1.5X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 275            276           1          0.4        2753.6       1.0X
[info] vector                                              149            166         169          0.7        1488.5       1.8X

For much more detailed performance numbers on x86 (w/ AVX-2), I'm currently running a JMH benchmark covering more cases. I'll link to it as soon as it finishes (by tomorrow morning CET).

EDIT: I've added the results of the JMH benchmark suite at https://github.com/luhenry/vectorizedblas/releases/tag/v0.1.7

@github-actions github-actions bot added the INFRA label Dec 22, 2020
Brings acceleration for sscal, dgemm[N,N], dgemm[N,T], dgemv[N], sgemv[N] and sgemv[T]
@luhenry
Copy link
Contributor Author

luhenry commented Dec 22, 2020

The latest BLASBenchmark results with 0.1.9:

[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  37             37           0        271.5           3.7       1.0X
[info] vector                                               24             25           4        416.1           2.4       1.5X
[info] 
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  70             70           0        143.2           7.0       1.0X
[info] vector                                               35             35           2        288.7           3.5       2.0X
[info] 
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  50             51           1        199.8           5.0       1.0X
[info] vector                                               15             15           0        648.7           1.5       3.2X
[info] 
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  34             34           0        295.6           3.4       1.0X
[info] vector                                               19             19           0        531.2           1.9       1.8X
[info] 
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  25             25           1        399.0           2.5       1.0X
[info] vector                                                8              9           1       1177.3           0.8       3.0X
[info] 
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  27             27           0          0.0       26651.5       1.0X
[info] vector                                               21             21           0          0.0       20646.3       1.3X
[info] 
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  36             36           0          0.0       35501.4       1.0X
[info] vector                                               22             22           0          0.0       21930.3       1.6X
[info] 
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  20             20           0          0.0       20283.3       1.0X
[info] vector                                                9              9           0          0.1        8657.7       2.3X
[info] 
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  30             30           0          0.0       29845.8       1.0X
[info] vector                                               10             10           1          0.1        9695.4       3.1X
[info] 
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 182            182           0          0.5        1820.0       1.0X
[info] vector                                              160            160           1          0.6        1597.6       1.1X
[info] 
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 211            211           1          0.5        2106.2       1.0X
[info] vector                                              156            157           0          0.6        1564.4       1.3X
[info] 
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 276            276           0          0.4        2757.8       1.0X
[info] vector                                              137            137           0          0.7        1365.1       2.0X

I'll now run a LogisticRegression benchmark that should benefit from these accelerated operations. I'll also later run a MultilayerPerceptronClassifier benchmark which should equally benefit from accelerated dgemm operations.

@SparkQA
Copy link

SparkQA commented Apr 12, 2021

Kubernetes integration test status failure
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/41802/

@SparkQA
Copy link

SparkQA commented Apr 12, 2021

Test build #137222 has finished for PR 30810 at commit d93b274.

  • This patch passes all tests.
  • This patch does not merge cleanly.
  • This patch adds the following public classes (experimental):
  • public class VectorizedBLAS extends F2jBLAS

Copy link
Member

@srowen srowen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm good with it. Looks like a clean speedup over f2j. This also changes in a few cases where native vs Java BLAS is invoked, but I think it's probably a good rationalization of those calls.

Copy link
Member

@srowen srowen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have a moment, you might also do one last check through the source for other occurrences of BLAS that may have arisen after you began the change. @zhengruifeng has been applying BLAS in more places.

@luhenry
Copy link
Contributor Author

luhenry commented Apr 13, 2021

This also changes in a few cases where native vs Java BLAS is invoked, but I think it's probably a good rationalization of those calls.

I couldn't find such a case. Changes from blas to BLAS.nativeBLAS should lead to using native libraries just the same. It's only in the case that the native library is not available, then it will try to fallback to VectorizedBLAS instead of F2jBLAS.

If you have a moment, you might also do one last check through the source for other occurrences of BLAS that may have arisen after you began the change. @zhengruifeng has been applying BLAS in more places.

I couldn't find any other occurrence.

Thank you again :)

@srowen
Copy link
Member

srowen commented Apr 13, 2021

OK sounds good. @zhengruifeng it may only be your PRs in flight that might need to adjust.

@srowen
Copy link
Member

srowen commented Apr 13, 2021

Jenkins retest this please

@SparkQA
Copy link

SparkQA commented Apr 14, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/41893/

@SparkQA
Copy link

SparkQA commented Apr 14, 2021

Kubernetes integration test status failure
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/41893/

@zhengruifeng
Copy link
Contributor

@srowen It is OK, I will wait for this PR and then adjust my one.

@SparkQA
Copy link

SparkQA commented Apr 14, 2021

Test build #137313 has finished for PR 30810 at commit 10595df.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class SQLProcessor(object):
  • case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression
  • trait BaseGroupingSets extends Expression with CodegenFallback
  • case class Cube(
  • case class Acos(child: Expression) extends UnaryMathExpression(math.acos, \"ACOS\")
  • case class Asin(child: Expression) extends UnaryMathExpression(math.asin, \"ASIN\")
  • case class Atan(child: Expression) extends UnaryMathExpression(math.atan, \"ATAN\")
  • case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, \"CBRT\")
  • case class Cos(child: Expression) extends UnaryMathExpression(math.cos, \"COS\")
  • case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, \"COSH\")
  • case class Log10(child: Expression) extends UnaryLogExpression(StrictMath.log10, \"LOG10\")
  • case class Signum(child: Expression) extends UnaryMathExpression(math.signum, \"SIGNUM\")
  • case class Sin(child: Expression) extends UnaryMathExpression(math.sin, \"SIN\")
  • case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, \"SINH\")
  • case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, \"SQRT\")
  • case class Tan(child: Expression) extends UnaryMathExpression(math.tan, \"TAN\")
  • case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, \"TANH\")
  • case class DeleteAction(condition: Option[Expression]) extends MergeAction
  • case class UpdateStarAction(condition: Option[Expression]) extends MergeAction
  • case class InsertStarAction(condition: Option[Expression]) extends MergeAction
  • case class RefreshTable(child: LogicalPlan) extends UnaryCommand
  • case class CommentOnNamespace(child: LogicalPlan, comment: String) extends UnaryCommand
  • case class CommentOnTable(child: LogicalPlan, comment: String) extends UnaryCommand
  • case class RefreshFunction(child: LogicalPlan) extends UnaryCommand
  • case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends UnaryCommand
  • case class RecoverPartitions(child: LogicalPlan) extends UnaryCommand
  • case class RuleId(id: Int)
  • abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with TreePatternBits
  • trait TreePatternBits
  • case class SetCommand(kv: Option[(String, Option[String])])
  • case class ResetCommand(config: Option[String]) extends LeafRunnableCommand with IgnoreCachedData
  • case class AddJarCommand(path: String) extends LeafRunnableCommand
  • case class AddFileCommand(path: String) extends LeafRunnableCommand
  • case class AddArchiveCommand(path: String) extends LeafRunnableCommand
  • case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends LeafRunnableCommand
  • case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends LeafRunnableCommand
  • case class ListArchivesCommand(archives: Seq[String] = Seq.empty[String])
  • abstract class DescribeCommandBase extends LeafRunnableCommand
  • case class WriteToDataSourceV2(
  • case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec
  • case class WriteToMicroBatchDataSource(

@srowen srowen closed this in 9244066 Apr 14, 2021
@srowen
Copy link
Member

srowen commented Apr 14, 2021

Merged to master

@zhengruifeng
Copy link
Contributor

zhengruifeng commented Apr 15, 2021

Late LGTM.
BTW, I am not familiar with the new vector api, is it possible to apply it in BLAS for sparse vec/mat?


/**
* BLAS routines for MLlib's vectors and matrices.
*/
private[spark] object BLAS extends Serializable {

@transient private var _f2jBLAS: NetlibBLAS = _
@transient private var _javaBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _
private val nativeL1Threshold: Int = 256
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

according to the performance test, I think we can increase nativeL1Threshold to 512?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would go even as far as using nativeBLAS exclusively for level-3 operations, and never for level-1 and level-2. The cost of copying the data from managed memory to native memory (necessary to pass the array to native code) is too great relative to the small speed up of native for the level-1 and level-2 routines.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

netlib-java does not copy memory when using native backend, it uses memory pinning (which has its own problems). Please provide benchmarks to show any degradation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"small speed up of native for the level-1 and level-2 routines." I think you need to do some more analysis on this. Native can be 10x faster than JVM for reasonable sized matrices. However, as shown in https://github.com/fommil/matrix-toolkits-java the EJML and common-math project are faster for matrices of 10x10 or smaller. If you want to heavily optimise for those usecases, then swap to using EJML which is heavily optimised for that usecase (not just "something on the JVM")

@kiszk
Copy link
Member

kiszk commented Apr 16, 2021

@luhenry Could you please update the description using -Pjvm-vectorized instead of -Pvectorized?

@luhenry
Copy link
Contributor Author

luhenry commented Apr 20, 2021

Late LGTM.
BTW, I am not familiar with the new vector api, is it possible to apply it in BLAS for sparse vec/mat?

This vector API provides access to hardware acceleration. So as long as you can express the sparse vec/matrix operations with hardware vectors, you should be able to use the Vector API. However, from my cursory glance at the implementation I’m BLAS.Scala, using hardware acceleration doesn’t seem trivial.

@luhenry
Copy link
Contributor Author

luhenry commented Apr 20, 2021

Late LGTM.
BTW, I am not familiar with the new vector api, is it possible to apply it in BLAS for sparse vec/mat?

This vector API provides access to hardware acceleration. So as long as you can express the sparse vec/matrix operations with hardware vectors, you should be able to use the Vector API. However, from my cursory glance at the implementation I’m BLAS.Scala, using hardware acceleration doesn’t seem trivial.

@zhengruifeng I looked further into that today and what might be looking interesting is Intel MKL support for level-2 and level-3 operations on sparse vectors/matrices (https://software.intel.com/content/www/us/en/develop/documentation/onemkl-developer-reference-c/top/blas-and-sparse-blas-routines/inspector-executor-sparse-blas-routines.html). I'll research what's applicable to Spark, and how we could surface it. Feel free to reach out if you want to discuss it offline.

srowen pushed a commit that referenced this pull request Apr 27, 2021
### What changes were proposed in this pull request?

Following #30810, I've continued looking for ways to accelerate the usage of BLAS in Spark. With this PR, I integrate work done in the [`dev.ludovic.netlib`](https://github.com/luhenry/netlib/) Maven package.

The `dev.ludovic.netlib` library wraps the original `com.github.fommil.netlib` library and focus on accelerating the linear algebra routines in use in Spark. When running the `org.apache.spark.ml.linalg.BLASBenchmark` benchmarking suite, I get the results at [1] on an Intel machine. Moreover, this library is thoroughly tested to return the exact same results as the reference implementation.

Under the hood, it reimplements the necessary algorithms in pure autovectorization-friendly Java 8, as well as takes advantage of the Vector API and Foreign Linker API introduced in JDK 16 when available.

A table summarising which version gets loaded in which case:

```
|                       | BLAS.nativeBLAS                                    | BLAS.javaBLAS                                      |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| with -Pnetlib-lgpl    | 1. dev.ludovic.netlib.blas.NetlibNativeBLAS, a     | 1. dev.ludovic.netlib.blas.VectorizedBLAS          |
|                       |     wrapper for com.github.fommil:all              |    (JDK16+, relies on the Vector API, requires     |
|                       | 2. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+,    |     `--add-modules=jdk.incubator.vector` on JDK16) |
|                       |    relies on the Foreign Linker API, requires      | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+)     |
|                       |    `--add-modules=jdk.incubator.foreign            | 3. dev.ludovic.netlib.blas.JavaBLAS                |
|                       |     -Dforeign.restricted=warn`)                    | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a        |
|                       | 3. fails to load, falls back to BLAS.javaBLAS in   |     wrapper for com.github.fommil:core             |
|                       |     org.apache.spark.ml.linalg.BLAS                |                                                    |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| without -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+,    | 1. dev.ludovic.netlib.blas.VectorizedBLAS          |
|                       |    relies on the Foreign Linker API, requires      |    (JDK16+, relies on the Vector API, requires     |
|                       |    `--add-modules=jdk.incubator.foreign            |     `--add-modules=jdk.incubator.vector` on JDK16) |
|                       |     -Dforeign.restricted=warn`)                    | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+)     |
|                       | 2. fails to load, falls back to BLAS.javaBLAS in   | 3. dev.ludovic.netlib.blas.JavaBLAS                |
|                       |     org.apache.spark.ml.linalg.BLAS                | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a        |
|                       |                                                    |     wrapper for com.github.fommil:core             |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
```

### Why are the changes needed?

Accelerates linear algebra operations when the pure-java fallback method is in use. Transparently falls back to native implementation (OpenBLAS, MKL) when available.

### Does this PR introduce _any_ user-facing change?

No, all changes are transparent to the user.

### How was this patch tested?

The `dev.ludovic.netlib` library has its own test suite [2]. It has also been validated by running the Spark test suite and benchmarking suite.

[1] Results for `org.apache.spark.ml.linalg.BLASBenchmark`:
#### JDK8:
```
[info] OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU  3.80GHz
[info]
[info] f2jBLAS    = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS   = dev.ludovic.netlib.blas.Java8BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java8BLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 223            232           8        448.0           2.2       1.0X
[info] java                                                221            228           7        453.0           2.2       1.0X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 122            128           4        821.2           1.2       1.0X
[info] java                                                122            128           4        822.3           1.2       1.0X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 109            112           2        921.4           1.1       1.0X
[info] java                                                 70             74           3       1423.5           0.7       1.5X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  96             98           2       1046.1           1.0       1.0X
[info] java                                                 47             49           2       2121.7           0.5       2.0X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 184            195           8        544.3           1.8       1.0X
[info] java                                                185            196           7        539.5           1.9       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  99            104           4       1011.9           1.0       1.0X
[info] java                                                 99            104           4       1010.4           1.0       1.0X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        947.2           1.1       1.0X
[info] java                                                  0              0           0       1584.8           0.6       1.7X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        867.4           1.2       1.0X
[info] java                                                  1              1           0        865.0           1.2       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        485.9           2.1       1.0X
[info] java                                                  1              1           0        486.8           2.1       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1843.0           0.5       1.0X
[info] java                                                  0              0           0       2690.6           0.4       1.5X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1214.7           0.8       1.0X
[info] java                                                  0              0           0       2536.8           0.4       2.1X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1895.9           0.5       1.0X
[info] java                                                  0              0           0       2961.1           0.3       1.6X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1223.4           0.8       1.0X
[info] java                                                  0              0           0       3091.4           0.3       2.5X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 560            575          20       1787.1           0.6       1.0X
[info] java                                                226            232           5       4432.4           0.2       2.5X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 570            586          23       1755.2           0.6       1.0X
[info] java                                                227            232           4       4410.1           0.2       2.5X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 863            879          17       1158.4           0.9       1.0X
[info] java                                                227            231           3       4407.9           0.2       3.8X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                1282           1305          23        780.0           1.3       1.0X
[info] java                                                227            232           4       4413.4           0.2       5.7X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 538            548           8       1858.6           0.5       1.0X
[info] java                                                221            226           3       4521.1           0.2       2.4X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 549            558          10       1819.9           0.5       1.0X
[info] java                                                222            229           7       4503.5           0.2       2.5X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 838            852          12       1193.0           0.8       1.0X
[info] java                                                222            229           5       4500.5           0.2       3.8X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 905            919          18       1104.8           0.9       1.0X
[info] java                                                221            228           5       4521.3           0.2       4.1X
```

#### JDK11:
```
[info] OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU  3.80GHz
[info]
[info] f2jBLAS    = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS   = dev.ludovic.netlib.blas.Java11BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java11BLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 195            204          10        512.7           2.0       1.0X
[info] java                                                195            202           7        512.4           2.0       1.0X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 108            113           4        923.3           1.1       1.0X
[info] java                                                102            107           4        984.4           1.0       1.1X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 107            110           3        938.1           1.1       1.0X
[info] java                                                 69             72           3       1447.1           0.7       1.5X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  96             98           2       1046.5           1.0       1.0X
[info] java                                                 43             45           2       2317.1           0.4       2.2X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 155            168           8        644.2           1.6       1.0X
[info] java                                                158            169           8        632.8           1.6       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  85             90           4       1178.1           0.8       1.0X
[info] java                                                 86             90           4       1167.7           0.9       1.0X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       1182.1           0.8       1.0X
[info] java                                                  0              0           0       1432.1           0.7       1.2X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        898.7           1.1       1.0X
[info] java                                                  1              1           0        891.5           1.1       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        495.4           2.0       1.0X
[info] java                                                  1              1           0        495.7           2.0       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2271.6           0.4       1.0X
[info] java                                                  0              0           0       3648.1           0.3       1.6X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1229.3           0.8       1.0X
[info] java                                                  0              0           0       2711.3           0.4       2.2X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2677.5           0.4       1.0X
[info] java                                                  0              0           0       3288.2           0.3       1.2X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1233.0           0.8       1.0X
[info] java                                                  0              0           0       2766.3           0.4       2.2X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 520            536          16       1923.6           0.5       1.0X
[info] java                                                214            221           7       4669.5           0.2       2.4X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 593            612          17       1686.5           0.6       1.0X
[info] java                                                215            219           3       4643.3           0.2       2.8X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 853            870          16       1172.8           0.9       1.0X
[info] java                                                215            218           3       4659.7           0.2       4.0X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                1350           1370          23        740.8           1.3       1.0X
[info] java                                                215            219           4       4656.6           0.2       6.3X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 460            468           6       2173.2           0.5       1.0X
[info] java                                                210            213           2       4752.7           0.2       2.2X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 535            544           8       1869.3           0.5       1.0X
[info] java                                                210            215           5       4761.8           0.2       2.5X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 843            853          11       1186.8           0.8       1.0X
[info] java                                                209            214           4       4793.4           0.2       4.0X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 891            904          15       1122.0           0.9       1.0X
[info] java                                                209            214           4       4777.2           0.2       4.3X
```

#### JDK16:
```
[info] OpenJDK 64-Bit Server VM 16+36 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU  3.80GHz
[info]
[info] f2jBLAS    = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS   = dev.ludovic.netlib.blas.VectorizedBLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.VectorizedBLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 194            199           7        515.7           1.9       1.0X
[info] java                                                181            186           3        551.1           1.8       1.1X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 109            115           4        915.0           1.1       1.0X
[info] java                                                 88             92           3       1138.8           0.9       1.2X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 108            110           2        922.6           1.1       1.0X
[info] java                                                 54             56           2       1839.2           0.5       2.0X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  96             97           2       1046.1           1.0       1.0X
[info] java                                                 29             30           1       3393.4           0.3       3.2X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 156            165           5        643.0           1.6       1.0X
[info] java                                                150            159           5        667.1           1.5       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  85             91           6       1171.0           0.9       1.0X
[info] java                                                 75             79           3       1340.6           0.7       1.1X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        917.0           1.1       1.0X
[info] java                                                  0              0           0       8147.2           0.1       8.9X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        859.3           1.2       1.0X
[info] java                                                  1              1           0        859.3           1.2       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        482.1           2.1       1.0X
[info] java                                                  1              1           0        482.6           2.1       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2214.2           0.5       1.0X
[info] java                                                  0              0           0       7975.8           0.1       3.6X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1231.4           0.8       1.0X
[info] java                                                  0              0           0       8680.9           0.1       7.0X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2684.3           0.4       1.0X
[info] java                                                  0              0           0      18527.1           0.1       6.9X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1235.4           0.8       1.0X
[info] java                                                  0              0           0      17347.9           0.1      14.0X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 530            552          18       1887.5           0.5       1.0X
[info] java                                                 58             64           3      17143.9           0.1       9.1X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 598            620          17       1671.1           0.6       1.0X
[info] java                                                 58             64           3      17196.6           0.1      10.3X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 834            847          14       1199.4           0.8       1.0X
[info] java                                                 57             63           4      17486.9           0.1      14.6X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                1338           1366          22        747.3           1.3       1.0X
[info] java                                                 58             63           3      17356.6           0.1      23.2X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 489            501           9       2045.5           0.5       1.0X
[info] java                                                 36             38           2      27721.9           0.0      13.6X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 478            488           9       2094.0           0.5       1.0X
[info] java                                                 36             38           2      27813.2           0.0      13.3X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 825            837          10       1211.6           0.8       1.0X
[info] java                                                 35             38           2      28433.1           0.0      23.5X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 900            918          15       1111.6           0.9       1.0X
[info] java                                                 36             38           2      28073.0           0.0      25.3X
```

[2] https://github.com/luhenry/netlib/tree/master/blas/src/test/java/dev/ludovic/netlib/blas

Closes #32253 from luhenry/master.

Authored-by: Ludovic Henry <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
a0x8o added a commit to a0x8o/spark that referenced this pull request Apr 27, 2021
### What changes were proposed in this pull request?

Following apache/spark#30810, I've continued looking for ways to accelerate the usage of BLAS in Spark. With this PR, I integrate work done in the [`dev.ludovic.netlib`](https://github.com/luhenry/netlib/) Maven package.

The `dev.ludovic.netlib` library wraps the original `com.github.fommil.netlib` library and focus on accelerating the linear algebra routines in use in Spark. When running the `org.apache.spark.ml.linalg.BLASBenchmark` benchmarking suite, I get the results at [1] on an Intel machine. Moreover, this library is thoroughly tested to return the exact same results as the reference implementation.

Under the hood, it reimplements the necessary algorithms in pure autovectorization-friendly Java 8, as well as takes advantage of the Vector API and Foreign Linker API introduced in JDK 16 when available.

A table summarising which version gets loaded in which case:

```
|                       | BLAS.nativeBLAS                                    | BLAS.javaBLAS                                      |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| with -Pnetlib-lgpl    | 1. dev.ludovic.netlib.blas.NetlibNativeBLAS, a     | 1. dev.ludovic.netlib.blas.VectorizedBLAS          |
|                       |     wrapper for com.github.fommil:all              |    (JDK16+, relies on the Vector API, requires     |
|                       | 2. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+,    |     `--add-modules=jdk.incubator.vector` on JDK16) |
|                       |    relies on the Foreign Linker API, requires      | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+)     |
|                       |    `--add-modules=jdk.incubator.foreign            | 3. dev.ludovic.netlib.blas.JavaBLAS                |
|                       |     -Dforeign.restricted=warn`)                    | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a        |
|                       | 3. fails to load, falls back to BLAS.javaBLAS in   |     wrapper for com.github.fommil:core             |
|                       |     org.apache.spark.ml.linalg.BLAS                |                                                    |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| without -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+,    | 1. dev.ludovic.netlib.blas.VectorizedBLAS          |
|                       |    relies on the Foreign Linker API, requires      |    (JDK16+, relies on the Vector API, requires     |
|                       |    `--add-modules=jdk.incubator.foreign            |     `--add-modules=jdk.incubator.vector` on JDK16) |
|                       |     -Dforeign.restricted=warn`)                    | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+)     |
|                       | 2. fails to load, falls back to BLAS.javaBLAS in   | 3. dev.ludovic.netlib.blas.JavaBLAS                |
|                       |     org.apache.spark.ml.linalg.BLAS                | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a        |
|                       |                                                    |     wrapper for com.github.fommil:core             |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
```

### Why are the changes needed?

Accelerates linear algebra operations when the pure-java fallback method is in use. Transparently falls back to native implementation (OpenBLAS, MKL) when available.

### Does this PR introduce _any_ user-facing change?

No, all changes are transparent to the user.

### How was this patch tested?

The `dev.ludovic.netlib` library has its own test suite [2]. It has also been validated by running the Spark test suite and benchmarking suite.

[1] Results for `org.apache.spark.ml.linalg.BLASBenchmark`:
#### JDK8:
```
[info] OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU  3.80GHz
[info]
[info] f2jBLAS    = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS   = dev.ludovic.netlib.blas.Java8BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java8BLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 223            232           8        448.0           2.2       1.0X
[info] java                                                221            228           7        453.0           2.2       1.0X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 122            128           4        821.2           1.2       1.0X
[info] java                                                122            128           4        822.3           1.2       1.0X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 109            112           2        921.4           1.1       1.0X
[info] java                                                 70             74           3       1423.5           0.7       1.5X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  96             98           2       1046.1           1.0       1.0X
[info] java                                                 47             49           2       2121.7           0.5       2.0X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 184            195           8        544.3           1.8       1.0X
[info] java                                                185            196           7        539.5           1.9       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  99            104           4       1011.9           1.0       1.0X
[info] java                                                 99            104           4       1010.4           1.0       1.0X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        947.2           1.1       1.0X
[info] java                                                  0              0           0       1584.8           0.6       1.7X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        867.4           1.2       1.0X
[info] java                                                  1              1           0        865.0           1.2       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        485.9           2.1       1.0X
[info] java                                                  1              1           0        486.8           2.1       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1843.0           0.5       1.0X
[info] java                                                  0              0           0       2690.6           0.4       1.5X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1214.7           0.8       1.0X
[info] java                                                  0              0           0       2536.8           0.4       2.1X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1895.9           0.5       1.0X
[info] java                                                  0              0           0       2961.1           0.3       1.6X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1223.4           0.8       1.0X
[info] java                                                  0              0           0       3091.4           0.3       2.5X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 560            575          20       1787.1           0.6       1.0X
[info] java                                                226            232           5       4432.4           0.2       2.5X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 570            586          23       1755.2           0.6       1.0X
[info] java                                                227            232           4       4410.1           0.2       2.5X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 863            879          17       1158.4           0.9       1.0X
[info] java                                                227            231           3       4407.9           0.2       3.8X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                1282           1305          23        780.0           1.3       1.0X
[info] java                                                227            232           4       4413.4           0.2       5.7X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 538            548           8       1858.6           0.5       1.0X
[info] java                                                221            226           3       4521.1           0.2       2.4X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 549            558          10       1819.9           0.5       1.0X
[info] java                                                222            229           7       4503.5           0.2       2.5X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 838            852          12       1193.0           0.8       1.0X
[info] java                                                222            229           5       4500.5           0.2       3.8X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 905            919          18       1104.8           0.9       1.0X
[info] java                                                221            228           5       4521.3           0.2       4.1X
```

#### JDK11:
```
[info] OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU  3.80GHz
[info]
[info] f2jBLAS    = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS   = dev.ludovic.netlib.blas.Java11BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java11BLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 195            204          10        512.7           2.0       1.0X
[info] java                                                195            202           7        512.4           2.0       1.0X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 108            113           4        923.3           1.1       1.0X
[info] java                                                102            107           4        984.4           1.0       1.1X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 107            110           3        938.1           1.1       1.0X
[info] java                                                 69             72           3       1447.1           0.7       1.5X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  96             98           2       1046.5           1.0       1.0X
[info] java                                                 43             45           2       2317.1           0.4       2.2X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 155            168           8        644.2           1.6       1.0X
[info] java                                                158            169           8        632.8           1.6       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  85             90           4       1178.1           0.8       1.0X
[info] java                                                 86             90           4       1167.7           0.9       1.0X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       1182.1           0.8       1.0X
[info] java                                                  0              0           0       1432.1           0.7       1.2X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        898.7           1.1       1.0X
[info] java                                                  1              1           0        891.5           1.1       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        495.4           2.0       1.0X
[info] java                                                  1              1           0        495.7           2.0       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2271.6           0.4       1.0X
[info] java                                                  0              0           0       3648.1           0.3       1.6X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1229.3           0.8       1.0X
[info] java                                                  0              0           0       2711.3           0.4       2.2X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2677.5           0.4       1.0X
[info] java                                                  0              0           0       3288.2           0.3       1.2X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1233.0           0.8       1.0X
[info] java                                                  0              0           0       2766.3           0.4       2.2X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 520            536          16       1923.6           0.5       1.0X
[info] java                                                214            221           7       4669.5           0.2       2.4X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 593            612          17       1686.5           0.6       1.0X
[info] java                                                215            219           3       4643.3           0.2       2.8X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 853            870          16       1172.8           0.9       1.0X
[info] java                                                215            218           3       4659.7           0.2       4.0X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                1350           1370          23        740.8           1.3       1.0X
[info] java                                                215            219           4       4656.6           0.2       6.3X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 460            468           6       2173.2           0.5       1.0X
[info] java                                                210            213           2       4752.7           0.2       2.2X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 535            544           8       1869.3           0.5       1.0X
[info] java                                                210            215           5       4761.8           0.2       2.5X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 843            853          11       1186.8           0.8       1.0X
[info] java                                                209            214           4       4793.4           0.2       4.0X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 891            904          15       1122.0           0.9       1.0X
[info] java                                                209            214           4       4777.2           0.2       4.3X
```

#### JDK16:
```
[info] OpenJDK 64-Bit Server VM 16+36 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU  3.80GHz
[info]
[info] f2jBLAS    = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS   = dev.ludovic.netlib.blas.VectorizedBLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.VectorizedBLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 194            199           7        515.7           1.9       1.0X
[info] java                                                181            186           3        551.1           1.8       1.1X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 109            115           4        915.0           1.1       1.0X
[info] java                                                 88             92           3       1138.8           0.9       1.2X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 108            110           2        922.6           1.1       1.0X
[info] java                                                 54             56           2       1839.2           0.5       2.0X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  96             97           2       1046.1           1.0       1.0X
[info] java                                                 29             30           1       3393.4           0.3       3.2X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 156            165           5        643.0           1.6       1.0X
[info] java                                                150            159           5        667.1           1.5       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  85             91           6       1171.0           0.9       1.0X
[info] java                                                 75             79           3       1340.6           0.7       1.1X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        917.0           1.1       1.0X
[info] java                                                  0              0           0       8147.2           0.1       8.9X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        859.3           1.2       1.0X
[info] java                                                  1              1           0        859.3           1.2       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        482.1           2.1       1.0X
[info] java                                                  1              1           0        482.6           2.1       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2214.2           0.5       1.0X
[info] java                                                  0              0           0       7975.8           0.1       3.6X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1231.4           0.8       1.0X
[info] java                                                  0              0           0       8680.9           0.1       7.0X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2684.3           0.4       1.0X
[info] java                                                  0              0           0      18527.1           0.1       6.9X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1235.4           0.8       1.0X
[info] java                                                  0              0           0      17347.9           0.1      14.0X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 530            552          18       1887.5           0.5       1.0X
[info] java                                                 58             64           3      17143.9           0.1       9.1X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 598            620          17       1671.1           0.6       1.0X
[info] java                                                 58             64           3      17196.6           0.1      10.3X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 834            847          14       1199.4           0.8       1.0X
[info] java                                                 57             63           4      17486.9           0.1      14.6X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                1338           1366          22        747.3           1.3       1.0X
[info] java                                                 58             63           3      17356.6           0.1      23.2X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 489            501           9       2045.5           0.5       1.0X
[info] java                                                 36             38           2      27721.9           0.0      13.6X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 478            488           9       2094.0           0.5       1.0X
[info] java                                                 36             38           2      27813.2           0.0      13.3X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 825            837          10       1211.6           0.8       1.0X
[info] java                                                 35             38           2      28433.1           0.0      23.5X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 900            918          15       1111.6           0.9       1.0X
[info] java                                                 36             38           2      28073.0           0.0      25.3X
```

[2] https://github.com/luhenry/netlib/tree/master/blas/src/test/java/dev/ludovic/netlib/blas

Closes #32253 from luhenry/master.

Authored-by: Ludovic Henry <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
@fommil
Copy link
Contributor

fommil commented May 5, 2021

the f2j backend in spark is provided via netlib-java which can swap to using machine optimised binaries if they are present. There are reasons not to use netlib-java (and I no longer recommend it, preferring direct handcoded access to machine optimised libblas) so it's good to see alternatives being proposed.

However, it is strange that you're not providing benchmarks against machine-optimised MKL (or otherwise) backends as described in http://fommil.com/scalax14/#/ The f2j backend is just the fallback and could be replaced in the most part (e.g. dgemm) with 10 lines of java code.

Lots of benchmarks over at https://github.com/fommil/matrix-toolkits-java as another consumer of netlib-java that address your comments about "memory copying" (which is incorrect).

@luhenry
Copy link
Contributor Author

luhenry commented May 5, 2021

However, it is strange that you're not providing benchmarks against machine-optimised MKL (or otherwise) backends as described in http://fommil.com/scalax14/#/

You can find the comparison to machine-optimised MKL and OpenBLAS in #32415. There, the javaBLAS on JDK 16 is using the Vector API. You can see the performance is equivalent or better on most level-1 and level-2 operations but still lacks on level-3 operations (there is a write-up why at https://mail.openjdk.java.net/pipermail/panama-dev/2021-January/011930.html).

The f2j backend is just the fallback and could be replaced in the most part (e.g. dgemm) with 10 lines of java code.

Yep, agreed. That's what I did for some operations in javaBLAS. I get some decent speedups over f2j for some of the operations and even matching the native implementations in some cases.

Lots of benchmarks over at https://github.com/fommil/matrix-toolkits-java as another consumer of netlib-java that address your comments about "memory copying" (which is incorrect).

I'll definitely look into these higher-level benchmarks, thanks for the pointer!

And you're right, my comment about memory copying was based on the wrong assumption that JNI doesn't support passing java heap memory to native libraries without copying. By using Get/ReleasePrimitiveArrayCritical, you can do just that. That is however something that is not possible with the Foreign Linker API without using aMemorySegment backed by native memory, and which would be very invasive for Spark (and not compatible with Java 8 nor 11).

A future avenue of improvement will be to support Sparse Matrix/Vector operations in dev.ludovic.netlib directly so that we can take advantage of native libraries that do provide such routines, or fall back to a hand-optimized pure-java fallback.

dongjoon-hyun pushed a commit that referenced this pull request Apr 28, 2024
…`JavaModuleOptions`

### What changes were proposed in this pull request?
The pr aims to:
- add `--add-modules=jdk.incubator.vector` to `JavaModuleOptions`
- remove `jdk.incubator.foreign` and `-Dforeign.restricted=warn` from `SparkBuild.scala`

### Why are the changes needed?
1.`jdk.incubator.vector`
First introduction: #30810
https://github.com/apache/spark/pull/30810/files#diff-6f545c33f2fcc975200bf208c900a600a593ce6b170180f81e2f93b3efb6cb3e
<img width="1045" alt="image" src="https://github.com/apache/spark/assets/15246973/6ac7919a-5d82-475c-b8a2-7d9de71acacc">

Why should we add `--add-modules=jdk.incubator.vector` to `JavaModuleOptions`,
Because when we only add `--add-modules=jdk.incubator.vector` to `SparkBuild.scala`, it will only take effect when compiling, as follows:
```
build/sbt "mllib-local/Test/runMain org.apache.spark.ml.linalg.BLASBenchmark"
...
```
<img width="619" alt="image" src="https://github.com/apache/spark/assets/15246973/54d5f55f-cefe-4126-b255-69488f8699a6">

However, when we use `spark-submit`, it is as follows:
```
./bin/spark-submit --class org.apache.spark.ml.linalg.BLASBenchmark /Users/panbingkun/Developer/spark/spark-community/mllib-local/target/scala-2.13/spark-mllib-local_2.13-4.0.0-SNAPSHOT-tests.jar
```
<img width="1399" alt="image" src="https://github.com/apache/spark/assets/15246973/8e02fa93-fef4-4cdc-96bd-908b3e9baea1">

Obviously, `--add-modules=jdk.incubator.vector` does not take effect in the `Spark runtime`, so I propose adding `--add-modules=jdk.incubator.vector` to the `JavaModuleOptions`(`Spark runtime options`) so that we can improve `performance` by using `hardware-accelerated BLAS operations` by default.

After this patch(add `--add-modules=jdk.incubator.vector` to the `JavaModuleOptions`), as follows:
<img width="1399" alt="image" src="https://github.com/apache/spark/assets/15246973/da7aa494-0d3c-4c60-9991-e7cd29a1cec5">

2.`jdk.incubator.foreign` and `-Dforeign.restricted=warn`
A.First introduction: #32253
https://github.com/apache/spark/pull/32253/files#diff-6f545c33f2fcc975200bf208c900a600a593ce6b170180f81e2f93b3efb6cb3e
<img width="1041" alt="image" src="https://github.com/apache/spark/assets/15246973/3f526019-c389-4e60-ab2a-7777f8e99cfb">
Use `dev.ludovic.netlib:blas:1.3.2`, the class `ForeignLinkerBLAS` uses `jdk.incubator.foreign.*` in this version, so we need to add `jdk.incubator.foreign` and `-Dforeign.restricted=warn` to `SparkBuild.scala`
https://github.com/apache/spark/pull/32253/files#diff-9c5fb3d1b7e3b0f54bc5c4182965c4fe1f9023d449017cece3005d3f90e8e4d8
<img width="497" alt="image" src="https://github.com/apache/spark/assets/15246973/4fd35e96-0da2-4456-a3f6-6b57ad2e9b64">
https://github.com/luhenry/netlib/blob/v1.3.2/blas/src/main/java/dev/ludovic/netlib/blas/ForeignLinkerBLAS.java#L36
<img width="743" alt="image" src="https://github.com/apache/spark/assets/15246973/4b7e3bd1-4650-4c7d-bdb4-c1761d48d478">

However, with the iterative development of `dev.ludovic.netlib`, `ForeignLinkerBLAS` has experienced one `major` change, as following:
luhenry/netlib@48e923c
<img width="452" alt="image" src="https://github.com/apache/spark/assets/15246973/7ba30b19-00c7-4cc4-bea7-a6ab4b326ad8">
From now on (V3.0.0), `jdk.incubator.foreign.*` will not be used in `dev.ludovic.netlib`

Currently, Spark has used the `dev.ludovic.netlib` of version `v3.0.3`. In this version, `ForeignLinkerBLAS` has be removed.
https://github.com/apache/spark/blob/master/pom.xml#L191

Double check (`jdk.incubator.foreign` cannot be found in the `netlib` source code):
<img width="674" alt="image" src="https://github.com/apache/spark/assets/15246973/5c6c6d73-6a5d-427a-9fb4-f626f02335ca">

So we can completely remove options `jdk.incubator.foreign` and `-Dforeign.restricted=warn`.

B.For JDK 21
(PS: This is to explain the historical reasons for the differences between the current code logic and the initial ones)
(Just because `Spark` made changes to support `JDK 21`)
https://issues.apache.org/jira/browse/SPARK-44088
<img width="1350" alt="image" src="https://github.com/apache/spark/assets/15246973/34e7e7e8-4e72-470e-abc0-d79406ad25e5">

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- Manually test
- Pass GA.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46246 from panbingkun/test_spark_build.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
JacobZheng0927 pushed a commit to JacobZheng0927/spark that referenced this pull request May 11, 2024
…`JavaModuleOptions`

### What changes were proposed in this pull request?
The pr aims to:
- add `--add-modules=jdk.incubator.vector` to `JavaModuleOptions`
- remove `jdk.incubator.foreign` and `-Dforeign.restricted=warn` from `SparkBuild.scala`

### Why are the changes needed?
1.`jdk.incubator.vector`
First introduction: apache#30810
https://github.com/apache/spark/pull/30810/files#diff-6f545c33f2fcc975200bf208c900a600a593ce6b170180f81e2f93b3efb6cb3e
<img width="1045" alt="image" src="https://github.com/apache/spark/assets/15246973/6ac7919a-5d82-475c-b8a2-7d9de71acacc">

Why should we add `--add-modules=jdk.incubator.vector` to `JavaModuleOptions`,
Because when we only add `--add-modules=jdk.incubator.vector` to `SparkBuild.scala`, it will only take effect when compiling, as follows:
```
build/sbt "mllib-local/Test/runMain org.apache.spark.ml.linalg.BLASBenchmark"
...
```
<img width="619" alt="image" src="https://github.com/apache/spark/assets/15246973/54d5f55f-cefe-4126-b255-69488f8699a6">

However, when we use `spark-submit`, it is as follows:
```
./bin/spark-submit --class org.apache.spark.ml.linalg.BLASBenchmark /Users/panbingkun/Developer/spark/spark-community/mllib-local/target/scala-2.13/spark-mllib-local_2.13-4.0.0-SNAPSHOT-tests.jar
```
<img width="1399" alt="image" src="https://github.com/apache/spark/assets/15246973/8e02fa93-fef4-4cdc-96bd-908b3e9baea1">

Obviously, `--add-modules=jdk.incubator.vector` does not take effect in the `Spark runtime`, so I propose adding `--add-modules=jdk.incubator.vector` to the `JavaModuleOptions`(`Spark runtime options`) so that we can improve `performance` by using `hardware-accelerated BLAS operations` by default.

After this patch(add `--add-modules=jdk.incubator.vector` to the `JavaModuleOptions`), as follows:
<img width="1399" alt="image" src="https://github.com/apache/spark/assets/15246973/da7aa494-0d3c-4c60-9991-e7cd29a1cec5">

2.`jdk.incubator.foreign` and `-Dforeign.restricted=warn`
A.First introduction: apache#32253
https://github.com/apache/spark/pull/32253/files#diff-6f545c33f2fcc975200bf208c900a600a593ce6b170180f81e2f93b3efb6cb3e
<img width="1041" alt="image" src="https://github.com/apache/spark/assets/15246973/3f526019-c389-4e60-ab2a-7777f8e99cfb">
Use `dev.ludovic.netlib:blas:1.3.2`, the class `ForeignLinkerBLAS` uses `jdk.incubator.foreign.*` in this version, so we need to add `jdk.incubator.foreign` and `-Dforeign.restricted=warn` to `SparkBuild.scala`
https://github.com/apache/spark/pull/32253/files#diff-9c5fb3d1b7e3b0f54bc5c4182965c4fe1f9023d449017cece3005d3f90e8e4d8
<img width="497" alt="image" src="https://github.com/apache/spark/assets/15246973/4fd35e96-0da2-4456-a3f6-6b57ad2e9b64">
https://github.com/luhenry/netlib/blob/v1.3.2/blas/src/main/java/dev/ludovic/netlib/blas/ForeignLinkerBLAS.java#L36
<img width="743" alt="image" src="https://github.com/apache/spark/assets/15246973/4b7e3bd1-4650-4c7d-bdb4-c1761d48d478">

However, with the iterative development of `dev.ludovic.netlib`, `ForeignLinkerBLAS` has experienced one `major` change, as following:
luhenry/netlib@48e923c
<img width="452" alt="image" src="https://github.com/apache/spark/assets/15246973/7ba30b19-00c7-4cc4-bea7-a6ab4b326ad8">
From now on (V3.0.0), `jdk.incubator.foreign.*` will not be used in `dev.ludovic.netlib`

Currently, Spark has used the `dev.ludovic.netlib` of version `v3.0.3`. In this version, `ForeignLinkerBLAS` has be removed.
https://github.com/apache/spark/blob/master/pom.xml#L191

Double check (`jdk.incubator.foreign` cannot be found in the `netlib` source code):
<img width="674" alt="image" src="https://github.com/apache/spark/assets/15246973/5c6c6d73-6a5d-427a-9fb4-f626f02335ca">

So we can completely remove options `jdk.incubator.foreign` and `-Dforeign.restricted=warn`.

B.For JDK 21
(PS: This is to explain the historical reasons for the differences between the current code logic and the initial ones)
(Just because `Spark` made changes to support `JDK 21`)
https://issues.apache.org/jira/browse/SPARK-44088
<img width="1350" alt="image" src="https://github.com/apache/spark/assets/15246973/34e7e7e8-4e72-470e-abc0-d79406ad25e5">

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- Manually test
- Pass GA.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46246 from panbingkun/test_spark_build.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants