diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index 247d2a5e31a8c..0fbee6e433608 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -33,7 +33,7 @@ import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -71,7 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -112,11 +112,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + cvModel.transform(test).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 5b92655e2e838..cbf697be80c24 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -28,7 +28,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. 
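For reference, a minimal Scala sketch of the pattern the Java examples above now follow after the rename: applySchema and sql results are typed as DataFrame, and temporary tables are registered with registerTempTable (formerly registerAsTable). The helper name and the table/column names here are illustrative, not part of the patch.

import org.apache.spark.sql.{DataFrame, Row, SQLContext}

// Hypothetical helper mirroring the updated examples: register a DataFrame as a
// temp table, query it with SQL, and read back the resulting Rows.
def showPositives(sqlContext: SQLContext, training: DataFrame): Unit = {
  training.registerTempTable("training")   // was registerAsTable before this change
  val positives: DataFrame =
    sqlContext.sql("SELECT id, text FROM training WHERE label = 1.0")
  positives.collect().foreach { case Row(id: Long, text: String) =>
    println(s"($id, $text) --> label=1.0")
  }
}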
LogisticRegression lr = new LogisticRegression(); @@ -94,14 +94,14 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test).registerAsTable("results"); - SchemaRDD results = + model2.transform(test).registerTempTable("results"); + DataFrame results = jsql.sql("SELECT features, label, probability, prediction FROM results"); for (Row r: results.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 74db449fada7d..82d665a3e1386 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -29,7 +29,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -79,11 +79,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. 
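A sketch of the same three-stage pipeline in Scala (it mirrors SimpleTextClassificationPipeline further down in this patch), assuming `training` and `test` DataFrames with the id/text/label columns used above; the score and prediction column names follow this pre-release API, and the helper name is illustrative.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.DataFrame

def fitAndPredict(training: DataFrame, test: DataFrame): DataFrame = {
  val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
  val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
  val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
  val model = new Pipeline().setStages(Array(tokenizer, hashingTF, lr)).fit(training)
  // transform() appends the score and prediction columns; select them by name.
  model.transform(test).select("id", "text", "score", "prediction")
}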
- model.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + model.transform(test).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index b70804635d5c9..2a589c3a4c01a 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -26,9 +26,9 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaSparkSQL { public static class Person implements Serializable { @@ -74,11 +74,11 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - SchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); + DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - SchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -99,11 +99,11 @@ public String call(Row row) { // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a JavaSchemaRDD. - SchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); + DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); - SchemaRDD teenagers2 = + DataFrame teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a JavaSchemaRDD from the file(s) pointed by path - SchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); + DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -134,7 +134,7 @@ public String call(Row row) { peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. - SchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. 
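The JavaSparkSQL changes above are purely a retyping: jsonFile, parquetFile, and sql now return DataFrame. A small Scala sketch of the same load-register-query round trip, reusing the people.json path and the query from the example (the helper name is illustrative):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

def teenagerNames(sqlContext: SQLContext): RDD[String] = {
  // jsonFile infers the schema from the JSON records and returns a DataFrame.
  val people: DataFrame = sqlContext.jsonFile("examples/src/main/resources/people.json")
  people.printSchema()
  people.registerTempTable("people")
  val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
  teenagers.map(row => "Name: " + row(0))
}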
// The columns of a row in the result can be accessed by ordinal. @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - SchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); // Take a look at the schema of this new JavaSchemaRDD. peopleFromJsonRDD.printSchema(); @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - SchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index d8c7ef38ee46d..283bb80f1c788 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator @@ -101,7 +100,7 @@ object CrossValidatorExample { // Make predictions on test documents. cvModel uses the best model found (lrModel). cvModel.transform(test) - .select('id, 'text, 'score, 'prediction) + .select("id", "text", "score", "prediction") .collect() .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a2adff929cb..4359ef161911c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.param.ParamMap import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -92,7 +91,7 @@ object SimpleParamsExample { // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. 
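A Scala counterpart of the nested-JSON query at the end of JavaSparkSQL above, assuming an existing SparkContext; it shows that jsonRDD also yields a DataFrame and that nested fields are reached with dot syntax in SQL. The helper name is illustrative.

import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}

def nestedJsonQuery(sc: SparkContext, sqlContext: SQLContext): DataFrame = {
  val jsonData = sc.parallelize(
    """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil)
  val people: DataFrame = sqlContext.jsonRDD(jsonData)
  people.printSchema()
  people.registerTempTable("people2")
  // Nested fields are addressed with dot syntax.
  sqlContext.sql("SELECT name, address.city FROM people2")
}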
model2.transform(test) - .select('features, 'label, 'probability, 'prediction) + .select("features", "label", "probability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index b9a6ef0229def..065db62b0f5ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -20,7 +20,6 @@ package org.apache.spark.examples.ml import scala.beans.BeanInfo import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} @@ -80,7 +79,7 @@ object SimpleTextClassificationPipeline { // Make predictions on test documents. model.transform(test) - .select('id, 'text, 'score, 'prediction) + .select("id", "text", "score", "prediction") .collect() .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index f8d83f4ec7327..dce93ddfb5e3a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -28,10 +28,10 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} +import org.apache.spark.sql.{Row, SQLContext, DataFrame} /** - * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with * {{{ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] * }}} @@ -81,18 +81,18 @@ object DatasetExample { println(s"Loaded ${origData.count()} instances from file: ${params.input}") // Convert input data to SchemaRDD explicitly. - val schemaRDD: SchemaRDD = origData + val schemaRDD: DataFrame = origData println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") // Select columns, using implicit conversion to SchemaRDD. 
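The 'probability' column mentioned in the comment above comes from overriding parameters at fit time. A minimal sketch of that pattern, following SimpleParamsExample and the varargs Estimator.fit in this patch; scoreCol is a parameter of this pre-release LogisticRegression API, and the helper name is illustrative.

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.sql.DataFrame

def fitWithOverrides(training: DataFrame): LogisticRegressionModel = {
  val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
  // Parameter pairs passed to fit() override the setters above; renaming scoreCol
  // makes the fitted model's transform() emit a 'probability' column instead of 'score'.
  lr.fit(training, lr.maxIter -> 30, lr.regParam -> 0.1, lr.scoreCol -> "probability")
}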
- val labelsSchemaRDD: SchemaRDD = origData.select('label) + val labelsSchemaRDD: DataFrame = origData.select("label") val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } val numLabels = labels.count() val meanLabel = labels.fold(0.0)(_ + _) / numLabels println(s"Selected label column with average value $meanLabel") - val featuresSchemaRDD: SchemaRDD = origData.select('features) + val featuresSchemaRDD: DataFrame = origData.select("features") val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), @@ -109,7 +109,7 @@ object DatasetExample { val newDataset = sqlContext.parquetFile(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 2e98b2dc30b80..a5d7f262581f5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,6 +19,8 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.dsl._ +import org.apache.spark.sql.dsl.literals._ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -54,7 +56,7 @@ object RDDRelation { rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. - rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println) + rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. rdd.saveAsParquetFile("pair.parquet") @@ -63,7 +65,7 @@ object RDDRelation { val parquetFile = sqlContext.parquetFile("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. - parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) + parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) // These files can also be registered as tables. 
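RDDRelation above shows the replacement for the Symbol-based DSL: quoted identifiers such as 'key become $"key" Column expressions from the new org.apache.spark.sql.dsl package, with dsl.literals._ supplying the implicit literal conversions. A brief sketch of the same queries against any DataFrame with key/value columns, assuming the imports behave as the patch's own example uses them:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.dsl._           // $"..." column syntax and callUDF helpers
import org.apache.spark.sql.dsl.literals._  // implicit conversions for literals such as 1

def dslQueries(df: DataFrame): Unit = {
  // 'key === 1 and 'value.asc become Column expressions built from $"..."
  df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println)
  // 'value as 'a becomes an explicit alias on a Column
  df.select($"value".as("a")).collect().foreach(println)
}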
parquetFile.registerTempTable("parquetFile") diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 77d230eb4a122..bc3defe968afd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -21,7 +21,7 @@ import scala.annotation.varargs import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -38,7 +38,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @return fitted model */ @varargs - def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { + def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { val map = new ParamMap().put(paramPairs: _*) fit(dataset, map) } @@ -50,7 +50,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @param paramMap parameter map * @return fitted model */ - def fit(dataset: SchemaRDD, paramMap: ParamMap): M + def fit(dataset: DataFrame, paramMap: ParamMap): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -61,7 +61,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @param paramMaps an array of parameter maps * @return fitted models, matching the input parameter maps */ - def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index db563dd550e56..d2ca2e6871e6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -35,5 +35,5 @@ abstract class Evaluator extends Identifiable { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double + def evaluate(dataset: DataFrame, paramMap: ParamMap): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index ad6fed178fae9..fe39cd1bc0bd2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -88,7 +88,7 @@ class Pipeline extends Estimator[PipelineModel] { * @param paramMap parameter map * @return fitted pipeline */ - override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) val map = this.paramMap ++ paramMap val theStages = map(stages) @@ -162,7 +162,7 @@ class PipelineModel 
private[ml] ( } } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap val map = (fittingParamMap ++ this.paramMap) ++ paramMap transformSchema(dataset.schema, map, logging = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index af56f9c435351..b233bff08305c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,9 +22,9 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql._ +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.types._ /** @@ -41,7 +41,7 @@ abstract class Transformer extends PipelineStage with Params { * @return transformed dataset */ @varargs - def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() paramPairs.foreach(map.put(_)) transform(dataset, map) @@ -53,7 +53,7 @@ abstract class Transformer extends PipelineStage with Params { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD + def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame } /** @@ -95,11 +95,10 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O StructType(outputFields) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr)) - dataset.select(Star(None), udf as map(outputCol)) + dataset.select($"*", callUDF( + this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8c570812f8316..eeb6301c3f64a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.storage.StorageLevel @@ -87,11 +87,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti def setScoreCol(value: String): this.type = set(scoreCol, value) def 
setPredictionCol(value: String): this.type = set(predictionCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) + val instances = dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) }.persist(StorageLevel.MEMORY_AND_DISK) @@ -131,9 +130,8 @@ class LogisticRegressionModel private[ml] ( validateAndTransformSchema(schema, paramMap, fitting = false) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap val score: Vector => Double = (v) => { val margin = BLAS.dot(v, weights) @@ -143,7 +141,7 @@ class LogisticRegressionModel private[ml] ( val predict: Double => Double = (score) => { if (score > t) 1.0 else 0.0 } - dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) - .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol))) + .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 12473cb2b5719..1979ab9eb6516 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** @@ -41,7 +41,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params def setScoreCol(value: String): this.type = set(scoreCol, value) def setLabelCol(value: String): this.type = set(labelCol, value) - override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { val map = this.paramMap ++ paramMap val schema = dataset.schema @@ -52,8 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params require(labelType == DoubleType, s"Label column ${map(labelCol)} must be double type but found $labelType") - import dataset.sqlContext._ - val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) + val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol)) .map { case Row(score: Double, label: Double) => (score, label) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72825f6e02182..e7bdb070c8193 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.types.{StructField, StructType} @@ -43,14 +43,10 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP def setInputCol(value: String): this.type = set(inputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val input = dataset.select(map(inputCol).attr) - .map { case Row(v: Vector) => - v - } + val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) Params.inheritValues(map, this, model) @@ -83,14 +79,13 @@ class StandardScalerModel private[ml] ( def setInputCol(value: String): this.type = set(inputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap val scale: (Vector) => Vector = (v) => { scaler.transform(v) } - dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol))) } private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 08fe99176424a..5d51c51346665 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -64,7 +64,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setEvaluator(value: Evaluator): this.type = set(evaluator, value) def setNumFolds(value: Int): this.type = set(numFolds, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { val map = this.paramMap ++ paramMap val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) @@ -74,7 +74,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP val epm = map(estimatorParamMaps) val numModels = epm.size val metrics = new Array[Double](epm.size) - val splits = MLUtils.kFold(dataset, map(numFolds), 0) + val splits = 
MLUtils.kFold(dataset.rdd, map(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.applySchema(training, schema).cache() val validationDataset = sqlCtx.applySchema(validation, schema).cache() @@ -117,7 +117,7 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { bestModel.transform(dataset, paramMap) } diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 47f1f46c6c260..56a9dbdd58b64 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -37,7 +37,7 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -65,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 2eba83335bb58..f4ba23c44563e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -34,7 +34,7 @@ public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -55,7 +55,7 @@ public void logisticRegression() { LogisticRegression lr = new LogisticRegression(); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } @@ -67,7 +67,7 @@ public void logisticRegressionWithSetters() { LogisticRegressionModel model = lr.fit(dataset); 
model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold .registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index a9f1c4a2c3ca7..074b58c07df7a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -38,7 +38,7 @@ public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 4515084bc7ae9..2f175fb117941 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame class PipelineSuite extends FunSuite { @@ -36,11 +36,11 @@ class PipelineSuite extends FunSuite { val estimator2 = mock[Estimator[MyModel]] val model2 = mock[MyModel] val transformer3 = mock[Transformer] - val dataset0 = mock[SchemaRDD] - val dataset1 = mock[SchemaRDD] - val dataset2 = mock[SchemaRDD] - val dataset3 = mock[SchemaRDD] - val dataset4 = mock[SchemaRDD] + val dataset0 = mock[DataFrame] + val dataset1 = mock[DataFrame] + val dataset2 = mock[DataFrame] + val dataset3 = mock[DataFrame] + val dataset4 = mock[DataFrame] when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) @@ -74,7 +74,7 @@ class PipelineSuite extends FunSuite { val estimator = mock[Estimator[MyModel]] val pipeline = new Pipeline() .setStages(Array(estimator, estimator)) - val dataset = mock[SchemaRDD] + val dataset = mock[DataFrame] intercept[IllegalArgumentException] { pipeline.fit(dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e8030fef55b1d..1912afce93b18 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -21,12 +21,12 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import 
org.apache.spark.sql.{SQLContext, DataFrame} class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() @@ -36,34 +36,28 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { } test("logistic regression") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression val model = lr.fit(dataset) model.transform(dataset) - .select('label, 'prediction) + .select("label", "prediction") .collect() } test("logistic regression with setters") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) val model = lr.fit(dataset) model.transform(dataset, model.threshold -> 0.8) // overwrite threshold - .select('label, 'score, 'prediction) + .select("label", "score", "prediction") .collect() } test("logistic regression fit and transform with varargs") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") - .select('label, 'probability, 'prediction) + .select("label", "probability", "prediction") .collect() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 41cc13da4d5b1..74104fa7a681a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -23,11 +23,11 @@ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, DataFrame} class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8319ef2bc7266..913138f82a600 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -48,6 +48,8 @@ object Literal { object Column { def unapply(col: Column): Option[Expression] = Some(col.expr) + + def apply(colName: String): Column = new Column(colName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 616bd156cda26..ece1a4531bc41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -84,6 +84,8 @@ class DataFrame( override def columns: Array[String] = schema.fields.map(_.name) + override def printSchema(): Unit = println(schema.treeString) + override def show(): Unit = { ??? 
} @@ -147,6 +149,11 @@ class DataFrame( Project(exprs.toSeq, logicalPlan) } + @scala.annotation.varargs + override def select(col: String, cols: String*): DataFrame = { + select((col +: cols).map(new Column(_)) :_*) + } + /** Filtering */ override def filter(condition: Column): DataFrame = { Filter(condition.expr, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala index ff97e0c493910..411848ee786ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.api.java.JavaRDD + import scala.reflect.ClassTag import org.apache.spark.annotation.Experimental @@ -62,6 +64,8 @@ trait DataFrameSpecificApi { */ def schema: StructType + def printSchema(): Unit + ///////////////////////////////////////////////////////////////////////////// // Metadata ///////////////////////////////////////////////////////////////////////////// @@ -89,6 +93,9 @@ trait DataFrameSpecificApi { @scala.annotation.varargs def select(cols: Column*): DataFrame + @scala.annotation.varargs + def select(col: String, cols: String*): DataFrame + /** Filtering */ def apply(condition: Column): DataFrame = filter(condition) @@ -152,6 +159,8 @@ trait DataFrameSpecificApi { def rdd: RDD[Row] + def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() + def toJSON: RDD[String] def registerTempTable(tableName: String): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala index ceb1d025103ef..043ed21bc156a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala @@ -20,8 +20,11 @@ package org.apache.spark.sql import java.sql.{Timestamp, Date} import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DataType package object dsl { @@ -56,6 +59,7 @@ package object dsl { def abs(e: Column): Column = Abs(e.expr) + object literals { implicit def booleanToLiteral(b: Boolean): Column = Literal(b) @@ -84,4 +88,393 @@ package object dsl { implicit def binaryToLiteral(a: Array[Byte]): Column = Literal(a) } + + // scalastyle:off + + /* Use the following code to generate: + (0 to 22).map { x => + val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Register a Scala closure of ${x} arguments as user-defined function (UDF). + */ + def callUDF[$typeTags](f: Function$x[$types]${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq($argsInUdf)) + }""") + } + */ + /** + * Call a Scala function of 0 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag](f: Function0[RT]): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT], arg1: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT], arg1: Column, arg2: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT], arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + /** + * Call a Scala function of 11 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](f: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) + } + + /** + * Call a Scala function of 12 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](f: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) + } + + /** + * Call a Scala function of 13 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](f: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) + } + + /** + * Call a Scala function of 14 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](f: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) + } + + /** + * Call a Scala function of 15 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](f: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) + } + + /** + * Call a Scala function of 16 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](f: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) + } + + /** + * Call a Scala function of 17 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](f: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) + } + + /** + * Call a Scala function of 18 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](f: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) + } + + /** + * Call a Scala function of 19 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](f: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) + } + + /** + * Call a Scala function of 20 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](f: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) + } + + /** + * Call a Scala function of 21 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](f: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) + } + + /** + * Call a Scala function of 22 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. 
+ */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](f: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Call a Scala function of 0 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function0[_], returnType: DataType): Column = { + ScalaUdf(f, returnType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + /** + * Call a Scala function of 11 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) + } + + /** + * Call a Scala function of 12 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) + } + + /** + * Call a Scala function of 13 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) + } + + /** + * Call a Scala function of 14 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) + } + + /** + * Call a Scala function of 15 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) + } + + /** + * Call a Scala function of 16 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) + } + + /** + * Call a Scala function of 17 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) + } + + /** + * Call a Scala function of 18 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) + } + + /** + * Call a Scala function of 19 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) + } + + /** + * Call a Scala function of 20 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) + } + + /** + * Call a Scala function of 21 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) + } + + /** + * Call a Scala function of 22 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) + } + + // scalastyle:on } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index ad4223fd2de00..f03b3a32e34e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -102,7 +102,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) - .select($"_1" cast decimal) + .select($"_1" cast decimal as "abcd") for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { withTempPath { dir =>
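Usage sketch for the TypeTag-based callUDF overloads added above: they let a caller pass a plain Scala function and have the return DataType derived automatically via ScalaReflection.schemaFor. This is only an illustrative sketch, not part of the patch; it assumes the overloads are exposed through the SQL DSL object of this branch (the import may need adjusting) and that `df` is a hypothetical DataFrame with integer columns "x" and "y".

  // Assumption: the callUDF overloads in this patch are exposed via the SQL DSL
  // object of this branch (later consolidated into org.apache.spark.sql.functions);
  // adjust the import to wherever they actually live.
  import org.apache.spark.sql.Dsl._

  // Plain Scala function; its Int result type is mapped to an integer column type
  // automatically through ScalaReflection.schemaFor(typeTag[RT]).
  val plus: (Int, Int) => Int = (a, b) => a + b

  // `df` is a hypothetical DataFrame with integer columns "x" and "y".
  val withSum = df.select(callUDF(plus, df("x"), df("y")) as "sum")

The `as "sum"` alias uses the same Column.as(String) call that the ParquetIOSuite change in this diff relies on (`cast decimal as "abcd"`).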
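When no TypeTag is available (for example, the function is only known through an erased FunctionN[_, ..., _] reference), the second group of overloads takes the return DataType explicitly. A sketch under the same assumptions as above, with the hypothetical `df` reused and the StringType import hedged to where the DataType hierarchy appears to live on this branch:

  // Assumption: on this branch the DataType hierarchy has already moved to
  // org.apache.spark.sql.types; adjust the import if it is still elsewhere.
  import org.apache.spark.sql.types.StringType

  // Only an erased Function1[_, _] is available here, so no TypeTag can be
  // summoned and the return type has to be spelled out as StringType.
  val upper: Function1[_, _] = (s: String) => s.toUpperCase
  // `df` is the same hypothetical DataFrame; "name" is assumed to be a string column.
  val shouted = df.select(callUDF(upper, StringType, df("name")) as "shouted")

Because the explicit-returnType variant takes one extra DataType argument, overload resolution cleanly separates it from the inference-based variant of the same function arity.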