[spark] Use batch predict in spark #2545

Merged 2 commits on May 2, 2023
1 change: 1 addition & 0 deletions docker/spark/Dockerfile
@@ -65,4 +65,5 @@ RUN echo "export HUGGINGFACE_HUB_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh
RUN echo "export TRANSFORMERS_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh
RUN echo "spark.yarn.appMasterEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf
RUN echo "spark.executorEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf
RUN echo "spark.sql.execution.arrow.maxRecordsPerBatch 500" >> /opt/hadoop-config/spark-defaults.conf
RUN echo "spark.hadoop.fs.s3a.connection.maximum 1000" >> /opt/hadoop-config/spark-defaults.conf
14 changes: 10 additions & 4 deletions extensions/spark/setup/djl_spark/task/audio/__init__.py
@@ -10,12 +10,18 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
"""DJL Spark Tasks Audio API."""

"""DJL Spark Tasks Text API."""

-from . import whisper_speech_recognizer
+from . import (
+    speech_recognizer,
+    whisper_speech_recognizer,
+)

+SpeechRecognizer = speech_recognizer.SpeechRecognizer
WhisperSpeechRecognizer = whisper_speech_recognizer.WhisperSpeechRecognizer

# Remove unnecessary modules to avoid duplication in API.
-del whisper_speech_recognizer
+del (
+    speech_recognizer,
+    whisper_speech_recognizer,
+)
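
The parenthesized `del` removes the submodule names from the package namespace so only the re-exported classes appear in the public API. A small sketch of the effect (assumes `djl_spark` is installed):

```python
from djl_spark.task import audio

# The re-exported classes remain public...
print(hasattr(audio, "SpeechRecognizer"))         # True
print(hasattr(audio, "WhisperSpeechRecognizer"))  # True
# ...but the module names themselves were deleted from the namespace.
print(hasattr(audio, "speech_recognizer"))        # False
```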
81 changes: 81 additions & 0 deletions extensions/spark/setup/djl_spark/task/audio/speech_recognizer.py
@@ -0,0 +1,81 @@
#!/usr/bin/env python
#
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

from pyspark import SparkContext
from pyspark.sql import DataFrame
from typing import Optional


class SpeechRecognizer:

def __init__(self,
input_col: str,
output_col: str,
model_url: str,
engine: Optional[str] = None,
batch_size: Optional[int] = None,
translator_factory=None,
batchifier: Optional[str] = None,
channels: Optional[int] = None,
sample_rate: Optional[int] = None,
sample_format: Optional[int] = None):
"""
Initializes the SpeechRecognizer.

:param input_col: The input column
:param output_col: The output column
:param model_url: The model URL
:param engine (optional): The engine
:param batch_size (optional): The batch size
:param translator_factory (optional): The translator factory.
Default is SpeechRecognitionTranslatorFactory.
:param batchifier (optional): The batchifier. Valid values include "none" (default),
"stack", and "padding".
:param channels (optional): The number of channels
:param sample_rate (optional): The audio sample rate
:param sample_format (optional): The audio sample format
"""
self.input_col = input_col
self.output_col = output_col
self.model_url = model_url
self.engine = engine
self.batch_size = batch_size
self.translator_factory = translator_factory
self.batchifier = batchifier
self.channels = channels
self.sample_rate = sample_rate
self.sample_format = sample_format

def recognize(self, dataset):
"""
Performs speech recognition on the provided dataset.

:param dataset: input dataset
:return: output dataset
"""
sc = SparkContext._active_spark_context
recognizer = sc._jvm.ai.djl.spark.task.audio.SpeechRecognizer() \
.setInputCol(self.input_col) \
.setOutputCol(self.output_col) \
.setModelUrl(self.model_url)
if self.engine is not None:
recognizer = recognizer.setEngine(self.engine)
if self.batch_size is not None:
recognizer = recognizer.setBatchSize(self.batch_size)
if self.translator_factory is not None:
recognizer = recognizer.setTranslatorFactory(
self.translator_factory)
        if self.batchifier is not None:
            recognizer = recognizer.setBatchifier(self.batchifier)
        if self.channels is not None:
            recognizer = recognizer.setChannels(self.channels)
        if self.sample_rate is not None:
            recognizer = recognizer.setSampleRate(self.sample_rate)
        if self.sample_format is not None:
            recognizer = recognizer.setSampleFormat(self.sample_format)
        return DataFrame(recognizer.recognize(dataset._jdf),
                         dataset.sparkSession)
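
A usage sketch for the new class (the bucket path, column names, and model URL are illustrative, not from this PR):

```python
from djl_spark.task.audio import SpeechRecognizer

# binaryFile puts the raw file bytes in a "content" column.
df = spark.read.format("binaryFile").load("s3://my-bucket/audio/*.wav")

recognizer = SpeechRecognizer(
    input_col="content",
    output_col="transcript",
    model_url="djl://ai.djl.pytorch/wav2vec2",  # illustrative model URL
    batch_size=10,
    sample_rate=16000,
)
recognizer.recognize(df).select("path", "transcript").show(truncate=False)
```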
extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py
@@ -16,19 +16,24 @@
import io
import librosa
import pandas as pd
-from typing import Iterator
+from typing import Iterator, Optional
from transformers import pipeline
from ...util import files_util, dependency_util


TASK = "automatic-speech-recognition"
APPLICATION = "audio/automatic_speech_recognition"
GROUP_ID = "ai/djl/huggingface/pytorch"


class WhisperSpeechRecognizer:

-    def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engine="PyTorch"):
+    def __init__(self,
+                 input_col: str,
+                 output_col: str,
+                 model_url: Optional[str] = None,
+                 hf_model_id: Optional[str] = None,
+                 engine: Optional[str] = "PyTorch",
+                 batch_size: Optional[int] = 10):
"""
Initializes the WhisperSpeechRecognizer.

@@ -37,12 +42,14 @@ def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engine="PyTorch"):
:param model_url: The model URL
:param hf_model_id: The Huggingface model ID
:param engine: The engine. Currently only PyTorch is supported.
+        :param batch_size: The batch size
"""
self.input_col = input_col
self.output_col = output_col
self.model_url = model_url
self.hf_model_id = hf_model_id
self.engine = engine
+        self.batch_size = batch_size

def recognize(self, dataset, generate_kwargs=None, **kwargs):
"""
@@ -57,24 +64,33 @@ def recognize(self, dataset, generate_kwargs=None, **kwargs):
raise ValueError("Only PyTorch engine is supported.")

if self.model_url:
-            cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, self.model_url)
+            cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID,
+                                                 self.model_url)
files_util.download_and_extract(self.model_url, cache_dir)
dependency_util.install(cache_dir)
model_id_or_path = cache_dir
elif self.hf_model_id:
model_id_or_path = self.hf_model_id
else:
raise ValueError("Either model_url or hf_model_id must be provided.")
raise ValueError(
"Either model_url or hf_model_id must be provided.")

@pandas_udf(StringType())
def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
-            pipe = pipeline(TASK, generate_kwargs=generate_kwargs,
-                            model=model_id_or_path, chunk_length_s=30, **kwargs)
+            pipe = pipeline(TASK,
+                            generate_kwargs=generate_kwargs,
+                            model=model_id_or_path,
+                            batch_size=self.batch_size,
+                            chunk_length_s=30,
+                            **kwargs)
for s in iterator:
# Model expects single channel, 16000 sample rate audio
-                batch = [librosa.load(io.BytesIO(d), mono=True, sr=16000)[0] for d in s]
+                batch = [
+                    librosa.load(io.BytesIO(d), mono=True, sr=16000)[0]
+                    for d in s
+                ]
output = pipe(batch)
-                text = map(lambda x: x["text"], output)
+                text = [o["text"] for o in output]
yield pd.Series(text)

        return dataset.withColumn(self.output_col, predict_udf(self.input_col))
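
A usage sketch showing how the new `batch_size` flows into the Hugging Face pipeline (paths and the model ID are illustrative):

```python
from djl_spark.task.audio import WhisperSpeechRecognizer

df = spark.read.format("binaryFile").load("s3://my-bucket/audio/*.flac")

recognizer = WhisperSpeechRecognizer(input_col="content",
                                     output_col="transcript",
                                     hf_model_id="openai/whisper-tiny",  # illustrative
                                     batch_size=10)
# generate_kwargs is forwarded to pipeline(); batch_size groups the decoded
# waveforms so the model runs on batches instead of single clips.
result = recognizer.recognize(df, generate_kwargs={"task": "transcribe"})
result.select("path", "transcript").show(truncate=False)
```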
4 changes: 3 additions & 1 deletion extensions/spark/setup/djl_spark/task/binary/__init__.py
@@ -10,9 +10,11 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
"""DJL Spark Tasks Binary API."""

from . import binary_predictor

BinaryPredictor = binary_predictor.BinaryPredictor

# Remove unnecessary modules to avoid duplication in API.
del binary_predictor
59 changes: 36 additions & 23 deletions extensions/spark/setup/djl_spark/task/binary/binary_predictor.py
@@ -12,36 +12,47 @@
# the specific language governing permissions and limitations under the License.

from pyspark import SparkContext
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import DataFrame
+from typing import Optional


class BinaryPredictor:
"""BinaryPredictor performs prediction on binary input.
"""

-    def __init__(self, input_col, output_col, model_url, engine=None,
-                 input_class=None, output_class=None, translator=None,
-                 batchifier="none"):
+    def __init__(self,
+                 input_col: str,
+                 output_col: str,
+                 model_url: str,
+                 engine: Optional[str] = None,
+                 batch_size: Optional[int] = None,
+                 input_class=None,
+                 output_class=None,
+                 translator_factory=None,
+                 batchifier: Optional[str] = None):
"""
Initializes the BinaryPredictor.

-        :param input_col: The input column.
-        :param output_col: The output column.
-        :param model_url: The model URL.
-        :param engine (optional): The engine.
+        :param input_col: The input column
+        :param output_col: The output column
+        :param model_url: The model URL
+        :param engine (optional): The engine
+        :param batch_size (optional): The batch size
         :param input_class (optional): The input class. Default is byte array.
         :param output_class (optional): The output class. Default is byte array.
-        :param translator (optional): The translator. Default is NpBinaryTranslator.
-        :param batchifier (optional): The batchifier. Valid values include none (default),
-            stack, and padding.
+        :param translator_factory (optional): The translator factory.
+            Default is NpBinaryTranslatorFactory.
+        :param batchifier (optional): The batchifier. Valid values include "none" (default),
+            "stack", and "padding".
"""
self.input_col = input_col
self.output_col = output_col
self.model_url = model_url
self.engine = engine
+        self.batch_size = batch_size
self.input_class = input_class
self.output_class = output_class
-        self.translator = translator
+        self.translator_factory = translator_factory
self.batchifier = batchifier

def predict(self, dataset):
@@ -52,18 +63,20 @@ def predict(self, dataset):
:return: output dataset
"""
sc = SparkContext._active_spark_context

predictor = sc._jvm.ai.djl.spark.task.binary.BinaryPredictor()
predictor = sc._jvm.ai.djl.spark.task.binary.BinaryPredictor() \
.setInputCol(self.input_col) \
.setOutputCol(self.output_col) \
.setModelUrl(self.model_url)
if self.engine is not None:
predictor = predictor.setEngine(self.engine)
if self.batch_size is not None:
predictor = predictor.setBatchSize(self.batch_size)
if self.input_class is not None:
predictor = predictor.setinputClass(self.input_class)
if self.output_class is not None:
predictor = predictor.setOutputClass(self.output_class)
if self.translator is not None:
self.translator = predictor.setTranslator(self.translator)
predictor = predictor.setInputCol(self.input_col) \
.setOutputCol(self.output_col) \
.setModelUrl(self.model_url) \
.setEngine(self.engine) \
.setBatchifier(self.batchifier)
return DataFrame(predictor.predict(dataset._jdf),
dataset.sparkSession)
if self.translator_factory is not None:
predictor = predictor.setTranslatorFactory(self.translator_factory)
if self.batchifier is not None:
predictor = predictor.setBatchifier(self.batchifier)
return DataFrame(predictor.predict(dataset._jdf), dataset.sparkSession)
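
A usage sketch for the updated API (the model URL and paths are illustrative); `batch_size` is now forwarded to the JVM predictor via `setBatchSize`:

```python
from djl_spark.task.binary import BinaryPredictor

df = spark.read.format("binaryFile").load("s3://my-bucket/images/*.jpg")

predictor = BinaryPredictor(input_col="content",
                            output_col="prediction",
                            model_url="djl://ai.djl.pytorch/resnet",  # illustrative
                            batch_size=10,
                            batchifier="stack")
predictor.predict(df).select("path", "prediction").show()
```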
36 changes: 25 additions & 11 deletions extensions/spark/setup/djl_spark/task/text/__init__.py
@@ -10,22 +10,36 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

"""DJL Spark Tasks Text API."""

-from . import text_decoder, text_encoder, text_tokenizer, text_embedder, text2text_generator, text_generator
+from . import (
+    question_answerer,
+    text2text_generator,
+    text_classifier,
+    text_decoder,
+    text_embedder,
+    text_encoder,
+    text_generator,
+    text_tokenizer,
+)

-TextDecoder = text_decoder.TextDecoder
-TextEncoder = text_encoder.TextEncoder
-TextTokenizer = text_tokenizer.TextTokenizer
-TextEmbedder = text_embedder.TextEmbedder
-Text2TextGenerator = text2text_generator.Text2TextGenerator
-TextGenerator = text_generator.TextGenerator
+QuestionAnswerer = question_answerer.QuestionAnswerer
+Text2TextGenerator = text2text_generator.Text2TextGenerator
+TextClassifier = text_classifier.TextClassifier
+TextDecoder = text_decoder.TextDecoder
+TextEmbedder = text_embedder.TextEmbedder
+TextEncoder = text_encoder.TextEncoder
+TextGenerator = text_generator.TextGenerator
+TextTokenizer = text_tokenizer.TextTokenizer

# Remove unnecessary modules to avoid duplication in API.
-del text_decoder
-del text_encoder
-del text_tokenizer
-del text_embedder
-del text2text_generator
-del text_generator
+del (
+    question_answerer,
+    text2text_generator,
+    text_classifier,
+    text_decoder,
+    text_embedder,
+    text_encoder,
+    text_generator,
+    text_tokenizer,
+)