Skip to content

Commit

Permalink
feat: expression support in JOINs (#4278)
Browse files: browse the repository at this point in the history
  • Loading branch information
agavra authored Jan 15, 2020
1 parent 1f6eabf commit 2d0bfe8
Show file tree
Hide file tree
Showing 14 changed files with 385 additions and 721 deletions.
20 changes: 10 additions & 10 deletions ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -281,30 +281,30 @@ public DataSource<?> getDataSource() {
@Immutable
public static final class JoinInfo {

private final ColumnRef leftJoinField;
private final ColumnRef rightJoinField;
private final Expression leftJoinExpression;
private final Expression rightJoinExpression;
private final JoinNode.JoinType type;
private final Optional<WithinExpression> withinExpression;

JoinInfo(
final ColumnRef leftJoinField,
final ColumnRef rightJoinField,
final Expression leftJoinExpression,
final Expression rightJoinExpression,
final JoinType type,
final Optional<WithinExpression> withinExpression

) {
this.leftJoinField = requireNonNull(leftJoinField, "leftJoinField");
this.rightJoinField = requireNonNull(rightJoinField, "rightJoinField");
this.leftJoinExpression = requireNonNull(leftJoinExpression, "leftJoinExpression");
this.rightJoinExpression = requireNonNull(rightJoinExpression, "rightJoinExpression");
this.type = requireNonNull(type, "type");
this.withinExpression = requireNonNull(withinExpression, "withinExpression");
}

public ColumnRef getLeftJoinField() {
return leftJoinField;
public Expression getLeftJoinExpression() {
return leftJoinExpression;
}

public ColumnRef getRightJoinField() {
return rightJoinField;
public Expression getRightJoinExpression() {
return rightJoinExpression;
}

public JoinType getType() {
Expand Down
142 changes: 49 additions & 93 deletions ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import static java.util.Objects.requireNonNull;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.MoreCollectors;
import io.confluent.ksql.analyzer.Analysis.AliasedDataSource;
import io.confluent.ksql.analyzer.Analysis.Into;
import io.confluent.ksql.analyzer.Analysis.JoinInfo;
Expand Down Expand Up @@ -67,7 +69,6 @@
import io.confluent.ksql.util.SchemaUtil;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -349,28 +350,63 @@ protected AstNode visitJoin(final Join node, final Void context) {
);
}

final ColumnRef leftJoinField = getJoinFieldName(
comparisonExpression,
left.getAlias(),
left.getDataSource().getSchema()
);
final Set<ColumnRef> colsUsedInLeft = new ExpressionAnalyzer(analysis.getFromSourceSchemas())
.analyzeExpression(comparisonExpression.getLeft(), false);
final Set<ColumnRef> colsUsedInRight = new ExpressionAnalyzer(analysis.getFromSourceSchemas())
.analyzeExpression(comparisonExpression.getRight(), false);

final ColumnRef rightJoinField = getJoinFieldName(
comparisonExpression,
right.getAlias(),
right.getDataSource().getSchema()
);
final SourceName leftSourceName = getOnlySourceForJoin(
comparisonExpression.getLeft(), comparisonExpression, colsUsedInLeft);
final SourceName rightSourceName = getOnlySourceForJoin(
comparisonExpression.getRight(), comparisonExpression, colsUsedInRight);

if (!validJoin(left.getAlias(), right.getAlias(), leftSourceName, rightSourceName)) {
throw new KsqlException(
"Each side of the join must reference exactly one source and not the same source. "
+ "Left side references " + leftSourceName
+ " and right references " + rightSourceName
);
}

final boolean flipped = leftSourceName.equals(right.getAlias());
analysis.setJoin(new JoinInfo(
leftJoinField,
rightJoinField,
flipped ? comparisonExpression.getRight() : comparisonExpression.getLeft(),
flipped ? comparisonExpression.getLeft() : comparisonExpression.getRight(),
joinType,
node.getWithinExpression()
));

return null;
}

/**
 * Checks whether the sources referenced by the two sides of the join comparison cover both
 * joined sources.
 *
 * @param leftName the name of the left joined source
 * @param rightName the name of the right joined source
 * @param leftExpressionSource the source referenced by the left side of the comparison
 * @param rightExpressionSource the source referenced by the right side of the comparison
 * @return {@code true} if the set of the two expression sources contains both
 *     {@code leftName} and {@code rightName}
 */
private boolean validJoin(
    final SourceName leftName,
    final SourceName rightName,
    final SourceName leftExpressionSource,
    final SourceName rightExpressionSource
) {
  // The comparison may be written in either order, so membership is checked as a set.
  final Set<SourceName> expressionSources =
      ImmutableSet.of(leftExpressionSource, rightExpressionSource);
  final List<SourceName> joinedSources = ImmutableList.of(leftName, rightName);

  return expressionSources.containsAll(joinedSources);
}

/**
 * Resolves the single source referenced by one side of a join comparison.
 *
 * <p>Several columns of the same source may appear on one side (for example
 * {@code a.x + a.y}); the sources are de-duplicated before checking that exactly one remains.
 *
 * @param exp the side of the comparison being inspected; used only in the error message
 * @param join the complete join comparison; used only in the error message
 * @param columnRefs the column references found in {@code exp}
 * @return the one distinct source present among {@code columnRefs}
 * @throws KsqlException if the references do not resolve to exactly one distinct source
 */
private SourceName getOnlySourceForJoin(
    final Expression exp,
    final ComparisonExpression join,
    final Set<ColumnRef> columnRefs
) {
  try {
    return columnRefs.stream()
        .map(ColumnRef::source)
        .filter(Optional::isPresent)
        .map(Optional::get)
        // Different columns of one source yield equal source names; without this,
        // onlyElement() would reject a side that legitimately uses a single source.
        .distinct()
        .collect(MoreCollectors.onlyElement());
  } catch (final Exception e) {
    // Keep the original failure as the cause so the underlying reason is not lost.
    throw new KsqlException("Invalid comparison expression '" + exp + "' in join '" + join
        + "'. Each side of the join comparision must contain references from exactly one "
        + "source.", e);
  }
}

private JoinNode.JoinType getJoinType(final Join node) {
final JoinNode.JoinType joinType;
switch (node.getType()) {
Expand All @@ -389,86 +425,6 @@ private JoinNode.JoinType getJoinType(final Join node) {
return joinType;
}

/**
 * Ensures that one side of a join comparison is a plain column reference.
 *
 * @param comparisonExpression the join comparison the sub-expression belongs to; used for the
 *     location and text of the error message
 * @param subExpression the side of the comparison to check
 * @return {@code subExpression} cast to {@link ColumnReferenceExp}
 * @throws KsqlException if {@code subExpression} is not a {@link ColumnReferenceExp}
 */
private ColumnReferenceExp checkExpressionType(
    final ComparisonExpression comparisonExpression,
    final Expression subExpression
) {
  if (subExpression instanceof ColumnReferenceExp) {
    return (ColumnReferenceExp) subExpression;
  }

  // The location prefix is left empty when the comparison carries no location.
  final String location =
      comparisonExpression.getLocation().map(Objects::toString).orElse("");

  final String message = String.format(
      "%s : Invalid comparison expression '%s' in join '%s'. Joins must only contain a "
          + "field comparison.",
      location,
      subExpression,
      comparisonExpression
  );
  throw new KsqlException(message);
}

/**
 * Finds the column of the given source that is used in a join comparison.
 *
 * <p>The left side of the comparison is tried first; the right side is only inspected (and
 * type-checked) when the left side does not yield a reference for {@code sourceAlias}.
 *
 * @param comparisonExpression the join comparison
 * @param sourceAlias the alias of the source whose join column is wanted
 * @param sourceSchema the schema in which the column is looked up
 * @return the column reference as returned by the schema lookup, with {@code sourceAlias}
 *     applied as its source
 * @throws KsqlException if a side that is inspected is not a column reference, or if the
 *     column is not found in {@code sourceSchema}
 * @throws IllegalStateException if neither side yields a reference for {@code sourceAlias}
 */
private ColumnRef getJoinFieldName(
    final ComparisonExpression comparisonExpression,
    final SourceName sourceAlias,
    final LogicalSchema sourceSchema
) {
  final ColumnReferenceExp leftRef =
      checkExpressionType(comparisonExpression, comparisonExpression.getLeft());

  final ColumnRef fieldName = getJoinFieldNameFromExpr(leftRef, sourceAlias)
      .orElseGet(() -> {
        final ColumnReferenceExp rightRef =
            checkExpressionType(comparisonExpression, comparisonExpression.getRight());

        // Should never happen as only QualifiedNameReference are allowed
        return getJoinFieldNameFromExpr(rightRef, sourceAlias)
            .orElseThrow(() -> new IllegalStateException("Cannot find join field name"));
      });

  final Optional<ColumnRef> resolved =
      getJoinFieldNameFromSource(fieldName.withoutSource(), sourceAlias, sourceSchema);

  if (resolved.isPresent()) {
    return resolved.get();
  }

  throw new KsqlException(
      String.format(
          "%s : Invalid join criteria %s. Column %s.%s does not exist.",
          comparisonExpression.getLocation().map(Objects::toString).orElse(""),
          comparisonExpression,
          sourceAlias.name(),
          fieldName.name().toString(FormatOptions.noEscape())
      )
  );
}

/**
 * Returns the column reference of {@code nameRef} unless it is explicitly qualified with a
 * source other than {@code sourceAlias}.
 *
 * @param nameRef the column reference expression to inspect
 * @param sourceAlias the alias the reference is expected to belong to
 * @return the reference when it has no source or its source equals {@code sourceAlias};
 *     otherwise an empty optional
 */
private Optional<ColumnRef> getJoinFieldNameFromExpr(
    final ColumnReferenceExp nameRef,
    final SourceName sourceAlias
) {
  final ColumnRef reference = nameRef.getReference();

  final boolean qualifiedWithOtherSource = reference.source()
      .filter(source -> !source.equals(sourceAlias))
      .isPresent();

  if (qualifiedWithOtherSource) {
    return Optional.empty();
  }
  return Optional.of(reference);
}

/**
 * Looks up a column in the given schema and qualifies the result with the source alias.
 *
 * @param fieldName the column reference to look up in {@code sourceSchema}
 * @param sourceAlias the alias to apply as the source of the found column's reference
 * @param sourceSchema the schema to search
 * @return the found column's reference with {@code sourceAlias} as its source, or an empty
 *     optional when the schema lookup finds nothing
 */
private Optional<ColumnRef> getJoinFieldNameFromSource(
    final ColumnRef fieldName,
    final SourceName sourceAlias,
    final LogicalSchema sourceSchema
) {
  final Optional<ColumnRef> schemaRef = sourceSchema.findColumn(fieldName).map(Column::ref);
  return schemaRef.map(columnRef -> columnRef.withSource(sourceAlias));
}

@Override
protected AstNode visitAliasedRelation(final AliasedRelation node, final Void context) {
final SourceName structuredDataSourceName = ((Table) node.getRelation()).getName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,16 @@

package io.confluent.ksql.analyzer;

import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
import io.confluent.ksql.execution.expression.tree.Cast;
import com.google.common.collect.Iterables;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.NotExpression;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.FormatOptions;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
Expand All @@ -48,99 +40,46 @@ class ExpressionAnalyzer {
this.sourceSchemas = Objects.requireNonNull(sourceSchemas, "sourceSchemas");
}

void analyzeExpression(final Expression expression, final boolean allowWindowMetaFields) {
final Visitor visitor = new Visitor(allowWindowMetaFields);
visitor.process(expression, null);
Set<ColumnRef> analyzeExpression(
final Expression expression,
final boolean allowWindowMetaFields
) {
final Set<ColumnRef> referencedColumns = new HashSet<>();
final ColumnExtractor extractor = new ColumnExtractor(allowWindowMetaFields, referencedColumns);
extractor.process(expression, null);
return referencedColumns;
}

private final class Visitor extends VisitParentExpressionVisitor<Object, Object> {
private final class ColumnExtractor extends TraversalExpressionVisitor<Object> {

private final Set<ColumnRef> referencedColumns;
private final boolean allowWindowMetaFields;

Visitor(final boolean allowWindowMetaFields) {
ColumnExtractor(
final boolean allowWindowMetaFields,
final Set<ColumnRef> referencedColumns
) {
this.allowWindowMetaFields = allowWindowMetaFields;
}

public Object visitLikePredicate(final LikePredicate node, final Object context) {
process(node.getValue(), null);
return null;
}

public Object visitFunctionCall(final FunctionCall node, final Object context) {
for (final Expression argExpr : node.getArguments()) {
process(argExpr, null);
}
return null;
}

public Object visitArithmeticBinary(
final ArithmeticBinaryExpression node,
final Object context) {
process(node.getLeft(), null);
process(node.getRight(), null);
return null;
}

public Object visitIsNotNullPredicate(final IsNotNullPredicate node, final Object context) {
return process(node.getValue(), context);
}

public Object visitIsNullPredicate(final IsNullPredicate node, final Object context) {
return process(node.getValue(), context);
}

public Object visitLogicalBinaryExpression(
final LogicalBinaryExpression node,
final Object context) {
process(node.getLeft(), null);
process(node.getRight(), null);
return null;
this.referencedColumns = referencedColumns;
}

@Override
public Object visitComparisonExpression(
final ComparisonExpression node,
final Object context) {
process(node.getLeft(), null);
process(node.getRight(), null);
return null;
}

@Override
public Object visitNotExpression(final NotExpression node, final Object context) {
return process(node.getValue(), null);
}

@Override
public Object visitCast(final Cast node, final Object context) {
process(node.getExpression(), context);
return null;
}

@Override
public Object visitColumnReference(
public Void visitColumnReference(
final ColumnReferenceExp node,
final Object context
) {
throwOnUnknownOrAmbiguousColumn(node.getReference());
final ColumnRef reference = node.getReference();
referencedColumns.add(getQualifiedColumnRef(reference));
return null;
}

@Override
public Object visitDereferenceExpression(
final DereferenceExpression node,
final Object context
) {
process(node.getBase(), context);
return null;
}

private void throwOnUnknownOrAmbiguousColumn(final ColumnRef name) {
private ColumnRef getQualifiedColumnRef(final ColumnRef name) {
final Set<SourceName> sourcesWithField = sourceSchemas.sourcesWithField(name);

if (sourcesWithField.isEmpty()) {
if (allowWindowMetaFields && name.name().equals(SchemaUtil.WINDOWSTART_NAME)) {
return;
// window start doesn't need a qualifier as it's a special hacky column
return name;
}

throw new KsqlException("Column '" + name.toString(FormatOptions.noEscape())
Expand All @@ -156,6 +95,8 @@ private void throwOnUnknownOrAmbiguousColumn(final ColumnRef name) {
throw new KsqlException("Column '" + name.name().name() + "' is ambiguous. "
+ "Could be any of: " + possibilities);
}

return name.withSource(Iterables.getOnlyElement(sourcesWithField));
}
}
}
Loading

0 comments on commit 2d0bfe8

Please sign in to comment.