Skip to content

Commit

Permalink
airframe-sql: Fix outputAttributes of join on using node (#2711)
Browse files Browse the repository at this point in the history
  • Loading branch information
xerial authored Jan 21, 2023
1 parent 01a4f89 commit dcb0c47
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
* limitations under the License.
*/
package wvlet.airframe.sql.analyzer
import wvlet.airframe.sql.SQLErrorCode
import wvlet.airframe.sql.{SQLError, SQLErrorCode}
import wvlet.airframe.sql.analyzer.RewriteRule.PlanRewriter
import wvlet.airframe.sql.model.Expression._
import wvlet.airframe.sql.model.LogicalPlan._
Expand Down Expand Up @@ -51,7 +51,13 @@ object TypeResolver extends LogSupport {
): LogicalPlan = {
val resolvedPlan = rules
.foldLeft(plan) { (targetPlan, rule) =>
rule.transform(targetPlan, analyzerContext)
try {
rule.transform(targetPlan, analyzerContext)
} catch {
case e: SQLError =>
debug(s"Failed to resolve with: ${rule.name}\n${targetPlan.pp}")
throw e
}
}
resolvedPlan
}
Expand Down Expand Up @@ -471,7 +477,10 @@ object TypeResolver extends LogSupport {
findMatchInInputAttributes(context, expr, inputAttributes) match {
case lst if lst.length > 1 =>
trace(s"${expr} is ambiguous in ${lst}")
throw SQLErrorCode.SyntaxError.newException(s"${expr.sqlExpr} is ambiguous", expr.nodeLocation)
throw SQLErrorCode.SyntaxError.newException(
s"${expr.sqlExpr} is ambiguous:\n- ${lst.mkString("\n- ")}",
expr.nodeLocation
)
case lst =>
lst.headOption.getOrElse(expr)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,15 @@ trait Attribute extends LeafExpression with LogSupport {
}

/**
* Sub Attributes used to generate this Attribute
* @return
* Return columns used for generating this attribute
*/
def inputColumns: Seq[Attribute]

/**
* Return columns generated from this attribute
*/
def outputColumns: Seq[Attribute]

/**
* Return true if this Attribute matches with a given column path
* @param columnPath
Expand Down Expand Up @@ -423,6 +427,7 @@ object Expression {
this.copy(qualifier = newQualifier)
}
override def inputColumns: Seq[Attribute] = Seq.empty
override def outputColumns: Seq[Attribute] = Seq.empty
override def sourceColumns: Seq[SourceColumn] = Seq.empty

}
Expand Down Expand Up @@ -509,6 +514,7 @@ object Expression {
case None => Nil
}
}
override def outputColumns: Seq[Attribute] = inputColumns

override def dataType: DataType = {
columns
Expand Down Expand Up @@ -544,8 +550,10 @@ object Expression {
expr: Expression,
nodeLocation: Option[NodeLocation]
) extends Attribute {
override def inputColumns: Seq[Attribute] = Seq(this)
override def children: Seq[Expression] = Seq(expr)
override def inputColumns: Seq[Attribute] = Seq(this)
override def outputColumns: Seq[Attribute] = inputColumns

override def children: Seq[Expression] = Seq(expr)

override def withQualifier(newQualifier: Option[String]): Attribute = {
this.copy(qualifier = newQualifier)
Expand Down Expand Up @@ -583,9 +591,11 @@ object Expression {
override def name: String = expr.attributeName
override def dataType: DataType = expr.dataType

override def inputColumns: Seq[Attribute] = Seq(this)
override def children: Seq[Expression] = Seq(expr)
override def toString = s"${fullName}:${dataTypeName} := ${expr}"
override def inputColumns: Seq[Attribute] = Seq(this)
override def outputColumns: Seq[Attribute] = inputColumns

override def children: Seq[Expression] = Seq(expr)
override def toString = s"${fullName}:${dataTypeName} := ${expr}"

override def sqlExpr: String = expr.sqlExpr
override def withQualifier(newQualifier: Option[String]): Attribute = {
Expand Down Expand Up @@ -622,6 +632,8 @@ object Expression {
SingleColumn(e, qualifier, e.nodeLocation)
}
}
override def outputColumns: Seq[Attribute] = Seq(this)

override def children: Seq[Expression] = {
// MultiSourceColumn is a reference to the multiple columns. Do not traverse here
Seq.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,8 @@ object LogicalPlan {
right: Relation,
cond: JoinCriteria,
nodeLocation: Option[NodeLocation]
) extends Relation {
) extends Relation
with LogSupport {
override def modelName: String = joinType.toString
override def children: Seq[LogicalPlan] = Seq(left, right)
override def sig(config: QuerySignatureConfig): String = {
Expand All @@ -736,9 +737,12 @@ object LogicalPlan {
cond match {
case ju: ResolvedJoinUsing =>
val joinKeys = ju.keys
val otherAttributes = inputAttributes.filter { x =>
!joinKeys.exists(jk => jk.name == x.name)
}
val otherAttributes = inputAttributes
// Expand AllColumns here
.flatMap(_.outputColumns)
.filter { x =>
!joinKeys.exists(jk => jk.name == x.name)
}
// report join keys (merged) and other attributes
joinKeys ++ otherAttributes
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ case class ResolvedAttribute(
override def withQualifier(newQualifier: Option[String]): Attribute = {
this.copy(qualifier = newQualifier)
}
override def inputColumns: Seq[Attribute] = Seq(this)
override def inputColumns: Seq[Attribute] = Seq(this)
override def outputColumns: Seq[Attribute] = inputColumns

override def toString = {
sourceColumn match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -945,4 +945,8 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper {
k2.fullName shouldBe "t2.id"
}
}

test("resolve a join key propagated through select *") {
val p = analyze("select id from A join (select * from B) using(id)")
}
}

0 comments on commit dcb0c47

Please sign in to comment.