Skip to content

Commit

Permalink
Fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
takezoe committed Jul 12, 2023
1 parent adfea27 commit 3a1f979
Show file tree
Hide file tree
Showing 11 changed files with 199 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ object SQLAnonymizer extends LogSupport {
} else {
None
}
val v = UnresolvedAttribute(qualifier, parts.last, u.nodeLocation)
val v = UnresolvedAttribute(qualifier, parts.last, None, u.nodeLocation)
m += u -> v
}
this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ object TypeResolver extends LogSupport {
inputs(index) match {
case a: AllColumns =>
resolveIndex(index, a.inputColumns)
case SingleColumn(expr, _, _) =>
case SingleColumn(expr, _, _, _) =>
expr
case Alias(_, _, expr, _) =>
case Alias(_, _, expr, _, _) =>
expr
case other =>
other
Expand Down Expand Up @@ -222,7 +222,7 @@ object TypeResolver extends LogSupport {
case Some(cte) =>
CTERelationRef(
qname.fullName,
cte.outputAttributes.map(_.withQualifier(qname.fullName)),
cte.outputAttributes.map(_.withTableAlias(qname.fullName)),
plan.nodeLocation
)
case None =>
Expand Down Expand Up @@ -260,14 +260,14 @@ object TypeResolver extends LogSupport {
val mergedJoinKeys = resolvedJoinKeys
.groupBy(_.attributeName).map { case (name, keys) =>
val resolvedKeys = keys.flatMap {
case SingleColumn(r: ResolvedAttribute, qual, _) =>
case SingleColumn(r: ResolvedAttribute, qual, _, _) =>
Seq(r.withQualifier(qual))
case m: MultiSourceColumn =>
m.inputs
case other =>
Seq(other)
}
MultiSourceColumn(resolvedKeys, None, None)
MultiSourceColumn(resolvedKeys, None, None, None)
}
.toSeq
// Preserve the original USING(k1, k2, ...) order
Expand Down Expand Up @@ -348,14 +348,14 @@ object TypeResolver extends LogSupport {
): Seq[Attribute] = {
val resolvedColumns = Seq.newBuilder[Attribute]
outputColumns.map {
case a @ Alias(qualifier, name, expr, _) =>
case a @ Alias(qualifier, name, expr, _, _) =>
val resolved = resolveExpression(context, expr, inputAttributes)
if (expr eq resolved) {
resolvedColumns += a
} else {
resolvedColumns += a.copy(expr = resolved)
}
case s @ SingleColumn(expr, qualifier, nodeLocation) =>
case s @ SingleColumn(expr, qualifier, _, nodeLocation) =>
resolveExpression(context, expr, inputAttributes) match {
case a: Attribute =>
resolvedColumns += a.withQualifier(qualifier)
Expand All @@ -372,15 +372,15 @@ object TypeResolver extends LogSupport {

def resolveAttribute(attribute: Attribute): Attribute = {
attribute match {
case a @ Alias(qualifier, name, attr: Attribute, _) =>
case a @ Alias(qualifier, name, attr: Attribute, _, _) =>
val resolved = resolveAttribute(attr)
if (attr eq resolved) {
a
} else {
a.copy(expr = resolved)
}
case SingleColumn(a: Attribute, qualifier, _) if a.resolved =>
a.withQualifier(qualifier)
case SingleColumn(a: Attribute, qualifier, _, _) if a.resolved =>
a
case m: MultiSourceColumn =>
var changed = false
val resolvedInputs = m.inputs.map {
Expand All @@ -402,28 +402,25 @@ object TypeResolver extends LogSupport {
}

private def toResolvedAttribute(name: String, expr: Expression): Attribute = {

def findSourceColumn(e: Expression): Option[SourceColumn] = {
e match {
case r: ResolvedAttribute =>
r.sourceColumn
case a: Alias =>
findSourceColumn(a.expr)
case _ => None
case r: ResolvedAttribute => r.sourceColumn
case a: Alias => findSourceColumn(a.expr)
case _ => None
}
}

expr match {
case a: Alias =>
ResolvedAttribute(a.name, a.expr.dataType, a.qualifier, findSourceColumn(a.expr), a.nodeLocation)
ResolvedAttribute(a.name, a.expr.dataType, a.qualifier, findSourceColumn(a.expr), None, a.nodeLocation)
case s: SingleColumn =>
ResolvedAttribute(name, s.dataType, s.qualifier, findSourceColumn(s.expr), s.nodeLocation)
ResolvedAttribute(name, s.dataType, s.qualifier, findSourceColumn(s.expr), None, s.nodeLocation)
case a: Attribute =>
// No need to resolve Attribute expressions
a
a.withTableAlias(None)
case other =>
// Resolve expr as ResolvedAttribute so as not to pull-up too much details
ResolvedAttribute(name, other.dataType, None, findSourceColumn(expr), other.nodeLocation)
ResolvedAttribute(name, other.dataType, None, findSourceColumn(expr), None, other.nodeLocation)
}
}

Expand Down Expand Up @@ -458,20 +455,18 @@ object TypeResolver extends LogSupport {
val results = expr match {
case i: Identifier =>
lookup(i.value, context).map(toResolvedAttribute(i.value, _))
case u @ UnresolvedAttribute(qualifier, name, _) =>
case u @ UnresolvedAttribute(qualifier, name, _, _) =>
lookup(u.fullName, context).map(toResolvedAttribute(name, _).withQualifier(qualifier))
case a @ AllColumns(qualifier, None, _) =>
case a @ AllColumns(qualifier, None, _, _) =>
// Resolve the inputs of AllColumn as ResolvedAttribute
// so as not to pull up too much details
val allColumns = resolvedAttributes.map {
case a: Attribute =>
// Attribute can be used as is
a
case other =>
toResolvedAttribute(other.name, other)
// Attribute can be used as is
case a: Attribute => a
case other => toResolvedAttribute(other.name, other)
}
List(a.copy(columns = Some((qualifier match {
case Some(q) => allColumns.filter(_.qualifier.contains(q))
case Some(q) => allColumns.filter(c => c.qualifier.contains(q) || c.tableAlias.contains(q))
case None => allColumns
}).map(_.withQualifier(None)))))
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import wvlet.airframe.sql.Assertion._
import wvlet.log.LogSupport

import java.util.Locale
import scala.annotation.tailrec

/**
*/
Expand Down Expand Up @@ -294,12 +295,6 @@ trait Attribute extends LeafExpression with LogSupport {
def qualifier: Option[String]
def withQualifier(newQualifier: String): Attribute = withQualifier(Some(newQualifier))
def withQualifier(newQualifier: Option[String]): Attribute
def setQualifierIfEmpty(newQualifier: Option[String]): Attribute = {
qualifier match {
case Some(q) => this
case None => this.withQualifier(newQualifier)
}
}

import Expression.Alias
def alias: Option[String] = {
Expand All @@ -320,11 +315,15 @@ trait Attribute extends LeafExpression with LogSupport {
// No need to have alias
other
case other =>
Alias(qualifier, alias, other, None)
Alias(qualifier, alias, other, other.tableAlias, None)
}
}
}

def tableAlias: Option[String]
def withTableAlias(tableAlias: String): Attribute = withTableAlias(Some(tableAlias))
def withTableAlias(tableAlias: Option[String]): Attribute

/**
* Return columns used for generating this attribute
*/
Expand Down Expand Up @@ -355,7 +354,7 @@ trait Attribute extends LeafExpression with LogSupport {
columnPath.table match {
// TODO handle (catalog).(database).(table) names in the qualifier
case Some(tableName) =>
qualifier.exists(_ == tableName) && matchesWith(columnPath.columnName)
(qualifier.contains(tableName) || tableAlias.contains(tableName)) && matchesWith(columnPath.columnName)
case None =>
matchesWith(columnPath.columnName)
}
Expand All @@ -366,12 +365,15 @@ trait Attribute extends LeafExpression with LogSupport {
* via Join, Union), return MultiSourceAttribute.
*/
def matched(columnPath: ColumnPath, context: AnalyzerContext): Option[Attribute] = {
@tailrec
def findMatched(tableName: Option[String], columnName: String): Seq[Attribute] = {
tableName match {
case Some(tableName) =>
this match {
case r: ResolvedAttribute
if r.qualifier.orElse(r.sourceColumn.map(_.table.name)).exists(_.equalsIgnoreCase(tableName)) =>
if r.qualifier
.orElse(tableAlias)
.orElse(r.sourceColumn.map(_.table.name)).exists(_.equalsIgnoreCase(tableName)) =>
findMatched(None, columnName)
case _ =>
Nil
Expand All @@ -394,6 +396,8 @@ trait Attribute extends LeafExpression with LogSupport {
if (databaseName == context.database) {
if (qualifier.contains(tableName)) {
findMatched(None, columnName).map(_.withQualifier(qualifier))
} else if (tableAlias.contains(tableName)) {
findMatched(None, columnName)
} else {
findMatched(Some(tableName), columnName)
}
Expand All @@ -405,8 +409,10 @@ trait Attribute extends LeafExpression with LogSupport {
}
}
case ColumnPath(None, Some(tableName), columnName) =>
if (qualifier.exists(_ == tableName)) {
if (qualifier.contains(tableName)) {
findMatched(None, columnName).map(_.withQualifier(qualifier))
} else if (tableAlias.contains(tableName)) {
findMatched(None, columnName)
} else {
findMatched(Some(tableName), columnName)
}
Expand All @@ -421,7 +427,7 @@ trait Attribute extends LeafExpression with LogSupport {
} else {
qualifier
}
Some(MultiSourceColumn(result, qualifier = q, None))
Some(MultiSourceColumn(result, qualifier = q, None, None))
} else {
result.headOption
}
Expand Down Expand Up @@ -490,6 +496,7 @@ object Expression {
case class UnresolvedAttribute(
override val qualifier: Option[String],
name: String,
tableAlias: Option[String],
nodeLocation: Option[NodeLocation]
) extends Attribute {
override def toString: String = s"UnresolvedAttribute(${fullName})"
Expand All @@ -498,6 +505,9 @@ object Expression {
override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = {
this.copy(qualifier = newQualifier)
}
override def withTableAlias(tableAlias: Option[String]): Attribute = {
this.copy(tableAlias = tableAlias)
}
override def inputColumns: Seq[Attribute] = Seq.empty
override def outputColumns: Seq[Attribute] = Seq.empty
override def sourceColumns: Seq[SourceColumn] = Seq.empty
Expand Down Expand Up @@ -566,6 +576,7 @@ object Expression {
case class AllColumns(
override val qualifier: Option[String],
columns: Option[Seq[Attribute]],
tableAlias: Option[String],
nodeLocation: Option[NodeLocation]
) extends Attribute
with LogSupport {
Expand All @@ -587,7 +598,7 @@ object Expression {
}
}
override def outputColumns: Seq[Attribute] = {
inputColumns.map(_.withQualifier(qualifier))
inputColumns.map(_.withTableAlias(tableAlias).withQualifier(qualifier))
}

override def dataType: DataType = {
Expand All @@ -599,6 +610,9 @@ object Expression {
override def withQualifier(newQualifier: Option[String]): Attribute = {
this.copy(qualifier = newQualifier)
}
override def withTableAlias(tableAlias: Option[String]): Attribute = {
this.copy(tableAlias = tableAlias)
}

override def toString = {
columns match {
Expand All @@ -622,6 +636,7 @@ object Expression {
qualifier: Option[String],
name: String,
expr: Expression,
tableAlias: Option[String],
nodeLocation: Option[NodeLocation]
) extends Attribute {
override def inputColumns: Seq[Attribute] = Seq(this)
Expand All @@ -633,6 +648,10 @@ object Expression {
this.copy(qualifier = newQualifier)
}

override def withTableAlias(tableAlias: Option[String]): Attribute = {
this.copy(tableAlias = tableAlias)
}

override def toString: String = {
s"<${fullName}> := ${expr}"
}
Expand All @@ -659,7 +678,8 @@ object Expression {
*/
case class SingleColumn(
expr: Expression,
qualifier: Option[String] = None,
qualifier: Option[String],
tableAlias: Option[String],
nodeLocation: Option[NodeLocation]
) extends Attribute {
override def name: String = expr.attributeName
Expand All @@ -675,6 +695,9 @@ object Expression {
override def withQualifier(newQualifier: Option[String]): Attribute = {
this.copy(qualifier = newQualifier)
}
override def withTableAlias(tableAlias: Option[String]): Attribute = {
this.copy(tableAlias = tableAlias)
}

override def sourceColumns: Seq[SourceColumn] = {
expr match {
Expand All @@ -693,6 +716,7 @@ object Expression {
case class MultiSourceColumn(
inputs: Seq[Expression],
qualifier: Option[String],
tableAlias: Option[String],
nodeLocation: Option[NodeLocation]
) extends Attribute {
require(inputs.nonEmpty, s"The inputs of MultiSourceColumn should not be empty: ${this}", nodeLocation)
Expand All @@ -703,7 +727,7 @@ object Expression {
inputs.map {
case a: Attribute => a
case e: Expression =>
SingleColumn(e, qualifier, e.nodeLocation)
SingleColumn(e, qualifier, None, e.nodeLocation)
}
}
override def outputColumns: Seq[Attribute] = Seq(this)
Expand All @@ -725,6 +749,9 @@ object Expression {
override def withQualifier(newQualifier: Option[String]): Attribute = {
this.copy(qualifier = newQualifier)
}
override def withTableAlias(tableAlias: Option[String]): Attribute = {
this.copy(tableAlias = tableAlias)
}

override def sourceColumns: Seq[SourceColumn] = {
inputs.flatMap {
Expand Down
Loading

0 comments on commit 3a1f979

Please sign in to comment.