From 28cf3db779322a487d26fa17282889e217f2d6b5 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 14 May 2024 10:16:21 +0800 Subject: [PATCH] [SPARK-48259][CONNECT][TESTS] Add 3 missing methods in dsl ### What changes were proposed in this pull request? Add 3 missing methods in dsl: `sampleBy`, `filter`, and `crossJoin`. ### Why are the changes needed? These methods can be used in tests. ### Does this PR introduce _any_ user-facing change? No, test only. ### How was this patch tested? CI. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46559 from zhengruifeng/missing_3_func. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../spark/sql/connect/dsl/package.scala | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 6aadb6c34b779..da9a0865b8ca6 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -513,6 +513,25 @@ package object dsl { freqItems(cols.toArray, support) def freqItems(cols: Seq[String]): Relation = freqItems(cols, 0.01) + + def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): Relation = { + Relation + .newBuilder() + .setSampleBy( + StatSampleBy + .newBuilder() + .setInput(logicalPlan) + .addAllFractions(fractions.toSeq.map { case (k, v) => + StatSampleBy.Fraction + .newBuilder() + .setStratum(toLiteralProto(k)) + .setFraction(v) + .build() + }.asJava) + .setSeed(seed) + .build()) + .build() + } } def select(exprs: Expression*): Relation = { @@ -587,6 +606,10 @@ package object dsl { .build() } + def filter(condition: Expression): Relation = { + where(condition) + } + def deduplicate(colNames: Seq[String]): Relation = Relation .newBuilder() @@ -641,6 +664,10 @@ package object dsl { 
join(otherPlan, joinType, usingColumns, None) } + def crossJoin(otherPlan: Relation): Relation = { + join(otherPlan, JoinType.JOIN_TYPE_CROSS, Seq(), None) + } + private def join( otherPlan: Relation, joinType: JoinType = JoinType.JOIN_TYPE_INNER,