From a052ed42501fee3641348337505b6176426653c4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 8 Feb 2015 18:56:51 -0800 Subject: [PATCH 01/24] [SPARK-5643][SQL] Add a show method to print the content of a DataFrame in tabular format. An example: ``` year month AVG('Adj Close) MAX('Adj Close) 1980 12 0.503218 0.595103 1981 01 0.523289 0.570307 1982 02 0.436504 0.475256 1983 03 0.410516 0.442194 1984 04 0.450090 0.483521 ``` Author: Reynold Xin Closes #4416 from rxin/SPARK-5643 and squashes the following commits: d0e0d6e [Reynold Xin] [SQL] Minor update to data source and statistics documentation. 269da83 [Reynold Xin] Updated isLocal comment. 2cf3c27 [Reynold Xin] Moved logic into optimizer. 1a04d8b [Reynold Xin] [SPARK-5643][SQL] Add a show method to print the content of a DataFrame in columnar format. --- .../sql/catalyst/optimizer/Optimizer.scala | 18 +++++- .../catalyst/plans/logical/LogicalPlan.scala | 7 ++- .../ConvertToLocalRelationSuite.scala | 57 +++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 21 ++++++- .../org/apache/spark/sql/DataFrameImpl.scala | 41 +++++++++++-- .../apache/spark/sql/IncomputableColumn.scala | 6 +- .../spark/sql/execution/basicOperators.scala | 7 +-- .../apache/spark/sql/sources/interfaces.scala | 15 +++-- 8 files changed, 151 insertions(+), 21 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8c8f2896eb99b..3bc48c95c5653 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -50,7 +50,9 @@ object DefaultOptimizer extends Optimizer { CombineFilters, PushPredicateThroughProject, PushPredicateThroughJoin, - ColumnPruning) :: Nil + ColumnPruning) :: + Batch("LocalRelation", FixedPoint(100), + ConvertToLocalRelation) :: Nil } /** @@ -610,3 +612,17 @@ object DecimalAggregates extends Rule[LogicalPlan] { DecimalType(prec + 4, scale + 4)) } } + +/** + * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to + * another LocalRelation. + * + * This is relatively simple as it currently handles only a single case: Project. + */ +object ConvertToLocalRelation extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Project(projectList, LocalRelation(output, data)) => + val projection = new InterpretedProjection(projectList, output) + LocalRelation(projectList.map(_.toAttribute), data.map(projection)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8d30528328946..7cf4b81274906 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -29,12 +29,15 @@ import org.apache.spark.sql.catalyst.trees /** * Estimates of various statistics. The default estimation logic simply lazily multiplies the * corresponding statistic produced by the children. To override this behavior, override - * `statistics` and assign it an overriden version of `Statistics`. + * `statistics` and assign it an overridden version of `Statistics`. * - * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the + * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the * performance of the implementations. The reason is that estimations might get triggered in * performance-critical processes, such as query plan planning. * + * Note that we are using a BigInt here since it is easy to overflow a 64-bit integer in + * cardinality estimation (e.g. cartesian joins). + * * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it * defaults to the product of children's `sizeInBytes`. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala new file mode 100644 index 0000000000000..cf42d43823399 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class ConvertToLocalRelationSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("LocalRelation", FixedPoint(100), + ConvertToLocalRelation) :: Nil + } + + test("Project on LocalRelation should be turned into a single LocalRelation") { + val testRelation = LocalRelation( + LocalRelation('a.int, 'b.int).output, + Row(1, 2) :: + Row(4, 5) :: Nil) + + val correctAnswer = LocalRelation( + LocalRelation('a1.int, 'b1.int).output, + Row(1, 3) :: + Row(4, 6) :: Nil) + + val projectOnLocal = testRelation.select( + UnresolvedAttribute("a").as("a1"), + (UnresolvedAttribute("b") + 1).as("b1")) + + val optimized = Optimize(projectOnLocal.analyze) + + comparePlans(optimized, correctAnswer) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8ad6526f872e5..17ea3cde8e50e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -102,7 +102,7 @@ trait DataFrame extends RDDApi[Row] { * }}} */ @scala.annotation.varargs - def toDataFrame(colName: String, colNames: String*): DataFrame + def toDataFrame(colNames: String*): DataFrame /** Returns the schema of this [[DataFrame]]. */ def schema: StructType @@ -116,6 +116,25 @@ trait DataFrame extends RDDApi[Row] { /** Prints the schema to the console in a nice tree format. */ def printSchema(): Unit + /** + * Returns true if the `collect` and `take` methods can be run locally + * (without any Spark executors). + */ + def isLocal: Boolean + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + */ + def show(): Unit + /** * Cartesian join with another [[DataFrame]]. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 789bcf6184b3e..fa05a5dcac6bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -90,14 +90,13 @@ private[sql] class DataFrameImpl protected[sql]( } } - override def toDataFrame(colName: String, colNames: String*): DataFrame = { - val newNames = colName +: colNames - require(schema.size == newNames.size, + override def toDataFrame(colNames: String*): DataFrame = { + require(schema.size == colNames.size, "The number of columns doesn't match.\n" + "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + - "New column names: " + newNames.mkString(", ")) + "New column names: " + colNames.mkString(", ")) - val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) => + val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) => apply(oldName).as(newName) } select(newCols :_*) @@ -113,6 +112,38 @@ private[sql] class DataFrameImpl protected[sql]( override def printSchema(): Unit = println(schema.treeString) + override def isLocal: Boolean = { + logicalPlan.isInstanceOf[LocalRelation] + } + + override def show(): Unit = { + val data = take(20) + val numCols = schema.fieldNames.length + + // For cells that are beyond 20 characters, replace it with the first 17 and "..." + val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => + row.toSeq.map { cell => + val str = if (cell == null) "null" else cell.toString + if (str.length > 20) str.substring(0, 17) + "..." else str + } : Seq[String] + } + + // Compute the width of each column + val colWidths = Array.fill(numCols)(0) + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } + } + + // Pad the cells and print them + println(rows.map { row => + row.zipWithIndex.map { case (cell, i) => + String.format(s"%-${colWidths(i)}s", cell) + }.mkString(" ") + }.mkString("\n")) + } + override def join(right: DataFrame): DataFrame = { Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala index 6043fb4dee01d..782f6e28eebb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala @@ -48,7 +48,7 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten protected[sql] override def logicalPlan: LogicalPlan = err() - override def toDataFrame(colName: String, colNames: String*): DataFrame = err() + override def toDataFrame(colNames: String*): DataFrame = err() override def schema: StructType = err() @@ -58,6 +58,10 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def printSchema(): Unit = err() + override def show(): Unit = err() + + override def isLocal: Boolean = false + override def join(right: DataFrame): DataFrame = err() override def join(right: DataFrame, joinExprs: Column): DataFrame = err() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 66aed5d5113d1..4dc506c21ab9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe.TypeTag - import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} @@ -40,7 +37,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - def execute() = child.execute().mapPartitions { iter => + override def execute() = child.execute().mapPartitions { iter => val resuableProjection = buildProjection() iter.map(resuableProjection) } @@ -55,7 +52,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { @transient lazy val conditionEvaluator = newPredicate(condition, child.output) - def execute() = child.execute().mapPartitions { iter => + override def execute() = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index a640ba57e0885..5eecc303ef72b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -87,13 +87,13 @@ trait CreatableRelationProvider { /** * ::DeveloperApi:: - * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must - * be able to produce the schema of their data in the form of a [[StructType]] Concrete + * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must + * be able to produce the schema of their data in the form of a [[StructType]]. Concrete * implementation should inherit from one of the descendant `Scan` classes, which define various * abstract methods for execution. * * BaseRelations must also define a equality function that only returns true when the two - * instances will return the same data. This equality function is used when determining when + * instances will return the same data. This equality function is used when determining when * it is safe to substitute cached results for a given relation. */ @DeveloperApi @@ -102,13 +102,16 @@ abstract class BaseRelation { def schema: StructType /** - * Returns an estimated size of this relation in bytes. This information is used by the planner + * Returns an estimated size of this relation in bytes. This information is used by the planner * to decided when it is safe to broadcast a relation and can be overridden by sources that * know the size ahead of time. By default, the system will assume that tables are too - * large to broadcast. This method will be called multiple times during query planning + * large to broadcast. This method will be called multiple times during query planning * and thus should not perform expensive operations for each invocation. + * + * Note that it is always better to overestimate size than underestimate, because underestimation + * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). */ - def sizeInBytes = sqlContext.conf.defaultSizeInBytes + def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes } /** From c17161189d57f2e3a8d3550ea59a68edf487c8b7 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 8 Feb 2015 21:07:36 -0800 Subject: [PATCH 02/24] [SPARK-5660][MLLIB] Make Matrix apply public This is #4447 with `override`. Closes #4447 Author: Joseph K. Bradley Author: Xiangrui Meng Closes #4462 from mengxr/SPARK-5660 and squashes the following commits: f82c8d6 [Xiangrui Meng] add override to matrix.apply 91cedde [Joseph K. Bradley] made matrix apply public --- .../main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 84f8ac2e0d9cd..c8a97b8c53d9b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -50,7 +50,7 @@ sealed trait Matrix extends Serializable { private[mllib] def toBreeze: BM[Double] /** Gets the (i, j)-th element. */ - private[mllib] def apply(i: Int, j: Int): Double + def apply(i: Int, j: Int): Double /** Return the index for the (i, j)-th element in the backing array. */ private[mllib] def index(i: Int, j: Int): Int @@ -163,7 +163,7 @@ class DenseMatrix( private[mllib] def apply(i: Int): Double = values(i) - private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) + override def apply(i: Int, j: Int): Double = values(index(i, j)) private[mllib] def index(i: Int, j: Int): Int = { if (!isTransposed) i + numRows * j else j + numCols * i @@ -398,7 +398,7 @@ class SparseMatrix( } } - private[mllib] def apply(i: Int, j: Int): Double = { + override def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) } From 4396dfb37f433ef186e3e0a09db9906986ec940b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 8 Feb 2015 21:08:50 -0800 Subject: [PATCH 03/24] SPARK-4405 [MLLIB] Matrices.* construction methods should check for rows x cols overflow Check that size of dense matrix array is not beyond Int.MaxValue in Matrices.* methods. jkbradley this should be an easy one. Review and/or merge as you see fit. Author: Sean Owen Closes #4461 from srowen/SPARK-4405 and squashes the following commits: c67574e [Sean Owen] Check that size of dense matrix array is not beyond Int.MaxValue in Matrices.* methods --- .../org/apache/spark/mllib/linalg/Matrices.scala | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index c8a97b8c53d9b..89b38679b7494 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -256,8 +256,11 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros */ - def zeros(numRows: Int, numCols: Int): DenseMatrix = + def zeros(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + } /** * Generate a `DenseMatrix` consisting of ones. @@ -265,8 +268,11 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones */ - def ones(numRows: Int, numCols: Int): DenseMatrix = + def ones(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + } /** * Generate an Identity Matrix in `DenseMatrix` format. @@ -291,6 +297,8 @@ object DenseMatrix { * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } @@ -302,6 +310,8 @@ object DenseMatrix { * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) } From 4575c5643a82818bf64f9648314bdc2fdc12febb Mon Sep 17 00:00:00 2001 From: Hung Lin Date: Sun, 8 Feb 2015 22:36:42 -0800 Subject: [PATCH 04/24] [SPARK-5472][SQL] Fix Scala code style Fix Scala code style. Author: Hung Lin Closes #4464 from hunglin/SPARK-5472 and squashes the following commits: ef7a3b3 [Hung Lin] SPARK-5472: fix scala style --- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 42 +++++++++---------- .../apache/spark/sql/jdbc/JDBCRelation.scala | 35 +++++++++------- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index a2f94675fb5a3..0bec32cca1325 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -17,13 +17,10 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, DatabaseMetaData, DriverManager, ResultSet, ResultSetMetaData, SQLException} -import scala.collection.mutable.ArrayBuffer +import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.util.NextIterator -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} import org.apache.spark.sql.types._ import org.apache.spark.sql.sources._ @@ -100,7 +97,7 @@ private[sql] object JDBCRDD extends Logging { try { val rsmd = rs.getMetaData val ncols = rsmd.getColumnCount - var fields = new Array[StructField](ncols); + val fields = new Array[StructField](ncols) var i = 0 while (i < ncols) { val columnName = rsmd.getColumnName(i + 1) @@ -176,23 +173,27 @@ private[sql] object JDBCRDD extends Logging { * * @return An RDD representing "SELECT requiredColumns FROM fqTable". */ - def scanTable(sc: SparkContext, - schema: StructType, - driver: String, - url: String, - fqTable: String, - requiredColumns: Array[String], - filters: Array[Filter], - parts: Array[Partition]): RDD[Row] = { + def scanTable( + sc: SparkContext, + schema: StructType, + driver: String, + url: String, + fqTable: String, + requiredColumns: Array[String], + filters: Array[Filter], + parts: Array[Partition]): RDD[Row] = { + val prunedSchema = pruneSchema(schema, requiredColumns) - return new JDBCRDD(sc, - getConnector(driver, url), - prunedSchema, - fqTable, - requiredColumns, - filters, - parts) + return new + JDBCRDD( + sc, + getConnector(driver, url), + prunedSchema, + fqTable, + requiredColumns, + filters, + parts) } } @@ -412,6 +413,5 @@ private[sql] class JDBCRDD( gotNext = false nextValue } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index e09125e406ba2..66ad38eb7c45b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -96,7 +96,8 @@ private[sql] class DefaultSource extends RelationProvider { if (driver != null) Class.forName(driver) - if ( partitionColumn != null + if ( + partitionColumn != null && (lowerBound == null || upperBound == null || numPartitions == null)) { sys.error("Partitioning incompletely specified") } @@ -104,30 +105,34 @@ private[sql] class DefaultSource extends RelationProvider { val partitionInfo = if (partitionColumn == null) { null } else { - JDBCPartitioningInfo(partitionColumn, - lowerBound.toLong, upperBound.toLong, - numPartitions.toInt) + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) } val parts = JDBCRelation.columnPartition(partitionInfo) JDBCRelation(url, table, parts)(sqlContext) } } -private[sql] case class JDBCRelation(url: String, - table: String, - parts: Array[Partition])( - @transient val sqlContext: SQLContext) - extends PrunedFilteredScan { +private[sql] case class JDBCRelation( + url: String, + table: String, + parts: Array[Partition])(@transient val sqlContext: SQLContext) extends PrunedFilteredScan { override val schema = JDBCRDD.resolveTable(url, table) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName - JDBCRDD.scanTable(sqlContext.sparkContext, - schema, - driver, url, - table, - requiredColumns, filters, - parts) + JDBCRDD.scanTable( + sqlContext.sparkContext, + schema, + driver, + url, + table, + requiredColumns, + filters, + parts) } } From 855d12ac0a9cdade4cd2cc64c4e7209478be6690 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 8 Feb 2015 23:40:36 -0800 Subject: [PATCH 05/24] [SPARK-5539][MLLIB] LDA guide This is the LDA user guide from jkbradley with Java and Scala code example. Author: Xiangrui Meng Author: Joseph K. Bradley Closes #4465 from mengxr/lda-guide and squashes the following commits: 6dcb7d1 [Xiangrui Meng] update java example in the user guide 76169ff [Xiangrui Meng] update java example 36c3ae2 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into lda-guide c2a1efe [Joseph K. Bradley] Added LDA programming guide, plus Java example (which is in the guide and probably should be removed). --- data/mllib/sample_lda_data.txt | 12 ++ docs/mllib-clustering.md | 129 +++++++++++++++++- .../spark/examples/mllib/JavaLDAExample.java | 75 ++++++++++ 3 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 data/mllib/sample_lda_data.txt create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java diff --git a/data/mllib/sample_lda_data.txt b/data/mllib/sample_lda_data.txt new file mode 100644 index 0000000000000..2e76702ca9d67 --- /dev/null +++ b/data/mllib/sample_lda_data.txt @@ -0,0 +1,12 @@ +1 2 6 0 2 3 1 1 0 0 3 +1 3 0 1 3 0 0 2 0 0 1 +1 4 1 0 0 4 9 0 1 2 0 +2 1 0 3 0 0 5 0 2 3 9 +3 1 1 9 3 0 2 0 0 1 3 +4 2 0 3 4 5 1 1 1 4 0 +2 1 0 3 0 0 5 0 2 2 9 +1 1 1 9 2 1 2 0 0 1 3 +4 4 0 3 4 2 1 3 0 0 0 +2 8 2 0 3 0 2 0 2 7 2 +1 1 1 9 0 2 2 0 0 3 3 +4 1 0 0 4 5 1 3 0 1 0 diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 1e9ef345b7435..99ed6b60e3f00 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -55,7 +55,7 @@ has the following parameters: Power iteration clustering is a scalable and efficient algorithm for clustering points given pointwise mutual affinity values. Internally the algorithm: -* accepts a [Graph](https://spark.apache.org/docs/0.9.2/api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points. +* accepts a [Graph](api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points. * calculates the principal eigenvalue and eigenvector * Clusters each of the input points according to their principal eigenvector component value @@ -71,6 +71,35 @@ Example outputs for a dataset inspired by the paper - but with five clusters ins

