diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BaseJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BaseJoinExec.scala
new file mode 100644
index 0000000000000..86b31eb0d0c7e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BaseJoinExec.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.execution.{BinaryExecNode, ExplainUtils}
+
+/**
+ * Holds common logic for join operators
+ */
+trait BaseJoinExec extends BinaryExecNode {
+  def joinType: JoinType
+  def condition: Option[Expression]
+  def leftKeys: Seq[Expression]
+  def rightKeys: Seq[Expression]
+
+  override def simpleStringWithNodeId(): String = {
+    val opId = ExplainUtils.getOpId(this)
+    s"$nodeName $joinType ($opId)".trim
+  }
+
+  override def verboseStringWithOperatorId(): String = {
+    val joinCondStr = if (condition.isDefined) {
+      s"${condition.get}"
+    } else "None"
+    if (leftKeys.nonEmpty || rightKeys.nonEmpty) {
+      s"""
+         |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
+         |${ExplainUtils.generateFieldString("Left keys", leftKeys)}
+         |${ExplainUtils.generateFieldString("Right keys", rightKeys)}
+         |${ExplainUtils.generateFieldString("Join condition", joinCondStr)}
+      """.stripMargin
+    } else {
+      s"""
+         |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
+         |${ExplainUtils.generateFieldString("Join condition", joinCondStr)}
+      """.stripMargin
+    }
+  }
+}
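The new trait is a straightforward pull-up: every join operator already exposed joinType, condition, leftKeys and rightKeys, so the EXPLAIN-string logic each operator duplicated can be written once against those abstract members, with keyless operators falling into the shorter branch. A minimal standalone sketch of the pattern, in plain Scala with hypothetical names and no Spark dependency:

    // Base trait: declares what every operator must provide and derives
    // the shared description from it. Keyless operators hit the short branch.
    trait BaseJoin {
      def nodeName: String
      def joinType: String
      def condition: Option[String]
      def leftKeys: Seq[String]
      def rightKeys: Seq[String]

      def describe: String = {
        val cond = condition.getOrElse("None")
        if (leftKeys.nonEmpty || rightKeys.nonEmpty) {
          s"$nodeName $joinType, keys: ${leftKeys.mkString(",")} <=> ${rightKeys.mkString(",")}, cond: $cond"
        } else {
          s"$nodeName $joinType, cond: $cond"
        }
      }
    }

    // A keyed operator only supplies its fields; the description comes free.
    case class MergeJoin(
        joinType: String,
        leftKeys: Seq[String],
        rightKeys: Seq[String],
        condition: Option[String]) extends BaseJoin {
      val nodeName = "MergeJoin"
    }

    object Demo extends App {
      println(MergeJoin("Inner", Seq("a#1"), Seq("b#2"), None).describe)
      // MergeJoin Inner, keys: a#1 <=> b#2, cond: None
    }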
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index fd4a7897c7ad1..08128d8f69dab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.{BooleanType, LongType}
 
@@ -44,7 +44,7 @@ case class BroadcastHashJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
-  extends BinaryExecNode with HashJoin with CodegenSupport {
+  extends HashJoin with CodegenSupport {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 5517c0dcdb188..888e7af7c07ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.collection.{BitSet, CompactBuffer}
 
@@ -32,7 +32,10 @@ case class BroadcastNestedLoopJoinExec(
     right: SparkPlan,
     buildSide: BuildSide,
     joinType: JoinType,
-    condition: Option[Expression]) extends BinaryExecNode {
+    condition: Option[Expression]) extends BaseJoinExec {
+
+  override def leftKeys: Seq[Expression] = Nil
+  override def rightKeys: Seq[Expression] = Nil
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -43,6 +46,11 @@ case class BroadcastNestedLoopJoinExec(
     case BuildLeft => (right, left)
   }
 
+  override def simpleStringWithNodeId(): String = {
+    val opId = ExplainUtils.getOpId(this)
+    s"$nodeName $joinType ${buildSide} ($opId)".trim
+  }
+
   override def requiredChildDistribution: Seq[Distribution] = buildSide match {
     case BuildLeft =>
       BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil
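With both broadcast operators now reporting through BaseJoinExec, and BroadcastNestedLoopJoinExec additionally printing its build side, the change is easiest to observe through formatted explain output. A hedged usage sketch against the public API (Spark 3.0-style explain mode; the quoted output shapes are illustrative, not verbatim):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.broadcast

    object ExplainJoins extends App {
      val spark = SparkSession.builder()
        .master("local[*]").appName("explain-joins").getOrCreate()
      import spark.implicits._

      val a = Seq((1, 10), (2, 20)).toDF("id", "x")
      val b = Seq((1, 30), (3, 40)).toDF("id", "y")

      // Equi-join against a broadcast side: planned as BroadcastHashJoinExec;
      // the verbose section lists "Left keys", "Right keys", "Join condition".
      a.join(broadcast(b), "id").explain("formatted")

      // Non-equi condition: planned as BroadcastNestedLoopJoinExec, whose
      // node header now also shows the build side, e.g.
      // "BroadcastNestedLoopJoin Inner BuildRight (3)".
      a.join(broadcast(b), $"x" < $"y").explain("formatted")

      spark.stop()
    }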
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 7e2f487fdcc5d..a71bf94c45034 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -22,7 +22,8 @@ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Predicate, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.execution.{BinaryExecNode, ExplainUtils, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.CompletionIterator
 
@@ -60,23 +61,17 @@ case class CartesianProductExec(
     left: SparkPlan,
     right: SparkPlan,
-    condition: Option[Expression]) extends BinaryExecNode {
+    condition: Option[Expression]) extends BaseJoinExec {
+
+  override def joinType: JoinType = Inner
+  override def leftKeys: Seq[Expression] = Nil
+  override def rightKeys: Seq[Expression] = Nil
+
   override def output: Seq[Attribute] = left.output ++ right.output
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
-  override def verboseStringWithOperatorId(): String = {
-    val joinCondStr = if (condition.isDefined) {
-      s"${condition.get}"
-    } else "None"
-
-    s"""
-       |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
-       |${ExplainUtils.generateFieldString("Join condition", joinCondStr)}
-    """.stripMargin
-  }
-
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index f4796c194cb4f..7f90a51c1f234 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -22,39 +22,18 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{ExplainUtils, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.{ExplainUtils, RowIterator}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.{IntegralType, LongType}
 
-trait HashJoin {
-  self: SparkPlan =>
-
-  def leftKeys: Seq[Expression]
-  def rightKeys: Seq[Expression]
-  def joinType: JoinType
+trait HashJoin extends BaseJoinExec {
   def buildSide: BuildSide
-  def condition: Option[Expression]
-  def left: SparkPlan
-  def right: SparkPlan
 
   override def simpleStringWithNodeId(): String = {
     val opId = ExplainUtils.getOpId(this)
     s"$nodeName $joinType ${buildSide} ($opId)".trim
   }
 
-  override def verboseStringWithOperatorId(): String = {
-    val joinCondStr = if (condition.isDefined) {
-      s"${condition.get}"
-    } else "None"
-
-    s"""
-       |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
-       |${ExplainUtils.generateFieldString("Left keys", leftKeys)}
-       |${ExplainUtils.generateFieldString("Right keys", rightKeys)}
-       |${ExplainUtils.generateFieldString("Join condition", joinCondStr)}
-    """.stripMargin
-  }
-
   override def output: Seq[Attribute] = {
     joinType match {
       case _: InnerLike =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index a8361fd7dd559..755a63e545ef1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
@@ -39,7 +39,7 @@ case class ShuffledHashJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
-  extends BinaryExecNode with HashJoin {
+  extends HashJoin {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
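Note how HashJoin loses its `self: SparkPlan =>` self-type and its own abstract members: it now sits between the base trait and the two hash implementations, adding only buildSide and a refined node header, while keyless operators like CartesianProductExec simply pin the members that do not apply to them. A standalone sketch of that layering (illustrative names only, not Spark code):

    trait JoinNode {
      def nodeName: String
      def joinType: String
      // Shared header, as in BaseJoinExec.simpleStringWithNodeId.
      def header: String = s"$nodeName $joinType"
    }

    // Intermediate trait: adds the hash-specific member and refines the
    // header to include it, like HashJoin.simpleStringWithNodeId.
    trait HashLike extends JoinNode {
      def buildSide: String
      override def header: String = s"$nodeName $joinType $buildSide"
    }

    case class ShuffledHash(joinType: String, buildSide: String) extends HashLike {
      val nodeName = "ShuffledHashJoin"
    }

    // Keyless operator: pins the members that do not apply, as
    // CartesianProductExec does with joinType = Inner and Nil key lists.
    case class Cartesian() extends JoinNode {
      val nodeName = "CartesianProduct"
      val joinType = "Inner"
    }

    object LayerDemo extends App {
      println(ShuffledHash("Inner", "BuildRight").header) // ShuffledHashJoin Inner BuildRight
      println(Cartesian().header)                         // CartesianProduct Inner
    }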
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 7a08dd1afd3a6..2c57956de5bca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -41,7 +41,7 @@ case class SortMergeJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    isSkewJoin: Boolean = false) extends BinaryExecNode with CodegenSupport {
+    isSkewJoin: Boolean = false) extends BaseJoinExec with CodegenSupport {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -52,23 +52,6 @@ case class SortMergeJoinExec(
 
   override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator
 
-  override def simpleStringWithNodeId(): String = {
-    val opId = ExplainUtils.getOpId(this)
-    s"$nodeName $joinType ($opId)".trim
-  }
-
-  override def verboseStringWithOperatorId(): String = {
-    val joinCondStr = if (condition.isDefined) {
-      s"${condition.get}"
-    } else "None"
-    s"""
-       |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
-       |${ExplainUtils.generateFieldString("Left keys", leftKeys)}
-       |${ExplainUtils.generateFieldString("Right keys", rightKeys)}
-       |${ExplainUtils.generateFieldString("Join condition", joinCondStr)}
-    """.stripMargin
-  }
-
   override def output: Seq[Attribute] = {
     joinType match {
       case _: InnerLike =>
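One retained override worth noting: SortMergeJoinExec keeps `stringArgs ... dropRight(1)`, which hides the trailing isSkewJoin flag from the operator's printed argument list. A tiny standalone illustration of that trick, using a hypothetical case class:

    case class Merge(leftKeys: Seq[String], rightKeys: Seq[String], isSkewJoin: Boolean = false) {
      // Drop the last constructor argument (the flag) from the printed form,
      // mirroring SortMergeJoinExec.stringArgs.
      def stringArgs: Iterator[Any] = productIterator.toSeq.dropRight(1).iterator
      override def toString: String = s"Merge(${stringArgs.mkString(", ")})"
    }

    object StringArgsDemo extends App {
      println(Merge(Seq("a#1"), Seq("b#2"))) // Merge(List(a#1), List(b#2))
    }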