Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dotted syntax (via lambda) for TypedColumns #449

Merged
merged 9 commits into from
Sep 6, 2021
Merged
37 changes: 15 additions & 22 deletions dataset/src/main/scala/frameless/TypedColumn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import shapeless.ops.record.Selector
import scala.annotation.implicitNotFound
import scala.reflect.ClassTag

import scala.language.experimental.macros

sealed trait UntypedExpression[T] {
def expr: Expression
def uencoder: TypedEncoder[_]
Expand Down Expand Up @@ -864,8 +866,7 @@ object SortedTypedColumn {
implicit def caseTypedColumn[T, U : CatalystOrdered] = at[TypedColumn[T, U]](c => defaultAscending(c))
implicit def caseTypeSortedColumn[T, U] = at[SortedTypedColumn[T, U]](identity)
}
}

}

object TypedColumn {
/** Evidence that type `T` has column `K` with type `V`. */
Expand Down Expand Up @@ -896,26 +897,18 @@ object TypedColumn {
i1: Selector.Aux[H, K, V]
): Exists[T, K, V] = new Exists[T, K, V] {}
}
}

/** Type-level intersection of `A` and `B`, exposed through the `Out` member:
  *
  *  - `With[A, A]` is meant to yield `A`
  *  - `With[A, B]` is meant to yield `A with B` (when `A` and `B` differ)
  *
  * This type function exists to prevent IDEs from inferring large types of
  * the shape `A with A with ... with A`, which could be confusing for both
  * end users and IDE type checkers.
  */
trait With[A, B] {
  type Out
}

/** Holds the `Aux` alias, the shared `With` instance and the `identity`
  * implicit that the `With` object inherits.
  */
trait LowPrioWith {

  /** A `With[A, B]` whose `Out` member is fixed to `W`. */
  type Aux[A, B, W] = With[A, B] { type Out = W }

  // Single instance reused for every combination of type arguments.
  protected[this] val theInstance = new With[Any, Any] {}

  /** Views the shared instance as evidence for `A`, `B` and result `W`. */
  protected[this] def of[A, B, W]: With[A, B] { type Out = W } = {
    val evidence = theInstance.asInstanceOf[Aux[A, B, W]]
    evidence
  }

  /** Instance for two identical types: `Out` is the type itself. */
  implicit def identity[T]: Aux[T, T, T] = of[T, T, T]
}
/** Creates a `TypedColumn[T, U]` from a lambda that selects a field of `T`.
*
* This is a macro: the lambda is inspected at compile time by
* `TypedColumnMacroImpl.applyImpl` rather than evaluated, and the selected
* field names become the column path.
*
* {{{
* import frameless.TypedColumn
*
* case class Foo(id: Int, bar: String)
*
* val colbar: TypedColumn[Foo, String] = TypedColumn { foo: Foo => foo.bar }
* val colid = TypedColumn[Foo, Int](_.id)
* }}}
*
* @tparam T type the selection starts from
* @tparam U type of the selected column
* @param x lambda selecting the field, e.g. `_.id`
*/
def apply[T, U](x: T => U): TypedColumn[T, U] =
macro TypedColumnMacroImpl.applyImpl[T, U]

/** Implicit instances of the `With` type function. */
object With extends LowPrioWith {

  /** Instance whose `Out` is the compound type `A with B`. */
  implicit def combine[A, B]: Aux[A, B, A with B] =
    of[A, B, A with B]
}
84 changes: 84 additions & 0 deletions dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package frameless

import scala.reflect.macros.whitebox

/** Macro implementation behind the lambda-based column selectors
  * (`TypedColumn.apply` and `TypedDataset.col`).
  */
