From 002cd5ad6473cbb000291e32ebe5a459d2870b61 Mon Sep 17 00:00:00 2001 From: Almog Gavra Date: Mon, 20 Apr 2020 15:54:51 -0700 Subject: [PATCH] feat: add multi-join expression support (#5081) --- .../ksql/analyzer/RewrittenAnalysis.java | 15 +-- .../io/confluent/ksql/planner/JoinTree.java | 108 ++++++++++++------ .../ksql/planner/LogicalPlanner.java | 41 ++++++- .../confluent/ksql/planner/JoinTreeTest.java | 51 ++++++++- .../query-validation-tests/multi-joins.json | 75 ++++++------ 5 files changed, 197 insertions(+), 93 deletions(-) diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/RewrittenAnalysis.java b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/RewrittenAnalysis.java index 8864b00a8d59..b5d770eb18f4 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/RewrittenAnalysis.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/RewrittenAnalysis.java @@ -59,10 +59,6 @@ public RewrittenAnalysis( this.rewriter = Objects.requireNonNull(rewriter, "rewriter"); } - public ImmutableAnalysis getOriginal() { - return original; - } - @Override public List getTableFunctions() { return rewriteList(original.getTableFunctions()); @@ -148,16 +144,7 @@ public OptionalInt getLimitClause() { @Override public List getJoin() { - return original.getJoin().stream().map( - j -> new JoinInfo( - j.getLeftSource(), - rewrite(j.getLeftJoinExpression()), - j.getRightSource(), - rewrite(j.getRightJoinExpression()), - j.getType(), - j.getWithinExpression() - ) - ).collect(Collectors.toList()); + return original.getJoin(); } @Override diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/JoinTree.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/JoinTree.java index f2daf9aa870a..adaac2ae02f8 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/JoinTree.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/JoinTree.java @@ -15,6 +15,8 @@ package io.confluent.ksql.planner; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import io.confluent.ksql.analyzer.Analysis.AliasedDataSource; import io.confluent.ksql.analyzer.Analysis.JoinInfo; import io.confluent.ksql.execution.expression.tree.Expression; @@ -22,6 +24,7 @@ import io.confluent.ksql.util.KsqlException; import java.util.List; import java.util.Objects; +import java.util.Set; import org.apache.commons.lang3.StringUtils; /** @@ -111,6 +114,18 @@ interface Node { * @return a debug string that pretty prints the tree */ String debugString(int indent); + + /** + * An {@code Expression} that is part of this Node's equivalence set evaluates + * to the same value as the key for this join. Consider the following JOIN: + *
{@code
+     *  SELECT * FROM A JOIN B on A.id = B.id + 1
+     *                  JOIN C on A.id = C.id - 1;
+     * }
+ * The equivalence set for the above join would be {@code {A.id, B.id + 1, c.id - 1}} + * since all of those expressions evaluate to the same value. + */ + Set joinEquivalenceSet(); } static class Join implements Node { @@ -123,23 +138,6 @@ static class Join implements Node { this.left = left; this.right = right; this.info = info; - - checkJoinConditions(info.getLeftSource(), info.getLeftJoinExpression()); - checkJoinConditions(info.getRightSource(), info.getRightJoinExpression()); - } - - /** - * @see #5062 - */ - private void checkJoinConditions(final AliasedDataSource source, final Expression expression) { - final Expression existing = joinForSource(source); - if (!Objects.equals(existing, expression)) { - throw new KsqlException( - "KSQL does not yet support multi-joins with different expressions. " - + "A join for " + source.getAlias() - + " already exists with the condition `" + existing - + "` - cannot additionally join on `" + expression + "`"); - } } public JoinInfo getInfo() { @@ -159,22 +157,6 @@ public boolean containsSource(final AliasedDataSource dataSource) { return left.containsSource(dataSource) || right.containsSource(dataSource); } - private Expression joinForSource(final AliasedDataSource dataSource) { - if (left.containsSource(dataSource)) { - return left instanceof Leaf - ? info.getLeftJoinExpression() - : ((Join) left).joinForSource(dataSource); - } - - if (right.containsSource(dataSource)) { - return right instanceof Leaf - ? info.getRightJoinExpression() - : ((Join) right).joinForSource(dataSource); - } - - throw new IllegalStateException("Expected to find a leaf containing " + dataSource); - } - @Override public String debugString(final int indent) { return "⋈\n" @@ -184,6 +166,61 @@ public String debugString(final int indent) { + right.debugString(indent + 3); } + @Override + public Set joinEquivalenceSet() { + // the algorithm to compute the keys on a tree recursively + // checks to see if the keys from subtrees are equivalent to + // the existing join criteria. take the following tree and + // join conditions as example: + // + // ⋈ + // / \ + // ⋈ ⋈ + // / \ / \ + // A B C D + // + // A JOIN B on A.id = B.id as AB + // C JOIN D on C.id = D.id as BC + // AB JOIN CD on A.id = C.id + 1 as ABCD + // + // The final topic would be partitioned on A.id, which is equivalent + // to any one of [A.id, B.id, C.id+1]. + // + // We can compute this set by checking if either the left or right side + // of the join expression is contained within the equivalence key set + // of the child node. If it is contained, then we know that the child + // equivalence set can be included in the parent equivalence set + // + // In the example above, the left side of the join, A.id, is in the AB + // child set. This means we can add all of AB's keys to the output. + // `C.id + 1` on the other hand, is not in either AB or BC's child set + // so we do not include them in the output set. + // + // We always include both sides of the current join in the output set, + // since we know the key will be the equivalence of those + + final Set keys = ImmutableSet.of( + info.getLeftJoinExpression(), + info.getRightJoinExpression() + ); + + final Set lefts = left.joinEquivalenceSet(); + final Set rights = right.joinEquivalenceSet(); + + final boolean includeLeft = !Sets.intersection(lefts, keys).isEmpty(); + final boolean includeRight = !Sets.intersection(rights, keys).isEmpty(); + + if (includeLeft && includeRight) { + return Sets.union(keys, Sets.union(lefts, rights)); + } else if (includeLeft) { + return Sets.union(keys, lefts); + } else if (includeRight) { + return Sets.union(keys, rights); + } else { + return keys; + } + } + @Override public String toString() { return "Join{" + "left=" + left @@ -234,6 +271,11 @@ public String debugString(final int indent) { return source.getAlias().toString(FormatOptions.noEscape()); } + @Override + public Set joinEquivalenceSet() { + return ImmutableSet.of(); // Leaf nodes don't have join equivalence sets + } + @Override public String toString() { return "Leaf{" + "source=" + source + '}'; diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java index 34426aaffe05..c9aca947dc26 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java @@ -449,12 +449,37 @@ public Optional visitQualifiedColumnReference( ); } + private PlanNode buildSourceForJoin( + final Join join, + final PlanNode joinedSource, + final String side, + final Expression joinExpression + ) { + // we do not need to repartition if the joinExpression + // is already part of the join equivalence set + if (join.joinEquivalenceSet().contains(joinExpression)) { + return joinedSource; + } + + return buildRepartitionNode( + side + "SourceKeyed", + joinedSource, + new PartitionBy( + Optional.empty(), + // We need to repartition on the original join expression, and we need to drop + // all qualifiers. + ExpressionTreeRewriter.rewriteWith(refRewriter::process, joinExpression), + Optional.empty() + ) + ); + } + private PlanNode buildSourceNode() { if (!analysis.isJoin()) { return buildNonJoinNode(analysis.getFrom()); } - final List joinInfo = analysis.getOriginal().getJoin(); + final List joinInfo = analysis.getJoin(); final JoinTree.Node tree = JoinTree.build(joinInfo); if (tree instanceof JoinTree.Leaf) { throw new IllegalStateException("Expected more than one source:" @@ -472,7 +497,12 @@ private PlanNode buildSourceNode() { private PlanNode buildJoin(final Join root, final String prefix) { final PlanNode left; if (root.getLeft() instanceof JoinTree.Join) { - left = buildJoin((Join) root.getLeft(), prefix + "L_"); + left = buildSourceForJoin( + (JoinTree.Join) root.getLeft(), + buildJoin((Join) root.getLeft(), prefix + "L_"), + prefix + "Left", + root.getInfo().getLeftJoinExpression() + ); } else { final JoinTree.Leaf leaf = (Leaf) root.getLeft(); left = buildSourceForJoin( @@ -481,7 +511,12 @@ private PlanNode buildJoin(final Join root, final String prefix) { final PlanNode right; if (root.getRight() instanceof JoinTree.Join) { - right = buildJoin((Join) root.getRight(), prefix + "R_"); + right = buildSourceForJoin( + (JoinTree.Join) root.getRight(), + buildJoin((Join) root.getRight(), prefix + "R_"), + prefix + "Right", + root.getInfo().getRightJoinExpression() + ); } else { final JoinTree.Leaf leaf = (Leaf) root.getRight(); right = buildSourceForJoin( diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/JoinTreeTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/JoinTreeTest.java index a85ba6fa505b..ed3c2913fb2c 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/JoinTreeTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/JoinTreeTest.java @@ -16,16 +16,18 @@ package io.confluent.ksql.planner; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThrows; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; import io.confluent.ksql.analyzer.Analysis.AliasedDataSource; import io.confluent.ksql.analyzer.Analysis.JoinInfo; +import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.name.SourceName; import io.confluent.ksql.planner.JoinTree.Join; import io.confluent.ksql.planner.JoinTree.Leaf; @@ -49,6 +51,11 @@ public class JoinTreeTest { @Mock private JoinInfo j1; @Mock private JoinInfo j2; + @Mock private Expression e1; + @Mock private Expression e2; + @Mock private Expression e3; + @Mock private Expression e4; + @Before public void setUp() { when(a.getAlias()).thenReturn(SourceName.of("a")); @@ -124,6 +131,48 @@ public void handlesRightThreeWayJoin() { )); } + @Test + public void shouldComputeEquivalenceSetWithOverlap() { + when(j1.getLeftSource()).thenReturn(a); + when(j1.getRightSource()).thenReturn(b); + when(j2.getLeftSource()).thenReturn(a); + when(j2.getRightSource()).thenReturn(c); + + when(j1.getLeftJoinExpression()).thenReturn(e1); + when(j1.getRightJoinExpression()).thenReturn(e2); + when(j2.getLeftJoinExpression()).thenReturn(e1); + when(j2.getRightJoinExpression()).thenReturn(e3); + + final List joins = ImmutableList.of(j1, j2); + + // When: + final Node root = JoinTree.build(joins); + + // Then: + assertThat(root.joinEquivalenceSet(), containsInAnyOrder(e1, e2, e3)); + } + + @Test + public void shouldComputeEquivalenceSetWithoutOverlap() { + when(j1.getLeftSource()).thenReturn(a); + when(j1.getRightSource()).thenReturn(b); + when(j2.getLeftSource()).thenReturn(a); + when(j2.getRightSource()).thenReturn(c); + + when(j1.getLeftJoinExpression()).thenReturn(e1); + when(j1.getRightJoinExpression()).thenReturn(e2); + when(j2.getLeftJoinExpression()).thenReturn(e3); + when(j2.getRightJoinExpression()).thenReturn(e4); + + final List joins = ImmutableList.of(j1, j2); + + // When: + final Node root = JoinTree.build(joins); + + // Then: + assertThat(root.joinEquivalenceSet(), contains(e3, e4)); + } + @Test public void outputsCorrectJoinTreeString() { // Given: diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/multi-joins.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/multi-joins.json index 3b671bd67c88..f3dbc3e9541b 100644 --- a/ksqldb-functional-tests/src/test/resources/query-validation-tests/multi-joins.json +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/multi-joins.json @@ -145,28 +145,53 @@ } }, { - "name": "stream-table-table - inner-inner - rekey with expression", + "name": "stream-table-table - inner-inner - rekey with different expression", "statements": [ "CREATE STREAM S1 (ROWKEY INT KEY, K INT, ID bigint) WITH (kafka_topic='left', value_format='JSON');", "CREATE TABLE T2 (ROWKEY INT PRIMARY KEY, ID bigint) WITH (kafka_topic='right', value_format='JSON');", "CREATE TABLE T3 (ROWKEY INT PRIMARY KEY, ID bigint) WITH (kafka_topic='right2', value_format='JSON');", - "CREATE STREAM OUTPUT as SELECT s1.ID, t2.ID, t3.ID FROM S1 JOIN T2 ON S1.K + 1 = T2.ROWKEY JOIN T3 ON S1.K + 1 = T3.ROWKEY;" + "CREATE STREAM OUTPUT as SELECT s1.ID, t2.ID, t3.ID FROM S1 JOIN T2 ON S1.K - 1 = T2.ROWKEY JOIN T3 ON S1.K + 1 = T3.ROWKEY;" ], "properties": { "ksql.any.key.name.enabled": true }, "inputs": [ - {"topic": "right2", "key": 1, "value": {"id": 3}, "timestamp": 10}, + {"topic": "right2", "key": 3, "value": {"id": 3}, "timestamp": 10}, {"topic": "right", "key": 1, "value": {"id": 2}, "timestamp": 11}, - {"topic": "left", "key": 0, "value": {"k": 0, "id": 1}, "timestamp": 12} + {"topic": "left", "key": 0, "value": {"k": 2, "id": 1}, "timestamp": 12} ], "outputs": [ - {"topic": "_confluent-ksql-some.ksql.service.idquery_CSAS_OUTPUT_0-L_Join-repartition", "key": 1, "value": {"S1_K": 0, "S1_ID": 1, "S1_ROWTIME": 12, "S1_ROWKEY": 0, "S1_KSQL_COL_0": 1}, "timestamp": 12}, - {"topic": "OUTPUT", "key": 1, "value": {"S1_ID": 1, "T2_ID": 2, "T3_ID": 3}, "timestamp": 12} + {"topic": "OUTPUT", "key": 3, "value": {"S1_ID": 1, "T2_ID": 2, "T3_ID": 3}, "timestamp": 12} + ], + "post": { + "sources": [ + {"name": "OUTPUT", "type": "stream", "schema": "KSQL_COL_1 INT KEY, S1_ID BIGINT, T2_ID BIGINT, T3_ID BIGINT"} + ] + } + }, + { + "name": "stream-table-table - inner-inner - rekey on different field", + "comments": ["https://github.com/confluentinc/ksql/issues/5062"], + "statements": [ + "CREATE STREAM S1 (ROWKEY INT KEY, K INT, ID int) WITH (kafka_topic='left', value_format='JSON');", + "CREATE TABLE T2 (ROWKEY INT PRIMARY KEY, ID int) WITH (kafka_topic='right', value_format='JSON');", + "CREATE TABLE T3 (ROWKEY INT PRIMARY KEY, ID int) WITH (kafka_topic='right2', value_format='JSON');", + "CREATE STREAM OUTPUT as SELECT s1.ID, t2.ID, t3.ID FROM S1 JOIN T2 ON S1.ID = T2.ROWKEY JOIN T3 ON S1.K = T3.ROWKEY;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "right2", "key": 2, "value": {"id": 3}, "timestamp": 10}, + {"topic": "right", "key": 1, "value": {"id": 2}, "timestamp": 11}, + {"topic": "left", "key": 0, "value": {"k": 2, "id": 1}, "timestamp": 12} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 2, "value": {"S1_ID": 1, "T2_ID": 2, "T3_ID": 3}, "timestamp": 12} ], "post": { "sources": [ - {"name": "OUTPUT", "type": "stream", "schema": "KSQL_COL_0 INT KEY, S1_ID BIGINT, T2_ID BIGINT, T3_ID BIGINT"} + {"name": "OUTPUT", "type": "stream", "schema": "S1_K INT KEY, S1_ID INT, T2_ID INT, T3_ID INT"} ] } }, @@ -728,7 +753,7 @@ } }, { - "name": "TTS - Invalid", + "name": "table-table-stream - Invalid", "statements": [ "CREATE STREAM S1 (ROWKEY INT KEY, ID bigint) WITH (kafka_topic='left_topic', value_format='JSON');", "CREATE TABLE T2 (ROWKEY INT PRIMARY KEY, ID bigint) WITH (kafka_topic='right_topic', value_format='JSON');", @@ -739,40 +764,6 @@ "type": "io.confluent.ksql.util.KsqlException", "message": "Join between invalid operands requested: left type: KTABLE, right type: KSTREAM" } - }, - { - "name": "stream-table-table - inner-inner - rekey on different field", - "comments": ["https://github.com/confluentinc/ksql/issues/5062"], - "statements": [ - "CREATE STREAM S1 (ROWKEY INT KEY, K INT, ID bigint) WITH (kafka_topic='left', value_format='JSON');", - "CREATE TABLE T2 (ROWKEY INT PRIMARY KEY, ID bigint) WITH (kafka_topic='right', value_format='JSON');", - "CREATE TABLE T3 (ROWKEY INT PRIMARY KEY, ID bigint) WITH (kafka_topic='right2', value_format='JSON');", - "CREATE STREAM OUTPUT as SELECT s1.ID, t2.ID, t3.ID FROM S1 JOIN T2 ON S1.ID = T2.ROWKEY JOIN T3 ON S1.K = T3.ROWKEY;" - ], - "properties": { - "ksql.any.key.name.enabled": true - }, - "expectedException": { - "type": "io.confluent.ksql.util.KsqlException", - "message": "KSQL does not yet support multi-joins with different expressions. A join for `S1` already exists with the condition `S1.ID` - cannot additionally join on `S1.K`" - } - }, - { - "name": "stream-table-table - inner-inner - rekey on different expression", - "comments": ["https://github.com/confluentinc/ksql/issues/5062"], - "statements": [ - "CREATE STREAM S1 (ROWKEY INT KEY, K INT, ID bigint) WITH (kafka_topic='left', value_format='JSON');", - "CREATE TABLE T2 (ROWKEY INT PRIMARY KEY, ID bigint) WITH (kafka_topic='right', value_format='JSON');", - "CREATE TABLE T3 (ROWKEY INT PRIMARY KEY, ID bigint) WITH (kafka_topic='right2', value_format='JSON');", - "CREATE STREAM OUTPUT as SELECT s1.ID, t2.ID, t3.ID FROM S1 JOIN T2 ON S1.K - 1 = T2.ROWKEY JOIN T3 ON S1.K = T3.ROWKEY;" - ], - "properties": { - "ksql.any.key.name.enabled": true - }, - "expectedException": { - "type": "io.confluent.ksql.util.KsqlException", - "message": "KSQL does not yet support multi-joins with different expressions. A join for `S1` already exists with the condition `(S1.K - 1)` - cannot additionally join on `S1.K`" - } } ] } \ No newline at end of file