Skip to content

Commit

Permalink
feat: add multi-join expression support (#5081)
Browse files Browse the repository at this point in the history
  • Loading branch information
agavra authored Apr 20, 2020
1 parent 9c36d98 commit 002cd5a
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ public RewrittenAnalysis(
this.rewriter = Objects.requireNonNull(rewriter, "rewriter");
}

/**
 * Returns the {@code original} analysis held by this instance as-is; no rewriting is
 * applied to the returned object by this accessor.
 *
 * @return the value of the {@code original} field
 */
public ImmutableAnalysis getOriginal() {
return original;
}

@Override
public List<FunctionCall> getTableFunctions() {
return rewriteList(original.getTableFunctions());
Expand Down Expand Up @@ -148,16 +144,7 @@ public OptionalInt getLimitClause() {

@Override
public List<JoinInfo> getJoin() {
return original.getJoin().stream().map(
j -> new JoinInfo(
j.getLeftSource(),
rewrite(j.getLeftJoinExpression()),
j.getRightSource(),
rewrite(j.getRightJoinExpression()),
j.getType(),
j.getWithinExpression()
)
).collect(Collectors.toList());
return original.getJoin();
}

@Override
Expand Down
108 changes: 75 additions & 33 deletions ksqldb-engine/src/main/java/io/confluent/ksql/planner/JoinTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@

package io.confluent.ksql.planner;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.confluent.ksql.analyzer.Analysis.AliasedDataSource;
import io.confluent.ksql.analyzer.Analysis.JoinInfo;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.schema.ksql.FormatOptions;
import io.confluent.ksql.util.KsqlException;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;

/**
Expand Down Expand Up @@ -111,6 +114,18 @@ interface Node {
* @return a debug string that pretty prints the tree
*/
String debugString(int indent);

/**
 * Returns the set of expressions that are equivalent to the key of this join.
 *
 * <p>An {@code Expression} that is part of this Node's equivalence set evaluates
 * to the same value as the key for this join. Consider the following JOIN:
 * <pre>{@code
 * SELECT * FROM A JOIN B on A.id = B.id + 1
 * JOIN C on A.id = C.id - 1;
 * }</pre>
 * The equivalence set for the above join would be {@code {A.id, B.id + 1, C.id - 1}}
 * since all of those expressions evaluate to the same value.
 *
 * @return the expressions that evaluate to the same value as this node's join key
 */
Set<Expression> joinEquivalenceSet();
}

static class Join implements Node {
Expand All @@ -123,23 +138,6 @@ static class Join implements Node {
this.left = left;
this.right = right;
this.info = info;

checkJoinConditions(info.getLeftSource(), info.getLeftJoinExpression());
checkJoinConditions(info.getRightSource(), info.getRightJoinExpression());
}

/**
 * Verifies that {@code expression} is equal to the join expression already recorded in
 * this tree for {@code source}; multi-joins that use a different expression for the same
 * source are rejected.
 *
 * @param source the aliased source whose existing join expression is looked up
 * @param expression the join expression about to be used for {@code source}
 * @throws KsqlException if the existing expression differs from {@code expression}
 * @see <a href="https://github.com/confluentinc/ksql/issues/5062">#5062</a>
 */
private void checkJoinConditions(final AliasedDataSource source, final Expression expression) {
  final Expression current = joinForSource(source);
  if (Objects.equals(current, expression)) {
    // Same expression as the one already in the tree: nothing to reject.
    return;
  }

  final String message = "KSQL does not yet support multi-joins with different expressions. "
      + "A join for " + source.getAlias()
      + " already exists with the condition `" + current
      + "` - cannot additionally join on `" + expression + "`";
  throw new KsqlException(message);
}

public JoinInfo getInfo() {
Expand All @@ -159,22 +157,6 @@ public boolean containsSource(final AliasedDataSource dataSource) {
return left.containsSource(dataSource) || right.containsSource(dataSource);
}

/**
 * Looks up the join expression that applies to {@code dataSource} within this subtree.
 *
 * <p>The left child is checked before the right child. When the child containing the
 * source is a {@code Leaf}, the matching side of this join's {@code info} is returned;
 * otherwise the lookup recurses into the nested {@code Join}.
 *
 * @param dataSource the source to search for
 * @return the join expression recorded for {@code dataSource}
 * @throws IllegalStateException if neither child contains {@code dataSource}
 */
private Expression joinForSource(final AliasedDataSource dataSource) {
  if (left.containsSource(dataSource)) {
    if (left instanceof Leaf) {
      return info.getLeftJoinExpression();
    }
    return ((Join) left).joinForSource(dataSource);
  }

  if (right.containsSource(dataSource)) {
    if (right instanceof Leaf) {
      return info.getRightJoinExpression();
    }
    return ((Join) right).joinForSource(dataSource);
  }

  throw new IllegalStateException("Expected to find a leaf containing " + dataSource);
}

@Override
public String debugString(final int indent) {
return "⋈\n"
Expand All @@ -184,6 +166,61 @@ public String debugString(final int indent) {
+ right.debugString(indent + 3);
}

/**
 * Computes the expressions that are equivalent to the key of this join.
 *
 * <p>Both sides of this join's own condition are always part of the result. The
 * equivalence set of a child is merged in only when it shares at least one expression
 * with this join's condition, because only then is the child's key known to be
 * equivalent to this join's key.
 *
 * @return this join's own expressions plus every child equivalence set that overlaps them
 */
@Override
public Set<Expression> joinEquivalenceSet() {
  // Worked example:
  //
  //        ⋈
  //      /   \
  //     ⋈     ⋈
  //    / \   / \
  //   A   B C   D
  //
  //   A  JOIN B  on A.id = B.id      as AB
  //   C  JOIN D  on C.id = D.id      as CD
  //   AB JOIN CD on A.id = C.id + 1  as ABCD
  //
  // The final topic is partitioned on A.id, which is equivalent to any one
  // of [A.id, B.id, C.id + 1].
  //
  // A.id (the left side of the top join) is in AB's set {A.id, B.id}, so all
  // of AB's keys are carried into the result. `C.id + 1` is in neither AB's
  // nor CD's set, so CD's keys {C.id, D.id} are left out.
  final Set<Expression> ownKeys = ImmutableSet.of(
      info.getLeftJoinExpression(),
      info.getRightJoinExpression()
  );

  final Set<Expression> leftKeys = left.joinEquivalenceSet();
  final Set<Expression> rightKeys = right.joinEquivalenceSet();

  // A child contributes only if one of its keys is also one of this join's keys.
  final boolean mergeLeft = leftKeys.stream().anyMatch(ownKeys::contains);
  final boolean mergeRight = rightKeys.stream().anyMatch(ownKeys::contains);

  final Set<Expression> withLeft = mergeLeft ? Sets.union(ownKeys, leftKeys) : ownKeys;
  return mergeRight ? Sets.union(withLeft, rightKeys) : withLeft;
}

@Override
public String toString() {
return "Join{" + "left=" + left
Expand Down Expand Up @@ -234,6 +271,11 @@ public String debugString(final int indent) {
return source.getAlias().toString(FormatOptions.noEscape());
}

/**
 * Returns an empty set: a leaf holds a single source and no join condition, so it has
 * no expressions to contribute.
 *
 * @return an empty immutable set
 */
@Override
public Set<Expression> joinEquivalenceSet() {
return ImmutableSet.of(); // Leaf nodes don't have join equivalence sets
}

@Override
public String toString() {
return "Leaf{" + "source=" + source + '}';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,37 @@ public Optional<Expression> visitQualifiedColumnReference(
);
}

/**
 * Returns the plan node to use as one side of an enclosing join, adding a repartition
 * step only when the nested join is not already keyed equivalently.
 *
 * @param join the join tree node that {@code joinedSource} was built from
 * @param joinedSource the plan node produced for {@code join}
 * @param side name prefix for the repartition node, which is named {@code side + "SourceKeyed"}
 * @param joinExpression the expression the enclosing join uses for this side
 * @return {@code joinedSource} itself when {@code joinExpression} is in the equivalence set
 *     of {@code join}, otherwise a repartition node wrapping it
 */
private PlanNode buildSourceForJoin(
    final Join join,
    final PlanNode joinedSource,
    final String side,
    final Expression joinExpression
) {
  final boolean alreadyKeyed = join.joinEquivalenceSet().contains(joinExpression);
  if (!alreadyKeyed) {
    // Repartition on the original join expression, rewritten so that all
    // qualifiers are dropped.
    final PartitionBy partitionBy = new PartitionBy(
        Optional.empty(),
        ExpressionTreeRewriter.rewriteWith(refRewriter::process, joinExpression),
        Optional.empty()
    );
    return buildRepartitionNode(side + "SourceKeyed", joinedSource, partitionBy);
  }

  // The expression is already part of the join equivalence set: no repartition needed.
  return joinedSource;
}

private PlanNode buildSourceNode() {
if (!analysis.isJoin()) {
return buildNonJoinNode(analysis.getFrom());
}

final List<JoinInfo> joinInfo = analysis.getOriginal().getJoin();
final List<JoinInfo> joinInfo = analysis.getJoin();
final JoinTree.Node tree = JoinTree.build(joinInfo);
if (tree instanceof JoinTree.Leaf) {
throw new IllegalStateException("Expected more than one source:"
Expand All @@ -472,7 +497,12 @@ private PlanNode buildSourceNode() {
private PlanNode buildJoin(final Join root, final String prefix) {
final PlanNode left;
if (root.getLeft() instanceof JoinTree.Join) {
left = buildJoin((Join) root.getLeft(), prefix + "L_");
left = buildSourceForJoin(
(JoinTree.Join) root.getLeft(),
buildJoin((Join) root.getLeft(), prefix + "L_"),
prefix + "Left",
root.getInfo().getLeftJoinExpression()
);
} else {
final JoinTree.Leaf leaf = (Leaf) root.getLeft();
left = buildSourceForJoin(
Expand All @@ -481,7 +511,12 @@ private PlanNode buildJoin(final Join root, final String prefix) {

final PlanNode right;
if (root.getRight() instanceof JoinTree.Join) {
right = buildJoin((Join) root.getRight(), prefix + "R_");
right = buildSourceForJoin(
(JoinTree.Join) root.getRight(),
buildJoin((Join) root.getRight(), prefix + "R_"),
prefix + "Right",
root.getInfo().getRightJoinExpression()
);
} else {
final JoinTree.Leaf leaf = (Leaf) root.getRight();
right = buildSourceForJoin(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@
package io.confluent.ksql.planner;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.analyzer.Analysis.AliasedDataSource;
import io.confluent.ksql.analyzer.Analysis.JoinInfo;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.planner.JoinTree.Join;
import io.confluent.ksql.planner.JoinTree.Leaf;
Expand All @@ -49,6 +51,11 @@ public class JoinTreeTest {
@Mock private JoinInfo j1;
@Mock private JoinInfo j2;

@Mock private Expression e1;
@Mock private Expression e2;
@Mock private Expression e3;
@Mock private Expression e4;

@Before
public void setUp() {
when(a.getAlias()).thenReturn(SourceName.of("a"));
Expand Down Expand Up @@ -124,6 +131,48 @@ public void handlesRightThreeWayJoin() {
));
}

@Test
public void shouldComputeEquivalenceSetWithOverlap() {
  // Given: j1 joins a and b on e1 = e2, j2 joins a and c on e1 = e3 (e1 is shared)
  when(j1.getLeftSource()).thenReturn(a);
  when(j1.getLeftJoinExpression()).thenReturn(e1);
  when(j1.getRightSource()).thenReturn(b);
  when(j1.getRightJoinExpression()).thenReturn(e2);

  when(j2.getLeftSource()).thenReturn(a);
  when(j2.getLeftJoinExpression()).thenReturn(e1);
  when(j2.getRightSource()).thenReturn(c);
  when(j2.getRightJoinExpression()).thenReturn(e3);

  final List<JoinInfo> joinInfos = ImmutableList.of(j1, j2);

  // When:
  final Node tree = JoinTree.build(joinInfos);

  // Then:
  assertThat(tree.joinEquivalenceSet(), containsInAnyOrder(e1, e2, e3));
}

@Test
public void shouldComputeEquivalenceSetWithoutOverlap() {
  // Given: j1 joins a and b on e1 = e2, j2 joins a and c on e3 = e4 (nothing shared)
  when(j1.getLeftSource()).thenReturn(a);
  when(j1.getLeftJoinExpression()).thenReturn(e1);
  when(j1.getRightSource()).thenReturn(b);
  when(j1.getRightJoinExpression()).thenReturn(e2);

  when(j2.getLeftSource()).thenReturn(a);
  when(j2.getLeftJoinExpression()).thenReturn(e3);
  when(j2.getRightSource()).thenReturn(c);
  when(j2.getRightJoinExpression()).thenReturn(e4);

  final List<JoinInfo> joinInfos = ImmutableList.of(j1, j2);

  // When:
  final Node tree = JoinTree.build(joinInfos);

  // Then: only j2's own expressions are expected.
  // NOTE(review): contains() also asserts iteration order - presumably this relies on
  // the returned set preserving insertion order; confirm.
  assertThat(tree.joinEquivalenceSet(), contains(e3, e4));
}

@Test
public void outputsCorrectJoinTreeString() {
// Given:
Expand Down
Loading

0 comments on commit 002cd5a

Please sign in to comment.