Skip to content

Commit

Permalink
[spark] Use batch predict in spark
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed May 1, 2023
1 parent 2ecf3fb commit b53db96
Show file tree
Hide file tree
Showing 38 changed files with 983 additions and 236 deletions.
1 change: 1 addition & 0 deletions docker/spark/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,5 @@ RUN echo "export HUGGINGFACE_HUB_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh
RUN echo "export TRANSFORMERS_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh
RUN echo "spark.yarn.appMasterEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf
RUN echo "spark.executorEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf
RUN echo "spark.sql.execution.arrow.maxRecordsPerBatch 500" >> /opt/hadoop-config/spark-defaults.conf
RUN echo "spark.hadoop.fs.s3a.connection.maximum 1000" >> /opt/hadoop-config/spark-defaults.conf
13 changes: 10 additions & 3 deletions extensions/spark/setup/djl_spark/task/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,18 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

"""DJL Spark Tasks Text API."""
"""DJL Spark Tasks Audio API."""

from . import whisper_speech_recognizer
from . import (
speech_recognizer,
whisper_speech_recognizer,
)

SpeechRecognizer = speech_recognizer.SpeechRecognizer
WhisperSpeechRecognizer = whisper_speech_recognizer.WhisperSpeechRecognizer

# Remove unnecessary modules to avoid duplication in API.
del whisper_speech_recognizer
del (
speech_recognizer,
whisper_speech_recognizer,
)
81 changes: 81 additions & 0 deletions extensions/spark/setup/djl_spark/task/audio/speech_recognizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python
#
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

from pyspark import SparkContext
from pyspark.sql import DataFrame
from typing import Optional


class SpeechRecognizer:
    """Speech recognition task that delegates to the DJL Spark JVM recognizer.

    The Python object only stores configuration. The JVM-side
    ``ai.djl.spark.task.audio.SpeechRecognizer`` is created and configured
    each time :meth:`recognize` is called.
    """

    def __init__(self,
                 input_col: str,
                 output_col: str,
                 model_url: str,
                 engine: Optional[str] = None,
                 batch_size: Optional[int] = None,
                 translator_factory=None,
                 batchifier: Optional[str] = None,
                 channels: Optional[int] = None,
                 sample_rate: Optional[int] = None,
                 sample_format: Optional[int] = None):
        """
        Initializes the SpeechRecognizer.

        :param input_col: The input column
        :param output_col: The output column
        :param model_url: The model URL
        :param engine (optional): The engine
        :param batch_size (optional): The batch size
        :param translator_factory (optional): The translator factory.
                                              Default is SpeechRecognitionTranslatorFactory.
        :param batchifier (optional): The batchifier. Valid values include "none" (default),
                                      "stack", and "padding".
        :param channels (optional): The number of channels
        :param sample_rate (optional): The audio sample rate
        :param sample_format (optional): The audio sample format
        """
        self.input_col = input_col
        self.output_col = output_col
        self.model_url = model_url
        self.engine = engine
        self.batch_size = batch_size
        self.translator_factory = translator_factory
        self.batchifier = batchifier
        self.channels = channels
        self.sample_rate = sample_rate
        self.sample_format = sample_format

    def recognize(self, dataset):
        """
        Performs speech recognition on the provided dataset.

        Optional settings left as ``None`` are not forwarded, so the JVM
        recognizer keeps its own defaults for them.

        :param dataset: input dataset
        :return: output dataset
        """
        sc = SparkContext._active_spark_context
        recognizer = (
            sc._jvm.ai.djl.spark.task.audio.SpeechRecognizer()
            .setInputCol(self.input_col)
            .setOutputCol(self.output_col)
            .setModelUrl(self.model_url)
        )
        if self.engine is not None:
            recognizer = recognizer.setEngine(self.engine)
        if self.batch_size is not None:
            recognizer = recognizer.setBatchSize(self.batch_size)
        if self.translator_factory is not None:
            recognizer = recognizer.setTranslatorFactory(self.translator_factory)
        if self.batchifier is not None:
            recognizer = recognizer.setBatchifier(self.batchifier)
        # The audio options below are accepted by __init__ and must be
        # forwarded, otherwise they are silently ignored.
        # NOTE(review): setter names follow the camelCase convention of the
        # setters above -- confirm against the JVM SpeechRecognizer class.
        if self.channels is not None:
            recognizer = recognizer.setChannels(self.channels)
        if self.sample_rate is not None:
            recognizer = recognizer.setSampleRate(self.sample_rate)
        if self.sample_format is not None:
            recognizer = recognizer.setSampleFormat(self.sample_format)
        return DataFrame(recognizer.recognize(dataset._jdf), dataset.sparkSession)
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import io
import librosa
import pandas as pd
from typing import Iterator
from typing import Iterator, Optional
from transformers import pipeline
from ...util import files_util, dependency_util

Expand All @@ -28,7 +28,13 @@

class WhisperSpeechRecognizer:

def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engine="PyTorch"):
def __init__(self,
input_col: str,
output_col: str,
model_url: Optional[str] = None,
hf_model_id: Optional[str] = None,
engine: Optional[str] = "PyTorch",
batch_size: Optional[int] = 10):
"""
Initializes the WhisperSpeechRecognizer.
Expand All @@ -37,12 +43,14 @@ def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engi
:param model_url: The model URL
:param hf_model_id: The Huggingface model ID
:param engine: The engine. Currently only PyTorch is supported.
:param batch_size: The batch size
"""
self.input_col = input_col
self.output_col = output_col
self.model_url = model_url
self.hf_model_id = hf_model_id
self.engine = engine
self.batch_size = batch_size

