From aa61c10c85cd42a7feff35b7be93d0d2a889df7c Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Fri, 12 May 2023 16:15:41 -0700 Subject: [PATCH] [spark] Update README (#2596) --- extensions/spark/README.md | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/extensions/spark/README.md b/extensions/spark/README.md index 717a0dc8abc..91661004848 100644 --- a/extensions/spark/README.md +++ b/extensions/spark/README.md @@ -45,32 +45,28 @@ Using the DJL Spark Extension is simple and straightforward. Here is an example ### Scala ```scala -import ai.djl.spark.SparkTransformer -import ai.djl.spark.translator.SparkImageClassificationTranslator +import ai.djl.spark.task.vision.ImageClassifier -val transformer = new SparkTransformer[Classifications]() - .setInputCols(Array("input_col1", "input_col2")) - .setOutputCols(Array("value")) +val classifier = new ImageClassifier() + .setInputCols(Array("origin", "height", "width", "nChannels", "mode", "data")) + .setOutputCol("prediction") .setEngine("PyTorch") - .setModelUrl("model_url") - .setOutputClass(classOf[Classifications]) - .setTranslator(new SparkImageClassificationTranslator()) -val outputDf = transformer.transform(df) + .setModelUrl("djl://ai.djl.pytorch/resnet") + .setTopK(2) +var outputDf = classifier.classify(df) ``` ### Python ```python -from djl_spark.transformer import SparkTransformer -from djl_spark.translator import SparkImageClassificationTranslator - -transformer = SparkTransformer(input_cols=["input_col1", "input_col2"], - output_cols=["value"], - engine="PyTorch", - model_url="model_url", - output_class="ai.djl.modality.Classifications", - translator=SparkImageClassificationTranslator()) -outputDf = transformer.transform(df) +from djl_spark.task.vision import ImageClassifier + +classifier = ImageClassifier(input_cols=["origin", "height", "width", "nChannels", "mode", "data"], + output_col="prediction", + engine="PyTorch", + model_url="djl://ai.djl.pytorch/resnet", + top_k=2) +outputDf = classifier.classify(df) ``` See [examples](https://github.com/deepjavalibrary/djl-demo/tree/master/apache-spark/spark3.0) for more details.