Skip to content

Commit

Permalink
feat: allow expressions on left table columns in FK-joins (#7904)
Browse files Browse the repository at this point in the history
Currently, FK-join require to use a plain column reference for the left join expression. This PR lifts this restriction and allows for actual expression in the left join expression.
  • Loading branch information
mjsax authored Jul 29, 2021
1 parent ccd5a72 commit a9668de
Show file tree
Hide file tree
Showing 91 changed files with 16,777 additions and 142 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.confluent.ksql.analyzer.RewrittenAnalysis;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.execution.codegen.CodeGenRunner;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Expression;
Expand All @@ -39,6 +40,7 @@
import io.confluent.ksql.execution.streams.PartitionByParamsFactory;
import io.confluent.ksql.execution.streams.timestamp.TimestampExtractionPolicyFactory;
import io.confluent.ksql.execution.timestamp.TimestampColumn;
import io.confluent.ksql.execution.transform.ExpressionEvaluator;
import io.confluent.ksql.execution.util.ExpressionTypeManager;
import io.confluent.ksql.execution.windows.KsqlWindowExpression;
import io.confluent.ksql.function.udf.AsValue;
Expand Down Expand Up @@ -597,26 +599,26 @@ private JoinNode buildJoin(final Join root, final String prefix, final boolean i
);
}

final Optional<ColumnReferenceExp> fkColumnName =
final Optional<Expression> fkExpression =
verifyJoin(root.getInfo(), preRepartitionLeft, preRepartitionRight);

final JoinKey joinKey = fkColumnName
.map(columnReferenceExp -> buildForeignJoinKey(root, fkColumnName.get()))
final JoinKey joinKey = fkExpression
.map(columnReferenceExp -> buildForeignJoinKey(root, fkExpression.get()))
.orElseGet(() -> buildJoinKey(root));

final PlanNode left = prepareSourceForJoin(
root.getLeft(),
preRepartitionLeft,
prefix + "Left",
root.getInfo().getLeftJoinExpression(),
fkColumnName.isPresent()
fkExpression.isPresent()
);
final PlanNode right = prepareSourceForJoin(
root.getRight(),
preRepartitionRight,
prefix + "Right",
root.getInfo().getRightJoinExpression(),
fkColumnName.isPresent()
fkExpression.isPresent()
);