private[frameless] object TypedColumnMacroImpl {

  /** Turns a field-selecting lambda such as `_.a.b` into a tree building the
    * `TypedColumn[T, U]` for column path `"a.b"`.
    *
    * Expansion aborts with a compile error unless `x` is a function literal
    * with a single parameter whose body is a chain of selections rooted at
    * that very parameter, each step being a stable term declared by the type
    * reached so far.
    *
    * @tparam T type the selection starts from
    * @tparam U type of the selected column
    * @param c the macro context
    * @param x tree of the lambda given at the call site
    * @return a tree equivalent to `new TypedColumn[T, U](col("<path>").expr)`
    */
  def applyImpl[T: c.WeakTypeTag, U: c.WeakTypeTag](c: whitebox.Context)(x: c.Tree): c.Expr[TypedColumn[T, U]] = {
    import c.universe._

    val t = c.weakTypeOf[T]
    val u = c.weakTypeOf[U]

    // Builds the resulting expression from the field names (root excluded).
    def buildExpression(path: List[String]): c.Expr[TypedColumn[T, U]] = {
      val columnName = path.mkString(".")

      c.Expr[TypedColumn[T, U]](q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($columnName)).expr)")
    }

    def abort(msg: String): Nothing = c.abort(c.enclosingPosition, msg)

    // Flattens `root.a.b` into List(root, a, b); anything other than nested
    // selections ending on an identifier is rejected.
    @annotation.tailrec
    def path(in: Select, out: List[TermName]): List[TermName] =
      in.qualifier match {
        case sub: Select =>
          path(sub, in.name.toTermName :: out)

        case id: Ident =>
          id.name.toTermName :: in.name.toTermName :: out

        case unsupported =>
          abort(s"Unsupported selection: $unsupported")
      }

    // Walks the field names from `current`, requiring each one to be a stable
    // term declared by the type reached so far.
    @annotation.tailrec
    def check(current: Type, in: List[TermName]): Boolean = in match {
      case next :: tail =>
        val sym = current.decl(next)

        // `decl` yields NoSymbol for an unknown (or not directly declared)
        // member; report it instead of letting `asTerm` throw.
        if (!sym.isTerm) {
          abort(s"Term not declared: ${current}.${next}")
        }

        val term = sym.asTerm

        if (!term.isStable) {
          abort(s"Stable term expected: ${current}.${next}")
        }

        check(term.info, tail)

      case _ =>
        true
    }

    x match {
      case fn: Function =>
        fn.body match {
          case select: Select if select.name.isTermName =>
            val param: ValDef = fn.vparams match {
              case List(single) =>
                single

              case params =>
                abort(s"Select expression must have a single parameter: ${params mkString ", "}")
            }

            path(select, List.empty) match {
              // The selection must start from the lambda parameter itself,
              // not from some other identifier that happens to be in scope.
              case root :: tail if root == param.name && check(t, tail) =>
                buildExpression(tail.map(_.toString))

              case _ =>
                abort(s"Invalid select expression: $select")
            }

          case other =>
            abort(s"Select expression expected: $other")
        }

      case _ =>
        abort(s"Function expected: $x")
    }
  }
}
13 changes: 13 additions & 0 deletions dataset/src/main/scala/frameless/TypedDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import shapeless.labelled.FieldType
import shapeless.ops.hlist.{Diff, IsHCons, Mapper, Prepend, ToTraversable, Tupler}
import shapeless.ops.record.{Keys, Modifier, Remover, Values}

import scala.language.experimental.macros

/** [[TypedDataset]] is a safer interface for working with `Dataset`.
*
* NOTE: Prefer `TypedDataset.create` over `new TypedDataset` unless you
Expand Down Expand Up @@ -238,6 +240,17 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
): TypedColumn[T, A] =
new TypedColumn[T, A](dataset(column.value.name).as[A](TypedExpressionEncoder[A]))

/** Returns the `TypedColumn` of type `A` designated by a field-selecting lambda.
*
* {{{
* td.col(_.id)
* }}}
*
* It is statically checked that column with such name exists and has type `A`:
* this method is a macro expanded by `TypedColumnMacroImpl.applyImpl`, so the
* lambda is analysed at compile time rather than evaluated.
*
* @tparam A type of the selected column
* @param x lambda indicating the field, e.g. `_.id`
*/
def col[A](x: Function1[T, A]): TypedColumn[T, A] =
macro TypedColumnMacroImpl.applyImpl[T, A]

/** Projects the entire TypedDataset[T] into a single column of type TypedColumn[T,T]
* {{{
* ts: TypedDataset[Foo] = ...
Expand Down
27 changes: 27 additions & 0 deletions dataset/src/main/scala/frameless/With.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package frameless

/** Type-level intersection of `A` and `B`, exposed through the `Out` member:
  *
  *  - `With[A, A]` is meant to yield `A`
  *  - `With[A, B]` is meant to yield `A with B` (when `A` and `B` differ)
  *
  * This type function exists to prevent IDEs from inferring large types of
  * the shape `A with A with ... with A`, which could be confusing for both
  * end users and IDE type checkers.
  */
trait With[A, B] {
  type Out
}

/** Implicit instances of the `With` type function. */
object With extends LowPrioWith {

  /** Instance whose `Out` is the compound type `A with B`. */
  implicit def combine[A, B]: Aux[A, B, A with B] =
    of[A, B, A with B]
}

/** Holds the `Aux` alias, the shared `With` instance and the `identity`
  * implicit that the `With` object inherits.
  */
private[frameless] sealed trait LowPrioWith {

  /** A `With[A, B]` whose `Out` member is fixed to `W`. */
  type Aux[A, B, W] = With[A, B] { type Out = W }

  // Single instance reused for every combination of type arguments.
  protected[this] val theInstance = new With[Any, Any] {}

  /** Views the shared instance as evidence for `A`, `B` and result `W`. */
  protected[this] def of[A, B, W]: With[A, B] { type Out = W } = {
    val evidence = theInstance.asInstanceOf[Aux[A, B, W]]
    evidence
  }

  /** Instance for two identical types: `Out` is the type itself. */
  implicit def identity[T]: Aux[T, T, T] = of[T, T, T]
}
20 changes: 19 additions & 1 deletion dataset/src/test/scala/frameless/ColumnTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import ceedubs.irrec.regex.gen.CharRegexGen.genCharRegexAndCandidate

