Skip to content

Commit

Permalink
refactor: clean up DereferenceExpression vs QualifiedNameReference (#…
Browse files Browse the repository at this point in the history
…3358)

Fixes: #1695

The code now uses `DereferenceExpression` exclusively for accessing STRUCT fields and uses `QualifiedNameReference` when referencing a column.

The base of `DereferenceExpression` can be a `QualifiedNameReference`, e.g. col0->someField, but can also be other expressions, e.g. a udf that returns a struct: `myUdf(blah)->someField`.
  • Loading branch information
big-andy-coates authored Sep 17, 2019
1 parent 6ae3af5 commit 6be6d37
Show file tree
Hide file tree
Showing 37 changed files with 302 additions and 379 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

package io.confluent.ksql.analyzer;

import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.QualifiedNameReference;
import java.util.Map;
import java.util.Set;

Expand All @@ -28,21 +28,21 @@ interface AggregateAnalysis extends AggregateAnalysisResult {
*
* @return the map of select expression to the set of source schema fields.
*/
Map<Expression, Set<DereferenceExpression>> getNonAggregateSelectExpressions();
Map<Expression, Set<QualifiedNameReference>> getNonAggregateSelectExpressions();

/**
* Get the set of select fields that are involved in aggregate columns, but not as parameters
* to the aggregate functions.
*
* @return the set of fields used in aggregate columns outside of aggregate function parameters.
*/
Set<DereferenceExpression> getAggregateSelectFields();
Set<QualifiedNameReference> getAggregateSelectFields();

/**
* Get the set of columns from the source schema that are used in the HAVING clause outside
* of aggregate functions.
*
* @return the set of non-aggregate columns used in the HAVING clause.
*/
Set<DereferenceExpression> getNonAggregateHavingFields();
Set<QualifiedNameReference> getNonAggregateHavingFields();
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

package io.confluent.ksql.analyzer;

import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.QualifiedNameReference;
import java.util.List;

public interface AggregateAnalysisResult {
Expand All @@ -30,7 +30,7 @@ public interface AggregateAnalysisResult {
*
* @return the full set of columns from the source schema that are required.
*/
List<DereferenceExpression> getRequiredColumns();
List<QualifiedNameReference> getRequiredColumns();

List<FunctionCall> getAggregateFunctions();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
package io.confluent.ksql.analyzer;

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.QualifiedNameReference;
import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.util.KsqlException;
Expand All @@ -31,12 +31,12 @@
class AggregateAnalyzer {

private final MutableAggregateAnalysis aggregateAnalysis;
private final DereferenceExpression defaultArgument;
private final QualifiedNameReference defaultArgument;
private final FunctionRegistry functionRegistry;

AggregateAnalyzer(
final MutableAggregateAnalysis aggregateAnalysis,
final DereferenceExpression defaultArgument,
final QualifiedNameReference defaultArgument,
final FunctionRegistry functionRegistry
) {
this.aggregateAnalysis = Objects.requireNonNull(aggregateAnalysis, "aggregateAnalysis");
Expand All @@ -45,7 +45,7 @@ class AggregateAnalyzer {
}

void processSelect(final Expression expression) {
final Set<DereferenceExpression> nonAggParams = new HashSet<>();
final Set<QualifiedNameReference> nonAggParams = new HashSet<>();
final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> {
if (!aggFuncName.isPresent()) {
nonAggParams.add(node);
Expand Down Expand Up @@ -83,12 +83,12 @@ void processHaving(final Expression expression) {

private final class AggregateVisitor extends TraversalExpressionVisitor<Void> {

private final BiConsumer<Optional<String>, DereferenceExpression> dereferenceCollector;
private final BiConsumer<Optional<String>, QualifiedNameReference> dereferenceCollector;
private Optional<String> aggFunctionName = Optional.empty();
private boolean visitedAggFunction = false;

private AggregateVisitor(
final BiConsumer<Optional<String>, DereferenceExpression> dereferenceCollector
final BiConsumer<Optional<String>, QualifiedNameReference> dereferenceCollector
) {
this.dereferenceCollector =
Objects.requireNonNull(dereferenceCollector, "dereferenceCollector");
Expand Down Expand Up @@ -126,15 +126,13 @@ public Void visitFunctionCall(final FunctionCall node, final Void context) {
}

@Override
public Void visitDereferenceExpression(
final DereferenceExpression node,
public Void visitQualifiedNameReference(
final QualifiedNameReference node,
final Void context
) {
dereferenceCollector.accept(aggFunctionName, node);
aggregateAnalysis.addRequiredColumn(node);
return null;
}


}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.Immutable;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.QualifiedName;
import io.confluent.ksql.execution.expression.tree.QualifiedNameReference;
Expand Down Expand Up @@ -168,10 +167,9 @@ void addDataSource(final String alias, final DataSource<?> dataSource) {
fromDataSources.add(new AliasedDataSource(alias, dataSource));
}

DereferenceExpression getDefaultArgument() {
final String base = fromDataSources.get(0).getAlias();
final Expression baseExpression = new QualifiedNameReference(QualifiedName.of(base));
return new DereferenceExpression(baseExpression, SchemaUtil.ROWTIME_NAME);
QualifiedNameReference getDefaultArgument() {
final String alias = fromDataSources.get(0).getAlias();
return new QualifiedNameReference(QualifiedName.of(alias, SchemaUtil.ROWTIME_NAME));
}

void setSerdeOptions(final Set<SerdeOption> serdeOptions) {
Expand Down
62 changes: 26 additions & 36 deletions ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import io.confluent.ksql.analyzer.Analysis.JoinInfo;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.QualifiedName;
import io.confluent.ksql.execution.expression.tree.QualifiedNameReference;
Expand Down Expand Up @@ -246,8 +245,7 @@ private List<String> getNoneMetaOrKeySelectAliases() {
for (int idx = selects.size() - 1; idx >= 0; --idx) {
final Expression select = selects.get(idx);

if (!(select instanceof DereferenceExpression)
&& !(select instanceof QualifiedNameReference)) {
if (!(select instanceof QualifiedNameReference)) {
continue;
}

Expand Down Expand Up @@ -379,11 +377,11 @@ private JoinNode.JoinType getJoinType(final Join node) {
return joinType;
}

private DereferenceExpression checkExpressionType(
private QualifiedNameReference checkExpressionType(
final ComparisonExpression comparisonExpression,
final Expression subExpression) {

if (!(subExpression instanceof DereferenceExpression)) {
if (!(subExpression instanceof QualifiedNameReference)) {
throw new KsqlException(
String.format(
"%s : Invalid comparison expression '%s' in join '%s'. Joins must only contain a "
Expand All @@ -394,40 +392,35 @@ private DereferenceExpression checkExpressionType(
)
);
}
return (DereferenceExpression) subExpression;
return (QualifiedNameReference) subExpression;
}

private String getJoinFieldName(
final ComparisonExpression comparisonExpression,
final String sourceAlias,
final LogicalSchema sourceSchema
) {
final QualifiedNameReference left =
checkExpressionType(comparisonExpression, comparisonExpression.getLeft());

final DereferenceExpression left = checkExpressionType(comparisonExpression,
comparisonExpression.getLeft());
Optional<String> joinFieldName = getJoinFieldNameFromExpr(
left,
sourceAlias
);
Optional<String> joinFieldName = getJoinFieldNameFromExpr(left, sourceAlias);

if (!joinFieldName.isPresent()) {
final DereferenceExpression right = checkExpressionType(comparisonExpression,
comparisonExpression.getRight());
joinFieldName = getJoinFieldNameFromExpr(
right,
sourceAlias
);
}
final QualifiedNameReference right =
checkExpressionType(comparisonExpression, comparisonExpression.getRight());

if (!joinFieldName.isPresent()) {
// Should never happen as we only allow DereferenceExpression
throw new IllegalStateException("Cannot find join field name");
joinFieldName = getJoinFieldNameFromExpr(right, sourceAlias);

if (!joinFieldName.isPresent()) {
// Should never happen as only QualifiedNameReference are allowed
throw new IllegalStateException("Cannot find join field name");
}
}

final String fieldName = joinFieldName.get();

final Optional<String> joinField =
getJoinFieldFromSource(fieldName, sourceAlias, sourceSchema);
getJoinFieldNameFromSource(fieldName, sourceAlias, sourceSchema);

return joinField
.orElseThrow(() -> new KsqlException(
Expand All @@ -437,23 +430,24 @@ private String getJoinFieldName(
comparisonExpression,
sourceAlias,
fieldName
)
)
));
}

private Optional<String> getJoinFieldNameFromExpr(
final DereferenceExpression expression,
final String sourceAlias) {
final String sourceAliasVal = expression.getBase().toString();
if (!sourceAliasVal.equalsIgnoreCase(sourceAlias)) {
final QualifiedNameReference nameRef,
final String sourceAlias
) {
if (nameRef.getName().qualifier().isPresent()
&& !nameRef.getName().qualifier().get().equalsIgnoreCase(sourceAlias)) {
return Optional.empty();
}

final String fieldName = expression.getFieldName();
final String fieldName = nameRef.getName().name();
return Optional.of(fieldName);
}

private Optional<String> getJoinFieldFromSource(
private Optional<String> getJoinFieldNameFromSource(
final String fieldName,
final String sourceAlias,
final LogicalSchema sourceSchema
Expand Down Expand Up @@ -531,18 +525,14 @@ private void visitSelectStar(final AllColumns allColumns) {
continue;
}

final QualifiedName name = QualifiedName.of(source.getAlias());

final QualifiedNameReference nameRef = new QualifiedNameReference(location, name);

final String aliasPrefix = analysis.isJoin()
? source.getAlias() + "_"
: "";

for (final Column column : source.getDataSource().getSchema().columns()) {

final DereferenceExpression selectItem =
new DereferenceExpression(location, nameRef, column.name());
final QualifiedNameReference selectItem = new QualifiedNameReference(location,
QualifiedName.of(source.getAlias(), column.name()));

final String alias = aliasPrefix + column.name();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
import io.confluent.ksql.execution.expression.tree.Cast;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.NotExpression;
import io.confluent.ksql.execution.expression.tree.QualifiedName;
import io.confluent.ksql.execution.expression.tree.QualifiedNameReference;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
import io.confluent.ksql.util.KsqlException;
Expand All @@ -50,19 +50,24 @@ void analyzeExpression(final Expression expression) {
visitor.process(expression, null);
}

private void throwOnUnknownField(final String columnName) {
final Set<String> sourcesWithField = sourceSchemas.sourcesWithField(columnName);
private void throwOnUnknownField(final QualifiedName name) {
final Set<String> sourcesWithField = sourceSchemas.sourcesWithField(name.name());
if (sourcesWithField.isEmpty()) {
throw new KsqlException("Field '" + columnName + "' cannot be resolved.");
throw new KsqlException("Field '" + name + "' cannot be resolved.");
}

if (sourcesWithField.size() > 1) {
if (name.qualifier().isPresent()) {
if (!sourcesWithField.contains(name.qualifier().get())) {
throw new KsqlException("Source '" + name.qualifier() + "', "
+ "used in '" + name + "' cannot be resolved.");
}
} else if (sourcesWithField.size() > 1) {
final String possibilities = sourcesWithField.stream()
.sorted()
.map(source -> SchemaUtil.buildAliasedFieldName(source, columnName))
.map(source -> SchemaUtil.buildAliasedFieldName(source, name.name()))
.collect(Collectors.joining(","));

throw new KsqlException("Field '" + columnName + "' is ambiguous. "
throw new KsqlException("Field '" + name + "' is ambiguous. "
+ "Could be any of: " + possibilities);
}
}
Expand Down Expand Up @@ -125,26 +130,12 @@ public Object visitCast(final Cast node, final Object context) {
return null;
}

@Override
public Object visitDereferenceExpression(
final DereferenceExpression node,
final Object context
) {
final String columnName = sourceSchemas.isJoin()
? node.toString()
: node.getFieldName();

throwOnUnknownField(columnName);
return null;
}

@Override
public Object visitQualifiedNameReference(
final QualifiedNameReference node,
final Object context
) {
final String columnName = node.getName().name();
throwOnUnknownField(columnName);
throwOnUnknownField(node.getName());
return null;
}
}
Expand Down
Loading

0 comments on commit 6be6d37

Please sign in to comment.