return new JoinNode(
Expand All @@ -634,7 +636,7 @@ private JoinNode buildJoin(final Join root, final String prefix, final boolean i
/**
* @return the foreign key column if this is a foreign key join
*/
private Optional<ColumnReferenceExp> verifyJoin(
private Optional<Expression> verifyJoin(
final JoinInfo joinInfo,
final PlanNode leftNode,
final PlanNode rightNode
Expand Down Expand Up @@ -721,7 +723,7 @@ private static void verifyStreamTableJoin(final JoinInfo joinInfo, final PlanNod
}
}

private Optional<ColumnReferenceExp> verifyForeignKeyJoin(
private Optional<Expression> verifyForeignKeyJoin(
final JoinInfo joinInfo,
final PlanNode leftNode,
final PlanNode rightNode
Expand Down Expand Up @@ -755,25 +757,34 @@ private Optional<ColumnReferenceExp> verifyForeignKeyJoin(
));
}

if (!(leftExpression instanceof ColumnReferenceExp)) {
throw new KsqlException(String.format(
"Invalid join condition:"
+ " foreign-key table-table joins with expressions are not supported yet."
+ " Got %s = %s.",
joinInfo.getFlippedLeftJoinExpression(),
joinInfo.getFlippedRightJoinExpression()
));
}

// we need to extend this to support expressions later on
final ColumnReferenceExp fkColumnReference = (ColumnReferenceExp) leftExpression;
final CodeGenRunner codeGenRunner = new CodeGenRunner(
leftNode.getSchema(),
ksqlConfig,
metaStore
);

final SqlType fkColumnType =
leftNode.getSchema().findColumn(fkColumnReference.getColumnName()).get().type();
final VisitParentExpressionVisitor<Optional<Expression>, Context<Void>> unqualifiedRewritter =
new VisitParentExpressionVisitor<Optional<Expression>, Context<Void>>(Optional.empty()) {
@Override
public Optional<Expression> visitQualifiedColumnReference(
final QualifiedColumnReferenceExp node,
final Context<Void> ctx
) {
return Optional.of(new UnqualifiedColumnReferenceExp(node.getColumnName()));
}
};

final Expression leftExpressionUnqualified =
ExpressionTreeRewriter.rewriteWith(unqualifiedRewritter::process, leftExpression);
final ExpressionEvaluator expressionEvaluator = codeGenRunner.buildCodeGenFromParseTree(
leftExpressionUnqualified,
"Left Join Expression"
);
final SqlType fkType = expressionEvaluator.getExpressionType();
final SqlType rightKeyType = Iterables.getOnlyElement(rightNode.getSchema().key()).type();

verifyJoinConditionTypes(
fkColumnType,
fkType,
rightKeyType,
leftExpression,
rightExpression,
Expand All @@ -786,7 +797,7 @@ private Optional<ColumnReferenceExp> verifyForeignKeyJoin(
);
}

return Optional.of(fkColumnReference);
return Optional.of(leftExpression);
}

private static boolean joinOnNonKeyAttribute(
Expand Down Expand Up @@ -872,21 +883,33 @@ private static boolean isInnerNode(final PlanNode node) {
throw new IllegalStateException("Unknown node type: " + node.getClass().getName());
}

private JoinKey buildForeignJoinKey(final Join join, final ColumnReferenceExp columnRef) {
private JoinKey buildForeignJoinKey(final Join join,
final Expression foreignKeyExpression) {
final AliasedDataSource leftSource = join.getInfo().getLeftSource();
final SourceName alias = leftSource.getAlias();
final List<QualifiedColumnReferenceExp> leftSourceKeys =
leftSource.getDataSource().getSchema().key().stream()
.map(c -> new QualifiedColumnReferenceExp(alias, c.name()))
.collect(Collectors.toList());

final ColumnName foreignKeyColumnName = columnRef.maybeQualifier().isPresent()
? ColumnNames.generatedJoinColumnAlias(
columnRef.maybeQualifier().get(),
columnRef.getColumnName())
: columnRef.getColumnName();
final VisitParentExpressionVisitor<Optional<Expression>, Context<Void>> aliasRewritter =
new VisitParentExpressionVisitor<Optional<Expression>, Context<Void>>(Optional.empty()) {
@Override
public Optional<Expression> visitQualifiedColumnReference(
final QualifiedColumnReferenceExp node,
final Context<Void> ctx
) {
return Optional.of(new UnqualifiedColumnReferenceExp(
ColumnNames.generatedJoinColumnAlias(node.getQualifier(), node.getColumnName())
));
}
};

final Expression aliasedForeignKeyExpression =
ExpressionTreeRewriter.rewriteWith(aliasRewritter::process, foreignKeyExpression);


return JoinKey.foreignKeyColumn(foreignKeyColumnName, leftSourceKeys);
return JoinKey.foreignKeyColumn(aliasedForeignKeyExpression, leftSourceKeys);
}

private static void verifyJoinConditionTypes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,8 @@ public SchemaKTable<K> join() {
rightTable,
((ForeignJoinKey) joinKey).getForeignKeyColumn(),
contextStacker,
valueFormatInfo
valueFormatInfo,
((ForeignJoinKey) joinKey).getForeignKeyExpression()
);
} else {
return leftTable.leftJoin(
Expand All @@ -550,7 +551,8 @@ public SchemaKTable<K> join() {
rightTable,
((ForeignJoinKey) joinKey).getForeignKeyColumn(),
contextStacker,
valueFormatInfo
valueFormatInfo,
((ForeignJoinKey) joinKey).getForeignKeyExpression()
);
} else {
return leftTable.innerJoin(
Expand Down Expand Up @@ -643,10 +645,10 @@ static JoinKey syntheticColumn() {
}

static JoinKey foreignKeyColumn(
final ColumnName foreignKeyColumn,
final Expression foreignKeyExpression,
final Collection<QualifiedColumnReferenceExp> viableKeyColumns
) {
return ForeignJoinKey.of(foreignKeyColumn, viableKeyColumns);
return ForeignJoinKey.of(foreignKeyExpression, viableKeyColumns);
}

/**
Expand Down Expand Up @@ -806,17 +808,34 @@ public JoinKey rewriteWith(
}

private static final class ForeignJoinKey implements JoinKey {
private final ColumnName foreignKeyColumn;
private final Optional<ColumnName> foreignKeyColumn;
private final Optional<Expression> foreignKeyExpression;
private final ImmutableList<? extends ColumnReferenceExp> leftSourceKeyColumns;

static JoinKey of(final ColumnName foreignKeyColumn,
final Collection<QualifiedColumnReferenceExp> leftSourceKeyColumns) {
return new ForeignJoinKey(foreignKeyColumn, leftSourceKeyColumns);
return new ForeignJoinKey(
Optional.of(foreignKeyColumn),
Optional.empty(),
leftSourceKeyColumns
);
}

private ForeignJoinKey(final ColumnName foreignKeyColumn,
static JoinKey of(final Expression foreignKeyExpression,
final Collection<QualifiedColumnReferenceExp> leftSourceKeyColumns) {
return new ForeignJoinKey(
Optional.empty(),
Optional.of(foreignKeyExpression),
leftSourceKeyColumns
);
}

private ForeignJoinKey(final Optional<ColumnName> foreignKeyColumn,
final Optional<Expression> foreignKeyExpression,
final Collection<? extends ColumnReferenceExp> viableKeyColumns) {
this.foreignKeyColumn = requireNonNull(foreignKeyColumn, "foreignKeyColumn");
this.foreignKeyExpression =
requireNonNull(foreignKeyExpression, "foreignKeyExpression");
this.leftSourceKeyColumns = ImmutableList
.copyOf(requireNonNull(viableKeyColumns, "viableKeyColumns"));
}
Expand Down Expand Up @@ -853,11 +872,15 @@ public JoinKey rewriteWith(
.map(e -> ExpressionTreeRewriter.rewriteWith(plugin, e))
.collect(Collectors.toList());

return new ForeignJoinKey(foreignKeyColumn, rewrittenViable);
return new ForeignJoinKey(Optional.empty(), foreignKeyExpression, rewrittenViable);
}

public ColumnName getForeignKeyColumn() {
public Optional<ColumnName> getForeignKeyColumn() {
return foreignKeyColumn;
}

public Optional<Expression> getForeignKeyExpression() {
return foreignKeyExpression;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,10 @@ public SchemaKTable<K> outerJoin(

public <KRightT> SchemaKTable<K> foreignKeyInnerJoin(
final SchemaKTable<KRightT> schemaKTable,
final ColumnName leftJoinColumnName,
final Optional<ColumnName> leftJoinColumnName,
final Stacker contextStacker,
final FormatInfo valueFormatInfo
final FormatInfo valueFormatInfo,
final Optional<Expression> leftJoinExpression
) {
final ForeignKeyTableTableJoin<K, KRightT> step =
ExecutionStepFactory.foreignKeyTableTableJoin(
Expand All @@ -332,7 +333,8 @@ public <KRightT> SchemaKTable<K> foreignKeyInnerJoin(
leftJoinColumnName,
InternalFormats.of(keyFormat, valueFormatInfo),
sourceTableStep,
schemaKTable.getSourceTableStep()
schemaKTable.getSourceTableStep(),
leftJoinExpression
);

return new SchemaKTable<>(
Expand All @@ -346,9 +348,10 @@ public <KRightT> SchemaKTable<K> foreignKeyInnerJoin(

public <KRightT> SchemaKTable<K> foreignKeyLeftJoin(
final SchemaKTable<KRightT> schemaKTable,
final ColumnName leftJoinColumnName,
final Optional<ColumnName> leftJoinColumnName,
final Stacker contextStacker,
final FormatInfo valueFormatInfo
final FormatInfo valueFormatInfo,
final Optional<Expression> leftJoinExpression
) {
final ForeignKeyTableTableJoin<K, KRightT> step =
ExecutionStepFactory.foreignKeyTableTableJoin(
Expand All @@ -357,7 +360,8 @@ public <KRightT> SchemaKTable<K> foreignKeyLeftJoin(
leftJoinColumnName,
InternalFormats.of(keyFormat, valueFormatInfo),
sourceTableStep,
schemaKTable.getSourceTableStep()
schemaKTable.getSourceTableStep(),
leftJoinExpression
);

return new SchemaKTable<>(
Expand Down
Loading

0 comments on commit a9668de

Please sign in to comment.