Skip to content

Commit

Permalink
[spark] Support requirements.txt in model tar file (#2528)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Apr 28, 2023
1 parent 86ddc66 commit 2ecf3fb
Show file tree
Hide file tree
Showing 17 changed files with 273 additions and 155 deletions.
6 changes: 6 additions & 0 deletions docker/spark/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,16 @@ ADD --chmod=644 https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java
# Set environment
# Use the `ENV key=value` form; the space-separated `ENV key value` form is
# legacy syntax that the Dockerfile reference discourages.
ENV PYTORCH_PRECXX11=true
ENV OMP_NUM_THREADS=1
# Cache locations for DJL and Hugging Face downloads.
# NOTE(review): /tmp is presumably chosen because it is writable for the
# user running Spark tasks — confirm against the cluster setup.
ENV DJL_CACHE_DIR=/tmp/.djl.ai
ENV HUGGINGFACE_HUB_CACHE=/tmp
ENV TRANSFORMERS_CACHE=/tmp

# Mirror the environment into spark-env.sh by appending one `export` line per RUN.
# The first line is single-quoted so that $SPARK_JAVA_OPTS is expanded when
# spark-env.sh is sourced, not while the image is being built.
RUN echo 'export SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS -Dai.djl.pytorch.graph_optimizer=false"' >> /opt/hadoop-config/spark-env.sh
RUN echo "export PYTORCH_PRECXX11=true" >> /opt/hadoop-config/spark-env.sh
RUN echo "export OMP_NUM_THREADS=1" >> /opt/hadoop-config/spark-env.sh
RUN echo "export DJL_CACHE_DIR=/tmp/.djl.ai" >> /opt/hadoop-config/spark-env.sh
RUN echo "export HUGGINGFACE_HUB_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh
RUN echo "export TRANSFORMERS_CACHE=/tmp" >> /opt/hadoop-config/spark-env.sh
# Append Spark properties: PYTORCH_PRECXX11 for the YARN application master
# (spark.yarn.appMasterEnv.*) and for executors (spark.executorEnv.*), plus the
# fs.s3a.connection.maximum setting passed through to Hadoop (spark.hadoop.*).
RUN echo "spark.yarn.appMasterEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf
RUN echo "spark.executorEnv.PYTORCH_PRECXX11 true" >> /opt/hadoop-config/spark-defaults.conf
RUN echo "spark.hadoop.fs.s3a.connection.maximum 1000" >> /opt/hadoop-config/spark-defaults.conf
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,70 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

from pyspark import SparkContext
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType
import io
import librosa
import pandas as pd
from typing import Iterator
from transformers import pipeline
from ...util import files_util, dependency_util


TASK = "automatic-speech-recognition"
APPLICATION = "audio/automatic_speech_recognition"
GROUP_ID = "ai/djl/huggingface/pytorch"


class WhisperSpeechRecognizer:
    """Speech recognition over a Spark DataFrame column using a Hugging Face pipeline."""

    def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engine="PyTorch"):
        """
        Initializes the WhisperSpeechRecognizer.

        :param input_col: The input column
        :param output_col: The output column
        :param model_url: The model URL
        :param hf_model_id: The Hugging Face model ID
        :param engine: The engine. Currently only PyTorch is supported.
        """
        self.input_col = input_col
        self.output_col = output_col
        self.model_url = model_url
        self.hf_model_id = hf_model_id
        self.engine = engine

    def recognize(self, dataset, generate_kwargs=None, **kwargs):
        """
        Performs speech recognition on the provided dataset.

        :param dataset: input dataset
        :param generate_kwargs: The dictionary of ad-hoc parametrization of generate_config
                                to be used for the generation call.
        :param kwargs: Extra keyword arguments forwarded to ``transformers.pipeline``.
        :return: output dataset
        :raises ValueError: if the engine is not PyTorch, or if neither ``model_url``
                            nor ``hf_model_id`` was provided.
        """
        if self.engine is None or self.engine.lower() != "pytorch":
            raise ValueError("Only PyTorch engine is supported.")

        # model_url takes precedence over hf_model_id when both are set.
        if self.model_url:
            cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, self.model_url)
            files_util.download_and_extract(self.model_url, cache_dir)
            # NOTE(review): presumably installs the requirements.txt shipped in
            # the model archive — confirm against dependency_util.
            dependency_util.install(cache_dir)
            model_id_or_path = cache_dir
        elif self.hf_model_id:
            model_id_or_path = self.hf_model_id
        else:
            raise ValueError("Either model_url or hf_model_id must be provided.")

        @pandas_udf(StringType())
        def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
            # The pipeline is created inside the UDF, once per invocation of
            # predict_udf and not once per batch.
            pipe = pipeline(TASK, generate_kwargs=generate_kwargs,
                            model=model_id_or_path, chunk_length_s=30, **kwargs)
            for s in iterator:
                # Model expects single channel, 16000 sample rate audio
                batch = [librosa.load(io.BytesIO(d), mono=True, sr=16000)[0] for d in s]
                output = pipe(batch)
                yield pd.Series([o["text"] for o in output])

        return dataset.withColumn(self.output_col, predict_udf(self.input_col))