def recognize(self, dataset, generate_kwargs=None, **kwargs):
"""
Expand All @@ -68,13 +76,13 @@ def recognize(self, dataset, generate_kwargs=None, **kwargs):

@pandas_udf(StringType())
def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
pipe = pipeline(TASK, generate_kwargs=generate_kwargs,
model=model_id_or_path, chunk_length_s=30, **kwargs)
pipe = pipeline(TASK, generate_kwargs=generate_kwargs, model=model_id_or_path,
batch_size=self.batch_size, chunk_length_s=30, **kwargs)
for s in iterator:
# Model expects single channel, 16000 sample rate audio
batch = [librosa.load(io.BytesIO(d), mono=True, sr=16000)[0] for d in s]
output = pipe(batch)
text = map(lambda x: x["text"], output)
text = [o["text"] for o in output]
yield pd.Series(text)

return dataset.withColumn(self.output_col, predict_udf(self.input_col))
return dataset.withColumn(self.output_col, predict_udf(self.input_col))
5 changes: 4 additions & 1 deletion extensions/spark/setup/djl_spark/task/binary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

"""DJL Spark Tasks Binary API."""

from . import binary_predictor

BinaryPredictor = binary_predictor.BinaryPredictor

# Remove unnecessary modules to avoid duplication in API.
del binary_predictor
del binary_predictor
61 changes: 38 additions & 23 deletions extensions/spark/setup/djl_spark/task/binary/binary_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,47 @@
# the specific language governing permissions and limitations under the License.

from pyspark import SparkContext
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame
from typing import Optional


class BinaryPredictor:
    """BinaryPredictor performs prediction on binary input.

    The Python object only stores configuration. The JVM-side
    ``ai.djl.spark.task.binary.BinaryPredictor`` is created and configured
    each time :meth:`predict` is called.
    """

    def __init__(self,
                 input_col: str,
                 output_col: str,
                 model_url: str,
                 engine: Optional[str] = None,
                 batch_size: Optional[int] = None,
                 input_class=None,
                 output_class=None,
                 translator_factory=None,
                 batchifier: Optional[str] = None):
        """
        Initializes the BinaryPredictor.

        :param input_col: The input column
        :param output_col: The output column
        :param model_url: The model URL
        :param engine (optional): The engine
        :param batch_size (optional): The batch size
        :param input_class (optional): The input class. Default is byte array.
        :param output_class (optional): The output class. Default is byte array.
        :param translator_factory (optional): The translator factory.
                                              Default is NpBinaryTranslatorFactory.
        :param batchifier (optional): The batchifier. Valid values include "none" (default),
                                      "stack", and "padding".
        """
        self.input_col = input_col
        self.output_col = output_col
        self.model_url = model_url
        self.engine = engine
        self.batch_size = batch_size
        self.input_class = input_class
        self.output_class = output_class
        self.translator_factory = translator_factory
        self.batchifier = batchifier

    def predict(self, dataset):
        """
        Performs prediction on the provided dataset.

        Optional settings left as ``None`` are not forwarded, so the JVM
        predictor keeps its own defaults for them.

        :param dataset: input dataset
        :return: output dataset
        """
        sc = SparkContext._active_spark_context
        predictor = (
            sc._jvm.ai.djl.spark.task.binary.BinaryPredictor()
            .setInputCol(self.input_col)
            .setOutputCol(self.output_col)
            .setModelUrl(self.model_url)
        )
        if self.engine is not None:
            predictor = predictor.setEngine(self.engine)
        if self.batch_size is not None:
            predictor = predictor.setBatchSize(self.batch_size)
        if self.input_class is not None:
            # Was spelled "setinputClass"; every sibling setter is camelCase
            # (cf. setOutputClass), so the lower-case name would not resolve.
            predictor = predictor.setInputClass(self.input_class)
        if self.output_class is not None:
            predictor = predictor.setOutputClass(self.output_class)
        if self.translator_factory is not None:
            predictor = predictor.setTranslatorFactory(self.translator_factory)
        if self.batchifier is not None:
            predictor = predictor.setBatchifier(self.batchifier)
        return DataFrame(predictor.predict(dataset._jdf), dataset.sparkSession)
35 changes: 25 additions & 10 deletions extensions/spark/setup/djl_spark/task/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,34 @@

"""DJL Spark Tasks Text API."""

from . import text_decoder, text_encoder, text_tokenizer, text_embedder, text2text_generator, text_generator
from . import (
question_answerer,
text2text_generator,
text_classifier,
text_decoder,
text_embedder,
text_encoder,
text_generator,
text_tokenizer,
)

QuestionAnswerer = question_answerer.QuestionAnswerer
Text2TextGenerator = text2text_generator.Text2TextGenerator
TextClassifier = text_classifier.TextClassifier
TextDecoder = text_decoder.TextDecoder
TextEncoder = text_encoder.TextEncoder
TextTokenizer = text_tokenizer.TextTokenizer
TextEmbedder = text_embedder.TextEmbedder
Text2TextGenerator = text2text_generator.Text2TextGenerator
TextEncoder = text_encoder.TextEncoder
TextGenerator = text_generator.TextGenerator
TextTokenizer = text_tokenizer.TextTokenizer

# Remove unnecessary modules to avoid duplication in API.
del text_decoder
del text_encoder
del text_tokenizer
del text_embedder
del text2text_generator
del text_generator
del (
question_answerer,
text2text_generator,
text_classifier,
text_decoder,
text_embedder,
text_encoder,
text_generator,
text_tokenizer,
)
Loading

0 comments on commit b53db96

Please sign in to comment.