Skip to content

Commit

Permalink
add smj support
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Dec 14, 2022
1 parent 897a3d1 commit 74a1cc6
Show file tree
Hide file tree
Showing 4 changed files with 455 additions and 273 deletions.
4 changes: 2 additions & 2 deletions ep/build-velox/src/get_velox.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

set -exu

VELOX_REPO=https://github.com/oap-project/velox.git
VELOX_BRANCH=main
VELOX_REPO=https://github.com/JkSelf/velox.git
VELOX_BRANCH=add-smj-operator

for arg in "$@"
do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,15 @@ trait HashJoinLikeExecTransformer
return false
}
val relNode = try {
createJoinRel(null, null, streamedPlan.output,
JoinUtils.createJoinRel(
streamedKeyExprs,
buildKeyExprs,
condition,
substraitJoinType,
exchangeTable,
joinType,
genJoinParametersBuilder(),
null, null, streamedPlan.output,
buildPlan.output, substraitContext, substraitContext.nextOperatorId, validation = true)
} catch {
case e: Throwable =>
Expand Down Expand Up @@ -731,14 +739,21 @@ trait HashJoinLikeExecTransformer
substraitContext.registerRelToOperator(operatorId)
}

if (preProjectionNeeded(streamedKeyExprs)) {
if (JoinUtils.preProjectionNeeded(streamedKeyExprs)) {
joinParams.streamPreProjectionNeeded = true
}
if (preProjectionNeeded(buildKeyExprs)) {
if (JoinUtils.preProjectionNeeded(buildKeyExprs)) {
joinParams.buildPreProjectionNeeded = true
}

val joinRel = createJoinRel(
val joinRel = JoinUtils.createJoinRel(
streamedKeyExprs,
buildKeyExprs,
condition,
substraitJoinType,
exchangeTable,
joinType,
genJoinParametersBuilder(),
inputStreamedRelNode,
inputBuildRelNode,
inputStreamedOutput,
Expand All @@ -748,205 +763,9 @@ trait HashJoinLikeExecTransformer

substraitContext.registerJoinParam(operatorId, joinParams)

createTransformContext(joinRel, inputStreamedOutput, inputBuildOutput)
}

/**
 * Tells whether a pre-projection has to be inserted for the given join keys.
 *
 * @param keyExprs join key expressions of one side
 * @return true as soon as one key is something other than an [[AttributeReference]]
 */
private def preProjectionNeeded(keyExprs: Seq[Expression]): Boolean =
  keyExprs.exists {
    case _: AttributeReference => false
    case _ => true
  }

/**
 * Builds the join-key expression nodes for one side of the join, placing a
 * projection on top of `inputNode` when at least one key is not a plain
 * [[AttributeReference]].
 *
 * @param keyExprs         join key expressions of this side
 * @param inputNode        relation feeding this side; returned unchanged when no
 *                         pre-projection is needed
 * @param inputNodeOutput  attributes that non-attribute key expressions are bound against
 *                         (presumably the output of `inputNode` -- confirm against callers)
 * @param joinOutput       attributes that attribute keys are bound against
 * @param substraitContext provides the `registeredFunction` map handed to `doTransform`
 * @param operatorId       forwarded to `RelBuilder.makeProjectRel`
 * @param validation       forwarded to `createExtensionNode`
 * @return a triple of (key node and data type for every entry of `keyExprs`, in order;
 *         the relation to use as this side's input; the attributes of that relation,
 *         extended with place-holder attributes for appended keys)
 */
protected def createPreProjectionIfNeeded(keyExprs: Seq[Expression],
inputNode: RelNode,
inputNodeOutput: Seq[Attribute],
joinOutput: Seq[Attribute],
substraitContext: SubstraitContext,
operatorId: java.lang.Long,
validation: Boolean)
: (Seq[(ExpressionNode, DataType)], RelNode, Seq[Attribute]) = {
if (!preProjectionNeeded(keyExprs)) {
// Skip pre-projection if all keys are [AttributeReference]s,
// which can be directly converted into SelectionNode.
val keys = keyExprs.map { expr =>
(ExpressionConverter.replaceWithExpressionTransformer(expr, joinOutput)
.asInstanceOf[AttributeReferenceTransformer]
.doTransform(substraitContext.registeredFunction), expr.dataType)
}
(keys, inputNode, inputNodeOutput)
} else {
// Pre-projection is constructed from original columns followed by join-key expressions.
val selectOrigins = inputNodeOutput.indices.map(ExpressionBuilder.makeSelection(_))
// Only the non-attribute keys are transformed here; they keep their relative
// order from keyExprs, which the iterator further below depends on.
val appendedKeys = keyExprs.flatMap {
case _: AttributeReference => None
case expr =>
Some(
(ExpressionConverter
.replaceWithExpressionTransformer(expr, inputNodeOutput)
.doTransform(substraitContext.registeredFunction), expr.dataType))
}
val preProjectNode = RelBuilder.makeProjectRel(
inputNode,
new java.util.ArrayList[ExpressionNode](
(selectOrigins ++ appendedKeys.map(_._1)).asJava),
createExtensionNode(inputNodeOutput, validation),
substraitContext,
operatorId)

// Compute index for join keys in join outputs.
// selectOrigins has exactly one entry per element of inputNodeOutput, so this
// offset is numerically equal to joinOutput.size.
val offset = joinOutput.size - inputNodeOutput.size + selectOrigins.size
// Stateful iterator: each non-attribute key met in the map below pulls the next
// appended key, re-pairing keyExprs with appendedKeys in their original order.
val appendedKeysAndIndices = appendedKeys.zipWithIndex.iterator
val keys = keyExprs.map {
case a: AttributeReference =>
// The selection index for original AttributeReference is unchanged.
(ExpressionConverter.replaceWithExpressionTransformer(a, joinOutput)
.asInstanceOf[AttributeReferenceTransformer]
.doTransform(substraitContext.registeredFunction), a.dataType)
case _ =>
val (key, idx) = appendedKeysAndIndices.next()
// Refer to the appended column by position; key._2 is the key's data type.
(ExpressionBuilder.makeSelection(idx + offset), key._2)
}
(keys, preProjectNode, inputNodeOutput ++
appendedKeys.zipWithIndex.map { case (key, idx) =>
// Create output attributes for appended keys.
// This is used as place holder for finding the right column indexes in post-join filters.
AttributeReference(s"col_${idx + offset}", key._2)()
}
)
}
}


/**
 * Assembles the relation tree for this join: optional pre-projections on both
 * sides, the join relation itself, and a final projection over the join.
 *
 * The steps, in order: build key nodes for the streamed and build sides via
 * `createPreProjectionIfNeeded`; pair the keys up into equal-to expressions combined
 * with AND; convert `condition` (if defined) into the post-join filter; create the
 * join via `RelBuilder.makeJoinRel`; wrap it in `RelBuilder.makeProjectRel` with the
 * result projection computed from `exchangeTable` and `joinType`.
 *
 * @param inputStreamedRelNode relation of the streamed side
 * @param inputBuildRelNode    relation of the build side
 * @param inputStreamedOutput  output attributes of the streamed side
 * @param inputBuildOutput     output attributes of the build side
 * @param substraitContext     context passed to the builders; also supplies the
 *                             `registeredFunction` map used for expression transformation
 * @param operatorId           operator id passed to every relation built here
 * @param validation           forwarded to `createPreProjectionIfNeeded` and
 *                             `createExtensionNode`; defaults to false
 * @return the project relation sitting on top of the join relation
 */
protected def createJoinRel(inputStreamedRelNode: RelNode,
inputBuildRelNode: RelNode,
inputStreamedOutput: Seq[Attribute],
inputBuildOutput: Seq[Attribute],
substraitContext: SubstraitContext,
operatorId: java.lang.Long,
validation: Boolean = false): RelNode = {
// Create pre-projection for build/streamed plan. Append projected keys to each side.
val (streamedKeys, streamedRelNode, streamedOutput) = createPreProjectionIfNeeded(
streamedKeyExprs,
inputStreamedRelNode,
inputStreamedOutput,
inputStreamedOutput,
substraitContext,
operatorId,
validation)

// Build-side attribute keys are bound against the streamed output (including any
// appended streamed keys) followed by the build input columns.
val (buildKeys, buildRelNode, buildOutput) = createPreProjectionIfNeeded(
buildKeyExprs,
inputBuildRelNode,
inputBuildOutput,
streamedOutput ++ inputBuildOutput,
substraitContext,
operatorId,
validation)

// Combine join keys to make a single expression.
// NOTE(review): reduce fails on an empty collection, so this assumes there is at
// least one key pair -- confirm that callers guarantee non-empty join keys.
val joinExpressionNode = (streamedKeys zip buildKeys).map {
case ((leftKey, leftType), (rightKey, rightType)) =>
HashJoinLikeExecTransformer.makeEqualToExpression(
leftKey, leftType, rightKey, rightType, substraitContext.registeredFunction)
}.reduce((l, r) =>
HashJoinLikeExecTransformer.makeAndExpression(l, r, substraitContext.registeredFunction))

// Create post-join filter, which will be computed in hash join.
val postJoinFilter = condition.map {
expr =>
ExpressionConverter
.replaceWithExpressionTransformer(expr, streamedOutput ++ buildOutput)
.doTransform(substraitContext.registeredFunction)
}

// Create JoinRel.
// A missing condition is passed on as null (postJoinFilter.orNull).
val joinRel = RelBuilder.makeJoinRel(
streamedRelNode,
buildRelNode,
substraitJoinType,
joinExpressionNode,
postJoinFilter.orNull,
createJoinExtensionNode(streamedOutput ++ buildOutput),
substraitContext,
operatorId)

// Result projection will drop the appended keys, and exchange columns order if BuildLeft.
val resultProjection = if (exchangeTable) {
// With exchangeTable set, the build side is treated as the left side of the output.
val (leftOutput, rightOutput) =
getDirectJoinOutput(inputBuildOutput, inputStreamedOutput)
joinType match {
case _: ExistenceJoin =>
// NOTE(review): the extra selection at buildOutput.size is presumably the
// "exists" column -- confirm the index against the join's output layout.
inputBuildOutput.indices.map(ExpressionBuilder.makeSelection(_)) ++
Seq(ExpressionBuilder.makeSelection(buildOutput.size))
case LeftExistence(_) =>
leftOutput.indices.map(ExpressionBuilder.makeSelection(_))
case _ =>
// Exchange the order of build and streamed.
// Build columns are addressed after the streamedOutput.size streamed columns.
leftOutput.indices.map(idx =>
ExpressionBuilder.makeSelection(idx + streamedOutput.size)) ++
rightOutput.indices
.map(ExpressionBuilder.makeSelection(_))
}
} else {
val (leftOutput, rightOutput) =
getDirectJoinOutput(inputStreamedOutput, inputBuildOutput)
if (joinType.isInstanceOf[ExistenceJoin]) {
// NOTE(review): the extra selection at streamedOutput.size is presumably the
// "exists" column -- confirm the index against the join's output layout.
inputStreamedOutput.indices.map(ExpressionBuilder.makeSelection(_)) ++
Seq(ExpressionBuilder.makeSelection(streamedOutput.size))
} else {
// Right-side columns are addressed after the streamedOutput.size streamed
// columns, which skips any keys appended by the streamed pre-projection.
leftOutput.indices.map(ExpressionBuilder.makeSelection(_)) ++
rightOutput.indices.map(idx => ExpressionBuilder.makeSelection(idx + streamedOutput.size))
}
}

// The extension node for the final projection is built from the direct join
// output, ordered build-first when exchangeTable is set.
RelBuilder.makeProjectRel(
joinRel,
new java.util.ArrayList[ExpressionNode](resultProjection.asJava),
createExtensionNode(
if (exchangeTable) getDirectJoinOutputSeq(buildOutput, streamedOutput)
else getDirectJoinOutputSeq(streamedOutput, buildOutput),
validation),
substraitContext,
operatorId)
}

/**
 * Wraps `rel` into a [[TransformContext]] together with this operator's `output`.
 *
 * The input attributes are the two sides concatenated: build side first when
 * `exchangeTable` is set, streamed side first otherwise.
 *
 * @param rel                 relation to expose through the context
 * @param inputStreamedOutput attributes of the streamed side
 * @param inputBuildOutput    attributes of the build side
 */
private def createTransformContext(rel: RelNode,
    inputStreamedOutput: Seq[Attribute],
    inputBuildOutput: Seq[Attribute]): TransformContext = {
  val (leading, trailing) =
    if (exchangeTable) (inputBuildOutput, inputStreamedOutput)
    else (inputStreamedOutput, inputBuildOutput)
  TransformContext(leading ++ trailing, output, rel)
}

/**
 * Packs the types of `output` into a protobuf `Any` holding a struct type.
 *
 * @param output attributes whose data type and nullability are encoded
 * @return the packed struct, created with an empty type URL prefix
 */
private def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
  val typeNodes = new util.ArrayList[TypeNode](output.size)
  output.foreach { attr =>
    typeNodes.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
  }
  val structType = TypeBuilder.makeStruct(false, typeNodes)
  // The enhancement is normally consumed only during plan validation, but here it is
  // used at execution time as well. The type URL prefix therefore has to be empty so
  // the cpp side can parse the message into a json string.
  val typeUrlPrefix = ""
  Any.pack(structType.toProtobuf, typeUrlPrefix)
}

/**
 * Creates the extension node used for input type validation.
 *
 * @param output     attributes whose types go into the node's enhancement field
 * @param validation whether an extension node is wanted at all
 * @return an advanced extension carrying the enhancement when `validation` is true,
 *         `null` otherwise
 */
protected def createExtensionNode(output: Seq[Attribute],
    validation: Boolean): AdvancedExtensionNode =
  if (!validation) {
    // No validation requested: callers receive null rather than an empty node.
    null
  } else {
    val enhancement = createEnhancement(output)
    ExtensionBuilder.makeAdvancedExtension(enhancement)
  }

protected def createJoinExtensionNode(output: Seq[Attribute]): AdvancedExtensionNode = {
// Use field [optimization] in a extension node
// to send some join parameters through Substrait plan.
val joinParameters = genJoinParametersBuilder()
val enhancement = createEnhancement(output)
ExtensionBuilder.makeAdvancedExtension(joinParameters.build(), enhancement)
JoinUtils.createTransformContext(
exchangeTable, output, joinRel,
inputStreamedOutput, inputBuildOutput)
}

def genJoinParametersBuilder(): Any.Builder = {
Expand Down Expand Up @@ -974,38 +793,6 @@ trait HashJoinLikeExecTransformer
(0, 0, "")
}

/**
 * Flattened form of `getDirectJoinOutput`: the left attributes followed by the
 * right attributes in a single sequence.
 */
protected def getDirectJoinOutputSeq(leftOutput: Seq[Attribute],
    rightOutput: Seq[Attribute]): Seq[Attribute] =
  getDirectJoinOutput(leftOutput, rightOutput) match {
    case (leftAttrs, rightAttrs) => leftAttrs ++ rightAttrs
  }

/**
 * Computes the direct join output as a (left, right) pair, according to `joinType`.
 *
 * Outer joins mark the side(s) that may be missing as nullable; an ExistenceJoin
 * appends its `exists` attribute to the left side and yields no right side; the
 * remaining left-existence joins yield the left side only.
 *
 * @param leftOutput  attributes of the left side
 * @param rightOutput attributes of the right side
 * @throws IllegalArgumentException for any other join type
 */
protected def getDirectJoinOutput(leftOutput: Seq[Attribute],
    rightOutput: Seq[Attribute]): (Seq[Attribute], Seq[Attribute]) = {
  def asNullable(attrs: Seq[Attribute]): Seq[Attribute] =
    attrs.map(_.withNullability(true))

  joinType match {
    case _: InnerLike => (leftOutput, rightOutput)
    case LeftOuter => (leftOutput, asNullable(rightOutput))
    case RightOuter => (asNullable(leftOutput), rightOutput)
    case FullOuter => (asNullable(leftOutput), asNullable(rightOutput))
    case existence: ExistenceJoin => (leftOutput :+ existence.exists, Nil)
    // LeftSemi | LeftAnti; ExistenceJoin never gets here, it is matched by the case above.
    case LeftExistence(_) => (leftOutput, Nil)
    case unsupported =>
      val operatorName = getClass.getSimpleName
      throw new IllegalArgumentException(
        s"$operatorName not take $unsupported as the JoinType")
  }
}

override protected def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException(
s"${
Expand Down
Loading

0 comments on commit 74a1cc6

Please sign in to comment.