23 changes: 12 additions & 11 deletions extensions/spark/setup/djl_spark/task/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@

"""DJL Spark Tasks Text API."""

from . import text_decoder, text_encoder, text_tokenizer, text_embedder, text2text_generator, text_generator

# Re-export the task classes at package level.
TextDecoder = text_decoder.TextDecoder
TextEncoder = text_encoder.TextEncoder
TextTokenizer = text_tokenizer.TextTokenizer
TextEmbedder = text_embedder.TextEmbedder
Text2TextGenerator = text2text_generator.Text2TextGenerator
TextGenerator = text_generator.TextGenerator

# Remove unnecessary modules to avoid duplication in API.
# Each submodule name is deleted exactly once; a second `del` of the same
# name would raise NameError at import time.
del text_decoder
del text_encoder
del text_tokenizer
del text_embedder
del text2text_generator
del text_generator
36 changes: 25 additions & 11 deletions extensions/spark/setup/djl_spark/task/text/text2text_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,30 @@
from pyspark.sql.types import StringType
from typing import Iterator
from transformers import pipeline
from ...util import files_util, dependency_util

TASK = "text2text-generation"
APPLICATION = "nlp/text2text_generation"
GROUP_ID = "ai/djl/huggingface/pytorch"


class Text2TextGenerator:
    """Text-to-text generation over a Spark DataFrame column using a Hugging Face pipeline."""

    def __init__(self, input_col, output_col, model_url=None, hf_model_id=None, engine="PyTorch"):
        """
        Initializes the Text2TextGenerator.

        :param input_col: The input column
        :param output_col: The output column
        :param model_url: The model URL
        :param hf_model_id: The Hugging Face model ID
        :param engine: The engine. Currently only PyTorch is supported.
        """
        self.input_col = input_col
        self.output_col = output_col
        self.model_url = model_url
        self.hf_model_id = hf_model_id
        self.engine = engine

    def generate(self, dataset, **kwargs):
        """
        Performs text2text generation on the provided dataset.

        :param dataset: input dataset
        :param kwargs: Extra keyword arguments forwarded to ``transformers.pipeline``.
        :return: output dataset
        :raises ValueError: if the engine is not PyTorch, or if neither ``model_url``
                            nor ``hf_model_id`` was provided.
        """
        if self.engine is None or self.engine.lower() != "pytorch":
            raise ValueError("Only PyTorch engine is supported.")

        # model_url takes precedence over hf_model_id when both are set.
        if self.model_url:
            cache_dir = files_util.get_cache_dir(APPLICATION, GROUP_ID, self.model_url)
            files_util.download_and_extract(self.model_url, cache_dir)
            # NOTE(review): presumably installs the requirements.txt shipped in
            # the model archive — confirm against dependency_util.
            dependency_util.install(cache_dir)
            model_id_or_path = cache_dir
        elif self.hf_model_id:
            model_id_or_path = self.hf_model_id
        else:
            raise ValueError("Either model_url or hf_model_id must be provided.")

        @pandas_udf(StringType())
        def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
            # The pipeline is created inside the UDF, once per invocation of
            # predict_udf and not once per batch.
            pipe = pipeline(TASK, model=model_id_or_path, **kwargs)
            for s in iterator:
                output = pipe(s.tolist())
                yield pd.Series([o["generated_text"] for o in output])

        return dataset.withColumn(self.output_col, predict_udf(self.input_col))
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,31 @@
from pyspark.sql import DataFrame


