Merge pull request alteryx#8 from davies/col-computability
fix python tests
rxin committed Feb 3, 2015
2 parents fd92bc7 + f79034c commit 6527b86
Showing 7 changed files with 69 additions and 37 deletions.
75 changes: 52 additions & 23 deletions python/pyspark/sql.py
@@ -2124,6 +2124,10 @@ def head(self, n=None):
return rs[0] if rs else None
return self.take(n)

def first(self):
""" Return the first row. """
return self.head()
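A minimal usage sketch for the new first() alias, assuming a DataFrame df like the one built in the test suite below:

    row = df.first()   # same as df.head(): returns the first Row, or None for an empty frame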

def tail(self):
raise NotImplemented

@@ -2159,7 +2163,7 @@ def select(self, *cols):
else:
cols = [c._jc for c in cols]
jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return DataFrame(jdf, self.sql_ctx)

def filter(self, condition):
@@ -2189,7 +2193,7 @@ def groupBy(self, *cols):
else:
cols = [c._jc for c in cols]
jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return GroupedDataFrame(jdf, self.sql_ctx)

def agg(self, *exprs):
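Both select() and groupBy() now funnel their arguments through the same Py4J conversion: build a Java List of Column handles with ListConverter, then let Dsl.toColumns turn it into the Seq[Column] the Scala API expects. A hedged sketch of that shared pattern (the helper name _jvm_columns is illustrative, not part of this commit):

    from py4j.java_collections import ListConverter

    def _jvm_columns(df, cols):
        # Column objects expose their Java handle as _jc; bare names are
        # wrapped first (see _create_column_from_name below)
        jcols = [c._jc if isinstance(c, Column) else _create_column_from_name(c)
                 for c in cols]
        jlist = ListConverter().convert(jcols, df.sql_ctx._sc._gateway._gateway_client)
        # Dsl.toColumns (added in Dsl.scala below) converts the Java list
        # into the Scala Seq[Column] that select()/groupBy() take
        return df.sql_ctx._sc._jvm.Dsl.toColumns(jlist)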
@@ -2278,14 +2282,17 @@ def agg(self, *exprs):
:param exprs: list of aggregate columns or a map from column
name to aggregate methods.
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
jmap = MapConverter().convert(exprs[0],
self.sql_ctx._sc._gateway._gateway_client)
jdf = self._jdf.agg(jmap)
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
jdf = self._jdf.agg(*exprs)
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
jcols = ListConverter().convert([c._jc for c in exprs[1:]],
self.sql_ctx._sc._gateway._gateway_client)
jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return DataFrame(jdf, self.sql_ctx)
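A hedged usage sketch of the reworked agg(): the dict form still maps column names to aggregate methods, while the Column form now sends the first expression directly and routes the rest through Dsl.toColumns (a df with columns key and value is assumed, as in the tests below).

    from pyspark.sql import Aggregator as Agg

    g = df.groupBy()
    g.agg({"key": "max", "value": "count"})       # map: column name -> aggregate method
    g.agg(Agg.first(df.key), Agg.last(df.value))  # Columns: exprs[0]._jc plus toColumns(rest)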

@dfapi
@@ -2347,7 +2354,7 @@ def _create_column_from_literal(literal):

def _create_column_from_name(name):
sc = SparkContext._active_spark_context
return sc._jvm.Column(name)
return sc._jvm.IncomputableColumn(name)
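A short, hedged note on the switch to IncomputableColumn: a column created from a bare name is not bound to any DataFrame yet, so the JVM side represents it with a wrapper that only becomes computable inside a query.

    jcol = _create_column_from_name("key")   # now wraps sc._jvm.IncomputableColumn("key")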


def _scalaMethod(name):
@@ -2371,7 +2378,7 @@ def _(self):
return _


def _bin_op(name, pass_literal_through=False):
def _bin_op(name, pass_literal_through=True):
""" Create a method for given binary operator
Keyword arguments:
@@ -2465,18 +2472,17 @@ def __init__(self, jc, jdf=None, sql_ctx=None):
# __getattr__ = _bin_op("getField")

# string methods
rlike = _bin_op("rlike", pass_literal_through=True)
like = _bin_op("like", pass_literal_through=True)
startswith = _bin_op("startsWith", pass_literal_through=True)
endswith = _bin_op("endsWith", pass_literal_through=True)
rlike = _bin_op("rlike")
like = _bin_op("like")
startswith = _bin_op("startsWith")
endswith = _bin_op("endsWith")
upper = _unary_op("upper")
lower = _unary_op("lower")
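With pass_literal_through now defaulting to True, these operators hand Python literals straight to the underlying JVM method instead of first wrapping them in a literal Column. A hedged example (df.value is assumed to be a string column, as in the tests below):

    df.value.like("99%")        # the raw Python string is forwarded as-is
    df.value.startswith("9")    # same for rlike / endsWith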

def substr(self, startPos, pos):
if type(startPos) != type(pos):
raise TypeError("Can not mix the type")
if isinstance(startPos, (int, long)):

jc = self._jc.substr(startPos, pos)
elif isinstance(startPos, Column):
jc = self._jc.substr(startPos._jc, pos._jc)
@@ -2507,30 +2513,53 @@ def cast(self, dataType):
return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)


def _to_java_column(col):
if isinstance(col, Column):
jcol = col._jc
else:
jcol = _create_column_from_name(col)
return jcol
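The new _to_java_column helper is what lets the Dsl wrappers accept either a Column or a plain column name; a minimal sketch:

    _to_java_column(df.key)   # Column -> its Java handle (col._jc)
    _to_java_column("key")    # str    -> _create_column_from_name("key")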


def _aggregate_func(name):
""" Create a function for aggregator by name"""
def _(col):
sc = SparkContext._active_spark_context
if isinstance(col, Column):
jcol = col._jc
else:
jcol = _create_column_from_name(col)
jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol)
jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
return Column(jc)

return staticmethod(_)


class Aggregator(object):
"""
A collection of builtin aggregators
"""
max = _aggregate_func("max")
min = _aggregate_func("min")
avg = mean = _aggregate_func("mean")
sum = _aggregate_func("sum")
first = _aggregate_func("first")
last = _aggregate_func("last")
count = _aggregate_func("count")
AGGS = [
'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
]
for _name in AGGS:
locals()[_name] = _aggregate_func(_name)
del _name

@staticmethod
def countDistinct(col, *cols):
sc = SparkContext._active_spark_context
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
sc._jvm.Dsl.toColumns(jcols))
return Column(jc)

@staticmethod
def approxCountDistinct(col, rsd=None):
sc = SparkContext._active_spark_context
if rsd is None:
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
else:
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
return Column(jc)
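A hedged usage sketch of the rebuilt Aggregator class: the helpers named in AGGS are generated from Dsl by name, while countDistinct and approxCountDistinct keep hand-written wrappers for their extra arguments (a df with columns key and value is assumed, mirroring test_aggregator below).

    from pyspark.sql import Aggregator as Agg

    g = df.groupBy()
    g.agg(Agg.max(df.key), Agg.sum("key")).first()            # generated helpers take a Column or a name
    g.agg(Agg.countDistinct(df.value)).first()[0]             # exact distinct count
    g.agg(Agg.approxCountDistinct(df.key, rsd=0.05)).first()  # rsd is optional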


def _test():
6 changes: 4 additions & 2 deletions python/pyspark/tests.py
@@ -1029,9 +1029,11 @@ def test_aggregator(self):
g = df.groupBy()
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
# TODO(davies): fix aggregators

from pyspark.sql import Aggregator as Agg
# self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0])

def test_help_command(self):
# Regression test for SPARK-5464
4 changes: 0 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -500,10 +500,6 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] {
////////////////////////////////////////////////////////////////////////////
// for Python API
////////////////////////////////////////////////////////////////////////////
/**
* A helpful function for Py4j, convert a list of Column to an array
*/
protected[sql] def toColumnArray(cols: JList[Column]): Array[Column]

/**
* Converts a JavaRDD to a PythonRDD.
@@ -323,10 +323,6 @@ private[sql] class DataFrameImpl protected[sql](
////////////////////////////////////////////////////////////////////////////
// for Python API
////////////////////////////////////////////////////////////////////////////
protected[sql] override def toColumnArray(cols: JList[Column]): Array[Column] = {
cols.toList.toArray
}

protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = {
val fieldTypes = schema.fields.map(_.dataType)
val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
13 changes: 11 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -17,8 +17,11 @@

package org.apache.spark.sql

import java.util.{List => JList}

import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
import scala.collection.JavaConversions._

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
@@ -105,8 +108,7 @@ object Dsl {
def countDistinct(expr: Column, exprs: Column*): Column =
CountDistinct((expr +: exprs).map(_.expr))

def approxCountDistinct(e: Column): Column =
ApproxCountDistinct(e.expr)
def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)
def approxCountDistinct(e: Column, rsd: Double): Column =
ApproxCountDistinct(e.expr, rsd)

@@ -121,6 +123,13 @@
def sqrt(e: Column): Column = Sqrt(e.expr)
def abs(e: Column): Column = Abs(e.expr)

/**
* This is a private API for Python
* TODO: move this to a private package
*/
def toColumns(cols: JList[Column]): Seq[Column] = {
cols.toList.toSeq
}

// scalastyle:off

@@ -17,6 +17,8 @@

package org.apache.spark.sql

import java.util.{List => JList}

import scala.language.implicitConversions
import scala.collection.JavaConversions._

@@ -156,7 +156,5 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten

override def toJSON: RDD[String] = err()

protected[sql] override def toColumnArray(cols: java.util.List[Column]): Array[Column] = err()

protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = err()
}
