[spark] Use batch predict in spark
xyang16 committed Apr 16, 2023
1 parent cef59a4 commit 412e906
Showing 12 changed files with 149 additions and 100 deletions.
BasePredictor.scala
@@ -31,6 +31,7 @@ abstract class BasePredictor[A, B](override val uid: String) extends Transformer
   def this() = this(Identifiable.randomUID("BasePredictor"))
 
   final val engine = new Param[String](this, "engine", "The engine")
+  final val batchSize = new Param[Int](this, "batchSize", "The batch size")
   final val modelUrl = new Param[String](this, "modelUrl", "The model URL")
   final val inputClass = new Param[Class[A]](this, "inputClass", "The input class")
   final val outputClass = new Param[Class[B]](this, "outputClass", "The output class")
@@ -85,6 +86,7 @@ abstract class BasePredictor[A, B](override val uid: String) extends Transformer
   def setTranslatorFactory(value: TranslatorFactory): this.type = set(translatorFactory, value)
 
   setDefault(engine, null)
+  setDefault(batchSize, 10)
   setDefault(modelUrl, null)
 
   /** @inheritdoc */
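Every transformRows override below applies the same pattern: chunk the partition's row iterator with grouped($(batchSize)), run one predictor call per chunk, and zip the outputs back onto the input rows so ordering is preserved. A minimal standalone sketch of the pattern, with toUpperCase standing in for the model call (plain Scala, no DJL or Spark required):

    object BatchingSketch {
      def main(args: Array[String]): Unit = {
        val batchSize = 10
        // Stand-in for predictor.batchPredict: one call per chunk of inputs.
        def batchPredict(inputs: Seq[String]): Seq[String] = inputs.map(_.toUpperCase)

        val rows: Iterator[String] = Iterator.tabulate(25)(i => s"row$i")
        // grouped() buffers at most batchSize elements at a time; zip keeps each
        // output aligned with its input row; flatMap restores a flat iterator.
        val out = rows.grouped(batchSize).flatMap(batch => batch.zip(batchPredict(batch)))
        out.foreach(println)
      }
    }

Because grouped is lazy, memory stays bounded at one chunk per partition while each model invocation amortizes its fixed overhead across up to batchSize inputs.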
SpeechRecognizer.scala
@@ -12,7 +12,7 @@
  */
 package ai.djl.spark.task.audio
 
-import ai.djl.modality.audio.AudioFactory
+import ai.djl.modality.audio.{Audio, AudioFactory}
 import ai.djl.modality.audio.translator.SpeechRecognitionTranslatorFactory
 import org.apache.spark.ml.param.IntParam
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
@@ -21,6 +21,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.types.{BinaryType, StringType, StructField, StructType}
 
 import java.io.ByteArrayInputStream
+import scala.jdk.CollectionConverters.asScalaBufferConverter
 
 /**
  * SpeechRecognizer performs speech recognition on audio.
@@ -92,29 +93,30 @@ class SpeechRecognizer(override val uid: String) extends BaseAudioPredictor[String]
     super.transform(dataset)
   }
 
-  /**
-   * Transforms the rows.
-   *
-   * @param iter the rows to transform
-   * @return the transformed rows
-   */
+  /** @inheritdoc */
   override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
     val predictor = model.newPredictor()
-    iter.map(row => {
-      val data = row.getAs[Array[Byte]](inputColIndex)
-      val audioFactory = AudioFactory.newInstance
-      if (isDefined(channels)) {
-        audioFactory.setChannels($(channels))
-      }
-      if (isDefined(sampleRate)) {
-        audioFactory.setSampleRate($(sampleRate))
-      }
-      if (isDefined(sampleFormat)) {
-        audioFactory.setSampleFormat($(sampleFormat))
-      }
-      val audio = audioFactory.fromInputStream(new ByteArrayInputStream(data))
-      Row.fromSeq(row.toSeq :+ predictor.predict(audio))
-    })
+    val audioFactory = AudioFactory.newInstance
+    if (isDefined(channels)) {
+      audioFactory.setChannels($(channels))
+    }
+    if (isDefined(sampleRate)) {
+      audioFactory.setSampleRate($(sampleRate))
+    }
+    if (isDefined(sampleFormat)) {
+      audioFactory.setSampleFormat($(sampleFormat))
+    }
+    iter.grouped($(batchSize)).flatMap { batch =>
+      val inputs = new java.util.ArrayList[Audio]()
+      batch.foreach { row =>
+        val data = row.getAs[Array[Byte]](inputColIndex)
+        inputs.add(audioFactory.fromInputStream(new ByteArrayInputStream(data)))
+      }
+      val output = predictor.batchPredict(inputs).asScala
+      batch.zip(output).map { case (row, out) =>
+        Row.fromSeq(row.toSeq :+ out)
+      }
+    }
   }
 
   /** @inheritdoc */
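Besides batching, the rewrite hoists the AudioFactory configuration out of the per-row closure: channels, sample rate, and sample format are now applied once per partition rather than once per row. A hedged usage sketch (the setters follow the Param/Has* pattern visible in the diff but their definitions sit in collapsed hunks; the model URL is a placeholder):

    import ai.djl.spark.task.audio.SpeechRecognizer
    import org.apache.spark.sql.SparkSession

    object SpeechDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("asr").getOrCreate()
        // Spark's binaryFile source yields the raw bytes in a "content" column.
        val df = spark.read.format("binaryFile").load("data/audio/*.wav")
        val recognizer = new SpeechRecognizer()
          .setInputCol("content")                          // assumed setter (HasInputCol)
          .setOutputCol("transcript")                      // assumed setter (HasOutputCol)
          .setChannels(1)                                  // assumed setter (channels param)
          .setSampleRate(16000)                            // assumed setter (sampleRate param)
          .setModelUrl("https://example.com/whisper.zip")  // placeholder URL
        recognizer.transform(df).select("path", "transcript").show(truncate = false)
      }
    }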
BinaryPredictor.scala
@@ -19,6 +19,8 @@ import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
+import scala.jdk.CollectionConverters.asScalaBufferConverter
+
 /**
  * BinaryPredictor performs prediction on binary input.
  *
@@ -61,17 +63,23 @@ class BinaryPredictor(override val uid: String) extends BasePredictor[Array[Byte], Array[Byte]]
 
   /** @inheritdoc */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    arguments.put("batchifier", $(batchifier))
     inputColIndex = dataset.schema.fieldIndex($(inputCol))
     super.transform(dataset)
   }
 
   /** @inheritdoc */
   override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
     val predictor = model.newPredictor()
-    iter.map(row => {
-      Row.fromSeq(row.toSeq :+ predictor.predict(row.getAs[Array[Byte]](inputColIndex)))
-    })
+    iter.grouped($(batchSize)).flatMap { batch =>
+      val inputs = new java.util.ArrayList[Array[Byte]]()
+      batch.foreach { row =>
+        inputs.add(row.getAs[Array[Byte]](inputColIndex))
+      }
+      val output = predictor.batchPredict(inputs).asScala
+      batch.zip(output).map { case (row, out) =>
+        Row.fromSeq(row.toSeq :+ out)
+      }
+    }
   }
 
   /** @inheritdoc */
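Predictor.batchPredict accepts and returns a java.util.List, which is why these overrides fill a java.util.ArrayList and convert the result back with asScala before zipping. The boundary in isolation (plain Scala; batchPredict here is a toUpperCase stand-in, not the DJL API):

    import scala.jdk.CollectionConverters._

    object ListBoundary {
      // Stand-in with the same shape as Predictor.batchPredict(java.util.List[A]).
      def batchPredict(inputs: java.util.List[String]): java.util.List[String] =
        inputs.asScala.map(_.toUpperCase).asJava

      def main(args: Array[String]): Unit = {
        val batch = Seq("a", "b", "c")
        val inputs = new java.util.ArrayList[String]()
        batch.foreach(s => inputs.add(s))  // fill the Java list in row order
        val output = batchPredict(inputs).asScala
        println(batch.zip(output))         // List((a,A), (b,B), (c,C))
      }
    }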
BaseTextPredictor.scala
@@ -23,4 +23,6 @@ import org.apache.spark.ml.util.Identifiable
 abstract class BaseTextPredictor[A, B](override val uid: String) extends BasePredictor[A, B] {
 
   def this() = this(Identifiable.randomUID("BaseTextPredictor"))
+
+  setDefault(batchSize, 100)
 }
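Text inputs are small relative to audio or images, so BaseTextPredictor raises the inherited batch size from 10 to 100. This works because the subclass constructor body runs after the base constructor, and a later setDefault on the same Param overwrites the earlier one. A toy reproduction of the mechanism (hypothetical BaseDemo/TextDemo classes; only spark-mllib is needed):

    import org.apache.spark.ml.param.{Param, ParamMap, Params}
    import org.apache.spark.ml.util.Identifiable

    class BaseDemo(override val uid: String) extends Params {
      final val batchSize = new Param[Int](this, "batchSize", "The batch size")
      setDefault(batchSize, 10) // base default, as in BasePredictor
      override def copy(extra: ParamMap): Params = defaultCopy(extra)
    }

    class TextDemo(uid: String) extends BaseDemo(uid) {
      setDefault(batchSize, 100) // runs later, so it overwrites the base default
    }

    object DefaultDemo {
      def main(args: Array[String]): Unit = {
        val t = new TextDemo(Identifiable.randomUID("TextDemo"))
        println(t.getDefault(t.batchSize)) // Some(100)
      }
    }

An explicit set on an instance would still take precedence over both defaults.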
QuestionAnswerer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
  *
  * @param uid An immutable unique ID for the object and its derivatives.
  */
-class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[QAInput, String]
+class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[Array[QAInput], Array[String]]
   with HasInputCols with HasOutputCol {
 
   def this() = this(Identifiable.randomUID("QuestionAnswerer"))
@@ -46,8 +46,8 @@ class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[QAInput, String]
    */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  setDefault(inputClass, classOf[QAInput])
-  setDefault(outputClass, classOf[String])
+  setDefault(inputClass, classOf[Array[QAInput]])
+  setDefault(outputClass, classOf[Array[String]])
   setDefault(translatorFactory, new QuestionAnsweringTranslatorFactory())
 
   /**
@@ -70,10 +70,14 @@ class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[QAInput, String]
   /** @inheritdoc */
   override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
     val predictor = model.newPredictor()
-    iter.map(row => {
-      Row.fromSeq(row.toSeq :+ predictor.predict(new QAInput(row.getString(inputColIndices(0)),
-        row.getString(inputColIndices(1)))))
-    })
+    iter.grouped($(batchSize)).flatMap { batch =>
+      val inputs = batch.map(row => new QAInput(row.getString(inputColIndices(0)),
+        row.getString(inputColIndices(1)))).toArray
+      val output = predictor.predict(inputs)
+      batch.zip(output).map { case (row, out) =>
+        Row.fromSeq(row.toSeq :+ out)
+      }
+    }
   }
 
   /** @inheritdoc */
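Note how the text predictors batch differently from the binary and vision ones: the type parameters become arrays and the whole chunk goes through a single predictor.predict(inputs) call, so batching happens inside the translator stack rather than through batchPredict. A hedged end-to-end sketch (setInputCols, setOutputCol, and setModelUrl follow the Has*/Param pattern though their definitions are collapsed here; the model URL is a placeholder):

    import ai.djl.spark.task.text.QuestionAnswerer
    import org.apache.spark.sql.SparkSession

    object QADemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("qa").getOrCreate()
        val df = spark.createDataFrame(Seq(
          ("When was DJL released?", "DJL was released in 2019."),
          ("What does DJL stand for?", "DJL stands for Deep Java Library.")
        )).toDF("question", "paragraph")

        val answerer = new QuestionAnswerer()
          .setInputCols(Array("question", "paragraph"))   // question and context columns
          .setOutputCol("answer")
          .setModelUrl("https://example.com/bert-qa.zip") // placeholder URL
        answerer.transform(df).show(truncate = false)
      }
    }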
TextClassifier.scala
@@ -14,21 +14,18 @@ package ai.djl.spark.task.text
 
 import ai.djl.huggingface.translator.TextClassificationTranslatorFactory
 import ai.djl.modality.Classifications
-import ai.djl.modality.Classifications.Classification
 import org.apache.spark.ml.param.Param
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.types.{ArrayType, DoubleType, MapType, StringType, StructField, StructType}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
-import scala.collection.mutable
-
 /**
  * TextClassifier performs text classification on text.
  *
  * @param uid An immutable unique ID for the object and its derivatives.
  */
-class TextClassifier(override val uid: String) extends BaseTextPredictor[String, Classifications]
+class TextClassifier(override val uid: String) extends BaseTextPredictor[Array[String], Array[Classifications]]
   with HasInputCol with HasOutputCol {
 
   def this() = this(Identifiable.randomUID("TextClassifier"))
@@ -58,9 +55,8 @@ class TextClassifier(override val uid: String) extends BaseTextPredictor[String, Classifications]
    */
   def setTopK(value: Int): this.type = set(topK, value)
 
-  setDefault(inputClass, classOf[String])
-  setDefault(outputClass, classOf[Classifications])
-  setDefault(topK, 3)
+  setDefault(inputClass, classOf[Array[String]])
+  setDefault(outputClass, classOf[Array[Classifications]])
   setDefault(translatorFactory, new TextClassificationTranslatorFactory())
 
   /**
@@ -75,26 +71,24 @@ class TextClassifier(override val uid: String) extends BaseTextPredictor[String, Classifications]
 
   /** @inheritdoc */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    arguments.put("batchifier", $(batchifier))
-    arguments.put("topK", $(topK).toString)
+    if (isDefined(topK)) {
+      arguments.put("topK", $(topK).toString)
+    }
    inputColIndex = dataset.schema.fieldIndex($(inputCol))
     super.transform(dataset)
   }
 
   /** @inheritdoc */
   override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
     val predictor = model.newPredictor()
-    iter.map(row => {
-      val prediction: Classifications = predictor.predict(row.getString(inputColIndex))
-      val top = mutable.LinkedHashMap[String, Double]()
-      val it: java.util.Iterator[Classification] = prediction.topK($(topK)).iterator()
-      while (it.hasNext) {
-        val t = it.next()
-        top += (t.getClassName -> t.getProbability)
-      }
-      Row.fromSeq(row.toSeq :+ Row(prediction.getClassNames.toArray,
-        prediction.getProbabilities.toArray, top))
-    })
+    iter.grouped($(batchSize)).flatMap { batch =>
+      val inputs = batch.map(_.getString(inputColIndex)).toArray
+      val output = predictor.predict(inputs)
+      val top = output.map(_.topK[Classifications.Classification]().toString)
+      batch.zip(output).map { case (row, out) =>
+        Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray, out.getProbabilities.toArray(), top))
+      }
+    }
  }
 
   /** @inheritdoc */
TextEmbedder.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
  *
  * @param uid An immutable unique ID for the object and its derivatives.
  */
-class TextEmbedder(override val uid: String) extends BaseTextPredictor[String, Array[Float]]
+class TextEmbedder(override val uid: String) extends BaseTextPredictor[Array[String], Array[Array[Float]]]
   with HasInputCol with HasOutputCol {
 
   def this() = this(Identifiable.randomUID("TextEmbedder"))
@@ -44,8 +44,8 @@ class TextEmbedder(override val uid: String) extends BaseTextPredictor[String, Array[Float]]
    */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  setDefault(inputClass, classOf[String])
-  setDefault(outputClass, classOf[Array[Float]])
+  setDefault(inputClass, classOf[Array[String]])
+  setDefault(outputClass, classOf[Array[Array[Float]]])
   setDefault(translatorFactory, new TextEmbeddingTranslatorFactory())
 
   /**
@@ -67,9 +67,13 @@ class TextEmbedder(override val uid: String) extends BaseTextPredictor[String, Array[Float]]
   /** @inheritdoc */
   override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
     val predictor = model.newPredictor()
-    iter.map(row => {
-      Row.fromSeq(row.toSeq :+ predictor.predict(row.getString(inputColIndex)))
-    })
+    iter.grouped($(batchSize)).flatMap { batch =>
+      val inputs = batch.map(_.getString(inputColIndex)).toArray
+      val output = predictor.predict(inputs)
+      batch.zip(output).map { case (row, out) =>
+        Row.fromSeq(row.toSeq :+ out)
+      }
+    }
   }
 
   /** @inheritdoc */
ImageClassifier.scala
@@ -13,7 +13,7 @@
 package ai.djl.spark.task.vision
 
 import ai.djl.modality.Classifications
-import ai.djl.modality.cv.ImageFactory
+import ai.djl.modality.cv.{Image, ImageFactory}
 import ai.djl.modality.cv.translator.ImageClassificationTranslatorFactory
 import org.apache.spark.ml.image.ImageSchema
 import org.apache.spark.ml.param.Param
@@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
 import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
+import scala.jdk.CollectionConverters.asJavaIterableConverter
 
 /**
  * ImageClassifier performs image classification on images.
@@ -61,8 +62,6 @@ class ImageClassifier(override val uid: String) extends BaseImagePredictor[Classifications]
 
   setDefault(outputClass, classOf[Classifications])
   setDefault(translatorFactory, new ImageClassificationTranslatorFactory())
-  setDefault(applySoftmax, true)
-  setDefault(topK, 5)
 
   /**
    * Performs image classification on the provided dataset.
@@ -76,22 +75,30 @@ class ImageClassifier(override val uid: String) extends BaseImagePredictor[Classifications]
 
   /** @inheritdoc */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    arguments.put("applySoftmax", $(applySoftmax).toString)
-    arguments.put("topK", $(topK).toString)
+    if (isDefined(applySoftmax)) {
+      arguments.put("applySoftmax", $(applySoftmax).toString)
+    }
+    if (isDefined(topK)) {
+      arguments.put("topK", $(topK).toString)
+    }
     super.transform(dataset)
   }
 
   /** @inheritdoc */
   override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
     val predictor = model.newPredictor()
-    iter.map(row => {
-      val image = ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)),
-        ImageSchema.getWidth(row), ImageSchema.getHeight(row))
-      val prediction = predictor.predict(image)
-      val top = prediction.topK[Classifications.Classification]($(topK)).map(item => item.toString)
-      Row.fromSeq(row.toSeq :+ Row(prediction.getClassNames.toArray,
-        prediction.getProbabilities.toArray, top))
-    })
+    iter.grouped($(batchSize)).flatMap { batch =>
+      val inputs = new java.util.ArrayList[Image]()
+      batch.foreach { row =>
+        inputs.add(ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)),
+          ImageSchema.getWidth(row), ImageSchema.getHeight(row)))
+      }
+      val output = predictor.batchPredict(inputs)
+      val top = output.map(_.topK[Classifications.Classification]().toString)
+      batch.zip(output).map { case (row, out) =>
+        Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray, out.getProbabilities.toArray(), top))
+      }
+    }
   }
 
   /** @inheritdoc */
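Spark's ImageSchema stores pixel data as a BGR byte array (OpenCV convention), while ImageFactory.fromPixels takes packed ARGB ints, hence the bgrToRgb call around ImageSchema.getData. The real bgrToRgb lives in BaseImagePredictor, outside the visible hunks; a sketch of what such a conversion involves, assuming 3-channel CV_8UC3 data:

    // Convert a BGR byte buffer (Spark ImageSchema layout, 3 channels) into
    // packed ARGB ints of the kind ImageFactory.fromPixels accepts.
    def bgrToRgb(bgr: Array[Byte]): Array[Int] = {
      val pixels = new Array[Int](bgr.length / 3)
      var i = 0
      while (i < pixels.length) {
        val b = bgr(3 * i) & 0xFF     // JVM bytes are signed; mask to 0-255
        val g = bgr(3 * i + 1) & 0xFF
        val r = bgr(3 * i + 2) & 0xFF
        pixels(i) = (0xFF << 24) | (r << 16) | (g << 8) | b // opaque alpha
        i += 1
      }
      pixels
    }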
ImageEmbedder.scala
@@ -12,14 +12,17 @@
  */
 package ai.djl.spark.task.vision
 
-import ai.djl.modality.cv.ImageFactory
+import ai.djl.modality.Classifications
+import ai.djl.modality.cv.{Image, ImageFactory}
 import ai.djl.modality.cv.translator.ImageFeatureExtractorFactory
 import org.apache.spark.ml.image.ImageSchema
 import org.apache.spark.ml.param.shared.HasOutputCol
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.types.{ArrayType, ByteType, StructField, StructType}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
+import scala.jdk.CollectionConverters.asScalaBufferConverter
+
 /**
  * ImageEmbedder performs image embedding on images.
  *
@@ -53,11 +56,17 @@ class ImageEmbedder(override val uid: String) extends BaseImagePredictor[Array[Byte]]
   /** @inheritdoc */
   override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
     val predictor = model.newPredictor()
-    iter.map(row => {
-      val image = ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)),
-        ImageSchema.getWidth(row), ImageSchema.getHeight(row))
-      Row.fromSeq(row.toSeq :+ predictor.predict(image))
-    })
+    iter.grouped($(batchSize)).flatMap { batch =>
+      val inputs = new java.util.ArrayList[Image]()
+      batch.foreach { row =>
+        inputs.add(ImageFactory.getInstance().fromPixels(bgrToRgb(ImageSchema.getData(row)),
+          ImageSchema.getWidth(row), ImageSchema.getHeight(row)))
+      }
+      val output = predictor.batchPredict(inputs).asScala
+      batch.zip(output).map { case (row, out) =>
+        Row.fromSeq(row.toSeq :+ out)
+      }
+    }
   }
 
   /** @inheritdoc */
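A hedged end-to-end sketch for the vision predictors, using Spark's built-in image data source. The wiring of the image struct columns is an assumption based on the ImageSchema accessors in the diff, setOutputCol comes from HasOutputCol, and the model URL is a placeholder:

    import ai.djl.spark.task.vision.ImageEmbedder
    import org.apache.spark.sql.SparkSession

    object EmbedDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("embed").getOrCreate()
        // The image source produces a struct matching ImageSchema:
        // origin, height, width, nChannels, mode, data.
        val images = spark.read.format("image")
          .option("dropInvalid", value = true)
          .load("data/images")
          .select("image.*")                               // assumed column wiring
        val embedder = new ImageEmbedder()
          .setOutputCol("embedding")                       // assumed setter (HasOutputCol)
          .setModelUrl("https://example.com/resnet18.zip") // placeholder URL
        embedder.transform(images).show()
      }
    }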