class TextDecoder:
    """Python wrapper that delegates decoding to the JVM-side ``ai.djl.spark.task.text.TextDecoder``."""

    def __init__(self, input_col, output_col, hf_model_id):
        """
        Initializes the TextDecoder.

        :param input_col: The input column
        :param output_col: The output column
        :param hf_model_id: The Hugging Face model ID
        """
        self.input_col = input_col
        self.output_col = output_col
        self.hf_model_id = hf_model_id

    def decode(self, dataset):
        """
        Performs sentence decoding on the provided dataset.

        :param dataset: input dataset
        :return: output dataset
        """
        sc = SparkContext._active_spark_context
        # Configure the JVM-side decoder through the Py4J gateway.
        decoder = sc._jvm.ai.djl.spark.task.text.TextDecoder() \
            .setInputCol(self.input_col) \
            .setOutputCol(self.output_col) \
            .setHfModelId(self.hf_model_id)
        # Wrap the returned Java DataFrame in a PySpark DataFrame on the same session.
        return DataFrame(decoder.decode(dataset._jdf),
                         dataset.sparkSession)
14 changes: 7 additions & 7 deletions extensions/spark/setup/djl_spark/task/text/text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@

class TextEmbedder:

def __init__(self, input_col, output_col, engine, model_url,
output_class=None, translator=None):
def __init__(self, input_col, output_col, model_url, engine=None,
output_class=None, translator_factory=None):
"""
Initializes the TextEmbedder.
:param input_col: The input column
:param output_col: The output column
:param engine (optional): The engine
:param model_url: The model URL
:param engine (optional): The engine
:param output_class (optional): The output class
:param translator (optional): The translator. Default is TextEmbeddingTranslator.
:param translator_factory (optional): The translator factory. Default is TextEmbeddingTranslatorFactory.
"""
self.input_col = input_col
self.output_col = output_col
self.engine = engine
self.model_url = model_url
self.output_class = output_class
self.translator = translator
self.translator_factory = translator_factory

def embed(self, dataset):
"""
Expand All @@ -47,8 +47,8 @@ def embed(self, dataset):
embedder = sc._jvm.ai.djl.spark.task.text.TextEmbedder()
if self.output_class is not None:
embedder = embedder.setOutputClass(self.output_class)
if self.translator is not None:
embedder = embedder.setTranslator(self.translator)
if self.translator_factory is not None:
embedder = embedder.setTranslatorFactory(self.translator_factory)
embedder = embedder.setInputCol(self.input_col) \
.setOutputCol(self.output_col) \
.setEngine(self.engine) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@
from pyspark.sql import DataFrame


class TextEncoder:
    """Python wrapper that delegates encoding to the JVM-side ``ai.djl.spark.task.text.TextEncoder``."""

    def __init__(self, input_col, output_col, hf_model_id):
        """
        Initializes the TextEncoder.

        :param input_col: The input column
        :param output_col: The output column
        :param hf_model_id: The Hugging Face model ID
        """
        self.input_col = input_col
        self.output_col = output_col
        self.hf_model_id = hf_model_id

    def encode(self, dataset):
        """
        Performs sentence encoding on the provided dataset.

        :param dataset: input dataset
        :return: output dataset
        """
        sc = SparkContext._active_spark_context
        # Configure the JVM-side encoder through the Py4J gateway.
        encoder = sc._jvm.ai.djl.spark.task.text.TextEncoder() \
            .setInputCol(self.input_col) \
            .setOutputCol(self.output_col) \
            .setHfModelId(self.hf_model_id)
        # Wrap the returned Java DataFrame in a PySpark DataFrame on the same session.
        return DataFrame(encoder.encode(dataset._jdf),
                         dataset.sparkSession)
Loading

0 comments on commit 2ecf3fb

Please sign in to comment.