+### Latent Dirichlet Allocation (LDA) + +[Latent Dirichlet Allocation (LDA)](http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) +is a topic model which infers topics from a collection of text documents. +LDA can be thought of as a clustering algorithm as follows: + +* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. +* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts. +* Rather than estimating a clustering using a traditional distance, LDA uses a function based + on a statistical model of how text documents are generated. + +LDA takes in a collection of documents as vectors of word counts. +It learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) +on the likelihood function. After fitting on the documents, LDA provides: + +* Topics: Inferred topics, each of which is a probability distribution over terms (words). +* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics. + +LDA takes the following parameters: + +* `k`: Number of topics (i.e., cluster centers) +* `maxIterations`: Limit on the number of iterations of EM used for learning +* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions. +* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions. +* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. + +*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet +support prediction on new documents, and it does not have a Python API. These will be added in the future. + ### Examples #### k-means @@ -293,6 +322,104 @@ for i in range(2): +#### Latent Dirichlet Allocation (LDA) Example + +In the following example, we load word count vectors representing a corpus of documents. +We then use [LDA](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) +to infer three topics from the documents. The number of desired clusters is passed +to the algorithm. We then output the topics, represented as probability distributions over words. + +
+
+ +{% highlight scala %} +import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/sample_lda_data.txt") +val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))) +// Index documents with unique IDs +val corpus = parsedData.zipWithIndex.map(_.swap).cache() + +// Cluster the documents into three topics using LDA +val ldaModel = new LDA().setK(3).run(corpus) + +// Output topics. Each is a distribution over words (matching word count vectors) +println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") +val topics = ldaModel.topicsMatrix +for (topic <- Range(0, 3)) { + print("Topic " + topic + ":") + for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + println() +} +{% endhighlight %} +
+ +
+{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.DistributedLDAModel; +import org.apache.spark.mllib.clustering.LDA; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class JavaLDAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("LDA Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_lda_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + // Index documents with unique IDs + JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 doc_id) { + return doc_id.swap(); + } + } + )); + corpus.cache(); + + // Cluster the documents into three topics using LDA + DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + + // Output topics. Each is a distribution over words (matching word count vectors) + System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() + + " words):"); + Matrix topics = ldaModel.topicsMatrix(); + for (int topic = 0; topic < 3; topic++) { + System.out.print("Topic " + topic + ":"); + for (int word = 0; word < ldaModel.vocabSize(); word++) { + System.out.print(" " + topics.apply(word, topic)); + } + System.out.println(); + } + } +} +{% endhighlight %} +
+ +
+ + In order to run the above application, follow the instructions provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) section of the Spark diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java new file mode 100644 index 0000000000000..f394ff2084463 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.DistributedLDAModel; +import org.apache.spark.mllib.clustering.LDA; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class JavaLDAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("LDA Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_lda_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + // Index documents with unique IDs + JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 doc_id) { + return doc_id.swap(); + } + } + )); + corpus.cache(); + + // Cluster the documents into three topics using LDA + DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + + // Output topics. Each is a distribution over words (matching word count vectors) + System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() + + " words):"); + Matrix topics = ldaModel.topicsMatrix(); + for (int topic = 0; topic < 3; topic++) { + System.out.print("Topic " + topic + ":"); + for (int word = 0; word < ldaModel.vocabSize(); word++) { + System.out.print(" " + topics.apply(word, topic)); + } + System.out.println(); + } + } +} From 4dfe180fc893bee1146161f8b2a6efd4d6d2bb8c Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 9 Feb 2015 09:44:53 +0000 Subject: [PATCH 06/24] [SPARK-5473] [EC2] Expose SSH failures after status checks pass If there is some fatal problem with launching a cluster, `spark-ec2` just hangs without giving the user useful feedback on what the problem is. This PR exposes the output of the SSH calls to the user if the SSH test fails during cluster launch for any reason but the instance status checks are all green. It also removes the growing trail of dots while waiting in favor of a fixed 3 dots. For example: ``` $ ./ec2/spark-ec2 -k key -i /incorrect/path/identity.pem --instance-type m3.medium --slaves 1 --zone us-east-1c launch "spark-test" Setting up security groups... Searching for existing cluster spark-test... Spark AMI: ami-35b1885c Launching instances... Launched 1 slaves in us-east-1c, regid = r-7dadd096 Launched master in us-east-1c, regid = r-fcadd017 Waiting for cluster to enter 'ssh-ready' state... Warning: SSH connection error. (This could be temporary.) Host: 127.0.0.1 SSH return code: 255 SSH output: Warning: Identity file /incorrect/path/identity.pem not accessible: No such file or directory. Warning: Permanently added '127.0.0.1' (RSA) to the list of known hosts. Permission denied (publickey). ``` This should give users enough information when some unrecoverable error occurs during launch so they can know to abort the launch. This will help avoid situations like the ones reported [here on Stack Overflow](http://stackoverflow.com/q/28002443/) and [here on the user list](http://mail-archives.apache.org/mod_mbox/spark-user/201501.mbox/%3C1422323829398-21381.postn3.nabble.com%3E), where the users couldn't tell what the problem was because it was being hidden by `spark-ec2`. This is a usability improvement that should be backported to 1.2. Resolves [SPARK-5473](https://issues.apache.org/jira/browse/SPARK-5473). Author: Nicholas Chammas Closes #4262 from nchammas/expose-ssh-failure and squashes the following commits: 8bda6ed [Nicholas Chammas] default to print SSH output 2b92534 [Nicholas Chammas] show SSH output after status check pass --- ec2/spark_ec2.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 725b1e47e0cea..87b2112fe4628 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -34,6 +34,7 @@ import sys import tarfile import tempfile +import textwrap import time import urllib2 import warnings @@ -681,21 +682,32 @@ def setup_spark_cluster(master, opts): print "Ganglia started at http://%s:5080/ganglia" % master -def is_ssh_available(host, opts): +def is_ssh_available(host, opts, print_ssh_output=True): """ Check if SSH is available on a host. """ - try: - with open(os.devnull, 'w') as devnull: - ret = subprocess.check_call( - ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', - '%s@%s' % (opts.user, host), stringify_command('true')], - stdout=devnull, - stderr=devnull - ) - return ret == 0 - except subprocess.CalledProcessError as e: - return False + s = subprocess.Popen( + ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', + '%s@%s' % (opts.user, host), stringify_command('true')], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order + ) + cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout + + if s.returncode != 0 and print_ssh_output: + # extra leading newline is for spacing in wait_for_cluster_state() + print textwrap.dedent("""\n + Warning: SSH connection error. (This could be temporary.) + Host: {h} + SSH return code: {r} + SSH output: {o} + """).format( + h=host, + r=s.returncode, + o=cmd_output.strip() + ) + + return s.returncode == 0 def is_cluster_ssh_available(cluster_instances, opts): From 0793ee1b4dea1f4b0df749e8ad7c1ab70b512faf Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Mon, 9 Feb 2015 10:12:12 +0000 Subject: [PATCH 07/24] SPARK-2149. [MLLIB] Univariate kernel density estimation Author: Sandy Ryza Closes #1093 from sryza/sandy-spark-2149 and squashes the following commits: 5f06b33 [Sandy Ryza] More review comments 0f73060 [Sandy Ryza] Respond to Sean's review comments 0dfa005 [Sandy Ryza] SPARK-2149. Univariate kernel density estimation --- .../spark/mllib/stat/KernelDensity.scala | 71 +++++++++++++++++++ .../apache/spark/mllib/stat/Statistics.scala | 14 ++++ .../spark/mllib/stat/KernelDensitySuite.scala | 47 ++++++++++++ 3 files changed, 132 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala new file mode 100644 index 0000000000000..0deef11b4511a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.rdd.RDD + +private[stat] object KernelDensity { + /** + * Given a set of samples from a distribution, estimates its density at the set of given points. + * Uses a Gaussian kernel with the given standard deviation. + */ + def estimate(samples: RDD[Double], standardDeviation: Double, + evaluationPoints: Array[Double]): Array[Double] = { + if (standardDeviation <= 0.0) { + throw new IllegalArgumentException("Standard deviation must be positive") + } + + // This gets used in each Gaussian PDF computation, so compute it up front + val logStandardDeviationPlusHalfLog2Pi = + Math.log(standardDeviation) + 0.5 * Math.log(2 * Math.PI) + + val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0))( + (x, y) => { + var i = 0 + while (i < evaluationPoints.length) { + x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi, + evaluationPoints(i)) + i += 1 + } + (x._1, i) + }, + (x, y) => { + var i = 0 + while (i < evaluationPoints.length) { + x._1(i) += y._1(i) + i += 1 + } + (x._1, x._2 + y._2) + }) + + var i = 0 + while (i < points.length) { + points(i) /= count + i += 1 + } + points + } + + private def normPdf(mean: Double, standardDeviation: Double, + logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = { + val x0 = x - mean + val x1 = x0 / standardDeviation + val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi + Math.exp(logDensity) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index b3fad0c52d655..32561620ac914 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -149,4 +149,18 @@ object Statistics { def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) } + + /** + * Given an empirical distribution defined by the input RDD of samples, estimate its density at + * each of the given evaluation points using a Gaussian kernel. + * + * @param samples The samples RDD used to define the empirical distribution. + * @param standardDeviation The standard deviation of the kernel Gaussians. + * @param evaluationPoints The points at which to estimate densities. + * @return An array the same size as evaluationPoints with the density at each point. + */ + def kernelDensity(samples: RDD[Double], standardDeviation: Double, + evaluationPoints: Iterable[Double]): Array[Double] = { + KernelDensity.estimate(samples, standardDeviation, evaluationPoints.toArray) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala new file mode 100644 index 0000000000000..f6a1e19f50296 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.scalatest.FunSuite + +import org.apache.commons.math3.distribution.NormalDistribution + +import org.apache.spark.mllib.util.LocalClusterSparkContext + +class KernelDensitySuite extends FunSuite with LocalClusterSparkContext { + test("kernel density single sample") { + val rdd = sc.parallelize(Array(5.0)) + val evaluationPoints = Array(5.0, 6.0) + val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints) + val normal = new NormalDistribution(5.0, 3.0) + val acceptableErr = 1e-6 + assert(densities(0) - normal.density(5.0) < acceptableErr) + assert(densities(0) - normal.density(6.0) < acceptableErr) + } + + test("kernel density multiple samples") { + val rdd = sc.parallelize(Array(5.0, 10.0)) + val evaluationPoints = Array(5.0, 6.0) + val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints) + val normal1 = new NormalDistribution(5.0, 3.0) + val normal2 = new NormalDistribution(10.0, 3.0) + val acceptableErr = 1e-6 + assert(densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2 < acceptableErr) + assert(densities(0) - (normal1.density(6.0) + normal2.density(6.0)) / 2 < acceptableErr) + } +} From de7806048ac49a8bfdf44d8f87bc11cea1dfb242 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 9 Feb 2015 10:33:57 -0800 Subject: [PATCH 08/24] SPARK-4267 [YARN] Failing to launch jobs on Spark on YARN with Hadoop 2.5.0 or later Before passing to YARN, escape arguments in "extraJavaOptions" args, in order to correctly handle cases like -Dfoo="one two three". Also standardize how these args are handled and ensure that individual args are treated as stand-alone args, not one string. vanzin andrewor14 Author: Sean Owen Closes #4452 from srowen/SPARK-4267.2 and squashes the following commits: c8297d2 [Sean Owen] Before passing to YARN, escape arguments in "extraJavaOptions" args, in order to correctly handle cases like -Dfoo="one two three". Also standardize how these args are handled and ensure that individual args are treated as stand-alone args, not one string. --- .../org/apache/spark/deploy/yarn/Client.scala | 9 +++++---- .../spark/deploy/yarn/ExecutorRunnable.scala | 17 +++++++++-------- .../spark/deploy/yarn/YarnClusterSuite.scala | 6 ++++-- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e7005094b5f3c..8afc1ccdad732 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -435,10 +435,11 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { - sparkConf.getOption("spark.driver.extraJavaOptions") + val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions") .orElse(sys.env.get("SPARK_JAVA_OPTS")) - .map(Utils.splitCommandString).getOrElse(Seq.empty) - .foreach(opts => javaOpts += opts) + driverOpts.foreach { opts => + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + } val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { @@ -460,7 +461,7 @@ private[spark] class Client( val msg = s"$amOptsKey is not allowed to alter memory settings (was '$opts')." throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts) + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 408cf09b9bdfa..7cd8c5f0f9204 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -128,14 +128,15 @@ class ExecutorRunnable( // Set the JVM memory val executorMemoryString = executorMemory + "m" - javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " + javaOpts += "-Xms" + executorMemoryString + javaOpts += "-Xmx" + executorMemoryString // Set extra Java options for the executor, if defined sys.props.get("spark.executor.extraJavaOptions").foreach { opts => - javaOpts += opts + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.env.get("SPARK_JAVA_OPTS").foreach { opts => - javaOpts += opts + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.props.get("spark.executor.extraLibraryPath").foreach { p => prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) @@ -173,11 +174,11 @@ class ExecutorRunnable( // The options are based on // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use // %20the%20Concurrent%20Low%20Pause%20Collector|outline - javaOpts += " -XX:+UseConcMarkSweepGC " - javaOpts += " -XX:+CMSIncrementalMode " - javaOpts += " -XX:+CMSIncrementalPacing " - javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 " - javaOpts += " -XX:CMSIncrementalDutyCycle=10 " + javaOpts += "-XX:+UseConcMarkSweepGC" + javaOpts += "-XX:+CMSIncrementalMode" + javaOpts += "-XX:+CMSIncrementalPacing" + javaOpts += "-XX:CMSIncrementalDutyCycleMin=0" + javaOpts += "-XX:CMSIncrementalDutyCycle=10" } */ diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index eda40efc4c77f..e39de82740b1d 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -75,6 +75,8 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit private var oldConf: Map[String, String] = _ override def beforeAll() { + super.beforeAll() + tempDir = Utils.createTempDir() val logConfDir = new File(tempDir, "log4j") @@ -129,8 +131,8 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit sys.props += ("spark.executor.instances" -> "1") sys.props += ("spark.driver.extraClassPath" -> childClasspath) sys.props += ("spark.executor.extraClassPath" -> childClasspath) - - super.beforeAll() + sys.props += ("spark.executor.extraJavaOptions" -> "-Dfoo=\"one two three\"") + sys.props += ("spark.driver.extraJavaOptions" -> "-Dfoo=\"one two three\"") } override def afterAll() { From afb131637d96e1e5e07eb8abf24e32e7f3b2304d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 9 Feb 2015 11:42:52 -0800 Subject: [PATCH 09/24] [SPARK-5678] Convert DataFrame to pandas.DataFrame and Series ``` pyspark.sql.DataFrame.to_pandas = to_pandas(self) unbound pyspark.sql.DataFrame method Collect all the rows and return a `pandas.DataFrame`. >>> df.to_pandas() # doctest: +SKIP age name 0 2 Alice 1 5 Bob pyspark.sql.Column.to_pandas = to_pandas(self) unbound pyspark.sql.Column method Return a pandas.Series from the column >>> df.age.to_pandas() # doctest: +SKIP 0 2 1 5 dtype: int64 ``` Not tests by jenkins (they depends on pandas) Author: Davies Liu Closes #4476 from davies/to_pandas and squashes the following commits: 6276fb6 [Davies Liu] Convert DataFrame to pandas.DataFrame and Series --- python/pyspark/sql.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index e55f285a778c4..6a6dfbc5851b8 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -2284,6 +2284,18 @@ def addColumn(self, colName, col): """ return self.select('*', col.alias(colName)) + def to_pandas(self): + """ + Collect all the rows and return a `pandas.DataFrame`. + + >>> df.to_pandas() # doctest: +SKIP + age name + 0 2 Alice + 1 5 Bob + """ + import pandas as pd + return pd.DataFrame.from_records(self.collect(), columns=self.columns) + # Having SchemaRDD for backward compatibility (for docs) class SchemaRDD(DataFrame): @@ -2551,6 +2563,19 @@ def cast(self, dataType): jc = self._jc.cast(jdt) return Column(jc, self.sql_ctx) + def to_pandas(self): + """ + Return a pandas.Series from the column + + >>> df.age.to_pandas() # doctest: +SKIP + 0 2 + 1 5 + dtype: int64 + """ + import pandas as pd + data = [c for c, in self.collect()] + return pd.Series(data) + def _aggregate_func(name, doc=""): """ Create a function for aggregator by name""" From dae216147f2247fd722fb0909da74fe71cf2fa8b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 Feb 2015 11:45:12 -0800 Subject: [PATCH 10/24] [SPARK-5664][BUILD] Restore stty settings when exiting from SBT's spark-shell For launching spark-shell from SBT. Author: Liang-Chi Hsieh Closes #4451 from viirya/restore_stty and squashes the following commits: fdfc480 [Liang-Chi Hsieh] Restore stty settings when exit (for launching spark-shell from SBT). --- build/sbt | 28 ++++++++++++++++++++++++++++ build/sbt-launch-lib.bash | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/build/sbt b/build/sbt index 28ebb64f7197c..cc3203d79bccd 100755 --- a/build/sbt +++ b/build/sbt @@ -125,4 +125,32 @@ loadConfigFile() { [[ -f "$etc_sbt_opts_file" ]] && set -- $(loadConfigFile "$etc_sbt_opts_file") "$@" [[ -f "$sbt_opts_file" ]] && set -- $(loadConfigFile "$sbt_opts_file") "$@" +exit_status=127 +saved_stty="" + +restoreSttySettings() { + stty $saved_stty + saved_stty="" +} + +onExit() { + if [[ "$saved_stty" != "" ]]; then + restoreSttySettings + fi + exit $exit_status +} + +saveSttySettings() { + saved_stty=$(stty -g 2>/dev/null) + if [[ ! $? ]]; then + saved_stty="" + fi +} + +saveSttySettings +trap onExit INT + run "$@" + +exit_status=$? +onExit diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 5e0c640fa5919..504be48b358fa 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -81,7 +81,7 @@ execRunner () { echo "" } - exec "$@" + "$@" } addJava () { From 6fe70d8432314f0b7290a66f114306f61e0a87cc Mon Sep 17 00:00:00 2001 From: mcheah Date: Mon, 9 Feb 2015 13:20:14 -0800 Subject: [PATCH 11/24] [SPARK-5691] Fixing wrong data structure lookup for dupe app registratio... In Master's registerApplication method, it checks if the application had already registered by examining the addressToWorker hash map. In reality, it should refer to the addressToApp data structure, as this is what really tracks which apps have been registered. Author: mcheah Closes #4477 from mccheah/spark-5691 and squashes the following commits: efdc573 [mcheah] [SPARK-5691] Fixing wrong data structure lookup for dupe app registration --- core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index b8b1a25abff2e..53e453990f8c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -671,7 +671,7 @@ private[spark] class Master( def registerApplication(app: ApplicationInfo): Unit = { val appAddress = app.driver.path.address - if (addressToWorker.contains(appAddress)) { + if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return } From 0765af9b21e9204c410c7a849c7201bc3eda8cc3 Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Mon, 9 Feb 2015 14:17:14 -0800 Subject: [PATCH 12/24] [SPARK-4905][STREAMING] FlumeStreamSuite fix. Using String constructor instead of CharsetDecoder to see if it fixes the issue of empty strings in Flume test output. Author: Hari Shreedharan Closes #4371 from harishreedharan/Flume-stream-attempted-fix and squashes the following commits: 550d363 [Hari Shreedharan] Fix imports. 8695950 [Hari Shreedharan] Use Charsets.UTF_8 instead of "UTF-8" in String constructors. af3ba14 [Hari Shreedharan] [SPARK-4905][STREAMING] FlumeStreamSuite fix. --- .../apache/spark/streaming/flume/FlumeStreamSuite.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index f333e3891b5f0..322de7bf2fed8 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.streaming.flume import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer -import java.nio.charset.Charset import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.base.Charsets import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.source.avro @@ -108,7 +108,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L val inputEvents = input.map { item => val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8"))) + event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) event } @@ -138,14 +138,13 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L status should be (avro.Status.OK) } - val decoder = Charset.forName("UTF-8").newDecoder() eventually(timeout(10 seconds), interval(100 milliseconds)) { val outputEvents = outputBuffer.flatten.map { _.event } outputEvents.foreach { event => event.getHeaders.get("test") should be("header") } - val output = outputEvents.map(event => decoder.decode(event.getBody()).toString) + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) output should be (input) } } From f48199eb354d6ec8675c2c1f96c3005064058d66 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 9 Feb 2015 14:51:46 -0800 Subject: [PATCH 13/24] [SPARK-5675][SQL] XyzType companion object should subclass XyzType Otherwise, the following will always return false in Java. ```scala dataType instanceof StringType ``` Author: Reynold Xin Closes #4463 from rxin/type-companion-object and squashes the following commits: 04d5d8d [Reynold Xin] Comment. 976e11e [Reynold Xin] [SPARK-5675][SQL]StringType case object should be subclass of StringType class --- .../apache/spark/sql/types/dataTypes.scala | 85 ++++++++++++++++--- 1 file changed, 73 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 91efe320546a7..2abb1caee9cd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -240,10 +240,16 @@ abstract class DataType { * @group dataType */ @DeveloperApi -case object NullType extends DataType { +class NullType private() extends DataType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "NullType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. override def defaultSize: Int = 1 } +case object NullType extends NullType + + protected[sql] object NativeType { val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) @@ -292,7 +298,10 @@ protected[sql] abstract class NativeType extends DataType { * @group dataType */ @DeveloperApi -case object StringType extends NativeType with PrimitiveType { +class StringType private() extends NativeType with PrimitiveType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "StringType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] @@ -303,6 +312,8 @@ case object StringType extends NativeType with PrimitiveType { override def defaultSize: Int = 4096 } +case object StringType extends StringType + /** * :: DeveloperApi :: @@ -313,7 +324,10 @@ case object StringType extends NativeType with PrimitiveType { * @group dataType */ @DeveloperApi -case object BinaryType extends NativeType with PrimitiveType { +class BinaryType private() extends NativeType with PrimitiveType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Array[Byte] @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = new Ordering[JvmType] { @@ -332,6 +346,8 @@ case object BinaryType extends NativeType with PrimitiveType { override def defaultSize: Int = 4096 } +case object BinaryType extends BinaryType + /** * :: DeveloperApi :: @@ -341,7 +357,10 @@ case object BinaryType extends NativeType with PrimitiveType { *@group dataType */ @DeveloperApi -case object BooleanType extends NativeType with PrimitiveType { +class BooleanType private() extends NativeType with PrimitiveType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] @@ -352,6 +371,8 @@ case object BooleanType extends NativeType with PrimitiveType { override def defaultSize: Int = 1 } +case object BooleanType extends BooleanType + /** * :: DeveloperApi :: @@ -362,7 +383,10 @@ case object BooleanType extends NativeType with PrimitiveType { * @group dataType */ @DeveloperApi -case object TimestampType extends NativeType { +class TimestampType private() extends NativeType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Timestamp @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } @@ -377,6 +401,8 @@ case object TimestampType extends NativeType { override def defaultSize: Int = 12 } +case object TimestampType extends TimestampType + /** * :: DeveloperApi :: @@ -387,7 +413,10 @@ case object TimestampType extends NativeType { * @group dataType */ @DeveloperApi -case object DateType extends NativeType { +class DateType private() extends NativeType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "DateType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Int @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } @@ -400,6 +429,8 @@ case object DateType extends NativeType { override def defaultSize: Int = 4 } +case object DateType extends DateType + abstract class NumericType extends NativeType with PrimitiveType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for @@ -438,7 +469,10 @@ protected[sql] sealed abstract class IntegralType extends NumericType { * @group dataType */ @DeveloperApi -case object LongType extends IntegralType { +class LongType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "LongType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Long @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Long]] @@ -453,6 +487,8 @@ case object LongType extends IntegralType { override def simpleString = "bigint" } +case object LongType extends LongType + /** * :: DeveloperApi :: @@ -462,7 +498,10 @@ case object LongType extends IntegralType { * @group dataType */ @DeveloperApi -case object IntegerType extends IntegralType { +class IntegerType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Int @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Int]] @@ -477,6 +516,8 @@ case object IntegerType extends IntegralType { override def simpleString = "int" } +case object IntegerType extends IntegerType + /** * :: DeveloperApi :: @@ -486,7 +527,10 @@ case object IntegerType extends IntegralType { * @group dataType */ @DeveloperApi -case object ShortType extends IntegralType { +class ShortType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Short @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Short]] @@ -501,6 +545,8 @@ case object ShortType extends IntegralType { override def simpleString = "smallint" } +case object ShortType extends ShortType + /** * :: DeveloperApi :: @@ -510,7 +556,10 @@ case object ShortType extends IntegralType { * @group dataType */ @DeveloperApi -case object ByteType extends IntegralType { +class ByteType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Byte @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Byte]] @@ -525,6 +574,8 @@ case object ByteType extends IntegralType { override def simpleString = "tinyint" } +case object ByteType extends ByteType + /** Matcher for any expressions that evaluate to [[FractionalType]]s */ protected[sql] object FractionalType { @@ -630,7 +681,10 @@ object DecimalType { * @group dataType */ @DeveloperApi -case object DoubleType extends FractionalType { +class DoubleType private() extends FractionalType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Double @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Double]] @@ -644,6 +698,8 @@ case object DoubleType extends FractionalType { override def defaultSize: Int = 8 } +case object DoubleType extends DoubleType + /** * :: DeveloperApi :: @@ -653,7 +709,10 @@ case object DoubleType extends FractionalType { * @group dataType */ @DeveloperApi -case object FloatType extends FractionalType { +class FloatType private() extends FractionalType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. private[sql] type JvmType = Float @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Float]] @@ -667,6 +726,8 @@ case object FloatType extends FractionalType { override def defaultSize: Int = 4 } +case object FloatType extends FloatType + object ArrayType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ From b884daa58084d4f42e2318894067565b94e07f9d Mon Sep 17 00:00:00 2001 From: Florian Verhein Date: Mon, 9 Feb 2015 23:47:07 +0000 Subject: [PATCH 14/24] [SPARK-5611] [EC2] Allow spark-ec2 repo and branch to be set on CLI of spark_ec2.py and by extension, the ami-list Useful for using alternate spark-ec2 repos or branches. Author: Florian Verhein Closes #4385 from florianverhein/master and squashes the following commits: 7e2b4be [Florian Verhein] [SPARK-5611] [EC2] typo 8b653dc [Florian Verhein] [SPARK-5611] [EC2] Enforce only supporting spark-ec2 forks from github, log improvement bc4b0ed [Florian Verhein] [SPARK-5611] allow spark-ec2 repos with different names 8b5c551 [Florian Verhein] improve option naming, fix logging, fix lint failing, add guard to enforce spark-ec2 7724308 [Florian Verhein] [SPARK-5611] [EC2] fixes b42b68c [Florian Verhein] [SPARK-5611] [EC2] Allow spark-ec2 repo and branch to be set on CLI of spark_ec2.py --- ec2/spark_ec2.py | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 87b2112fe4628..3e4c49c0e1db6 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -62,10 +62,10 @@ DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" -MESOS_SPARK_EC2_BRANCH = "branch-1.3" -# A URL prefix from which to fetch AMI information -AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) +# Default location to get the spark-ec2 scripts (and ami-list) from +DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2" +DEFAULT_SPARK_EC2_BRANCH = "branch-1.3" def setup_boto(): @@ -147,6 +147,14 @@ def parse_args(): "--spark-git-repo", default=DEFAULT_SPARK_GITHUB_REPO, help="Github repo from which to checkout supplied commit hash (default: %default)") + parser.add_option( + "--spark-ec2-git-repo", + default=DEFAULT_SPARK_EC2_GITHUB_REPO, + help="Github repo from which to checkout spark-ec2 (default: %default)") + parser.add_option( + "--spark-ec2-git-branch", + default=DEFAULT_SPARK_EC2_BRANCH, + help="Github repo branch of spark-ec2 to use (default: %default)") parser.add_option( "--hadoop-major-version", default="1", help="Major version of Hadoop (default: %default)") @@ -333,7 +341,12 @@ def get_spark_ami(opts): print >> stderr,\ "Don't recognize %s, assuming type is pvm" % opts.instance_type - ami_path = "%s/%s/%s" % (AMI_PREFIX, opts.region, instance_type) + # URL prefix from which to fetch AMI information + ami_prefix = "{r}/{b}/ami-list".format( + r=opts.spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1), + b=opts.spark_ec2_git_branch) + + ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) try: ami = urllib2.urlopen(ami_path).read().strip() print "Spark AMI: " + ami @@ -650,12 +663,15 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten + print "Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( + r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch) ssh( host=master, opts=opts, command="rm -rf spark-ec2" + " && " - + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH) + + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo, + b=opts.spark_ec2_git_branch) ) print "Deploying files to master..." @@ -1038,6 +1054,17 @@ def real_main(): print >> stderr, "ebs-vol-num cannot be greater than 8" sys.exit(1) + # Prevent breaking ami_prefix (/, .git and startswith checks) + # Prevent forks with non spark-ec2 names for now. + if opts.spark_ec2_git_repo.endswith("/") or \ + opts.spark_ec2_git_repo.endswith(".git") or \ + not opts.spark_ec2_git_repo.startswith("https://github.com") or \ + not opts.spark_ec2_git_repo.endswith("spark-ec2"): + print >> stderr, "spark-ec2-git-repo must be a github repo and it must not have a " \ + "trailing / or .git. " \ + "Furthermore, we currently only support forks named spark-ec2." + sys.exit(1) + try: conn = ec2.connect_to_region(opts.region) except Exception as e: From 68b25cf695e0fce9e465288d5a053e540a3fccb4 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 9 Feb 2015 16:02:56 -0800 Subject: [PATCH 15/24] [SQL] Add some missing DataFrame functions. - as with a `Symbol` - distinct - sqlContext.emptyDataFrame - move add/remove col out of RDDApi section Author: Michael Armbrust Closes #4437 from marmbrus/dfMissingFuncs and squashes the following commits: 2004023 [Michael Armbrust] Add missing functions --- .../scala/org/apache/spark/sql/Column.scala | 9 +++++ .../org/apache/spark/sql/DataFrame.scala | 12 +++++-- .../org/apache/spark/sql/DataFrameImpl.scala | 34 +++++++++++-------- .../apache/spark/sql/IncomputableColumn.scala | 10 +++--- .../scala/org/apache/spark/sql/RDDApi.scala | 2 ++ .../org/apache/spark/sql/SQLContext.scala | 5 ++- 6 files changed, 51 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 878b2b0556de7..1011bf0bb5ef4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -549,6 +549,15 @@ trait Column extends DataFrame { */ override def as(alias: String): Column = exprToColumn(Alias(expr, alias)()) + /** + * Gives the column an alias. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".as('colB)) + * }}} + */ + override def as(alias: Symbol): Column = exprToColumn(Alias(expr, alias.name)()) + /** * Casts the column to a different data type. * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 17ea3cde8e50e..6abfb7853cf1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -156,7 +156,7 @@ trait DataFrame extends RDDApi[Row] { def join(right: DataFrame, joinExprs: Column): DataFrame /** - * Join with another [[DataFrame]], usin g the given join expression. The following performs + * Join with another [[DataFrame]], using the given join expression. The following performs * a full outer join between `df1` and `df2`. * * {{{ @@ -233,7 +233,12 @@ trait DataFrame extends RDDApi[Row] { /** * Returns a new [[DataFrame]] with an alias set. */ - def as(name: String): DataFrame + def as(alias: String): DataFrame + + /** + * (Scala-specific) Returns a new [[DataFrame]] with an alias set. + */ + def as(alias: Symbol): DataFrame /** * Selects a set of expressions. @@ -516,6 +521,9 @@ trait DataFrame extends RDDApi[Row] { */ override def repartition(numPartitions: Int): DataFrame + /** Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. */ + override def distinct: DataFrame + override def persist(): this.type override def persist(newLevel: StorageLevel): this.type diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index fa05a5dcac6bf..73393295ab0a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -196,7 +196,9 @@ private[sql] class DataFrameImpl protected[sql]( }.toSeq :_*) } - override def as(name: String): DataFrame = Subquery(name, logicalPlan) + override def as(alias: String): DataFrame = Subquery(alias, logicalPlan) + + override def as(alias: Symbol): DataFrame = Subquery(alias.name, logicalPlan) override def select(cols: Column*): DataFrame = { val exprs = cols.zipWithIndex.map { @@ -215,7 +217,19 @@ private[sql] class DataFrameImpl protected[sql]( override def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => Column(new SqlParser().parseExpression(expr)) - } :_*) + }: _*) + } + + override def addColumn(colName: String, col: Column): DataFrame = { + select(Column("*"), col.as(colName)) + } + + override def renameColumn(existingName: String, newName: String): DataFrame = { + val colNames = schema.map { field => + val name = field.name + if (name == existingName) Column(name).as(newName) else Column(name) + } + select(colNames :_*) } override def filter(condition: Column): DataFrame = { @@ -264,18 +278,8 @@ private[sql] class DataFrameImpl protected[sql]( } ///////////////////////////////////////////////////////////////////////////// - - override def addColumn(colName: String, col: Column): DataFrame = { - select(Column("*"), col.as(colName)) - } - - override def renameColumn(existingName: String, newName: String): DataFrame = { - val colNames = schema.map { field => - val name = field.name - if (name == existingName) Column(name).as(newName) else Column(name) - } - select(colNames :_*) - } + // RDD API + ///////////////////////////////////////////////////////////////////////////// override def head(n: Int): Array[Row] = limit(n).collect() @@ -307,6 +311,8 @@ private[sql] class DataFrameImpl protected[sql]( sqlContext.applySchema(rdd.repartition(numPartitions), schema) } + override def distinct: DataFrame = Distinct(logicalPlan) + override def persist(): this.type = { sqlContext.cacheManager.cacheQuery(this) this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala index 782f6e28eebb0..0600dcc226b4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala @@ -86,6 +86,10 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def selectExpr(exprs: String*): DataFrame = err() + override def addColumn(colName: String, col: Column): DataFrame = err() + + override def renameColumn(existingName: String, newName: String): DataFrame = err() + override def filter(condition: Column): DataFrame = err() override def filter(conditionExpr: String): DataFrame = err() @@ -110,10 +114,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten ///////////////////////////////////////////////////////////////////////////// - override def addColumn(colName: String, col: Column): DataFrame = err() - - override def renameColumn(existingName: String, newName: String): DataFrame = err() - override def head(n: Int): Array[Row] = err() override def head(): Row = err() @@ -140,6 +140,8 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def repartition(numPartitions: Int): DataFrame = err() + override def distinct: DataFrame = err() + override def persist(): this.type = err() override def persist(newLevel: StorageLevel): this.type = err() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala index 38e6382f171d5..df866fd1ad8ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala @@ -60,4 +60,6 @@ private[sql] trait RDDApi[T] { def first(): T def repartition(numPartitions: Int): DataFrame + + def distinct: DataFrame } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index bf3990671029e..97e3777f933e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoRelation} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} @@ -130,6 +130,9 @@ class SQLContext(@transient val sparkContext: SparkContext) */ val experimental: ExperimentalMethods = new ExperimentalMethods(this) + /** Returns a [[DataFrame]] with no rows or columns. */ + lazy val emptyDataFrame = DataFrame(this, NoRelation) + /** * A collection of methods for registering user-defined functions (UDF). * From 5f0b30e59cc6a3017168189d3aaf09402699dc3b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 9 Feb 2015 16:20:42 -0800 Subject: [PATCH 16/24] [SQL] Code cleanup. I added an unnecessary line of code in https://github.com/apache/spark/commit/13531dd97c08563e53dacdaeaf1102bdd13ef825. My bad. Let's delete it. Author: Yin Huai Closes #4482 from yhuai/unnecessaryCode and squashes the following commits: 3645af0 [Yin Huai] Code cleanup. --- .../org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c23575fe96898..036efa84d7c85 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -351,9 +351,6 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |) """.stripMargin) - new Path("/Users/yhuai/Desktop/whatever") - - val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") val filesystemPath = new Path(expectedPath) val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) From b8080aa86d55e0467fd4328f10a2f0d6605e6cc6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 9 Feb 2015 16:23:12 -0800 Subject: [PATCH 17/24] [SPARK-5696] [SQL] [HOTFIX] Asks HiveThriftServer2 to re-initialize log4j using Hive configurations In this way, log4j configurations overriden by jets3t-0.9.2.jar can be again overriden by Hive default log4j configurations. This might not be the best solution for this issue since it requires users to use `hive-log4j.properties` rather than `log4j.properties` to initialize `HiveThriftServer2` logging configurations, which can be confusing. The main purpose of this PR is to fix Jenkins PR build. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/4484) Author: Cheng Lian Closes #4484 from liancheng/spark-5696 and squashes the following commits: df83956 [Cheng Lian] Hot fix: asks HiveThriftServer2 to re-initialize log4j using Hive configurations --- .../apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 6e07df18b0e15..525777aa454c4 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.commons.logging.LogFactory +import org.apache.hadoop.hive.common.LogUtils import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} @@ -54,6 +55,8 @@ object HiveThriftServer2 extends Logging { System.exit(-1) } + LogUtils.initHiveLog4j() + logInfo("Starting SparkContext") SparkSQLEnv.init() From 2a36292534a1e9f7a501e88f69bfc3a09fb62cb3 Mon Sep 17 00:00:00 2001 From: Lu Yan Date: Mon, 9 Feb 2015 16:25:38 -0800 Subject: [PATCH 18/24] [SPARK-5614][SQL] Predicate pushdown through Generate. Now in Catalyst's rules, predicates can not be pushed through "Generate" nodes. Further more, partition pruning in HiveTableScan can not be applied on those queries involves "Generate". This makes such queries very inefficient. In practice, it finds patterns like ```scala Filter(predicate, Generate(generator, _, _, _, grandChild)) ``` and splits the predicate into 2 parts by referencing the generated column from Generate node or not. And a new Filter will be created for those conjuncts can be pushed beneath Generate node. If nothing left for the original Filter, it will be removed. For example, physical plan for query ```sql select len, bk from s_server lateral view explode(len_arr) len_table as len where len > 5 and day = '20150102'; ``` where 'day' is a partition column in metastore is like this in current version of Spark SQL: > Project [len, bk] > > Filter ((len > "5") && "(day = "20150102")") > > Generate explode(len_arr), true, false > > HiveTableScan [bk, len_arr, day], (MetastoreRelation default, s_server, None), None But theoretically the plan should be like this > Project [len, bk] > > Filter (len > "5") > > Generate explode(len_arr), true, false > > HiveTableScan [bk, len_arr, day], (MetastoreRelation default, s_server, None), Some(day = "20150102") Where partition pruning predicates can be pushed to HiveTableScan nodes. Author: Lu Yan Closes #4394 from ianluyan/ppd and squashes the following commits: a67dce9 [Lu Yan] Fix English grammar. 7cea911 [Lu Yan] Revised based on @marmbrus's opinions ffc59fc [Lu Yan] [SPARK-5614][SQL] Predicate pushdown through Generate. --- .../sql/catalyst/optimizer/Optimizer.scala | 25 ++++++++ .../optimizer/FilterPushdownSuite.scala | 63 ++++++++++++++++++- 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3bc48c95c5653..fd58b9681ea24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -50,6 +50,7 @@ object DefaultOptimizer extends Optimizer { CombineFilters, PushPredicateThroughProject, PushPredicateThroughJoin, + PushPredicateThroughGenerate, ColumnPruning) :: Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation) :: Nil @@ -455,6 +456,30 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { } } +/** + * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference + * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath. + */ +object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case filter @ Filter(condition, + generate @ Generate(generator, join, outer, alias, grandChild)) => + // Predicates that reference attributes produced by the `Generate` operator cannot + // be pushed below the operator. + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { + conjunct => conjunct.references subsetOf grandChild.outputSet + } + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild)) + stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) + } else { + filter + } + } +} + /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ebb123c1f909e..1158b5dfc6147 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.expressions.Explode import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.IntegerType class FilterPushdownSuite extends PlanTest { @@ -34,7 +36,8 @@ class FilterPushdownSuite extends PlanTest { Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, - PushPredicateThroughJoin) :: Nil + PushPredicateThroughJoin, + PushPredicateThroughGenerate) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -411,4 +414,62 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer)) } + + val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) + + test("generate: predicate referenced no generated column") { + val originalQuery = { + testRelationWithArrayType + .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")) + .where(('b >= 5) && ('a > 6)) + } + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = { + testRelationWithArrayType + .where(('b >= 5) && ('a > 6)) + .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze + } + + comparePlans(optimized, correctAnswer) + } + + test("generate: part of conjuncts referenced generated column") { + val generator = Explode(Seq("c"), 'c_arr) + val originalQuery = { + testRelationWithArrayType + .generate(generator, true, false, Some("arr")) + .where(('b >= 5) && ('c > 6)) + } + val optimized = Optimize(originalQuery.analyze) + val referenceResult = { + testRelationWithArrayType + .where('b >= 5) + .generate(generator, true, false, Some("arr")) + .where('c > 6).analyze + } + + // Since newly generated columns get different ids every time being analyzed + // e.g. comparePlans(originalQuery.analyze, originalQuery.analyze) fails. + // So we check operators manually here. + // Filter("c" > 6) + assertResult(classOf[Filter])(optimized.getClass) + assertResult(1)(optimized.asInstanceOf[Filter].condition.references.size) + assertResult("c"){ + optimized.asInstanceOf[Filter].condition.references.toSeq(0).name + } + + // the rest part + comparePlans(optimized.children(0), referenceResult.children(0)) + } + + test("generate: all conjuncts referenced generated column") { + val originalQuery = { + testRelationWithArrayType + .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")) + .where(('c > 6) || ('b > 5)).analyze + } + val optimized = Optimize(originalQuery) + + comparePlans(optimized, originalQuery) + } } From 0ee53ebce9944722e76b2b28fae79d9956be9f17 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 9 Feb 2015 16:39:34 -0800 Subject: [PATCH 19/24] [SPARK-2096][SQL] support dot notation on array of struct ~~The rule is simple: If you want `a.b` work, then `a` must be some level of nested array of struct(level 0 means just a StructType). And the result of `a.b` is same level of nested array of b-type. An optimization is: the resolve chain looks like `Attribute -> GetItem -> GetField -> GetField ...`, so we could transmit the nested array information between `GetItem` and `GetField` to avoid repeated computation of `innerDataType` and `containsNullList` of that nested array.~~ marmbrus Could you take a look? to evaluate `a.b`, if `a` is array of struct, then `a.b` means get field `b` on each element of `a`, and return a result of array. Author: Wenchen Fan Closes #2405 from cloud-fan/nested-array-dot and squashes the following commits: 08a228a [Wenchen Fan] support dot notation on array of struct --- .../sql/catalyst/analysis/Analyzer.scala | 30 +++++++++------- .../catalyst/expressions/complexTypes.scala | 34 ++++++++++++++++--- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../ExpressionEvaluationSuite.scala | 2 +- .../org/apache/spark/sql/json/JsonSuite.scala | 6 ++-- 5 files changed, 53 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0b59ed1739566..fb2ff014cef07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType} /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -311,18 +310,25 @@ class Analyzer(catalog: Catalog, * desired fields are found. */ protected def resolveGetField(expr: Expression, fieldName: String): Expression = { + def findField(fields: Array[StructField]): Int = { + val checkField = (f: StructField) => resolver(f.name, fieldName) + val ordinal = fields.indexWhere(checkField) + if (ordinal == -1) { + sys.error( + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { + sys.error(s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") + } else { + ordinal + } + } expr.dataType match { case StructType(fields) => - val actualField = fields.filter(f => resolver(f.name, fieldName)) - if (actualField.length == 0) { - sys.error( - s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") - } else if (actualField.length == 1) { - val field = actualField(0) - GetField(expr, field, fields.indexOf(field)) - } else { - sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}") - } + val ordinal = findField(fields) + StructGetField(expr, fields(ordinal), ordinal) + case ArrayType(StructType(fields), containsNull) => + val ordinal = findField(fields) + ArrayGetField(expr, fields(ordinal), ordinal, containsNull) case otherType => sys.error(s"GetField is not valid on fields of type $otherType") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 66e2e5c4bafce..68051a2a2007e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -70,22 +70,48 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { } } + +trait GetField extends UnaryExpression { + self: Product => + + type EvaluatedType = Any + override def foldable = child.foldable + override def toString = s"$child.${field.name}" + + def field: StructField +} + /** * Returns the value of fields in the Struct `child`. */ -case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression { - type EvaluatedType = Any +case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField { def dataType = field.dataType override def nullable = child.nullable || field.nullable - override def foldable = child.foldable override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Row] if (baseValue == null) null else baseValue(ordinal) } +} - override def toString = s"$child.${field.name}" +/** + * Returns the array of value of fields in the Array of Struct `child`. + */ +case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean) + extends GetField { + + def dataType = ArrayType(field.dataType, containsNull) + override def nullable = child.nullable + + override def eval(input: Row): Any = { + val baseValue = child.eval(input).asInstanceOf[Seq[Row]] + if (baseValue == null) null else { + baseValue.map { row => + if (row == null) null else row(ordinal) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index fd58b9681ea24..0da081ed1a6e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -209,7 +209,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) - case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType) + case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType) + case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 7cf6c80194f6c..dcfd8b28cb02a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -851,7 +851,7 @@ class ExpressionEvaluationSuite extends FunSuite { expr.dataType match { case StructType(fields) => val field = fields.find(_.name == fieldName).get - GetField(expr, field, fields.indexOf(field)) + StructGetField(expr, field, fields.indexOf(field)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 926ba68828ee8..7870cf9b0a868 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -342,21 +342,19 @@ class JsonSuite extends QueryTest { ) } - ignore("Complex field and type inferring (Ignored)") { + test("GetField operation on complex data type") { val jsonDF = jsonRDD(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") - // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), Row(true, "str1") ) - // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2. // Getting all values of a specific field from an array of structs. checkAnswer( sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), - Row(Seq(true, false), Seq("str1", null)) + Row(Seq(true, false, null), Seq("str1", null, null)) ) } From d08e7c2b498584609cb3c7922eaaa2a0d115603f Mon Sep 17 00:00:00 2001 From: DoingDone9 <799203320@qq.com> Date: Mon, 9 Feb 2015 16:40:26 -0800 Subject: [PATCH 20/24] [SPARK-5648][SQL] support "alter ... unset tblproperties("key")" make hivecontext support "alter ... unset tblproperties("key")" like : alter view viewName unset tblproperties("k") alter table tableName unset tblproperties("k") Author: DoingDone9 <799203320@qq.com> Closes #4424 from DoingDone9/unset and squashes the following commits: 6dd8bee [DoingDone9] support "alter ... unset tblproperties("key")" --- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2a4b88092179f..f51af62d3340b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -103,6 +103,7 @@ private[hive] object HiveQl { "TOK_CREATEINDEX", "TOK_DROPDATABASE", "TOK_DROPINDEX", + "TOK_DROPTABLE_PROPERTIES", "TOK_MSCK", "TOK_ALTERVIEW_ADDPARTS", @@ -111,6 +112,7 @@ private[hive] object HiveQl { "TOK_ALTERVIEW_PROPERTIES", "TOK_ALTERVIEW_RENAME", "TOK_CREATEVIEW", + "TOK_DROPVIEW_PROPERTIES", "TOK_DROPVIEW", "TOK_EXPORT", From 3ec3ad295ddd1435da68251b7479ffb60aec7037 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 9 Feb 2015 16:52:05 -0800 Subject: [PATCH 21/24] [SPARK-5699] [SQL] [Tests] Runs hive-thriftserver tests whenever SQL code is modified [Review on Reviewable](https://reviewable.io/reviews/apache/spark/4486) Author: Cheng Lian Closes #4486 from liancheng/spark-5699 and squashes the following commits: 538001d [Cheng Lian] Runs hive-thriftserver tests whenever SQL code is modified --- dev/run-tests | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index 2257a566bb1bb..483958757a2dd 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -36,7 +36,7 @@ function handle_error () { } -# Build against the right verison of Hadoop. +# Build against the right version of Hadoop. { if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then @@ -77,7 +77,7 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" fi } -# Only run Hive tests if there are sql changes. +# Only run Hive tests if there are SQL changes. # Partial solution for SPARK-1455. if [ -n "$AMPLAB_JENKINS" ]; then git fetch origin master:master @@ -183,7 +183,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string # will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test" "mllib/test") else SBT_MAVEN_TEST_ARGS=("test") fi From d302c4800bf2f74eceb731169ddf1766136b7398 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 9 Feb 2015 17:33:29 -0800 Subject: [PATCH 22/24] [SPARK-5698] Do not let user request negative # of executors Otherwise we might crash the ApplicationMaster. Why? Please see https://issues.apache.org/jira/browse/SPARK-5698. sryza I believe this is also relevant in your patch #4168. Author: Andrew Or Closes #4483 from andrewor14/da-negative and squashes the following commits: 53ed955 [Andrew Or] Throw IllegalArgumentException instead 0e89fd5 [Andrew Or] Check against negative requests --- .../scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9d2fb4f3b4729..f9ca93432bf41 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -314,6 +314,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste * Return whether the request is acknowledged. */ final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized { + if (numAdditionalExecutors < 0) { + throw new IllegalArgumentException( + "Attempted to request a negative number of additional executor(s) " + + s"$numAdditionalExecutors from the cluster manager. Please specify a positive number!") + } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") logDebug(s"Number of pending executors is now $numPendingExecutors") numPendingExecutors += numAdditionalExecutors From 08488c175f2e8532cb6aab84da2abd9ad57179cc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 9 Feb 2015 20:49:22 -0800 Subject: [PATCH 23/24] [SPARK-5469] restructure pyspark.sql into multiple files All the DataTypes moved into pyspark.sql.types The changes can be tracked by `--find-copies-harder -M25` ``` davieslocalhost:~/work/spark/python$ git diff --find-copies-harder -M25 --numstat master.. 2 5 python/docs/pyspark.ml.rst 0 3 python/docs/pyspark.mllib.rst 10 2 python/docs/pyspark.sql.rst 1 1 python/pyspark/mllib/linalg.py 21 14 python/pyspark/{mllib => sql}/__init__.py 14 2108 python/pyspark/{sql.py => sql/context.py} 10 1772 python/pyspark/{sql.py => sql/dataframe.py} 7 6 python/pyspark/{sql_tests.py => sql/tests.py} 8 1465 python/pyspark/{sql.py => sql/types.py} 4 2 python/run-tests 1 1 sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala ``` Also `git blame -C -C python/pyspark/sql/context.py` to track the history. Author: Davies Liu Closes #4479 from davies/sql and squashes the following commits: 1b5f0a5 [Davies Liu] Merge branch 'master' of github.com:apache/spark into sql 2b2b983 [Davies Liu] restructure pyspark.sql --- python/docs/pyspark.ml.rst | 7 +- python/docs/pyspark.mllib.rst | 3 - python/docs/pyspark.sql.rst | 12 +- python/pyspark/mllib/linalg.py | 2 +- python/pyspark/sql.py | 2736 ----------------- python/pyspark/sql/__init__.py | 42 + python/pyspark/sql/context.py | 642 ++++ python/pyspark/sql/dataframe.py | 974 ++++++ python/pyspark/{sql_tests.py => sql/tests.py} | 13 +- python/pyspark/sql/types.py | 1279 ++++++++ python/run-tests | 6 +- .../spark/sql/test/ExamplePointUDT.scala | 2 +- 12 files changed, 2962 insertions(+), 2756 deletions(-) delete mode 100644 python/pyspark/sql.py create mode 100644 python/pyspark/sql/__init__.py create mode 100644 python/pyspark/sql/context.py create mode 100644 python/pyspark/sql/dataframe.py rename python/pyspark/{sql_tests.py => sql/tests.py} (96%) create mode 100644 python/pyspark/sql/types.py diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index f10d1339a9a8f..4da6d4a74a299 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -1,11 +1,8 @@ pyspark.ml package ===================== -Submodules ----------- - -pyspark.ml module ------------------ +Module Context +-------------- .. automodule:: pyspark.ml :members: diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index 4548b8739ed91..21f66ca344a3c 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -1,9 +1,6 @@ pyspark.mllib package ===================== -Submodules ----------- - pyspark.mllib.classification module ----------------------------------- diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 65b3650ae10ab..80c6f02a9df41 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -1,10 +1,18 @@ pyspark.sql module ================== -Module contents ---------------- +Module Context +-------------- .. automodule:: pyspark.sql :members: :undoc-members: :show-inheritance: + + +pyspark.sql.types module +------------------------ +.. automodule:: pyspark.sql.types + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 7f21190ed8c25..597012b1c967c 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,7 +29,7 @@ import numpy as np -from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ +from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ IntegerType, ByteType diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py deleted file mode 100644 index 6a6dfbc5851b8..0000000000000 --- a/python/pyspark/sql.py +++ /dev/null @@ -1,2736 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -public classes of Spark SQL: - - - L{SQLContext} - Main entry point for SQL functionality. - - L{DataFrame} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, DataFrames also support SQL. - - L{GroupedData} - - L{Column} - Column is a DataFrame with a single column. - - L{Row} - A Row of data returned by a Spark SQL query. - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. -""" - -import sys -import itertools -import decimal -import datetime -import keyword -import warnings -import json -import re -import random -import os -from tempfile import NamedTemporaryFile -from array import array -from operator import itemgetter -from itertools import imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - -from pyspark.context import SparkContext -from pyspark.rdd import RDD, _prepare_for_python_RDD -from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ - CloudPickleSerializer, UTF8Deserializer -from pyspark.storagelevel import StorageLevel -from pyspark.traceback_utils import SCCallSiteSync - - -__all__ = [ - "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", - "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", - "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "DataFrame", "GroupedData", "Column", "Row", "Dsl", - "SchemaRDD"] - - -class DataType(object): - - """Spark SQL DataType""" - - def __repr__(self): - return self.__class__.__name__ - - def __hash__(self): - return hash(str(self)) - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) - - def __ne__(self, other): - return not self.__eq__(other) - - @classmethod - def typeName(cls): - return cls.__name__[:-4].lower() - - def jsonValue(self): - return self.typeName() - - def json(self): - return json.dumps(self.jsonValue(), - separators=(',', ':'), - sort_keys=True) - - -class PrimitiveTypeSingleton(type): - - """Metaclass for PrimitiveType""" - - _instances = {} - - def __call__(cls): - if cls not in cls._instances: - cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() - return cls._instances[cls] - - -class PrimitiveType(DataType): - - """Spark SQL PrimitiveType""" - - __metaclass__ = PrimitiveTypeSingleton - - def __eq__(self, other): - # because they should be the same object - return self is other - - -class NullType(PrimitiveType): - - """Spark SQL NullType - - The data type representing None, used for the types which has not - been inferred. - """ - - -class StringType(PrimitiveType): - - """Spark SQL StringType - - The data type representing string values. - """ - - -class BinaryType(PrimitiveType): - - """Spark SQL BinaryType - - The data type representing bytearray values. - """ - - -class BooleanType(PrimitiveType): - - """Spark SQL BooleanType - - The data type representing bool values. - """ - - -class DateType(PrimitiveType): - - """Spark SQL DateType - - The data type representing datetime.date values. - """ - - -class TimestampType(PrimitiveType): - - """Spark SQL TimestampType - - The data type representing datetime.datetime values. - """ - - -class DecimalType(DataType): - - """Spark SQL DecimalType - - The data type representing decimal.Decimal values. - """ - - def __init__(self, precision=None, scale=None): - self.precision = precision - self.scale = scale - self.hasPrecisionInfo = precision is not None - - def jsonValue(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal" - - def __repr__(self): - if self.hasPrecisionInfo: - return "DecimalType(%d,%d)" % (self.precision, self.scale) - else: - return "DecimalType()" - - -class DoubleType(PrimitiveType): - - """Spark SQL DoubleType - - The data type representing float values. - """ - - -class FloatType(PrimitiveType): - - """Spark SQL FloatType - - The data type representing single precision floating-point values. - """ - - -class ByteType(PrimitiveType): - - """Spark SQL ByteType - - The data type representing int values with 1 singed byte. - """ - - -class IntegerType(PrimitiveType): - - """Spark SQL IntegerType - - The data type representing int values. - """ - - -class LongType(PrimitiveType): - - """Spark SQL LongType - - The data type representing long values. If the any value is - beyond the range of [-9223372036854775808, 9223372036854775807], - please use DecimalType. - """ - - -class ShortType(PrimitiveType): - - """Spark SQL ShortType - - The data type representing int values with 2 signed bytes. - """ - - -class ArrayType(DataType): - - """Spark SQL ArrayType - - The data type representing list values. An ArrayType object - comprises two fields, elementType (a DataType) and containsNull (a bool). - The field of elementType is used to specify the type of array elements. - The field of containsNull is used to specify if the array has None values. - - """ - - def __init__(self, elementType, containsNull=True): - """Creates an ArrayType - - :param elementType: the data type of elements. - :param containsNull: indicates whether the list contains None values. - - >>> ArrayType(StringType) == ArrayType(StringType, True) - True - >>> ArrayType(StringType, False) == ArrayType(StringType) - False - """ - self.elementType = elementType - self.containsNull = containsNull - - def __repr__(self): - return "ArrayType(%s,%s)" % (self.elementType, - str(self.containsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "elementType": self.elementType.jsonValue(), - "containsNull": self.containsNull} - - @classmethod - def fromJson(cls, json): - return ArrayType(_parse_datatype_json_value(json["elementType"]), - json["containsNull"]) - - -class MapType(DataType): - - """Spark SQL MapType - - The data type representing dict values. A MapType object comprises - three fields, keyType (a DataType), valueType (a DataType) and - valueContainsNull (a bool). - - The field of keyType is used to specify the type of keys in the map. - The field of valueType is used to specify the type of values in the map. - The field of valueContainsNull is used to specify if values of this - map has None values. - - For values of a MapType column, keys are not allowed to have None values. - - """ - - def __init__(self, keyType, valueType, valueContainsNull=True): - """Creates a MapType - :param keyType: the data type of keys. - :param valueType: the data type of values. - :param valueContainsNull: indicates whether values contains - null values. - - >>> (MapType(StringType, IntegerType) - ... == MapType(StringType, IntegerType, True)) - True - >>> (MapType(StringType, IntegerType, False) - ... == MapType(StringType, FloatType)) - False - """ - self.keyType = keyType - self.valueType = valueType - self.valueContainsNull = valueContainsNull - - def __repr__(self): - return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, - str(self.valueContainsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "keyType": self.keyType.jsonValue(), - "valueType": self.valueType.jsonValue(), - "valueContainsNull": self.valueContainsNull} - - @classmethod - def fromJson(cls, json): - return MapType(_parse_datatype_json_value(json["keyType"]), - _parse_datatype_json_value(json["valueType"]), - json["valueContainsNull"]) - - -class StructField(DataType): - - """Spark SQL StructField - - Represents a field in a StructType. - A StructField object comprises three fields, name (a string), - dataType (a DataType) and nullable (a bool). The field of name - is the name of a StructField. The field of dataType specifies - the data type of a StructField. - - The field of nullable specifies if values of a StructField can - contain None values. - - """ - - def __init__(self, name, dataType, nullable=True, metadata=None): - """Creates a StructField - :param name: the name of this field. - :param dataType: the data type of this field. - :param nullable: indicates whether values of this field - can be null. - :param metadata: metadata of this field, which is a map from string - to simple type that can be serialized to JSON - automatically - - >>> (StructField("f1", StringType, True) - ... == StructField("f1", StringType, True)) - True - >>> (StructField("f1", StringType, True) - ... == StructField("f2", StringType, True)) - False - """ - self.name = name - self.dataType = dataType - self.nullable = nullable - self.metadata = metadata or {} - - def __repr__(self): - return "StructField(%s,%s,%s)" % (self.name, self.dataType, - str(self.nullable).lower()) - - def jsonValue(self): - return {"name": self.name, - "type": self.dataType.jsonValue(), - "nullable": self.nullable, - "metadata": self.metadata} - - @classmethod - def fromJson(cls, json): - return StructField(json["name"], - _parse_datatype_json_value(json["type"]), - json["nullable"], - json["metadata"]) - - -class StructType(DataType): - - """Spark SQL StructType - - The data type representing rows. - A StructType object comprises a list of L{StructField}. - - """ - - def __init__(self, fields): - """Creates a StructType - - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True)]) - >>> struct1 == struct2 - True - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True), - ... [StructField("f2", IntegerType, False)]]) - >>> struct1 == struct2 - False - """ - self.fields = fields - - def __repr__(self): - return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) - - def jsonValue(self): - return {"type": self.typeName(), - "fields": [f.jsonValue() for f in self.fields]} - - @classmethod - def fromJson(cls, json): - return StructType([StructField.fromJson(f) for f in json["fields"]]) - - -class UserDefinedType(DataType): - """ - .. note:: WARN: Spark Internal Use Only - SQL User-Defined Type (UDT). - """ - - @classmethod - def typeName(cls): - return cls.__name__.lower() - - @classmethod - def sqlType(cls): - """ - Underlying SQL storage type for this UDT. - """ - raise NotImplementedError("UDT must implement sqlType().") - - @classmethod - def module(cls): - """ - The Python module of the UDT. - """ - raise NotImplementedError("UDT must implement module().") - - @classmethod - def scalaUDT(cls): - """ - The class name of the paired Scala UDT. - """ - raise NotImplementedError("UDT must have a paired Scala UDT.") - - def serialize(self, obj): - """ - Converts the a user-type object into a SQL datum. - """ - raise NotImplementedError("UDT must implement serialize().") - - def deserialize(self, datum): - """ - Converts a SQL datum into a user-type object. - """ - raise NotImplementedError("UDT must implement deserialize().") - - def json(self): - return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) - - def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } - return schema - - @classmethod - def fromJson(cls, json): - pyUDT = json["pyClass"] - split = pyUDT.rfind(".") - pyModule = pyUDT[:split] - pyClass = pyUDT[split+1:] - m = __import__(pyModule, globals(), locals(), [pyClass], -1) - UDT = getattr(m, pyClass) - return UDT() - - def __eq__(self, other): - return type(self) == type(other) - - -_all_primitive_types = dict((v.typeName(), v) - for v in globals().itervalues() - if type(v) is PrimitiveTypeSingleton and - v.__base__ == PrimitiveType) - - -_all_complex_types = dict((v.typeName(), v) - for v in [ArrayType, MapType, StructType]) - - -def _parse_datatype_json_string(json_string): - """Parses the given data type JSON string. - >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) - ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) - ... return datatype == python_datatype - >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) - True - >>> # Simple ArrayType. - >>> simple_arraytype = ArrayType(StringType(), True) - >>> check_datatype(simple_arraytype) - True - >>> # Simple MapType. - >>> simple_maptype = MapType(StringType(), LongType()) - >>> check_datatype(simple_maptype) - True - >>> # Simple StructType. - >>> simple_structtype = StructType([ - ... StructField("a", DecimalType(), False), - ... StructField("b", BooleanType(), True), - ... StructField("c", LongType(), True), - ... StructField("d", BinaryType(), False)]) - >>> check_datatype(simple_structtype) - True - >>> # Complex StructType. - >>> complex_structtype = StructType([ - ... StructField("simpleArray", simple_arraytype, True), - ... StructField("simpleMap", simple_maptype, True), - ... StructField("simpleStruct", simple_structtype, True), - ... StructField("boolean", BooleanType(), False), - ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) - >>> check_datatype(complex_structtype) - True - >>> # Complex ArrayType. - >>> complex_arraytype = ArrayType(complex_structtype, True) - >>> check_datatype(complex_arraytype) - True - >>> # Complex MapType. - >>> complex_maptype = MapType(complex_structtype, - ... complex_arraytype, False) - >>> check_datatype(complex_maptype) - True - >>> check_datatype(ExamplePointUDT()) - True - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> check_datatype(structtype_with_udt) - True - """ - return _parse_datatype_json_value(json.loads(json_string)) - - -_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") - - -def _parse_datatype_json_value(json_value): - if type(json_value) is unicode: - if json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() - elif json_value == u'decimal': - return DecimalType() - elif _FIXED_DECIMAL.match(json_value): - m = _FIXED_DECIMAL.match(json_value) - return DecimalType(int(m.group(1)), int(m.group(2))) - else: - raise ValueError("Could not parse datatype: %s" % json_value) - else: - tpe = json_value["type"] - if tpe in _all_complex_types: - return _all_complex_types[tpe].fromJson(json_value) - elif tpe == 'udt': - return UserDefinedType.fromJson(json_value) - else: - raise ValueError("not supported type: %s" % tpe) - - -# Mapping Python types to Spark SQL DataType -_type_mappings = { - type(None): NullType, - bool: BooleanType, - int: IntegerType, - long: LongType, - float: DoubleType, - str: StringType, - unicode: StringType, - bytearray: BinaryType, - decimal.Decimal: DecimalType, - datetime.date: DateType, - datetime.datetime: TimestampType, - datetime.time: TimestampType, -} - - -def _infer_type(obj): - """Infer the DataType from obj - - >>> p = ExamplePoint(1.0, 2.0) - >>> _infer_type(p) - ExamplePointUDT - """ - if obj is None: - raise ValueError("Can not infer type for None") - - if hasattr(obj, '__UDT__'): - return obj.__UDT__ - - dataType = _type_mappings.get(type(obj)) - if dataType is not None: - return dataType() - - if isinstance(obj, dict): - for key, value in obj.iteritems(): - if key is not None and value is not None: - return MapType(_infer_type(key), _infer_type(value), True) - else: - return MapType(NullType(), NullType(), True) - elif isinstance(obj, (list, array)): - for v in obj: - if v is not None: - return ArrayType(_infer_type(obj[0]), True) - else: - return ArrayType(NullType(), True) - else: - try: - return _infer_schema(obj) - except ValueError: - raise ValueError("not supported type: %s" % type(obj)) - - -def _infer_schema(row): - """Infer the schema from dict/namedtuple/object""" - if isinstance(row, dict): - items = sorted(row.items()) - - elif isinstance(row, tuple): - if hasattr(row, "_fields"): # namedtuple - items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row - items = zip(row.__FIELDS__, tuple(row)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in row): - items = row - else: - raise ValueError("Can't infer schema from tuple") - - elif hasattr(row, "__dict__"): # object - items = sorted(row.__dict__.items()) - - else: - raise ValueError("Can not infer schema for type: %s" % type(row)) - - fields = [StructField(k, _infer_type(v), True) for k, v in items] - return StructType(fields) - - -def _need_python_to_sql_conversion(dataType): - """ - Checks whether we need python to sql conversion for the given type. - For now, only UDTs need this conversion. - - >>> _need_python_to_sql_conversion(DoubleType()) - False - >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), - ... StructField("values", ArrayType(DoubleType(), False), False)]) - >>> _need_python_to_sql_conversion(schema0) - False - >>> _need_python_to_sql_conversion(ExamplePointUDT()) - True - >>> schema1 = ArrayType(ExamplePointUDT(), False) - >>> _need_python_to_sql_conversion(schema1) - True - >>> schema2 = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> _need_python_to_sql_conversion(schema2) - True - """ - if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) - elif isinstance(dataType, ArrayType): - return _need_python_to_sql_conversion(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_python_to_sql_conversion(dataType.keyType) or \ - _need_python_to_sql_conversion(dataType.valueType) - elif isinstance(dataType, UserDefinedType): - return True - else: - return False - - -def _python_to_sql_converter(dataType): - """ - Returns a converter that converts a Python object into a SQL datum for the given type. - - >>> conv = _python_to_sql_converter(DoubleType()) - >>> conv(1.0) - 1.0 - >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) - >>> conv([1.0, 2.0]) - [1.0, 2.0] - >>> conv = _python_to_sql_converter(ExamplePointUDT()) - >>> conv(ExamplePoint(1.0, 2.0)) - [1.0, 2.0] - >>> schema = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> conv = _python_to_sql_converter(schema) - >>> conv((1.0, ExamplePoint(1.0, 2.0))) - (1.0, [1.0, 2.0]) - """ - if not _need_python_to_sql_conversion(dataType): - return lambda x: x - - if isinstance(dataType, StructType): - names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = map(_python_to_sql_converter, types) - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) - else: - return tuple(c(v) for c, v in zip(converters, obj)) - else: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) - return converter - elif isinstance(dataType, ArrayType): - element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: [element_converter(v) for v in a] - elif isinstance(dataType, MapType): - key_converter = _python_to_sql_converter(dataType.keyType) - value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) - elif isinstance(dataType, UserDefinedType): - return lambda obj: dataType.serialize(obj) - else: - raise ValueError("Unexpected type %r" % dataType) - - -def _has_nulltype(dt): - """ Return whether there is NullType in `dt` or not """ - if isinstance(dt, StructType): - return any(_has_nulltype(f.dataType) for f in dt.fields) - elif isinstance(dt, ArrayType): - return _has_nulltype((dt.elementType)) - elif isinstance(dt, MapType): - return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) - else: - return isinstance(dt, NullType) - - -def _merge_type(a, b): - if isinstance(a, NullType): - return b - elif isinstance(b, NullType): - return a - elif type(a) is not type(b): - # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (a, b)) - - # same type - if isinstance(a, StructType): - nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) - for f in a.fields] - names = set([f.name for f in fields]) - for n in nfs: - if n not in names: - fields.append(StructField(n, nfs[n])) - return StructType(fields) - - elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) - - elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), - True) - else: - return a - - -def _create_converter(dataType): - """Create an converter to drop the names of fields in obj """ - if isinstance(dataType, ArrayType): - conv = _create_converter(dataType.elementType) - return lambda row: map(conv, row) - - elif isinstance(dataType, MapType): - kconv = _create_converter(dataType.keyType) - vconv = _create_converter(dataType.valueType) - return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) - - elif isinstance(dataType, NullType): - return lambda x: None - - elif not isinstance(dataType, StructType): - return lambda x: x - - # dataType must be StructType - names = [f.name for f in dataType.fields] - converters = [_create_converter(f.dataType) for f in dataType.fields] - - def convert_struct(obj): - if obj is None: - return - - if isinstance(obj, tuple): - if hasattr(obj, "_fields"): - d = dict(zip(obj._fields, obj)) - elif hasattr(obj, "__FIELDS__"): - d = dict(zip(obj.__FIELDS__, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - d = dict(obj) - else: - raise ValueError("unexpected tuple: %s" % str(obj)) - - elif isinstance(obj, dict): - d = obj - elif hasattr(obj, "__dict__"): # object - d = obj.__dict__ - else: - raise ValueError("Unexpected obj: %s" % obj) - - return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - - return convert_struct - - -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _split_schema_abstract(s): - """ - split the schema abstract into fields - - >>> _split_schema_abstract("a b c") - ['a', 'b', 'c'] - >>> _split_schema_abstract("a(a b)") - ['a(a b)'] - >>> _split_schema_abstract("a b[] c{a b}") - ['a', 'b[]', 'c{a b}'] - >>> _split_schema_abstract(" ") - [] - """ - - r = [] - w = '' - brackets = [] - for c in s: - if c == ' ' and not brackets: - if w: - r.append(w) - w = '' - else: - w += c - if c in _BRACKETS: - brackets.append(c) - elif c in _BRACKETS.values(): - if not brackets or c != _BRACKETS[brackets.pop()]: - raise ValueError("unexpected " + c) - - if brackets: - raise ValueError("brackets not closed: %s" % brackets) - if w: - r.append(w) - return r - - -def _parse_field_abstract(s): - """ - Parse a field in schema abstract - - >>> _parse_field_abstract("a") - StructField(a,None,true) - >>> _parse_field_abstract("b(c d)") - StructField(b,StructType(...c,None,true),StructField(d... - >>> _parse_field_abstract("a[]") - StructField(a,ArrayType(None,true),true) - >>> _parse_field_abstract("a{[]}") - StructField(a,MapType(None,ArrayType(None,true),true),true) - """ - if set(_BRACKETS.keys()) & set(s): - idx = min((s.index(c) for c in _BRACKETS if c in s)) - name = s[:idx] - return StructField(name, _parse_schema_abstract(s[idx:]), True) - else: - return StructField(s, None, True) - - -def _parse_schema_abstract(s): - """ - parse abstract into schema - - >>> _parse_schema_abstract("a b c") - StructType...a...b...c... - >>> _parse_schema_abstract("a[b c] b{}") - StructType...a,ArrayType...b...c...b,MapType... - >>> _parse_schema_abstract("c{} d{a b}") - StructType...c,MapType...d,MapType...a...b... - >>> _parse_schema_abstract("a b(t)").fields[1] - StructField(b,StructType(List(StructField(t,None,true))),true) - """ - s = s.strip() - if not s: - return - - elif s.startswith('('): - return _parse_schema_abstract(s[1:-1]) - - elif s.startswith('['): - return ArrayType(_parse_schema_abstract(s[1:-1]), True) - - elif s.startswith('{'): - return MapType(None, _parse_schema_abstract(s[1:-1])) - - parts = _split_schema_abstract(s) - fields = [_parse_field_abstract(p) for p in parts] - return StructType(fields) - - -def _infer_schema_type(obj, dataType): - """ - Fill the dataType with types inferred from obj - - >>> schema = _parse_schema_abstract("a b c d") - >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) - >>> _infer_schema_type(row, schema) - StructType...IntegerType...DoubleType...StringType...DateType... - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> _infer_schema_type(row, schema) - StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... - """ - if dataType is None: - return _infer_type(obj) - - if not obj: - return NullType() - - if isinstance(dataType, ArrayType): - eType = _infer_schema_type(obj[0], dataType.elementType) - return ArrayType(eType, True) - - elif isinstance(dataType, MapType): - k, v = obj.iteritems().next() - return MapType(_infer_schema_type(k, dataType.keyType), - _infer_schema_type(v, dataType.valueType)) - - elif isinstance(dataType, StructType): - fs = dataType.fields - assert len(fs) == len(obj), \ - "Obj(%s) have different length with fields(%s)" % (obj, fs) - fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) - for o, f in zip(obj, fs)] - return StructType(fields) - - else: - raise ValueError("Unexpected dataType: %s" % dataType) - - -_acceptable_types = { - BooleanType: (bool,), - ByteType: (int, long), - ShortType: (int, long), - IntegerType: (int, long), - LongType: (int, long), - FloatType: (float,), - DoubleType: (float,), - DecimalType: (decimal.Decimal,), - StringType: (str, unicode), - BinaryType: (bytearray,), - DateType: (datetime.date,), - TimestampType: (datetime.datetime,), - ArrayType: (list, tuple, array), - MapType: (dict,), - StructType: (tuple, list), -} - - -def _verify_type(obj, dataType): - """ - Verify the type of obj against dataType, raise an exception if - they do not match. - - >>> _verify_type(None, StructType([])) - >>> _verify_type("", StringType()) - >>> _verify_type(0, IntegerType()) - >>> _verify_type(range(3), ArrayType(ShortType())) - >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - >>> _verify_type({}, MapType(StringType(), IntegerType())) - >>> _verify_type((), StructType([])) - >>> _verify_type([], StructType([])) - >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - """ - # all objects are nullable - if obj is None: - return - - if isinstance(dataType, UserDefinedType): - if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.serialize(obj), dataType.sqlType()) - return - - _type = type(dataType) - assert _type in _acceptable_types, "unkown datatype: %s" % dataType - - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) - - if isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType) - - elif isinstance(dataType, MapType): - for k, v in obj.iteritems(): - _verify_type(k, dataType.keyType) - _verify_type(v, dataType.valueType) - - elif isinstance(dataType, StructType): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with" - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType) - - -_cached_cls = {} - - -def _restore_object(dataType, obj): - """ Restore object during unpickling. """ - # use id(dataType) as key to speed up lookup in dict - # Because of batched pickling, dataType will be the - # same object in most cases. - k = id(dataType) - cls = _cached_cls.get(k) - if cls is None: - # use dataType as key to avoid create multiple class - cls = _cached_cls.get(dataType) - if cls is None: - cls = _create_cls(dataType) - _cached_cls[dataType] = cls - _cached_cls[k] = cls - return cls(obj) - - -def _create_object(cls, v): - """ Create an customized object with class `cls`. """ - # datetime.date would be deserialized as datetime.datetime - # from java type, so we need to set it back. - if cls is datetime.date and isinstance(v, datetime.datetime): - return v.date() - return cls(v) if v is not None else v - - -def _create_getter(dt, i): - """ Create a getter for item `i` with schema """ - cls = _create_cls(dt) - - def getter(self): - return _create_object(cls, self[i]) - - return getter - - -def _has_struct_or_date(dt): - """Return whether `dt` is or has StructType/DateType in it""" - if isinstance(dt, StructType): - return True - elif isinstance(dt, ArrayType): - return _has_struct_or_date(dt.elementType) - elif isinstance(dt, MapType): - return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) - elif isinstance(dt, DateType): - return True - elif isinstance(dt, UserDefinedType): - return True - return False - - -def _create_properties(fields): - """Create properties according to fields""" - ps = {} - for i, f in enumerate(fields): - name = f.name - if (name.startswith("__") and name.endswith("__") - or keyword.iskeyword(name)): - warnings.warn("field name %s can not be accessed in Python," - "use position to access it instead" % name) - if _has_struct_or_date(f.dataType): - # delay creating object until accessing it - getter = _create_getter(f.dataType, i) - else: - getter = itemgetter(i) - ps[name] = property(getter) - return ps - - -def _create_cls(dataType): - """ - Create an class by dataType - - The created class is similar to namedtuple, but can have nested schema. - - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> import pickle - >>> pickle.loads(pickle.dumps(obj)) - Row(a=1, b=1.0, c='str') - - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> pickle.loads(pickle.dumps(obj)) - Row(a=[1], b={'key': Row(c=1, d=2.0)}) - >>> pickle.loads(pickle.dumps(obj.a)) - [1] - >>> pickle.loads(pickle.dumps(obj.b)) - {'key': Row(c=1, d=2.0)} - """ - - if isinstance(dataType, ArrayType): - cls = _create_cls(dataType.elementType) - - def List(l): - if l is None: - return - return [_create_object(cls, v) for v in l] - - return List - - elif isinstance(dataType, MapType): - kcls = _create_cls(dataType.keyType) - vcls = _create_cls(dataType.valueType) - - def Dict(d): - if d is None: - return - return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) - - return Dict - - elif isinstance(dataType, DateType): - return datetime.date - - elif isinstance(dataType, UserDefinedType): - return lambda datum: dataType.deserialize(datum) - - elif not isinstance(dataType, StructType): - # no wrapper for primitive types - return lambda x: x - - class Row(tuple): - - """ Row in DataFrame """ - __DATATYPE__ = dataType - __FIELDS__ = tuple(f.name for f in dataType.fields) - __slots__ = () - - # create property for fast access - locals().update(_create_properties(dataType.fields)) - - def asDict(self): - """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__FIELDS__) - - def __repr__(self): - # call collect __repr__ for nested objects - return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) - - def __reduce__(self): - return (_restore_object, (self.__DATATYPE__, tuple(self))) - - return Row - - -class SQLContext(object): - - """Main entry point for Spark SQL functionality. - - A SQLContext can be used create L{DataFrame}, register L{DataFrame} as - tables, execute SQL over tables, cache tables, and read parquet files. - """ - - def __init__(self, sparkContext, sqlContext=None): - """Create a new SQLContext. - - :param sparkContext: The SparkContext to wrap. - :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new - SQLContext in the JVM, instead we make all calls to this object. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - - >>> bad_rdd = sc.parallelize([1,2,3]) - >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - - >>> from datetime import datetime - >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, - ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), - ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> df = sqlCtx.inferSchema(allTypes) - >>> df.registerTempTable("allTypes") - >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' - ... 'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, - ... x.row.a, x.list)).collect() - [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] - """ - self._sc = sparkContext - self._jsc = self._sc._jsc - self._jvm = self._sc._jvm - self._scala_SQLContext = sqlContext - - @property - def _ssql_ctx(self): - """Accessor for the JVM Spark SQL context. - - Subclasses can override this property to provide their own - JVM Contexts. - """ - if self._scala_SQLContext is None: - self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) - return self._scala_SQLContext - - def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. - - >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] - >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] - """ - func = lambda _, it: imap(lambda x: f(*x), it) - ser = AutoBatchedSerializer(PickleSerializer()) - command = (func, None, ser, ser) - pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) - self._ssql_ctx.udf().registerPython(name, - bytearray(pickled_cmd), - env, - includes, - self._sc.pythonExec, - bvars, - self._sc._javaAccumulator, - returnType.json()) - - def inferSchema(self, rdd, samplingRatio=None): - """Infer and apply a schema to an RDD of L{Row}. - - When samplingRatio is specified, the schema is inferred by looking - at the types of each row in the sampled dataset. Otherwise, the - first 100 rows of the RDD are inspected. Nested collections are - supported, which can include array, dict, list, Row, tuple, - namedtuple, or object. - - Each row could be L{pyspark.sql.Row} object or namedtuple or objects. - Using top level dicts is deprecated, as dict is used to represent Maps. - - If a single column has multiple distinct inferred types, it may cause - runtime exceptions. - - >>> rdd = sc.parallelize( - ... [Row(field1=1, field2="row1"), - ... Row(field1=2, field2="row2"), - ... Row(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') - - >>> NestedRow = Row("f1", "f2") - >>> nestedRdd1 = sc.parallelize([ - ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), - ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> df = sqlCtx.inferSchema(nestedRdd1) - >>> df.collect() - [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] - - >>> nestedRdd2 = sc.parallelize([ - ... NestedRow([[1, 2], [2, 3]], [1, 2]), - ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> df = sqlCtx.inferSchema(nestedRdd2) - >>> df.collect() - [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] - - >>> from collections import namedtuple - >>> CustomRow = namedtuple('CustomRow', 'field1 field2') - >>> rdd = sc.parallelize( - ... [CustomRow(field1=1, field2="row1"), - ... CustomRow(field1=2, field2="row2"), - ... CustomRow(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') - """ - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - first = rdd.first() - if not first: - raise ValueError("The first row in RDD is empty, " - "can not infer schema") - if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") - - if samplingRatio is None: - schema = _infer_schema(first) - if _has_nulltype(schema): - for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) - if not _has_nulltype(schema): - break - else: - warnings.warn("Some of types cannot be determined by the " - "first 100 rows, please try again with sampling") - else: - if samplingRatio > 0.99: - rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) - - converter = _create_converter(schema) - rdd = rdd.map(converter) - return self.applySchema(rdd, schema) - - def applySchema(self, rdd, schema): - """ - Applies the given schema to the given RDD of L{tuple} or L{list}. - - These tuples or lists can contain complex nested structures like - lists, maps or nested rows. - - The schema should be a StructType. - - It is important that the schema matches the types of the objects - in each row or exceptions could be thrown at runtime. - - >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) - >>> schema = StructType([StructField("field1", IntegerType(), False), - ... StructField("field2", StringType(), False)]) - >>> df = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT * from table1") - >>> df2.collect() - [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - - >>> from datetime import date, datetime - >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, - ... date(2010, 1, 1), - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3], None)]) - >>> schema = StructType([ - ... StructField("byte1", ByteType(), False), - ... StructField("byte2", ByteType(), False), - ... StructField("short1", ShortType(), False), - ... StructField("short2", ShortType(), False), - ... StructField("int", IntegerType(), False), - ... StructField("float", FloatType(), False), - ... StructField("date", DateType(), False), - ... StructField("time", TimestampType(), False), - ... StructField("map", - ... MapType(StringType(), IntegerType(), False), False), - ... StructField("struct", - ... StructType([StructField("b", ShortType(), False)]), False), - ... StructField("list", ArrayType(ByteType(), False), False), - ... StructField("null", DoubleType(), True)]) - >>> df = sqlCtx.applySchema(rdd, schema) - >>> results = df.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, - ... x.time, x.map["a"], x.struct.b, x.list, x.null)) - >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE - (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), - datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - - >>> df.registerTempTable("table2") - >>> sqlCtx.sql( - ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.5 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] - - >>> rdd = sc.parallelize([(127, -32768, 1.0, - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3])]) - >>> abstract = "byte short float time map{} struct(b) list[]" - >>> schema = _parse_schema_abstract(abstract) - >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> df = sqlCtx.applySchema(rdd, typedSchema) - >>> df.collect() - [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] - """ - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) - rows = rdd.take(10) - - for row in rows: - _verify_type(row, schema) - - # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) - - jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) - - def registerRDDAsTable(self, rdd, tableName): - """Registers the given RDD as a temporary table in the catalog. - - Temporary tables exist only during the lifetime of this instance of - SQLContext. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - """ - if (rdd.__class__ is DataFrame): - df = rdd._jdf - self._ssql_ctx.registerRDDAsTable(df, tableName) - else: - raise ValueError("Can only register DataFrame as table") - - def parquetFile(self, *paths): - """Loads a Parquet file, returning the result as a L{DataFrame}. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - gateway = self._sc._gateway - jpath = paths[0] - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1) - for i in range(1, len(paths)): - jpaths[i] = paths[i] - jdf = self._ssql_ctx.parquetFile(jpath, jpaths) - return DataFrame(jdf, self) - - def jsonFile(self, path, schema=None, samplingRatio=1.0): - """ - Loads a text file storing one JSON object per line as a - L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> ofn = open(jsonFile, 'w') - >>> for json in jsonStrings: - ... print>>ofn, json - >>> ofn.close() - >>> df1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in df2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) - >>> sqlCtx.registerRDDAsTable(df3, "table2") - >>> df4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in df4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> df5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(df5, "table3") - >>> df6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> df6.collect() - [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] - """ - if schema is None: - df = self._ssql_ctx.jsonFile(path, samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonFile(path, scala_datatype) - return DataFrame(df, self) - - def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> df1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in df2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) - >>> sqlCtx.registerRDDAsTable(df3, "table2") - >>> df4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in df4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> df5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(df5, "table3") - >>> df6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> df6.collect() - [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] - - >>> sqlCtx.jsonRDD(sc.parallelize(['{}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - """ - - def func(iterator): - for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): - x = x.encode("utf-8") - yield x - keyed = rdd.mapPartitions(func) - keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - if schema is None: - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return DataFrame(df, self) - - def sql(self, sqlQuery): - """Return a L{DataFrame} representing the result of the given query. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> df2.collect() - [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] - """ - return DataFrame(self._ssql_ctx.sql(sqlQuery), self) - - def table(self, tableName): - """Returns the specified table as a L{DataFrame}. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.table("table1") - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - return DataFrame(self._ssql_ctx.table(tableName), self) - - def cacheTable(self, tableName): - """Caches the specified table in-memory.""" - self._ssql_ctx.cacheTable(tableName) - - def uncacheTable(self, tableName): - """Removes the specified table from the in-memory cache.""" - self._ssql_ctx.uncacheTable(tableName) - - -class HiveContext(SQLContext): - - """A variant of Spark SQL that integrates with data stored in Hive. - - Configuration for Hive is read from hive-site.xml on the classpath. - It supports running both SQL and HiveQL commands. - """ - - def __init__(self, sparkContext, hiveContext=None): - """Create a new HiveContext. - - :param sparkContext: The SparkContext to wrap. - :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new - HiveContext in the JVM, instead we make all calls to this object. - """ - SQLContext.__init__(self, sparkContext) - - if hiveContext: - self._scala_HiveContext = hiveContext - - @property - def _ssql_ctx(self): - try: - if not hasattr(self, '_scala_HiveContext'): - self._scala_HiveContext = self._get_hive_ctx() - return self._scala_HiveContext - except Py4JError as e: - raise Exception("You must build Spark with Hive. " - "Export 'SPARK_HIVE=true' and run " - "build/sbt assembly", e) - - def _get_hive_ctx(self): - return self._jvm.HiveContext(self._jsc.sc()) - - -def _create_row(fields, values): - row = Row(*values) - row.__FIELDS__ = fields - return row - - -class Row(tuple): - - """ - A row in L{DataFrame}. The fields in it can be accessed like attributes. - - Row can be used to create a row object by using named arguments, - the fields will be sorted by names. - - >>> row = Row(name="Alice", age=11) - >>> row - Row(age=11, name='Alice') - >>> row.name, row.age - ('Alice', 11) - - Row also can be used to create another Row like class, then it - could be used to create Row objects, such as - - >>> Person = Row("name", "age") - >>> Person - - >>> Person("Alice", 11) - Row(name='Alice', age=11) - """ - - def __new__(self, *args, **kwargs): - if args and kwargs: - raise ValueError("Can not use both args " - "and kwargs to create Row") - if args: - # create row class or objects - return tuple.__new__(self, args) - - elif kwargs: - # create row objects - names = sorted(kwargs.keys()) - values = tuple(kwargs[n] for n in names) - row = tuple.__new__(self, values) - row.__FIELDS__ = names - return row - - else: - raise ValueError("No args or kwargs") - - def asDict(self): - """ - Return as an dict - """ - if not hasattr(self, "__FIELDS__"): - raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__FIELDS__, self)) - - # let obect acs like class - def __call__(self, *args): - """create new Row object""" - return _create_row(self, args) - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - try: - # it will be slow when it has many fields, - # but this will not be used in normal cases - idx = self.__FIELDS__.index(item) - return self[idx] - except IndexError: - raise AttributeError(item) - - def __reduce__(self): - if hasattr(self, "__FIELDS__"): - return (_create_row, (self.__FIELDS__, tuple(self))) - else: - return tuple.__reduce__(self) - - def __repr__(self): - if hasattr(self, "__FIELDS__"): - return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) - else: - return "" % ", ".join(self) - - -class DataFrame(object): - - """A collection of rows that have the same columns. - - A :class:`DataFrame` is equivalent to a relational table in Spark SQL, - and can be created using various functions in :class:`SQLContext`:: - - people = sqlContext.parquetFile("...") - - Once created, it can be manipulated using the various domain-specific-language - (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. - - To select a column from the data frame, use the apply method:: - - ageCol = people.age - - Note that the :class:`Column` type can also be manipulated - through its various functions:: - - # The following creates a new column that increases everybody's age by 10. - people.age + 10 - - - A more concrete example:: - - # To create DataFrame using SQLContext - people = sqlContext.parquetFile("...") - department = sqlContext.parquetFile("...") - - people.filter(people.age > 30).join(department, people.deptId == department.id)) \ - .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - self._sc = sql_ctx and sql_ctx._sc - self.is_cached = False - - @property - def rdd(self): - """ - Return the content of the :class:`DataFrame` as an :class:`RDD` - of :class:`Row` s. - """ - if not hasattr(self, '_lazy_rdd'): - jrdd = self._jdf.javaToPython() - rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) - schema = self.schema() - - def applySchema(it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - self._lazy_rdd = rdd.mapPartitions(applySchema) - - return self._lazy_rdd - - def toJSON(self, use_unicode=False): - """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. - - >>> df1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( "SELECT * from table1") - >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' - True - >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] - True - """ - rdd = self._jdf.toJSON() - return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - - def saveAsParquetFile(self, path): - """Save the contents as a Parquet file, preserving the schema. - - Files that are written out using this method can be read back in as - a DataFrame using the L{SQLContext.parquetFile} method. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(df2.collect()) == sorted(df.collect()) - True - """ - self._jdf.saveAsParquetFile(path) - - def registerTempTable(self, name): - """Registers this RDD as a temporary table using the given name. - - The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this DataFrame. - - >>> df.registerTempTable("people") - >>> df2 = sqlCtx.sql("select * from people") - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - self._jdf.registerTempTable(name) - - def registerAsTable(self, name): - """DEPRECATED: use registerTempTable() instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) - self.registerTempTable(name) - - def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this DataFrame into the specified table. - - Optionally overwriting any existing data. - """ - self._jdf.insertInto(tableName, overwrite) - - def saveAsTable(self, tableName): - """Creates a new table with the contents of this DataFrame.""" - self._jdf.saveAsTable(tableName) - - def schema(self): - """Returns the schema of this DataFrame (represented by - a L{StructType}). - - >>> df.schema() - StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) - """ - return _parse_datatype_json_string(self._jdf.schema().json()) - - def printSchema(self): - """Prints out the schema in the tree format. - - >>> df.printSchema() - root - |-- age: integer (nullable = true) - |-- name: string (nullable = true) - - """ - print (self._jdf.schema().treeString()) - - def count(self): - """Return the number of elements in this RDD. - - Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the DataFrame, - which supports features such as filter pushdown. - - >>> df.count() - 2L - """ - return self._jdf.count() - - def collect(self): - """Return a list that contains all of the rows. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - >>> df.collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] - """ - with SCCallSiteSync(self._sc) as css: - bytesInJava = self._jdf.javaToPython().collect().iterator() - tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) - tempFile.close() - self._sc._writeToFile(bytesInJava, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) - os.unlink(tempFile.name) - cls = _create_cls(self.schema()) - return [cls(r) for r in rs] - - def limit(self, num): - """Limit the result count to the number specified. - - >>> df.limit(1).collect() - [Row(age=2, name=u'Alice')] - >>> df.limit(0).collect() - [] - """ - jdf = self._jdf.limit(num) - return DataFrame(jdf, self.sql_ctx) - - def take(self, num): - """Take the first num rows of the RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - >>> df.take(2) - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] - """ - return self.limit(num).collect() - - def map(self, f): - """ Return a new RDD by applying a function to each Row, it's a - shorthand for df.rdd.map() - - >>> df.map(lambda p: p.name).collect() - [u'Alice', u'Bob'] - """ - return self.rdd.map(f) - - def mapPartitions(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(iterator): yield 1 - >>> rdd.mapPartitions(f).sum() - 4 - """ - return self.rdd.mapPartitions(f, preservesPartitioning) - - def cache(self): - """ Persist with the default storage level (C{MEMORY_ONLY_SER}). - """ - self.is_cached = True - self._jdf.cache() - return self - - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): - """ Set the storage level to persist its values across operations - after the first time it is computed. This can only be used to assign - a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). - """ - self.is_cached = True - javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) - self._jdf.persist(javaStorageLevel) - return self - - def unpersist(self, blocking=True): - """ Mark it as non-persistent, and remove all blocks for it from - memory and disk. - """ - self.is_cached = False - self._jdf.unpersist(blocking) - return self - - # def coalesce(self, numPartitions, shuffle=False): - # rdd = self._jdf.coalesce(numPartitions, shuffle, None) - # return DataFrame(rdd, self.sql_ctx) - - def repartition(self, numPartitions): - """ Return a new :class:`DataFrame` that has exactly `numPartitions` - partitions. - """ - rdd = self._jdf.repartition(numPartitions, None) - return DataFrame(rdd, self.sql_ctx) - - def sample(self, withReplacement, fraction, seed=None): - """ - Return a sampled subset of this DataFrame. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.sample(False, 0.5, 97).count() - 2L - """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jdf.sample(withReplacement, fraction, long(seed)) - return DataFrame(rdd, self.sql_ctx) - - # def takeSample(self, withReplacement, num, seed=None): - # """Return a fixed-size sampled subset of this DataFrame. - # - # >>> df = sqlCtx.inferSchema(rdd) - # >>> df.takeSample(False, 2, 97) - # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] - # """ - # seed = seed if seed is not None else random.randint(0, sys.maxint) - # with SCCallSiteSync(self.context) as css: - # bytesInJava = self._jdf \ - # .takeSampleToPython(withReplacement, num, long(seed)) \ - # .iterator() - # cls = _create_cls(self.schema()) - # return map(cls, self._collect_iterator_through_file(bytesInJava)) - - @property - def dtypes(self): - """Return all column names and their data types as a list. - - >>> df.dtypes - [('age', 'integer'), ('name', 'string')] - """ - return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields] - - @property - def columns(self): - """ Return all column names as a list. - - >>> df.columns - [u'age', u'name'] - """ - return [f.name for f in self.schema().fields] - - def join(self, other, joinExprs=None, joinType=None): - """ - Join with another DataFrame, using the given join expression. - The following performs a full outer join between `df1` and `df2`:: - - :param other: Right side of the join - :param joinExprs: Join expression - :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. - - >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] - """ - - if joinExprs is None: - jdf = self._jdf.join(other._jdf) - else: - assert isinstance(joinExprs, Column), "joinExprs should be Column" - if joinType is None: - jdf = self._jdf.join(other._jdf, joinExprs._jc) - else: - assert isinstance(joinType, basestring), "joinType should be basestring" - jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) - return DataFrame(jdf, self.sql_ctx) - - def sort(self, *cols): - """ Return a new :class:`DataFrame` sorted by the specified column. - - :param cols: The columns or expressions used for sorting - - >>> df.sort(df.age.desc()).collect() - [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] - >>> df.sortBy(df.age.desc()).collect() - [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] - """ - if not cols: - raise ValueError("should sort by at least one column") - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) - - sortBy = sort - - def head(self, n=None): - """ Return the first `n` rows or the first row if n is None. - - >>> df.head() - Row(age=2, name=u'Alice') - >>> df.head(1) - [Row(age=2, name=u'Alice')] - """ - if n is None: - rs = self.head(1) - return rs[0] if rs else None - return self.take(n) - - def first(self): - """ Return the first row. - - >>> df.first() - Row(age=2, name=u'Alice') - """ - return self.head() - - def __getitem__(self, item): - """ Return the column by given name - - >>> df['age'].collect() - [Row(age=2), Row(age=5)] - >>> df[ ["name", "age"]].collect() - [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] - >>> df[ df.age > 3 ].collect() - [Row(age=5, name=u'Bob')] - """ - if isinstance(item, basestring): - jc = self._jdf.apply(item) - return Column(jc, self.sql_ctx) - elif isinstance(item, Column): - return self.filter(item) - elif isinstance(item, list): - return self.select(*item) - else: - raise IndexError("unexpected index: %s" % item) - - def __getattr__(self, name): - """ Return the column by given name - - >>> df.age.collect() - [Row(age=2), Row(age=5)] - """ - if name.startswith("__"): - raise AttributeError(name) - jc = self._jdf.apply(name) - return Column(jc, self.sql_ctx) - - def select(self, *cols): - """ Selecting a set of expressions. - - >>> df.select().collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] - >>> df.select('*').collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] - >>> df.select('name', 'age').collect() - [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] - >>> df.select(df.name, (df.age + 10).alias('age')).collect() - [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] - """ - if not cols: - cols = ["*"] - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) - - def selectExpr(self, *expr): - """ - Selects a set of SQL expressions. This is a variant of - `select` that accepts SQL expressions. - - >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)] - """ - jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) - jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) - return DataFrame(jdf, self.sql_ctx) - - def filter(self, condition): - """ Filtering rows using the given condition, which could be - Column expression or string of SQL expression. - - where() is an alias for filter(). - - >>> df.filter(df.age > 3).collect() - [Row(age=5, name=u'Bob')] - >>> df.where(df.age == 2).collect() - [Row(age=2, name=u'Alice')] - - >>> df.filter("age > 3").collect() - [Row(age=5, name=u'Bob')] - >>> df.where("age = 2").collect() - [Row(age=2, name=u'Alice')] - """ - if isinstance(condition, basestring): - jdf = self._jdf.filter(condition) - elif isinstance(condition, Column): - jdf = self._jdf.filter(condition._jc) - else: - raise TypeError("condition should be string or Column") - return DataFrame(jdf, self.sql_ctx) - - where = filter - - def groupBy(self, *cols): - """ Group the :class:`DataFrame` using the specified columns, - so we can run aggregation on them. See :class:`GroupedData` - for all the available aggregate functions. - - >>> df.groupBy().avg().collect() - [Row(AVG(age#0)=3.5)] - >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] - >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] - """ - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return GroupedData(jdf, self.sql_ctx) - - def agg(self, *exprs): - """ Aggregate on the entire :class:`DataFrame` without groups - (shorthand for df.groupBy.agg()). - - >>> df.agg({"age": "max"}).collect() - [Row(MAX(age#0)=5)] - >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.min(df.age)).collect() - [Row(MIN(age#0)=2)] - """ - return self.groupBy().agg(*exprs) - - def unionAll(self, other): - """ Return a new DataFrame containing union of rows in this - frame and another frame. - - This is equivalent to `UNION ALL` in SQL. - """ - return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) - - def intersect(self, other): - """ Return a new :class:`DataFrame` containing rows only in - both this frame and another frame. - - This is equivalent to `INTERSECT` in SQL. - """ - return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) - - def subtract(self, other): - """ Return a new :class:`DataFrame` containing rows in this frame - but not in another frame. - - This is equivalent to `EXCEPT` in SQL. - """ - return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) - - def addColumn(self, colName, col): - """ Return a new :class:`DataFrame` by adding a column. - - >>> df.addColumn('age2', df.age + 2).collect() - [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] - """ - return self.select('*', col.alias(colName)) - - def to_pandas(self): - """ - Collect all the rows and return a `pandas.DataFrame`. - - >>> df.to_pandas() # doctest: +SKIP - age name - 0 2 Alice - 1 5 Bob - """ - import pandas as pd - return pd.DataFrame.from_records(self.collect(), columns=self.columns) - - -# Having SchemaRDD for backward compatibility (for docs) -class SchemaRDD(DataFrame): - """ - SchemaRDD is deprecated, please use DataFrame - """ - - -def dfapi(f): - def _api(self): - name = f.__name__ - jdf = getattr(self._jdf, name)() - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -class GroupedData(object): - - """ - A set of methods for aggregations on a :class:`DataFrame`, - created by DataFrame.groupBy(). - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - - def agg(self, *exprs): - """ Compute aggregates by specifying a map from column name - to aggregate methods. - - The available aggregate methods are `avg`, `max`, `min`, - `sum`, `count`. - - :param exprs: list or aggregate columns or a map from column - name to aggregate methods. - - >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"age": "max"}).collect() - [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)] - >>> from pyspark.sql import Dsl - >>> gdf.agg(Dsl.min(df.age)).collect() - [Row(MIN(age#0)=5), Row(MIN(age#0)=2)] - """ - assert exprs, "exprs should not be empty" - if len(exprs) == 1 and isinstance(exprs[0], dict): - jmap = MapConverter().convert(exprs[0], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(jmap) - else: - # Columns - assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jcols = ListConverter().convert([c._jc for c in exprs[1:]], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) - - @dfapi - def count(self): - """ Count the number of rows for each group. - - >>> df.groupBy(df.age).count().collect() - [Row(age=2, count=1), Row(age=5, count=1)] - """ - - @dfapi - def mean(self): - """Compute the average value for each numeric columns - for each group. This is an alias for `avg`.""" - - @dfapi - def avg(self): - """Compute the average value for each numeric columns - for each group.""" - - @dfapi - def max(self): - """Compute the max value for each numeric columns for - each group. """ - - @dfapi - def min(self): - """Compute the min value for each numeric column for - each group.""" - - @dfapi - def sum(self): - """Compute the sum for each numeric columns for each - group.""" - - -def _create_column_from_literal(literal): - sc = SparkContext._active_spark_context - return sc._jvm.Dsl.lit(literal) - - -def _create_column_from_name(name): - sc = SparkContext._active_spark_context - return sc._jvm.Dsl.col(name) - - -def _to_java_column(col): - if isinstance(col, Column): - jcol = col._jc - else: - jcol = _create_column_from_name(col) - return jcol - - -def _unary_op(name, doc="unary operator"): - """ Create a method for given unary operator """ - def _(self): - jc = getattr(self._jc, name)() - return Column(jc, self.sql_ctx) - _.__doc__ = doc - return _ - - -def _dsl_op(name, doc=''): - def _(self): - jc = getattr(self._sc._jvm.Dsl, name)(self._jc) - return Column(jc, self.sql_ctx) - _.__doc__ = doc - return _ - - -def _bin_op(name, doc="binary operator"): - """ Create a method for given binary operator - """ - def _(self, other): - jc = other._jc if isinstance(other, Column) else other - njc = getattr(self._jc, name)(jc) - return Column(njc, self.sql_ctx) - _.__doc__ = doc - return _ - - -def _reverse_op(name, doc="binary operator"): - """ Create a method for binary operator (this object is on right side) - """ - def _(self, other): - jother = _create_column_from_literal(other) - jc = getattr(jother, name)(self._jc) - return Column(jc, self.sql_ctx) - _.__doc__ = doc - return _ - - -class Column(DataFrame): - - """ - A column in a DataFrame. - - `Column` instances can be created by:: - - # 1. Select a column out of a DataFrame - df.colName - df["colName"] - - # 2. Create from an expression - df.colName + 1 - 1 / df.colName - """ - - def __init__(self, jc, sql_ctx=None): - self._jc = jc - super(Column, self).__init__(jc, sql_ctx) - - # arithmetic operators - __neg__ = _dsl_op("negate") - __add__ = _bin_op("plus") - __sub__ = _bin_op("minus") - __mul__ = _bin_op("multiply") - __div__ = _bin_op("divide") - __mod__ = _bin_op("mod") - __radd__ = _bin_op("plus") - __rsub__ = _reverse_op("minus") - __rmul__ = _bin_op("multiply") - __rdiv__ = _reverse_op("divide") - __rmod__ = _reverse_op("mod") - - # logistic operators - __eq__ = _bin_op("equalTo") - __ne__ = _bin_op("notEqual") - __lt__ = _bin_op("lt") - __le__ = _bin_op("leq") - __ge__ = _bin_op("geq") - __gt__ = _bin_op("gt") - - # `and`, `or`, `not` cannot be overloaded in Python, - # so use bitwise operators as boolean operators - __and__ = _bin_op('and') - __or__ = _bin_op('or') - __invert__ = _dsl_op('not') - __rand__ = _bin_op("and") - __ror__ = _bin_op("or") - - # container operators - __contains__ = _bin_op("contains") - __getitem__ = _bin_op("getItem") - getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") - - # string methods - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") - - def substr(self, startPos, length): - """ - Return a Column which is a substring of the column - - :param startPos: start position (int or Column) - :param length: length of the substring (int or Column) - - >>> df.name.substr(1, 3).collect() - [Row(col=u'Ali'), Row(col=u'Bob')] - """ - if type(startPos) != type(length): - raise TypeError("Can not mix the type") - if isinstance(startPos, (int, long)): - jc = self._jc.substr(startPos, length) - elif isinstance(startPos, Column): - jc = self._jc.substr(startPos._jc, length._jc) - else: - raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc, self.sql_ctx) - - __getslice__ = substr - - # order - asc = _unary_op("asc") - desc = _unary_op("desc") - - isNull = _unary_op("isNull", "True if the current expression is null.") - isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") - - def alias(self, alias): - """Return a alias for this column - - >>> df.age.alias("age2").collect() - [Row(age2=2), Row(age2=5)] - """ - return Column(getattr(self._jc, "as")(alias), self.sql_ctx) - - def cast(self, dataType): - """ Convert the column into type `dataType` - - >>> df.select(df.age.cast("string").alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - >>> df.select(df.age.cast(StringType()).alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - """ - if self.sql_ctx is None: - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - else: - ssql_ctx = self.sql_ctx._ssql_ctx - if isinstance(dataType, basestring): - jc = self._jc.cast(dataType) - elif isinstance(dataType, DataType): - jdt = ssql_ctx.parseDataType(dataType.json()) - jc = self._jc.cast(jdt) - return Column(jc, self.sql_ctx) - - def to_pandas(self): - """ - Return a pandas.Series from the column - - >>> df.age.to_pandas() # doctest: +SKIP - 0 2 - 1 5 - dtype: int64 - """ - import pandas as pd - data = [c for c, in self.collect()] - return pd.Series(data) - - -def _aggregate_func(name, doc=""): - """ Create a function for aggregator by name""" - def _(col): - sc = SparkContext._active_spark_context - jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col)) - return Column(jc) - _.__name__ = name - _.__doc__ = doc - return staticmethod(_) - - -class UserDefinedFunction(object): - def __init__(self, func, returnType): - self.func = func - self.returnType = returnType - self._broadcast = None - self._judf = self._create_judf() - - def _create_judf(self): - f = self.func # put it in closure `func` - func = lambda _, it: imap(lambda x: f(*x), it) - ser = AutoBatchedSerializer(PickleSerializer()) - command = (func, None, ser, ser) - sc = SparkContext._active_spark_context - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, - includes, sc.pythonExec, broadcast_vars, - sc._javaAccumulator, jdt) - return judf - - def __del__(self): - if self._broadcast is not None: - self._broadcast.unpersist() - self._broadcast = None - - def __call__(self, *cols): - sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) - return Column(jc) - - -class Dsl(object): - """ - A collections of builtin aggregators - """ - DSLS = { - 'lit': 'Creates a :class:`Column` of literal value.', - 'col': 'Returns a :class:`Column` based on the given column name.', - 'column': 'Returns a :class:`Column` based on the given column name.', - 'upper': 'Converts a string expression to upper case.', - 'lower': 'Converts a string expression to upper case.', - 'sqrt': 'Computes the square root of the specified float value.', - 'abs': 'Computes the absolutle value.', - - 'max': 'Aggregate function: returns the maximum value of the expression in a group.', - 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', - 'count': 'Aggregate function: returns the number of items in a group.', - 'sum': 'Aggregate function: returns the sum of all values in the expression.', - 'avg': 'Aggregate function: returns the average of the values in a group.', - 'mean': 'Aggregate function: returns the average of the values in a group.', - 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', - } - - for _name, _doc in DSLS.items(): - locals()[_name] = _aggregate_func(_name, _doc) - del _name, _doc - - @staticmethod - def countDistinct(col, *cols): - """ Return a new Column for distinct count of (col, *cols) - - >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect() - [Row(c=2)] - - >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect() - [Row(c=2)] - """ - sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = sc._jvm.Dsl.countDistinct(_to_java_column(col), - sc._jvm.PythonUtils.toSeq(jcols)) - return Column(jc) - - @staticmethod - def approxCountDistinct(col, rsd=None): - """ Return a new Column for approxiate distinct count of (col, *cols) - - >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect() - [Row(c=2)] - """ - sc = SparkContext._active_spark_context - if rsd is None: - jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col)) - else: - jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd) - return Column(jc) - - @staticmethod - def udf(f, returnType=StringType()): - """Create a user defined function (UDF) - - >>> slen = Dsl.udf(lambda s: len(s), IntegerType()) - >>> df.select(slen(df.name).alias('slen')).collect() - [Row(slen=5), Row(slen=3)] - """ - return UserDefinedFunction(f, returnType) - - -def _test(): - import doctest - from pyspark.context import SparkContext - # let doctest run in pyspark.sql, so DataTypes can be picklable - import pyspark.sql - from pyspark.sql import Row, SQLContext - from pyspark.sql_tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') - globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) - globs['rdd'] = sc.parallelize( - [Row(field1=1, field2="row1"), - Row(field1=2, field2="row2"), - Row(field1=3, field2="row3")] - ) - rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]) - rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]) - globs['df'] = sqlCtx.inferSchema(rdd2) - globs['df2'] = sqlCtx.inferSchema(rdd3) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings - globs['json'] = sc.parallelize(jsonStrings) - (failure_count, test_count) = doctest.testmod( - pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py new file mode 100644 index 0000000000000..0a5ba00393aab --- /dev/null +++ b/python/pyspark/sql/__init__.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +public classes of Spark SQL: + + - L{SQLContext} + Main entry point for SQL functionality. + - L{DataFrame} + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, DataFrames also support SQL. + - L{GroupedData} + - L{Column} + Column is a DataFrame with a single column. + - L{Row} + A Row of data returned by a Spark SQL query. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive.. +""" + +from pyspark.sql.context import SQLContext, HiveContext +from pyspark.sql.types import Row +from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD + +__all__ = [ + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', + 'Dsl', +] diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py new file mode 100644 index 0000000000000..49f016a9cf2e9 --- /dev/null +++ b/python/pyspark/sql/context.py @@ -0,0 +1,642 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +import json +from array import array +from itertools import imap + +from py4j.protocol import Py4JError + +from pyspark.rdd import _prepare_for_python_RDD +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql.types import StringType, StructType, _verify_type, \ + _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter +from pyspark.sql.dataframe import DataFrame + +__all__ = ["SQLContext", "HiveContext"] + + +class SQLContext(object): + + """Main entry point for Spark SQL functionality. + + A SQLContext can be used create L{DataFrame}, register L{DataFrame} as + tables, execute SQL over tables, cache tables, and read parquet files. + """ + + def __init__(self, sparkContext, sqlContext=None): + """Create a new SQLContext. + + :param sparkContext: The SparkContext to wrap. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + SQLContext in the JVM, instead we make all calls to this object. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + + >>> bad_rdd = sc.parallelize([1,2,3]) + >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + + >>> from datetime import datetime + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> df = sqlCtx.inferSchema(allTypes) + >>> df.registerTempTable("allTypes") + >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + ... x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + """ + self._sc = sparkContext + self._jsc = self._sc._jsc + self._jvm = self._sc._jvm + self._scala_SQLContext = sqlContext + + @property + def _ssql_ctx(self): + """Accessor for the JVM Spark SQL context. + + Subclasses can override this property to provide their own + JVM Contexts. + """ + if self._scala_SQLContext is None: + self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) + return self._scala_SQLContext + + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + >>> from pyspark.sql.types import IntegerType + >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + """ + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) + self._ssql_ctx.udf().registerPython(name, + bytearray(pickled_cmd), + env, + includes, + self._sc.pythonExec, + bvars, + self._sc._javaAccumulator, + returnType.json()) + + def inferSchema(self, rdd, samplingRatio=None): + """Infer and apply a schema to an RDD of L{Row}. + + When samplingRatio is specified, the schema is inferred by looking + at the types of each row in the sampled dataset. Otherwise, the + first 100 rows of the RDD are inspected. Nested collections are + supported, which can include array, dict, list, Row, tuple, + namedtuple, or object. + + Each row could be L{pyspark.sql.Row} object or namedtuple or objects. + Using top level dicts is deprecated, as dict is used to represent Maps. + + If a single column has multiple distinct inferred types, it may cause + runtime exceptions. + + >>> rdd = sc.parallelize( + ... [Row(field1=1, field2="row1"), + ... Row(field1=2, field2="row2"), + ... Row(field1=3, field2="row3")]) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] + Row(field1=1, field2=u'row1') + + >>> NestedRow = Row("f1", "f2") + >>> nestedRdd1 = sc.parallelize([ + ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), + ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) + >>> df = sqlCtx.inferSchema(nestedRdd1) + >>> df.collect() + [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] + + >>> nestedRdd2 = sc.parallelize([ + ... NestedRow([[1, 2], [2, 3]], [1, 2]), + ... NestedRow([[2, 3], [3, 4]], [2, 3])]) + >>> df = sqlCtx.inferSchema(nestedRdd2) + >>> df.collect() + [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] + + >>> from collections import namedtuple + >>> CustomRow = namedtuple('CustomRow', 'field1 field2') + >>> rdd = sc.parallelize( + ... [CustomRow(field1=1, field2="row1"), + ... CustomRow(field1=2, field2="row2"), + ... CustomRow(field1=3, field2="row3")]) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] + Row(field1=1, field2=u'row1') + """ + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.sql.Row instead") + + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + warnings.warn("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio > 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + + converter = _create_converter(schema) + rdd = rdd.map(converter) + return self.applySchema(rdd, schema) + + def applySchema(self, rdd, schema): + """ + Applies the given schema to the given RDD of L{tuple} or L{list}. + + These tuples or lists can contain complex nested structures like + lists, maps or nested rows. + + The schema should be a StructType. + + It is important that the schema matches the types of the objects + in each row or exceptions could be thrown at runtime. + + >>> from pyspark.sql.types import * + >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) + >>> schema = StructType([StructField("field1", IntegerType(), False), + ... StructField("field2", StringType(), False)]) + >>> df = sqlCtx.applySchema(rdd2, schema) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT * from table1") + >>> df2.collect() + [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] + + >>> from datetime import date, datetime + >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + ... date(2010, 1, 1), + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, (2,), [1, 2, 3], None)]) + >>> schema = StructType([ + ... StructField("byte1", ByteType(), False), + ... StructField("byte2", ByteType(), False), + ... StructField("short1", ShortType(), False), + ... StructField("short2", ShortType(), False), + ... StructField("int", IntegerType(), False), + ... StructField("float", FloatType(), False), + ... StructField("date", DateType(), False), + ... StructField("time", TimestampType(), False), + ... StructField("map", + ... MapType(StringType(), IntegerType(), False), False), + ... StructField("struct", + ... StructType([StructField("b", ShortType(), False)]), False), + ... StructField("list", ArrayType(ByteType(), False), False), + ... StructField("null", DoubleType(), True)]) + >>> df = sqlCtx.applySchema(rdd, schema) + >>> results = df.map( + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, + ... x.time, x.map["a"], x.struct.b, x.list, x.null)) + >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE + (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), + datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + + >>> df.registerTempTable("table2") + >>> sqlCtx.sql( + ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + + ... "float + 1.5 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] + + >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type + >>> rdd = sc.parallelize([(127, -32768, 1.0, + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, (2,), [1, 2, 3])]) + >>> abstract = "byte short float time map{} struct(b) list[]" + >>> schema = _parse_schema_abstract(abstract) + >>> typedSchema = _infer_schema_type(rdd.first(), schema) + >>> df = sqlCtx.applySchema(rdd, typedSchema) + >>> df.collect() + [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] + """ + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + + # take the first few rows to verify schema + rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + + for row in rows: + _verify_type(row, schema) + + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) + + def registerRDDAsTable(self, rdd, tableName): + """Registers the given RDD as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of + SQLContext. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + """ + if (rdd.__class__ is DataFrame): + df = rdd._jdf + self._ssql_ctx.registerRDDAsTable(df, tableName) + else: + raise ValueError("Can only register DataFrame as table") + + def parquetFile(self, *paths): + """Loads a Parquet file, returning the result as a L{DataFrame}. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + gateway = self._sc._gateway + jpath = paths[0] + jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1) + for i in range(1, len(paths)): + jpaths[i] = paths[i] + jdf = self._ssql_ctx.parquetFile(jpath, jpaths) + return DataFrame(jdf, self) + + def jsonFile(self, path, schema=None, samplingRatio=1.0): + """ + Loads a text file storing one JSON object per line as a + L{DataFrame}. + + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. + + >>> import tempfile, shutil + >>> jsonFile = tempfile.mkdtemp() + >>> shutil.rmtree(jsonFile) + >>> ofn = open(jsonFile, 'w') + >>> for json in jsonStrings: + ... print>>ofn, json + >>> ofn.close() + >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") + >>> for r in df2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") + >>> for r in df4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) + >>> df5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... "field3.field5[0] as f3 from table3") + >>> df6.collect() + [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] + """ + if schema is None: + df = self._ssql_ctx.jsonFile(path, samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) + + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): + """Loads an RDD storing one JSON object per string as a L{DataFrame}. + + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. + + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") + >>> for r in df2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") + >>> for r in df4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) + >>> df5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... "field3.field5[0] as f3 from table3") + >>> df6.collect() + [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] + + >>> sqlCtx.jsonRDD(sc.parallelize(['{}', + ... '{"key0": {"key1": "value1"}}'])).collect() + [Row(key0=None), Row(key0=Row(key1=u'value1'))] + >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', + ... '{"key0": {"key1": "value1"}}'])).collect() + [Row(key0=None), Row(key0=Row(key1=u'value1'))] + """ + + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = rdd.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._jvm.BytesToString()) + if schema is None: + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) + + def sql(self, sqlQuery): + """Return a L{DataFrame} representing the result of the given query. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + """ + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) + + def table(self, tableName): + """Returns the specified table as a L{DataFrame}. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + return DataFrame(self._ssql_ctx.table(tableName), self) + + def cacheTable(self, tableName): + """Caches the specified table in-memory.""" + self._ssql_ctx.cacheTable(tableName) + + def uncacheTable(self, tableName): + """Removes the specified table from the in-memory cache.""" + self._ssql_ctx.uncacheTable(tableName) + + +class HiveContext(SQLContext): + + """A variant of Spark SQL that integrates with data stored in Hive. + + Configuration for Hive is read from hive-site.xml on the classpath. + It supports running both SQL and HiveQL commands. + """ + + def __init__(self, sparkContext, hiveContext=None): + """Create a new HiveContext. + + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + HiveContext in the JVM, instead we make all calls to this object. + """ + SQLContext.__init__(self, sparkContext) + + if hiveContext: + self._scala_HiveContext = hiveContext + + @property + def _ssql_ctx(self): + try: + if not hasattr(self, '_scala_HiveContext'): + self._scala_HiveContext = self._get_hive_ctx() + return self._scala_HiveContext + except Py4JError as e: + raise Exception("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " + "build/sbt assembly", e) + + def _get_hive_ctx(self): + return self._jvm.HiveContext(self._jsc.sc()) + + +def _create_row(fields, values): + row = Row(*values) + row.__FIELDS__ = fields + return row + + +class Row(tuple): + + """ + A row in L{DataFrame}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) + """ + + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + values = tuple(kwargs[n] for n in names) + row = tuple.__new__(self, values) + row.__FIELDS__ = names + return row + + else: + raise ValueError("No args or kwargs") + + def asDict(self): + """ + Return as an dict + """ + if not hasattr(self, "__FIELDS__"): + raise TypeError("Cannot convert a Row class into dict") + return dict(zip(self.__FIELDS__, self)) + + # let obect acs like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__FIELDS__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__FIELDS__"): + return (_create_row, (self.__FIELDS__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__FIELDS__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__FIELDS__, self)) + else: + return "" % ", ".join(self) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.context + globs = pyspark.sql.context.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['rdd'] = sc.parallelize( + [Row(field1=1, field2="row1"), + Row(field1=2, field2="row2"), + Row(field1=3, field2="row3")] + ) + jsonStrings = [ + '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' + '"field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", ' + '"field3":{"field4":33, "field5": []}}' + ] + globs['jsonStrings'] = jsonStrings + globs['json'] = sc.parallelize(jsonStrings) + (failure_count, test_count) = doctest.testmod( + pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py new file mode 100644 index 0000000000000..cda704eea75f5 --- /dev/null +++ b/python/pyspark/sql/dataframe.py @@ -0,0 +1,974 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import itertools +import warnings +import random +import os +from tempfile import NamedTemporaryFile +from itertools import imap + +from py4j.java_collections import ListConverter, MapConverter + +from pyspark.context import SparkContext +from pyspark.rdd import RDD, _prepare_for_python_RDD +from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ + UTF8Deserializer +from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync +from pyspark.sql.types import * +from pyspark.sql.types import _create_cls, _parse_datatype_json_string + + +__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"] + + +class DataFrame(object): + + """A collection of rows that have the same columns. + + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: + + people = sqlContext.parquetFile("...") + + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. + + To select a column from the data frame, use the apply method:: + + ageCol = people.age + + Note that the :class:`Column` type can also be manipulated + through its various functions:: + + # The following creates a new column that increases everybody's age by 10. + people.age + 10 + + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.parquetFile("...") + department = sqlContext.parquetFile("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + self._sc = sql_ctx and sql_ctx._sc + self.is_cached = False + + @property + def rdd(self): + """ + Return the content of the :class:`DataFrame` as an :class:`RDD` + of :class:`Row` s. + """ + if not hasattr(self, '_lazy_rdd'): + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema() + + def applySchema(it): + cls = _create_cls(schema) + return itertools.imap(cls, it) + + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd + + def toJSON(self, use_unicode=False): + """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. + + >>> df.toJSON().first() + '{"age":2,"name":"Alice"}' + """ + rdd = self._jdf.toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + + def saveAsParquetFile(self, path): + """Save the contents as a Parquet file, preserving the schema. + + Files that are written out using this method can be read back in as + a DataFrame using the L{SQLContext.parquetFile} method. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) + True + """ + self._jdf.saveAsParquetFile(path) + + def registerTempTable(self, name): + """Registers this RDD as a temporary table using the given name. + + The lifetime of this temporary table is tied to the L{SQLContext} + that was used to create this DataFrame. + + >>> df.registerTempTable("people") + >>> df2 = sqlCtx.sql("select * from people") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + self._jdf.registerTempTable(name) + + def registerAsTable(self, name): + """DEPRECATED: use registerTempTable() instead""" + warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + self.registerTempTable(name) + + def insertInto(self, tableName, overwrite=False): + """Inserts the contents of this DataFrame into the specified table. + + Optionally overwriting any existing data. + """ + self._jdf.insertInto(tableName, overwrite) + + def saveAsTable(self, tableName): + """Creates a new table with the contents of this DataFrame.""" + self._jdf.saveAsTable(tableName) + + def schema(self): + """Returns the schema of this DataFrame (represented by + a L{StructType}). + + >>> df.schema() + StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) + """ + return _parse_datatype_json_string(self._jdf.schema().json()) + + def printSchema(self): + """Prints out the schema in the tree format. + + >>> df.printSchema() + root + |-- age: integer (nullable = true) + |-- name: string (nullable = true) + + """ + print (self._jdf.schema().treeString()) + + def count(self): + """Return the number of elements in this RDD. + + Unlike the base RDD implementation of count, this implementation + leverages the query optimizer to compute the count on the DataFrame, + which supports features such as filter pushdown. + + >>> df.count() + 2L + """ + return self._jdf.count() + + def collect(self): + """Return a list that contains all of the rows. + + Each object in the list is a Row, the fields can be accessed as + attributes. + + >>> df.collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + with SCCallSiteSync(self._sc) as css: + bytesInJava = self._jdf.javaToPython().collect().iterator() + tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) + tempFile.close() + self._sc._writeToFile(bytesInJava, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) + os.unlink(tempFile.name) + cls = _create_cls(self.schema()) + return [cls(r) for r in rs] + + def limit(self, num): + """Limit the result count to the number specified. + + >>> df.limit(1).collect() + [Row(age=2, name=u'Alice')] + >>> df.limit(0).collect() + [] + """ + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) + + def take(self, num): + """Take the first num rows of the RDD. + + Each object in the list is a Row, the fields can be accessed as + attributes. + + >>> df.take(2) + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + return self.limit(num).collect() + + def map(self, f): + """ Return a new RDD by applying a function to each Row, it's a + shorthand for df.rdd.map() + + >>> df.map(lambda p: p.name).collect() + [u'Alice', u'Bob'] + """ + return self.rdd.map(f) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) + + def cache(self): + """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self._jdf.cache() + return self + + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """ Set the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) + return self + + def unpersist(self, blocking=True): + """ Mark it as non-persistent, and remove all blocks for it from + memory and disk. + """ + self.is_cached = False + self._jdf.unpersist(blocking) + return self + + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) + + def repartition(self, numPartitions): + """ Return a new :class:`DataFrame` that has exactly `numPartitions` + partitions. + """ + rdd = self._jdf.repartition(numPartitions, None) + return DataFrame(rdd, self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ + Return a sampled subset of this DataFrame. + + >>> df.sample(False, 0.5, 97).count() + 1L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + # def takeSample(self, withReplacement, num, seed=None): + # """Return a fixed-size sampled subset of this DataFrame. + # + # >>> df = sqlCtx.inferSchema(rdd) + # >>> df.takeSample(False, 2, 97) + # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + # """ + # seed = seed if seed is not None else random.randint(0, sys.maxint) + # with SCCallSiteSync(self.context) as css: + # bytesInJava = self._jdf \ + # .takeSampleToPython(withReplacement, num, long(seed)) \ + # .iterator() + # cls = _create_cls(self.schema()) + # return map(cls, self._collect_iterator_through_file(bytesInJava)) + + @property + def dtypes(self): + """Return all column names and their data types as a list. + + >>> df.dtypes + [('age', 'integer'), ('name', 'string')] + """ + return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields] + + @property + def columns(self): + """ Return all column names as a list. + + >>> df.columns + [u'age', u'name'] + """ + return [f.name for f in self.schema().fields] + + def join(self, other, joinExprs=None, joinType=None): + """ + Join with another DataFrame, using the given join expression. + The following performs a full outer join between `df1` and `df2`:: + + :param other: Right side of the join + :param joinExprs: Join expression + :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + + >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + """ + + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + else: + assert isinstance(joinExprs, Column), "joinExprs should be Column" + if joinType is None: + jdf = self._jdf.join(other._jdf, joinExprs._jc) + else: + assert isinstance(joinType, basestring), "joinType should be basestring" + jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + return DataFrame(jdf, self.sql_ctx) + + def sort(self, *cols): + """ Return a new :class:`DataFrame` sorted by the specified column. + + :param cols: The columns or expressions used for sorting + + >>> df.sort(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.sortBy(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + """ + if not cols: + raise ValueError("should sort by at least one column") + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + sortBy = sort + + def head(self, n=None): + """ Return the first `n` rows or the first row if n is None. + + >>> df.head() + Row(age=2, name=u'Alice') + >>> df.head(1) + [Row(age=2, name=u'Alice')] + """ + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + def first(self): + """ Return the first row. + + >>> df.first() + Row(age=2, name=u'Alice') + """ + return self.head() + + def __getitem__(self, item): + """ Return the column by given name + + >>> df['age'].collect() + [Row(age=2), Row(age=5)] + >>> df[ ["name", "age"]].collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df[ df.age > 3 ].collect() + [Row(age=5, name=u'Bob')] + """ + if isinstance(item, basestring): + jc = self._jdf.apply(item) + return Column(jc, self.sql_ctx) + elif isinstance(item, Column): + return self.filter(item) + elif isinstance(item, list): + return self.select(*item) + else: + raise IndexError("unexpected index: %s" % item) + + def __getattr__(self, name): + """ Return the column by given name + + >>> df.age.collect() + [Row(age=2), Row(age=5)] + """ + if name.startswith("__"): + raise AttributeError(name) + jc = self._jdf.apply(name) + return Column(jc, self.sql_ctx) + + def select(self, *cols): + """ Selecting a set of expressions. + + >>> df.select().collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('*').collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('name', 'age').collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df.select(df.name, (df.age + 10).alias('age')).collect() + [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] + """ + if not cols: + cols = ["*"] + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + def selectExpr(self, *expr): + """ + Selects a set of SQL expressions. This is a variant of + `select` that accepts SQL expressions. + + >>> df.selectExpr("age * 2", "abs(age)").collect() + [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)] + """ + jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) + jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) + return DataFrame(jdf, self.sql_ctx) + + def filter(self, condition): + """ Filtering rows using the given condition, which could be + Column expression or string of SQL expression. + + where() is an alias for filter(). + + >>> df.filter(df.age > 3).collect() + [Row(age=5, name=u'Bob')] + >>> df.where(df.age == 2).collect() + [Row(age=2, name=u'Alice')] + + >>> df.filter("age > 3").collect() + [Row(age=5, name=u'Bob')] + >>> df.where("age = 2").collect() + [Row(age=2, name=u'Alice')] + """ + if isinstance(condition, basestring): + jdf = self._jdf.filter(condition) + elif isinstance(condition, Column): + jdf = self._jdf.filter(condition._jc) + else: + raise TypeError("condition should be string or Column") + return DataFrame(jdf, self.sql_ctx) + + where = filter + + def groupBy(self, *cols): + """ Group the :class:`DataFrame` using the specified columns, + so we can run aggregation on them. See :class:`GroupedData` + for all the available aggregate functions. + + >>> df.groupBy().avg().collect() + [Row(AVG(age#0)=3.5)] + >>> df.groupBy('name').agg({'age': 'mean'}).collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] + >>> df.groupBy(df.name).avg().collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] + """ + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return GroupedData(jdf, self.sql_ctx) + + def agg(self, *exprs): + """ Aggregate on the entire :class:`DataFrame` without groups + (shorthand for df.groupBy.agg()). + + >>> df.agg({"age": "max"}).collect() + [Row(MAX(age#0)=5)] + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.min(df.age)).collect() + [Row(MIN(age#0)=2)] + """ + return self.groupBy().agg(*exprs) + + def unionAll(self, other): + """ Return a new DataFrame containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. + """ + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + + def intersect(self, other): + """ Return a new :class:`DataFrame` containing rows only in + both this frame and another frame. + + This is equivalent to `INTERSECT` in SQL. + """ + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def subtract(self, other): + """ Return a new :class:`DataFrame` containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def addColumn(self, colName, col): + """ Return a new :class:`DataFrame` by adding a column. + + >>> df.addColumn('age2', df.age + 2).collect() + [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ + return self.select('*', col.alias(colName)) + + def to_pandas(self): + """ + Collect all the rows and return a `pandas.DataFrame`. + + >>> df.to_pandas() # doctest: +SKIP + age name + 0 2 Alice + 1 5 Bob + """ + import pandas as pd + return pd.DataFrame.from_records(self.collect(), columns=self.columns) + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """ + SchemaRDD is deprecated, please use DataFrame + """ + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by DataFrame.groupBy(). + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + def agg(self, *exprs): + """ Compute aggregates by specifying a map from column name + to aggregate methods. + + The available aggregate methods are `avg`, `max`, `min`, + `sum`, `count`. + + :param exprs: list or aggregate columns or a map from column + name to aggregate methods. + + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"age": "max"}).collect() + [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)] + >>> from pyspark.sql import Dsl + >>> gdf.agg(Dsl.min(df.age)).collect() + [Row(MIN(age#0)=5), Row(MIN(age#0)=2)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jmap = MapConverter().convert(exprs[0], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(jmap) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jcols = ListConverter().convert([c._jc for c in exprs[1:]], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """ Count the number of rows for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @dfapi + def mean(self): + """Compute the average value for each numeric columns + for each group. This is an alias for `avg`.""" + + @dfapi + def avg(self): + """Compute the average value for each numeric columns + for each group.""" + + @dfapi + def max(self): + """Compute the max value for each numeric columns for + each group. """ + + @dfapi + def min(self): + """Compute the min value for each numeric column for + each group.""" + + @dfapi + def sum(self): + """Compute the sum for each numeric columns for each + group.""" + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.Dsl.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.Dsl.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc, self.sql_ctx) + _.__doc__ = doc + return _ + + +def _dsl_op(name, doc=''): + def _(self): + jc = getattr(self._sc._jvm.Dsl, name)(self._jc) + return Column(jc, self.sql_ctx) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc, self.sql_ctx) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc, self.sql_ctx) + _.__doc__ = doc + return _ + + +class Column(DataFrame): + + """ + A column in a DataFrame. + + `Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + df.colName + df["colName"] + + # 2. Create from an expression + df.colName + 1 + 1 / df.colName + """ + + def __init__(self, jc, sql_ctx=None): + self._jc = jc + super(Column, self).__init__(jc, sql_ctx) + + # arithmetic operators + __neg__ = _dsl_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _dsl_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + def substr(self, startPos, length): + """ + Return a Column which is a substring of the column + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.name.substr(1, 3).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc, self.sql_ctx) + + __getslice__ = substr + + # order + asc = _unary_op("asc") + desc = _unary_op("desc") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + def alias(self, alias): + """Return a alias for this column + + >>> df.age.alias("age2").collect() + [Row(age2=2), Row(age2=5)] + """ + return Column(getattr(self._jc, "as")(alias), self.sql_ctx) + + def cast(self, dataType): + """ Convert the column into type `dataType` + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if self.sql_ctx is None: + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + else: + ssql_ctx = self.sql_ctx._ssql_ctx + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + return Column(jc, self.sql_ctx) + + def to_pandas(self): + """ + Return a pandas.Series from the column + + >>> df.age.to_pandas() # doctest: +SKIP + 0 2 + 1 5 + dtype: int64 + """ + import pandas as pd + data = [c for c, in self.collect()] + return pd.Series(data) + + +def _aggregate_func(name, doc=""): + """ Create a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col)) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return staticmethod(_) + + +class UserDefinedFunction(object): + def __init__(self, func, returnType): + self.func = func + self.returnType = returnType + self._broadcast = None + self._judf = self._create_judf() + + def _create_judf(self): + f = self.func # put it in closure `func` + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + sc = SparkContext._active_spark_context + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(self.returnType.json()) + judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + includes, sc.pythonExec, broadcast_vars, + sc._javaAccumulator, jdt) + return judf + + def __del__(self): + if self._broadcast is not None: + self._broadcast.unpersist() + self._broadcast = None + + def __call__(self, *cols): + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + sc._gateway._gateway_client) + jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + +class Dsl(object): + """ + A collections of builtin aggregators + """ + DSLS = { + 'lit': 'Creates a :class:`Column` of literal value.', + 'col': 'Returns a :class:`Column` based on the given column name.', + 'column': 'Returns a :class:`Column` based on the given column name.', + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to upper case.', + 'sqrt': 'Computes the square root of the specified float value.', + 'abs': 'Computes the absolutle value.', + + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', + } + + for _name, _doc in DSLS.items(): + locals()[_name] = _aggregate_func(_name, _doc) + del _name, _doc + + @staticmethod + def countDistinct(col, *cols): + """ Return a new Column for distinct count of (col, *cols) + + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect() + [Row(c=2)] + + >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + sc._gateway._gateway_client) + jc = sc._jvm.Dsl.countDistinct(_to_java_column(col), + sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + @staticmethod + def approxCountDistinct(col, rsd=None): + """ Return a new Column for approxiate distinct count of (col, *cols) + + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + @staticmethod + def udf(f, returnType=StringType()): + """Create a user defined function (UDF) + + >>> slen = Dsl.udf(lambda s: len(s), IntegerType()) + >>> df.select(slen(df.name).alias('slen')).collect() + [Row(slen=5), Row(slen=3)] + """ + return UserDefinedFunction(f, returnType) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.dataframe + globs = pyspark.sql.dataframe.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx = SQLContext(sc) + rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]) + rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]) + globs['df'] = sqlCtx.inferSchema(rdd2) + globs['df2'] = sqlCtx.inferSchema(rdd3) + (failure_count, test_count) = doctest.testmod( + pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql_tests.py b/python/pyspark/sql/tests.py similarity index 96% rename from python/pyspark/sql_tests.py rename to python/pyspark/sql/tests.py index d314f46e8d2d5..d25c6365ed067 100644 --- a/python/pyspark/sql_tests.py +++ b/python/pyspark/sql/tests.py @@ -34,8 +34,10 @@ else: import unittest -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType + +from pyspark.sql import SQLContext, Column +from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType, LongType from pyspark.tests import ReusedPySparkTestCase @@ -220,7 +222,7 @@ def test_convert_row_to_dict(self): self.assertEqual(1.0, row.asDict()['d']['key'].c) def test_infer_schema_with_udt(self): - from pyspark.sql_tests import ExamplePoint, ExamplePointUDT + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) df = self.sqlCtx.inferSchema(rdd) @@ -232,7 +234,7 @@ def test_infer_schema_with_udt(self): self.assertEqual(point, ExamplePoint(1.0, 2.0)) def test_apply_schema_with_udt(self): - from pyspark.sql_tests import ExamplePoint, ExamplePointUDT + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = (1.0, ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), @@ -242,7 +244,7 @@ def test_apply_schema_with_udt(self): self.assertEquals(point, ExamplePoint(1.0, 2.0)) def test_parquet_with_udt(self): - from pyspark.sql_tests import ExamplePoint + from pyspark.sql.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) df0 = self.sqlCtx.inferSchema(rdd) @@ -253,7 +255,6 @@ def test_parquet_with_udt(self): self.assertEquals(point, ExamplePoint(1.0, 2.0)) def test_column_operators(self): - from pyspark.sql import Column, LongType ci = self.df.key cs = self.df.value c = ci == cs diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py new file mode 100644 index 0000000000000..41afefe48ee5e --- /dev/null +++ b/python/pyspark/sql/types.py @@ -0,0 +1,1279 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import decimal +import datetime +import keyword +import warnings +import json +import re +from array import array +from operator import itemgetter + + +__all__ = [ + "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", + "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", + "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", ] + + +class DataType(object): + + """Spark SQL DataType""" + + def __repr__(self): + return self.__class__.__name__ + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.__dict__ == other.__dict__) + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + + +class PrimitiveTypeSingleton(type): + + """Metaclass for PrimitiveType""" + + _instances = {} + + def __call__(cls): + if cls not in cls._instances: + cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + return cls._instances[cls] + + +class PrimitiveType(DataType): + + """Spark SQL PrimitiveType""" + + __metaclass__ = PrimitiveTypeSingleton + + def __eq__(self, other): + # because they should be the same object + return self is other + + +class NullType(PrimitiveType): + + """Spark SQL NullType + + The data type representing None, used for the types which has not + been inferred. + """ + + +class StringType(PrimitiveType): + + """Spark SQL StringType + + The data type representing string values. + """ + + +class BinaryType(PrimitiveType): + + """Spark SQL BinaryType + + The data type representing bytearray values. + """ + + +class BooleanType(PrimitiveType): + + """Spark SQL BooleanType + + The data type representing bool values. + """ + + +class DateType(PrimitiveType): + + """Spark SQL DateType + + The data type representing datetime.date values. + """ + + +class TimestampType(PrimitiveType): + + """Spark SQL TimestampType + + The data type representing datetime.datetime values. + """ + + +class DecimalType(DataType): + + """Spark SQL DecimalType + + The data type representing decimal.Decimal values. + """ + + def __init__(self, precision=None, scale=None): + self.precision = precision + self.scale = scale + self.hasPrecisionInfo = precision is not None + + def jsonValue(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal" + + def __repr__(self): + if self.hasPrecisionInfo: + return "DecimalType(%d,%d)" % (self.precision, self.scale) + else: + return "DecimalType()" + + +class DoubleType(PrimitiveType): + + """Spark SQL DoubleType + + The data type representing float values. + """ + + +class FloatType(PrimitiveType): + + """Spark SQL FloatType + + The data type representing single precision floating-point values. + """ + + +class ByteType(PrimitiveType): + + """Spark SQL ByteType + + The data type representing int values with 1 singed byte. + """ + + +class IntegerType(PrimitiveType): + + """Spark SQL IntegerType + + The data type representing int values. + """ + + +class LongType(PrimitiveType): + + """Spark SQL LongType + + The data type representing long values. If the any value is + beyond the range of [-9223372036854775808, 9223372036854775807], + please use DecimalType. + """ + + +class ShortType(PrimitiveType): + + """Spark SQL ShortType + + The data type representing int values with 2 signed bytes. + """ + + +class ArrayType(DataType): + + """Spark SQL ArrayType + + The data type representing list values. An ArrayType object + comprises two fields, elementType (a DataType) and containsNull (a bool). + The field of elementType is used to specify the type of array elements. + The field of containsNull is used to specify if the array has None values. + + """ + + def __init__(self, elementType, containsNull=True): + """Creates an ArrayType + + :param elementType: the data type of elements. + :param containsNull: indicates whether the list contains None values. + + >>> ArrayType(StringType) == ArrayType(StringType, True) + True + >>> ArrayType(StringType, False) == ArrayType(StringType) + False + """ + self.elementType = elementType + self.containsNull = containsNull + + def __repr__(self): + return "ArrayType(%s,%s)" % (self.elementType, + str(self.containsNull).lower()) + + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + + +class MapType(DataType): + + """Spark SQL MapType + + The data type representing dict values. A MapType object comprises + three fields, keyType (a DataType), valueType (a DataType) and + valueContainsNull (a bool). + + The field of keyType is used to specify the type of keys in the map. + The field of valueType is used to specify the type of values in the map. + The field of valueContainsNull is used to specify if values of this + map has None values. + + For values of a MapType column, keys are not allowed to have None values. + + """ + + def __init__(self, keyType, valueType, valueContainsNull=True): + """Creates a MapType + :param keyType: the data type of keys. + :param valueType: the data type of values. + :param valueContainsNull: indicates whether values contains + null values. + + >>> (MapType(StringType, IntegerType) + ... == MapType(StringType, IntegerType, True)) + True + >>> (MapType(StringType, IntegerType, False) + ... == MapType(StringType, FloatType)) + False + """ + self.keyType = keyType + self.valueType = valueType + self.valueContainsNull = valueContainsNull + + def __repr__(self): + return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, + str(self.valueContainsNull).lower()) + + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + + +class StructField(DataType): + + """Spark SQL StructField + + Represents a field in a StructType. + A StructField object comprises three fields, name (a string), + dataType (a DataType) and nullable (a bool). The field of name + is the name of a StructField. The field of dataType specifies + the data type of a StructField. + + The field of nullable specifies if values of a StructField can + contain None values. + + """ + + def __init__(self, name, dataType, nullable=True, metadata=None): + """Creates a StructField + :param name: the name of this field. + :param dataType: the data type of this field. + :param nullable: indicates whether values of this field + can be null. + :param metadata: metadata of this field, which is a map from string + to simple type that can be serialized to JSON + automatically + + >>> (StructField("f1", StringType, True) + ... == StructField("f1", StringType, True)) + True + >>> (StructField("f1", StringType, True) + ... == StructField("f2", StringType, True)) + False + """ + self.name = name + self.dataType = dataType + self.nullable = nullable + self.metadata = metadata or {} + + def __repr__(self): + return "StructField(%s,%s,%s)" % (self.name, self.dataType, + str(self.nullable).lower()) + + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable, + "metadata": self.metadata} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"], + json["metadata"]) + + +class StructType(DataType): + + """Spark SQL StructType + + The data type representing rows. + A StructType object comprises a list of L{StructField}. + + """ + + def __init__(self, fields): + """Creates a StructType + + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True), + ... [StructField("f2", IntegerType, False)]]) + >>> struct1 == struct2 + False + """ + self.fields = fields + + def __repr__(self): + return ("StructType(List(%s))" % + ",".join(str(field) for field in self.fields)) + + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} + + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) + + +class UserDefinedType(DataType): + """ + .. note:: WARN: Spark Internal Use Only + SQL User-Defined Type (UDT). + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. + """ + raise NotImplementedError("UDT must implement deserialize().") + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass], -1) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + +_all_primitive_types = dict((v.typeName(), v) + for v in globals().itervalues() + if type(v) is PrimitiveTypeSingleton and + v.__base__ == PrimitiveType) + + +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) + + +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. + >>> def check_datatype(datatype): + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) + ... return datatype == python_datatype + >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) + True + >>> # Simple ArrayType. + >>> simple_arraytype = ArrayType(StringType(), True) + >>> check_datatype(simple_arraytype) + True + >>> # Simple MapType. + >>> simple_maptype = MapType(StringType(), LongType()) + >>> check_datatype(simple_maptype) + True + >>> # Simple StructType. + >>> simple_structtype = StructType([ + ... StructField("a", DecimalType(), False), + ... StructField("b", BooleanType(), True), + ... StructField("c", LongType(), True), + ... StructField("d", BinaryType(), False)]) + >>> check_datatype(simple_structtype) + True + >>> # Complex StructType. + >>> complex_structtype = StructType([ + ... StructField("simpleArray", simple_arraytype, True), + ... StructField("simpleMap", simple_maptype, True), + ... StructField("simpleStruct", simple_structtype, True), + ... StructField("boolean", BooleanType(), False), + ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) + >>> check_datatype(complex_structtype) + True + >>> # Complex ArrayType. + >>> complex_arraytype = ArrayType(complex_structtype, True) + >>> check_datatype(complex_arraytype) + True + >>> # Complex MapType. + >>> complex_maptype = MapType(complex_structtype, + ... complex_arraytype, False) + >>> check_datatype(complex_maptype) + True + >>> check_datatype(ExamplePointUDT()) + True + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + True + """ + return _parse_datatype_json_value(json.loads(json_string)) + + +_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") + + +def _parse_datatype_json_value(json_value): + if type(json_value) is unicode: + if json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + elif json_value == u'decimal': + return DecimalType() + elif _FIXED_DECIMAL.match(json_value): + m = _FIXED_DECIMAL.match(json_value) + return DecimalType(int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Could not parse datatype: %s" % json_value) + else: + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) + + +# Mapping Python types to Spark SQL DataType +_type_mappings = { + type(None): NullType, + bool: BooleanType, + int: IntegerType, + long: LongType, + float: DoubleType, + str: StringType, + unicode: StringType, + bytearray: BinaryType, + decimal.Decimal: DecimalType, + datetime.date: DateType, + datetime.datetime: TimestampType, + datetime.time: TimestampType, +} + + +def _infer_type(obj): + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ + if obj is None: + raise ValueError("Can not infer type for None") + + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + + dataType = _type_mappings.get(type(obj)) + if dataType is not None: + return dataType() + + if isinstance(obj, dict): + for key, value in obj.iteritems(): + if key is not None and value is not None: + return MapType(_infer_type(key), _infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) + elif isinstance(obj, (list, array)): + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) + else: + try: + return _infer_schema(obj) + except ValueError: + raise ValueError("not supported type: %s" % type(obj)) + + +def _infer_schema(row): + """Infer the schema from dict/namedtuple/object""" + if isinstance(row, dict): + items = sorted(row.items()) + + elif isinstance(row, tuple): + if hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) + elif hasattr(row, "__FIELDS__"): # Row + items = zip(row.__FIELDS__, tuple(row)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in row): + items = row + else: + raise ValueError("Can't infer schema from tuple") + + elif hasattr(row, "__dict__"): # object + items = sorted(row.__dict__.items()) + + else: + raise ValueError("Can not infer schema for type: %s" % type(row)) + + fields = [StructField(k, _infer_type(v), True) for k, v in items] + return StructType(fields) + + +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise TypeError("Can not merge type %s and %s" % (a, b)) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _create_converter(dataType): + """Create an converter to drop the names of fields in obj """ + if isinstance(dataType, ArrayType): + conv = _create_converter(dataType.elementType) + return lambda row: map(conv, row) + + elif isinstance(dataType, MapType): + kconv = _create_converter(dataType.keyType) + vconv = _create_converter(dataType.valueType) + return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) + + elif isinstance(dataType, NullType): + return lambda x: None + + elif not isinstance(dataType, StructType): + return lambda x: x + + # dataType must be StructType + names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in dataType.fields] + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, tuple): + if hasattr(obj, "_fields"): + d = dict(zip(obj._fields, obj)) + elif hasattr(obj, "__FIELDS__"): + d = dict(zip(obj.__FIELDS__, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): + d = dict(obj) + else: + raise ValueError("unexpected tuple: %s" % str(obj)) + + elif isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ + else: + raise ValueError("Unexpected obj: %s" % obj) + + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + + return convert_struct + + +_BRACKETS = {'(': ')', '[': ']', '{': '}'} + + +def _split_schema_abstract(s): + """ + split the schema abstract into fields + + >>> _split_schema_abstract("a b c") + ['a', 'b', 'c'] + >>> _split_schema_abstract("a(a b)") + ['a(a b)'] + >>> _split_schema_abstract("a b[] c{a b}") + ['a', 'b[]', 'c{a b}'] + >>> _split_schema_abstract(" ") + [] + """ + + r = [] + w = '' + brackets = [] + for c in s: + if c == ' ' and not brackets: + if w: + r.append(w) + w = '' + else: + w += c + if c in _BRACKETS: + brackets.append(c) + elif c in _BRACKETS.values(): + if not brackets or c != _BRACKETS[brackets.pop()]: + raise ValueError("unexpected " + c) + + if brackets: + raise ValueError("brackets not closed: %s" % brackets) + if w: + r.append(w) + return r + + +def _parse_field_abstract(s): + """ + Parse a field in schema abstract + + >>> _parse_field_abstract("a") + StructField(a,None,true) + >>> _parse_field_abstract("b(c d)") + StructField(b,StructType(...c,None,true),StructField(d... + >>> _parse_field_abstract("a[]") + StructField(a,ArrayType(None,true),true) + >>> _parse_field_abstract("a{[]}") + StructField(a,MapType(None,ArrayType(None,true),true),true) + """ + if set(_BRACKETS.keys()) & set(s): + idx = min((s.index(c) for c in _BRACKETS if c in s)) + name = s[:idx] + return StructField(name, _parse_schema_abstract(s[idx:]), True) + else: + return StructField(s, None, True) + + +def _parse_schema_abstract(s): + """ + parse abstract into schema + + >>> _parse_schema_abstract("a b c") + StructType...a...b...c... + >>> _parse_schema_abstract("a[b c] b{}") + StructType...a,ArrayType...b...c...b,MapType... + >>> _parse_schema_abstract("c{} d{a b}") + StructType...c,MapType...d,MapType...a...b... + >>> _parse_schema_abstract("a b(t)").fields[1] + StructField(b,StructType(List(StructField(t,None,true))),true) + """ + s = s.strip() + if not s: + return + + elif s.startswith('('): + return _parse_schema_abstract(s[1:-1]) + + elif s.startswith('['): + return ArrayType(_parse_schema_abstract(s[1:-1]), True) + + elif s.startswith('{'): + return MapType(None, _parse_schema_abstract(s[1:-1])) + + parts = _split_schema_abstract(s) + fields = [_parse_field_abstract(p) for p in parts] + return StructType(fields) + + +def _infer_schema_type(obj, dataType): + """ + Fill the dataType with types inferred from obj + + >>> schema = _parse_schema_abstract("a b c d") + >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) + >>> _infer_schema_type(row, schema) + StructType...IntegerType...DoubleType...StringType...DateType... + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> _infer_schema_type(row, schema) + StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... + """ + if dataType is None: + return _infer_type(obj) + + if not obj: + return NullType() + + if isinstance(dataType, ArrayType): + eType = _infer_schema_type(obj[0], dataType.elementType) + return ArrayType(eType, True) + + elif isinstance(dataType, MapType): + k, v = obj.iteritems().next() + return MapType(_infer_schema_type(k, dataType.keyType), + _infer_schema_type(v, dataType.valueType)) + + elif isinstance(dataType, StructType): + fs = dataType.fields + assert len(fs) == len(obj), \ + "Obj(%s) have different length with fields(%s)" % (obj, fs) + fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) + for o, f in zip(obj, fs)] + return StructType(fields) + + else: + raise ValueError("Unexpected dataType: %s" % dataType) + + +_acceptable_types = { + BooleanType: (bool,), + ByteType: (int, long), + ShortType: (int, long), + IntegerType: (int, long), + LongType: (int, long), + FloatType: (float,), + DoubleType: (float,), + DecimalType: (decimal.Decimal,), + StringType: (str, unicode), + BinaryType: (bytearray,), + DateType: (datetime.date,), + TimestampType: (datetime.datetime,), + ArrayType: (list, tuple, array), + MapType: (dict,), + StructType: (tuple, list), +} + + +def _verify_type(obj, dataType): + """ + Verify the type of obj against dataType, raise an exception if + they do not match. + + >>> _verify_type(None, StructType([])) + >>> _verify_type("", StringType()) + >>> _verify_type(0, IntegerType()) + >>> _verify_type(range(3), ArrayType(ShortType())) + >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + >>> _verify_type({}, MapType(StringType(), IntegerType())) + >>> _verify_type((), StructType([])) + >>> _verify_type([], StructType([])) + >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + """ + # all objects are nullable + if obj is None: + return + + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + + _type = type(dataType) + assert _type in _acceptable_types, "unkown datatype: %s" % dataType + + # subclass of them can not be deserialized in JVM + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" + % (dataType, type(obj))) + + if isinstance(dataType, ArrayType): + for i in obj: + _verify_type(i, dataType.elementType) + + elif isinstance(dataType, MapType): + for k, v in obj.iteritems(): + _verify_type(k, dataType.keyType) + _verify_type(v, dataType.valueType) + + elif isinstance(dataType, StructType): + if len(obj) != len(dataType.fields): + raise ValueError("Length of object (%d) does not match with" + "length of fields (%d)" % (len(obj), len(dataType.fields))) + for v, f in zip(obj, dataType.fields): + _verify_type(v, f.dataType) + + +_cached_cls = {} + + +def _restore_object(dataType, obj): + """ Restore object during unpickling. """ + # use id(dataType) as key to speed up lookup in dict + # Because of batched pickling, dataType will be the + # same object in most cases. + k = id(dataType) + cls = _cached_cls.get(k) + if cls is None: + # use dataType as key to avoid create multiple class + cls = _cached_cls.get(dataType) + if cls is None: + cls = _create_cls(dataType) + _cached_cls[dataType] = cls + _cached_cls[k] = cls + return cls(obj) + + +def _create_object(cls, v): + """ Create an customized object with class `cls`. """ + # datetime.date would be deserialized as datetime.datetime + # from java type, so we need to set it back. + if cls is datetime.date and isinstance(v, datetime.datetime): + return v.date() + return cls(v) if v is not None else v + + +def _create_getter(dt, i): + """ Create a getter for item `i` with schema """ + cls = _create_cls(dt) + + def getter(self): + return _create_object(cls, self[i]) + + return getter + + +def _has_struct_or_date(dt): + """Return whether `dt` is or has StructType/DateType in it""" + if isinstance(dt, StructType): + return True + elif isinstance(dt, ArrayType): + return _has_struct_or_date(dt.elementType) + elif isinstance(dt, MapType): + return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) + elif isinstance(dt, DateType): + return True + elif isinstance(dt, UserDefinedType): + return True + return False + + +def _create_properties(fields): + """Create properties according to fields""" + ps = {} + for i, f in enumerate(fields): + name = f.name + if (name.startswith("__") and name.endswith("__") + or keyword.iskeyword(name)): + warnings.warn("field name %s can not be accessed in Python," + "use position to access it instead" % name) + if _has_struct_or_date(f.dataType): + # delay creating object until accessing it + getter = _create_getter(f.dataType, i) + else: + getter = itemgetter(i) + ps[name] = property(getter) + return ps + + +def _create_cls(dataType): + """ + Create an class by dataType + + The created class is similar to namedtuple, but can have nested schema. + + >>> schema = _parse_schema_abstract("a b c") + >>> row = (1, 1.0, "str") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> import pickle + >>> pickle.loads(pickle.dumps(obj)) + Row(a=1, b=1.0, c='str') + + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> pickle.loads(pickle.dumps(obj)) + Row(a=[1], b={'key': Row(c=1, d=2.0)}) + >>> pickle.loads(pickle.dumps(obj.a)) + [1] + >>> pickle.loads(pickle.dumps(obj.b)) + {'key': Row(c=1, d=2.0)} + """ + + if isinstance(dataType, ArrayType): + cls = _create_cls(dataType.elementType) + + def List(l): + if l is None: + return + return [_create_object(cls, v) for v in l] + + return List + + elif isinstance(dataType, MapType): + kcls = _create_cls(dataType.keyType) + vcls = _create_cls(dataType.valueType) + + def Dict(d): + if d is None: + return + return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) + + return Dict + + elif isinstance(dataType, DateType): + return datetime.date + + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + + elif not isinstance(dataType, StructType): + # no wrapper for primitive types + return lambda x: x + + class Row(tuple): + + """ Row in DataFrame """ + __DATATYPE__ = dataType + __FIELDS__ = tuple(f.name for f in dataType.fields) + __slots__ = () + + # create property for fast access + locals().update(_create_properties(dataType.fields)) + + def asDict(self): + """ Return as a dict """ + return dict((n, getattr(self, n)) for n in self.__FIELDS__) + + def __repr__(self): + # call collect __repr__ for nested objects + return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) + for n in self.__FIELDS__)) + + def __reduce__(self): + return (_restore_object, (self.__DATATYPE__, tuple(self))) + + return Row + + +def _create_row(fields, values): + row = Row(*values) + row.__FIELDS__ = fields + return row + + +class Row(tuple): + + """ + A row in L{DataFrame}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) + """ + + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + values = tuple(kwargs[n] for n in names) + row = tuple.__new__(self, values) + row.__FIELDS__ = names + return row + + else: + raise ValueError("No args or kwargs") + + def asDict(self): + """ + Return as an dict + """ + if not hasattr(self, "__FIELDS__"): + raise TypeError("Cannot convert a Row class into dict") + return dict(zip(self.__FIELDS__, self)) + + # let obect acs like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__FIELDS__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__FIELDS__"): + return (_create_row, (self.__FIELDS__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__FIELDS__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__FIELDS__, self)) + else: + return "" % ", ".join(self) + + +def _test(): + import doctest + from pyspark.context import SparkContext + # let doctest run in pyspark.sql.types, so DataTypes can be picklable + import pyspark.sql.types + from pyspark.sql import Row, SQLContext + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + globs = pyspark.sql.types.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT + (failure_count, test_count) = doctest.testmod( + pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/run-tests b/python/run-tests index 649a2c44d187b..58a26dd8ff088 100755 --- a/python/run-tests +++ b/python/run-tests @@ -64,8 +64,10 @@ function run_core_tests() { function run_sql_tests() { echo "Run sql tests ..." - run_test "pyspark/sql.py" - run_test "pyspark/sql_tests.py" + run_test "pyspark/sql/types.py" + run_test "pyspark/sql/context.py" + run_test "pyspark/sql/dataframe.py" + run_test "pyspark/sql/tests.py" } function run_mllib_tests() { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index e6f622e87f7a4..eb045e37bf5a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -37,7 +37,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def sqlType: DataType = ArrayType(DoubleType, false) - override def pyUDT: String = "pyspark.sql_tests.ExamplePointUDT" + override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" override def serialize(obj: Any): Seq[Double] = { obj match { From 31d435ecfdc24a788a6e38f4e82767bc275a3283 Mon Sep 17 00:00:00 2001 From: KaiXinXiaoLei Date: Mon, 9 Feb 2015 20:58:58 -0800 Subject: [PATCH 24/24] Add a config option to print DAG. Add a config option "spark.rddDebug.enable" to check whether to print DAG info. When "spark.rddDebug.enable" is true, it will print information about DAG in the log. Author: KaiXinXiaoLei Closes #4257 from KaiXinXiaoLei/DAGprint and squashes the following commits: d9fe42e [KaiXinXiaoLei] change log info c27ee76 [KaiXinXiaoLei] change log info 83c2b32 [KaiXinXiaoLei] change config option adcb14f [KaiXinXiaoLei] change the file. f4e7b9e [KaiXinXiaoLei] add a option to print DAG --- core/src/main/scala/org/apache/spark/SparkContext.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 71bdbc9b38ddb..8d3c3d000adf3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1420,6 +1420,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val callSite = getCallSite val cleanedFunc = clean(func) logInfo("Starting job: " + callSite.shortForm) + if (conf.getBoolean("spark.logLineage", false)) { + logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString) + } dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) progressBar.foreach(_.finishAll())