
[SPARK-33882][ML] Add a vectorized BLAS implementation #30810

Status: Closed. 20 commits.
40 changes: 40 additions & 0 deletions mllib-local/pom.xml
@@ -56,6 +56,11 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>

<!--
This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude
@@ -68,6 +73,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<profile>
@@ -81,6 +93,34 @@
</dependency>
</dependencies>
</profile>
<profile>
<id>vectorized</id>
<properties>
<extra.source.dir>src/vectorized/java</extra.source.dir>
</properties>
<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<id>add-vectorized-sources</id>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>${extra.source.dir}</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
46 changes: 28 additions & 18 deletions mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
@@ -18,28 +18,46 @@
package org.apache.spark.ml.linalg

import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
import scala.util.Try

import org.apache.spark.util.Utils

/**
* BLAS routines for MLlib's vectors and matrices.
*/
private[spark] object BLAS extends Serializable {

@transient private var _f2jBLAS: NetlibBLAS = _
@transient private var _javaBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _
private val nativeL1Threshold: Int = 256
Review discussion on nativeL1Threshold:

Contributor: According to the performance test, I think we can increase nativeL1Threshold to 512?

Contributor (author): I would go even as far as using nativeBLAS exclusively for level-3 operations, and never for level-1 and level-2. The cost of copying the data from managed memory to native memory (necessary to pass the array to native code) is too great relative to the small speedup of native for the level-1 and level-2 routines.

Contributor: netlib-java does not copy memory when using the native backend; it uses memory pinning (which has its own problems). Please provide benchmarks to show any degradation.

Contributor: Regarding the "small speed up of native for the level-1 and level-2 routines": I think you need to do some more analysis on this. Native can be 10x faster than the JVM for reasonably sized matrices. However, as shown in https://github.com/fommil/matrix-toolkits-java, the EJML and commons-math projects are faster for matrices of 10x10 or smaller. If you want to heavily optimise for those use cases, then swap to EJML, which is heavily optimised for that use case (not just "something on the JVM").

// For level-1 function dspmv, use f2jBLAS for better performance.
private[ml] def f2jBLAS: NetlibBLAS = {
if (_f2jBLAS == null) {
_f2jBLAS = new F2jBLAS
// For level-1 function dspmv, use javaBLAS for better performance.
private[ml] def javaBLAS: NetlibBLAS = {
if (_javaBLAS == null) {
_javaBLAS = Try(
Utils.classForName("org.apache.spark.ml.linalg.VectorizedBLAS")
.newInstance()
.asInstanceOf[NetlibBLAS]).getOrElse(new F2jBLAS)
}
_javaBLAS
}
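The javaBLAS accessor above lazily loads an optional implementation by class name and falls back to F2jBLAS when it is absent. A minimal sketch of that reflective-fallback pattern in plain Java; the class name `com.example.MissingBlas` is hypothetical and deliberately not on the classpath:

```java
import java.util.function.Supplier;

// Sketch of the reflective-fallback pattern: try to instantiate an optional
// implementation by class name, return a default when it cannot be loaded.
public class ReflectiveFallback {
    static <T> T loadOrElse(String className, Supplier<T> fallback) {
        try {
            @SuppressWarnings("unchecked")
            T t = (T) Class.forName(className).getDeclaredConstructor().newInstance();
            return t;
        } catch (ReflectiveOperationException e) {
            // ClassNotFoundException and friends land here; use the fallback.
            return fallback.get();
        }
    }

    public static void main(String[] args) {
        // "com.example.MissingBlas" is a hypothetical, absent class.
        String impl = loadOrElse("com.example.MissingBlas", () -> "fallback");
        System.out.println(impl); // fallback
    }
}
```

This keeps the optional implementation an entirely soft dependency: nothing references its type at compile time, so the module builds and runs whether or not the extra sources are present.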

// For level-3 routines, we use the native BLAS.
private[ml] def nativeBLAS: NetlibBLAS = {
if (_nativeBLAS == null) {
_nativeBLAS =
if (NetlibBLAS.getInstance.isInstanceOf[F2jBLAS]) {
javaBLAS
} else {
NetlibBLAS.getInstance
}
}
_f2jBLAS
_nativeBLAS
}

private[ml] def getBLAS(vectorSize: Int): NetlibBLAS = {
if (vectorSize < nativeL1Threshold) {
f2jBLAS
javaBLAS
} else {
nativeBLAS
}
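getBLAS above routes small vectors to the JVM implementation and larger ones to native BLAS, since the fixed cost of a JNI call dominates for short arrays. A simplified stand-in sketch of that size-based dispatch (the types and names here are illustrative, not Spark's):

```java
// Illustrative sketch of size-based BLAS dispatch: below a threshold, use the
// pure-JVM implementation to avoid JNI overhead; above it, use native BLAS.
public class BlasDispatch {
    interface Blas { String name(); }

    static final Blas JAVA_BLAS = () -> "java";
    static final Blas NATIVE_BLAS = () -> "native";
    static final int NATIVE_L1_THRESHOLD = 256;

    static Blas getBlas(int vectorSize) {
        return vectorSize < NATIVE_L1_THRESHOLD ? JAVA_BLAS : NATIVE_BLAS;
    }

    public static void main(String[] args) {
        System.out.println(getBlas(128).name()); // java
        System.out.println(getBlas(512).name()); // native
    }
}
```

The threshold is a tuning knob, which is exactly what the review thread above debates: where the JNI overhead stops mattering depends on the backend and the workload.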
@@ -235,14 +253,6 @@ private[spark] object BLAS extends Serializable {
}
}

// For level-3 routines, we use the native BLAS.
private[ml] def nativeBLAS: NetlibBLAS = {
if (_nativeBLAS == null) {
_nativeBLAS = NativeBLAS
}
_nativeBLAS
}

/**
* Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR.
*
@@ -267,7 +277,7 @@
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
f2jBLAS.dspmv("U", n, alpha, A.values, x.values, 1, beta, y.values, 1)
javaBLAS.dspmv("U", n, alpha, A.values, x.values, 1, beta, y.values, 1)
}
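For reference, ?SPMV computes y := alpha * A * x + beta * y where the symmetric matrix A is stored as a packed upper triangle, column-major, with entry (i, j) for i <= j at index i + j*(j+1)/2. A naive, unoptimized Java sketch of those semantics (unit strides assumed):

```java
// Naive reference for the packed symmetric matrix-vector multiply (?SPMV):
// y := alpha * A * x + beta * y, upper triangle of A packed column-major.
public class Dspmv {
    static void dspmv(int n, double alpha, double[] ap, double[] x,
                      double beta, double[] y) {
        for (int i = 0; i < n; i++) {
            double sum = 0.0;
            for (int j = 0; j < n; j++) {
                // Symmetry: A(i,j) = A(j,i), stored only for row <= col.
                int row = Math.min(i, j), col = Math.max(i, j);
                sum += ap[row + col * (col + 1) / 2] * x[j];
            }
            y[i] = alpha * sum + beta * y[i];
        }
    }

    public static void main(String[] args) {
        // 2x2 symmetric A = [[1, 2], [2, 3]], packed upper triangle {1, 2, 3}.
        double[] ap = {1.0, 2.0, 3.0};
        double[] x = {1.0, 1.0};
        double[] y = {0.0, 0.0};
        dspmv(2, 1.0, ap, x, 0.0, y);
        System.out.println(y[0] + " " + y[1]); // 3.0 5.0
    }
}
```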

/**
@@ -279,7 +289,7 @@
val n = v.size
v match {
case DenseVector(values) =>
NativeBLAS.dspr("U", n, alpha, values, 1, U)
nativeBLAS.dspr("U", n, alpha, values, 1, U)
case SparseVector(size, indices, values) =>
val nnz = indices.length
var colStartIdx = 0
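The spr path above is BLAS's ?SPR, the packed symmetric rank-1 update A := alpha * x * x^T + A. A naive Java reference for the dense-vector case, using the same packed upper-triangle, column-major layout:

```java
// Naive reference for the packed symmetric rank-1 update (?SPR):
// A := alpha * x * x^T + A, upper triangle of A packed column-major.
public class Dspr {
    static void dspr(int n, double alpha, double[] x, double[] ap) {
        for (int j = 0; j < n; j++) {
            for (int i = 0; i <= j; i++) {
                // Only the upper triangle (i <= j) is stored and updated.
                ap[i + j * (j + 1) / 2] += alpha * x[i] * x[j];
            }
        }
    }

    public static void main(String[] args) {
        double[] ap = new double[3]; // 2x2 packed upper triangle, zeroed
        double[] x = {1.0, 2.0};
        dspr(2, 1.0, x, ap);
        // x * x^T = [[1, 2], [2, 4]], packed upper triangle {1, 2, 4}.
        System.out.println(java.util.Arrays.toString(ap)); // [1.0, 2.0, 4.0]
    }
}
```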
@@ -0,0 +1,226 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.linalg

import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
import scala.util.Try

import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.util.Utils

/**
 * Benchmark for BLAS implementations.
* To run this benchmark:
* {{{
* 1. without sbt: bin/spark-submit --class <this class> <spark mllib test jar>
* 2. build/sbt "mllib/test:runMain <this class>"
 * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "mllib/test:runMain <this class>"
 * Results will be written to "benchmarks/BLASBenchmark-results.txt".
* }}}
*/
object BLASBenchmark extends BenchmarkBase {

override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {

val iters = 1e2.toInt
val rnd = new scala.util.Random(0)

val f2jBLAS = new F2jBLAS
val nativeBLAS = NetlibBLAS.getInstance
val vectorBLAS = Try(
Utils.classForName("org.apache.spark.ml.linalg.VectorizedBLAS")
.newInstance()
.asInstanceOf[NetlibBLAS]).getOrElse(new F2jBLAS)

// scalastyle:off println
println("nativeBLAS = " + nativeBLAS.getClass.getName)
println("f2jBLAS = " + f2jBLAS.getClass.getName)
println("vectorBLAS = " + vectorBLAS.getClass.getName)
// scalastyle:on println

runBenchmark("daxpy") {
val n = 1e7.toInt
val a = rnd.nextDouble
val x = Array.fill(n) { rnd.nextDouble }
val y = Array.fill(n) { rnd.nextDouble }

val benchmark = new Benchmark("daxpy", n, iters, output = output)

benchmark.addCase("f2j") { _ =>
f2jBLAS.daxpy(n, a, x, 1, y, 1)
}

if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("native") { _ =>
nativeBLAS.daxpy(n, a, x, 1, y, 1)
}
}

if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("vector") { _ =>
vectorBLAS.daxpy(n, a, x, 1, y, 1)
}
}

benchmark.run()
}
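The daxpy kernel being benchmarked computes y := a * x + y, the simplest level-1 routine and the one most sensitive to per-call overhead. A naive Java reference with unit strides:

```java
// Naive reference for daxpy: y := a * x + y, unit strides.
public class Daxpy {
    static void daxpy(int n, double a, double[] x, double[] y) {
        for (int i = 0; i < n; i++) {
            y[i] += a * x[i];
        }
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0};
        double[] y = {10.0, 20.0};
        daxpy(2, 3.0, x, y);
        System.out.println(java.util.Arrays.toString(y)); // [13.0, 26.0]
    }
}
```

Loops like this are also what the JIT auto-vectorizes well, which is why a pure-JVM (or Vector API) implementation can be competitive with native code at small sizes once JNI overhead is counted.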

runBenchmark("sdot") {
val n = 1e7.toInt
val x = Array.fill(n) { rnd.nextFloat }
val y = Array.fill(n) { rnd.nextFloat }

val benchmark = new Benchmark("sdot", n, iters, output = output)

benchmark.addCase("f2j") { _ =>
f2jBLAS.sdot(n, x, 1, y, 1)
}

if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("native") { _ =>
nativeBLAS.sdot(n, x, 1, y, 1)
}
}

if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("vector") { _ =>
vectorBLAS.sdot(n, x, 1, y, 1)
}
}

benchmark.run()
}

runBenchmark("ddot") {
val n = 1e7.toInt
val x = Array.fill(n) { rnd.nextDouble }
val y = Array.fill(n) { rnd.nextDouble }

val benchmark = new Benchmark("ddot", n, iters, output = output)

benchmark.addCase("f2j") { _ =>
f2jBLAS.ddot(n, x, 1, y, 1)
}

if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("native") { _ =>
nativeBLAS.ddot(n, x, 1, y, 1)
}
}

if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("vector") { _ =>
vectorBLAS.ddot(n, x, 1, y, 1)
}
}

benchmark.run()
}

runBenchmark("dscal") {
val n = 1e7.toInt
val a = rnd.nextDouble
val x = Array.fill(n) { rnd.nextDouble }

val benchmark = new Benchmark("dscal", n, iters, output = output)

benchmark.addCase("f2j") { _ =>
f2jBLAS.dscal(n, a, x, 1)
}

if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("native") { _ =>
nativeBLAS.dscal(n, a, x, 1)
}
}

if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("vector") { _ =>
vectorBLAS.dscal(n, a, x, 1)
}
}

benchmark.run()
}

runBenchmark("dgemv[T]") {
val m = 1e4.toInt
val n = 1e3.toInt
val alpha = rnd.nextDouble
val a = Array.fill(n * m) { rnd.nextDouble }
val lda = m
val x = Array.fill(m) { rnd.nextDouble }
val beta = rnd.nextDouble
val y = Array.fill(n) { rnd.nextDouble }

val benchmark = new Benchmark("dgemv[T]", n, iters, output = output)

benchmark.addCase("f2j") { _ =>
f2jBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1)
}

if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("native") { _ =>
nativeBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1)
}
}

if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("vector") { _ =>
vectorBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1)
}
}

benchmark.run()
}
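dgemv with trans = "T" computes y := alpha * A^T * x + beta * y for a column-major m x n matrix A with leading dimension lda = m, so each output element is a column of A dotted with x. A naive Java reference:

```java
// Naive reference for dgemv("T", ...): y := alpha * A^T * x + beta * y,
// A is m x n column-major with leading dimension lda.
public class DgemvT {
    static void dgemvT(int m, int n, double alpha, double[] a, int lda,
                       double[] x, double beta, double[] y) {
        for (int j = 0; j < n; j++) {
            double sum = 0.0;
            for (int i = 0; i < m; i++) {
                sum += a[i + j * lda] * x[i]; // column j of A dotted with x
            }
            y[j] = alpha * sum + beta * y[j];
        }
    }

    public static void main(String[] args) {
        double[] a = {1, 2, 3, 4}; // 2x2 column-major: [[1, 3], [2, 4]]
        double[] x = {1, 1};
        double[] y = {0, 0};
        dgemvT(2, 2, 1.0, a, 2, x, 0.0, y);
        System.out.println(java.util.Arrays.toString(y)); // [3.0, 7.0]
    }
}
```

Note the transposed case streams each column contiguously, which is the memory-friendly orientation for column-major storage.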

runBenchmark("dgemm[T,N]") {
val m = 1e3.toInt
val n = 1e2.toInt
val k = 1e3.toInt
val alpha = rnd.nextDouble
val a = Array.fill(k * m) { rnd.nextDouble }
val lda = k
val b = Array.fill(n * k) { rnd.nextDouble }
val ldb = k
val beta = rnd.nextDouble
val c = Array.fill(m * n) { rnd.nextDouble }
      val ldc = m

      val benchmark = new Benchmark("dgemm[T,N]", m * n, iters, output = output)

benchmark.addCase("f2j") { _ =>
f2jBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
}

if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("native") { _ =>
nativeBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
}
}

if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) {
benchmark.addCase("vector") { _ =>
vectorBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
}
}

benchmark.run()
}
}
}
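dgemm("T", "N", ...) as benchmarked above computes C := alpha * A^T * B + beta * C with all matrices column-major: A is k x m (lda = k), B is k x n (ldb = k), C is m x n (ldc = m). A naive triple-loop Java reference with no blocking or vectorization:

```java
// Naive reference for dgemm("T", "N", ...): C := alpha * A^T * B + beta * C,
// all matrices column-major; A is k x m, B is k x n, C is m x n.
public class DgemmTN {
    static void dgemmTN(int m, int n, int k, double alpha, double[] a, int lda,
                        double[] b, int ldb, double beta, double[] c, int ldc) {
        for (int j = 0; j < n; j++) {
            for (int i = 0; i < m; i++) {
                double sum = 0.0;
                for (int l = 0; l < k; l++) {
                    sum += a[l + i * lda] * b[l + j * ldb]; // (A^T)(i,l) = A(l,i)
                }
                c[i + j * ldc] = alpha * sum + beta * c[i + j * ldc];
            }
        }
    }

    public static void main(String[] args) {
        // m = n = k = 2; A = I, B = [[1, 2], [3, 4]] (both column-major).
        double[] a = {1, 0, 0, 1};
        double[] b = {1, 3, 2, 4};
        double[] c = new double[4];
        dgemmTN(2, 2, 2, 1.0, a, 2, b, 2, 0.0, c, 2);
        System.out.println(java.util.Arrays.toString(c)); // [1.0, 3.0, 2.0, 4.0]
    }
}
```

Level-3 routines like this do O(m*n*k) work on O(m*n + m*k + k*n) data, which is why the review thread concludes that native (or otherwise heavily optimized) BLAS pays off most clearly here: the arithmetic dwarfs any JNI transfer cost.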
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.TestingUtils._
class BLASSuite extends SparkMLFunSuite {

test("nativeL1Threshold") {
assert(getBLAS(128) == BLAS.f2jBLAS)
assert(getBLAS(128) == BLAS.javaBLAS)
assert(getBLAS(256) == BLAS.nativeBLAS)
assert(getBLAS(512) == BLAS.nativeBLAS)
}