import scala.math.Ordering.Implicits._

class ColumnTests extends TypedDatasetSuite with Matchers {
final class ColumnTests extends TypedDatasetSuite with Matchers {

private implicit object OrderingImplicits {
implicit val sqlDateOrdering: Ordering[SQLDate] = Ordering.by(_.days)
Expand Down Expand Up @@ -438,4 +438,22 @@ class ColumnTests extends TypedDatasetSuite with Matchers {

"ds.select(ds('_2).field('_3))" shouldNot typeCheck
}

test("col through lambda") {
  case class MyClass1(a: Int, b: String, c: MyClass2)
  case class MyClass2(d: Long)

  val ds = TypedDataset.create(Seq(MyClass1(1, "2", MyClass2(3L)), MyClass1(4, "5", MyClass2(6L))))

  // The ascriptions are verified by the compiler. A runtime `isInstanceOf`
  // against a parameterised type is erased and could not tell the column
  // types apart.
  val colA: TypedColumn[MyClass1, Int] = ds.col(_.a)
  val colB: TypedColumn[MyClass1, String] = ds.col(_.b)
  val colD: TypedColumn[MyClass1, Long] = ds.col(_.c.d)

  // Anything that is not a plain chain of field selections must be rejected.
  "ds.col(_.c.toString)" shouldNot typeCheck
  "ds.col(_.c.toInt)" shouldNot typeCheck
  "ds.col(x => java.lang.Math.abs(x.a))" shouldNot typeCheck

  // we should be able to block the following as well...
  "ds.col(_.a.toInt)" shouldNot typeCheck
}
}
79 changes: 79 additions & 0 deletions dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package frameless

import org.scalatest.matchers.should.Matchers
import shapeless.test.illTyped

/** Root fixture record: two flat fields, a nested record and an optional one. */
final case class MyClass1(a: Int, b: String, c: MyClass2, g: Option[MyClass4])

/** First nesting level, itself containing a nested record. */
final case class MyClass2(d: Long, e: MyClass3)

/** Innermost nested record. */
final case class MyClass3(f: Double)

/** Record only reachable through the optional field `MyClass1.g`. */
final case class MyClass4(h: Boolean)

/** Exercises the lambda-based column selectors, `TypedColumn.apply` and
  * `TypedDataset.col`: accepted field selections, and expressions that must
  * be rejected at compile time.
  */
final class ColumnViaLambdaTests extends TypedDatasetSuite with Matchers {

  /** Two-row fixture; the second row has no `g` value. */
  def ds: TypedDataset[MyClass1] =
    TypedDataset.create(Seq(
      MyClass1(1, "2", MyClass2(3L, MyClass3(7.0D)), Some(MyClass4(true))),
      MyClass1(4, "5", MyClass2(6L, MyClass3(8.0D)), None)))

  test("col(_.a)") {
    val col = TypedColumn[MyClass1, Int](_.a)

    ds.select(col).collect.run() shouldEqual Seq(1, 4)
  }

  test("col(x => x.a") {
    val col = TypedColumn[MyClass1, Int](x => x.a)

    ds.select(col).collect.run() shouldEqual Seq(1, 4)
  }

  test("col((x: MyClass1) => x.a") {
    val col = TypedColumn { (x: MyClass1) => x.a }

    ds.select(col).collect.run() shouldEqual Seq(1, 4)
  }

  test("col((x: MyClass1) => x.c.e.f") {
    val col = TypedColumn { (x: MyClass1) => x.c.e.f }

    ds.select(col).collect.run() shouldEqual Seq(7.0D, 8.0D)
  }

  test("col(_.c.d)") {
    val col = TypedColumn[MyClass1, Long](_.c.d)

    ds.select(col).collect.run() shouldEqual Seq(3L, 6L)
  }

  test("col(_.c.e.f)") {
    val col = TypedColumn[MyClass1, Double](_.c.e.f)

    ds.select(col).collect.run() shouldEqual Seq(7.0D, 8.0D)
  }

  test("col(_.c.d) as int does not compile (is long)") {
    illTyped("TypedColumn[MyClass1, Int](_.c.d)")
  }

  test("col(_.g.h does not compile") {
    val col = ds.col(_.g) // the path "ends" at .g (can't access h)

    illTyped("""ds.col(_.g.h)""")
  }

  test("col(_.a.toString) does not compile") {
    illTyped("""ds.col(_.a.toString)""")
  }

  test("col(_.a.toString.size) does not compile") {
    illTyped("""ds.col(_.a.toString.size)""")
  }

  test("col((x: MyClass1) => x.toString.size) does not compile") {
    illTyped("""ds.col((x: MyClass1) => x.toString.size)""")
  }

  test("col(x => java.lang.Math.abs(x.a)) does not compile") {
    // Qualified with `ds` so the snippet is rejected by the selector itself,
    // not merely because no `col` identifier is in scope.
    illTyped("""ds.col(x => java.lang.Math.abs(x.a))""")
  }
}