From b8e0c991a982229afd6d42dc8752a29a35fa937e Mon Sep 17 00:00:00 2001 From: Alan Sheinberg <57688982+AlanConfluent@users.noreply.github.com> Date: Thu, 28 Jan 2021 15:00:58 -0800 Subject: [PATCH] feat: Rewrites pull query WHERE clause to be in DNF and allow more expressions (#6874) * feat: More generic where clause and key extraction --- .../ksql/physical/pull/PullPhysicalPlan.java | 36 +- .../pull/PullPhysicalPlanBuilder.java | 16 +- .../operators/KeyedTableLookupOperator.java | 5 +- .../KeyedWindowedTableLookupOperator.java | 33 +- .../ksql/planner/plan/KeyConstraint.java | 129 ++++++++ .../ksql/planner/plan/LogicRewriter.java | 308 ++++++++++++++++++ .../ksql/planner/plan/LookupConstraint.java | 25 ++ .../ksql/planner/plan/NonKeyConstraint.java | 26 ++ .../ksql/planner/plan/PullFilterNode.java | 192 +++++------ .../ksql/planner/plan/PullQueryRewriter.java | 80 +++++ .../KeyedTableLookupOperatorTest.java | 40 ++- .../KeyedWindowedTableLookupOperatorTest.java | 116 +++++-- .../ksql/planner/plan/LogicRewriterTest.java | 159 +++++++++ .../ksql/planner/plan/PullFilterNodeTest.java | 253 ++++++++++---- .../planner/plan/PullQueryRewriterTest.java | 57 ++++ ...eries-against-materialized-aggregates.json | 94 +++++- .../confluent/ksql/util/MetaStoreFixture.java | 25 ++ .../streams/materialization/Locator.java | 15 +- .../streams/materialization/ks/KsLocator.java | 24 +- .../materialization/ks/KsLocatorTest.java | 53 ++- 20 files changed, 1404 insertions(+), 282 deletions(-) create mode 100644 ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/KeyConstraint.java create mode 100644 ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/LogicRewriter.java create mode 100644 ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/LookupConstraint.java create mode 100644 ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/NonKeyConstraint.java create mode 100644 ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PullQueryRewriter.java create mode 
100644 ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/LogicRewriterTest.java create mode 100644 ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/PullQueryRewriterTest.java diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlan.java b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlan.java index b68b188e2f35..67ece2713272 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlan.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlan.java @@ -15,14 +15,19 @@ package io.confluent.ksql.physical.pull; -import io.confluent.ksql.GenericKey; +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.execution.streams.materialization.Locator.KsqlKey; import io.confluent.ksql.execution.streams.materialization.Locator.KsqlPartitionLocation; import io.confluent.ksql.execution.streams.materialization.Materialization; import io.confluent.ksql.physical.pull.operators.AbstractPhysicalOperator; import io.confluent.ksql.physical.pull.operators.DataSourceOperator; +import io.confluent.ksql.planner.plan.KeyConstraint; +import io.confluent.ksql.planner.plan.KeyConstraint.ConstraintOperator; +import io.confluent.ksql.planner.plan.LookupConstraint; import io.confluent.ksql.query.PullQueryQueue; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.schema.ksql.LogicalSchema; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.function.BiFunction; @@ -41,7 +46,7 @@ public class PullPhysicalPlan { private final AbstractPhysicalOperator root; private final LogicalSchema schema; private final QueryId queryId; - private final List keys; + private final List lookupConstraints; private final Materialization mat; private final DataSourceOperator dataSourceOperator; @@ -49,14 +54,14 @@ public PullPhysicalPlan( final AbstractPhysicalOperator root, final LogicalSchema schema, final 
QueryId queryId, - final List keys, + final List lookupConstraints, final Materialization mat, final DataSourceOperator dataSourceOperator ) { this.root = Objects.requireNonNull(root, "root"); this.schema = Objects.requireNonNull(schema, "schema"); this.queryId = Objects.requireNonNull(queryId, "queryId"); - this.keys = Objects.requireNonNull(keys, "keys"); + this.lookupConstraints = Objects.requireNonNull(lookupConstraints, "lookupConstraints"); this.mat = Objects.requireNonNull(mat, "mat"); this.dataSourceOperator = Objects.requireNonNull( dataSourceOperator, "dataSourceOperator"); @@ -108,8 +113,27 @@ public Materialization getMaterialization() { return mat; } - public List getKeys() { - return keys; + public List getKeys() { + if (requiresRequestsToAllPartitions()) { + return Collections.emptyList(); + } + return lookupConstraints.stream() + .filter(lookupConstraint -> lookupConstraint instanceof KeyConstraint) + .map(KeyConstraint.class::cast) + .filter(keyConstraint -> keyConstraint.getConstraintOperator() == ConstraintOperator.EQUAL) + .map(KeyConstraint::getKsqlKey) + .collect(ImmutableList.toImmutableList()); + } + + private boolean requiresRequestsToAllPartitions() { + return lookupConstraints.stream() + .anyMatch(lookupConstraint -> { + if (lookupConstraint instanceof KeyConstraint) { + final KeyConstraint keyConstraint = (KeyConstraint) lookupConstraint; + return keyConstraint.getConstraintOperator() != ConstraintOperator.EQUAL; + } + return true; + }); } public LogicalSchema getOutputSchema() { diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlanBuilder.java b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlanBuilder.java index 17fd17567345..34eb43a39a4b 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlanBuilder.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlanBuilder.java @@ -15,7 +15,6 @@ package 
io.confluent.ksql.physical.pull; -import io.confluent.ksql.GenericKey; import io.confluent.ksql.analyzer.ImmutableAnalysis; import io.confluent.ksql.analyzer.PullQueryValidator; import io.confluent.ksql.execution.context.QueryContext.Stacker; @@ -37,10 +36,10 @@ import io.confluent.ksql.planner.LogicalPlanNode; import io.confluent.ksql.planner.plan.DataSourceNode; import io.confluent.ksql.planner.plan.KsqlBareOutputNode; +import io.confluent.ksql.planner.plan.LookupConstraint; import io.confluent.ksql.planner.plan.OutputNode; import io.confluent.ksql.planner.plan.PlanNode; import io.confluent.ksql.planner.plan.PullFilterNode; -import io.confluent.ksql.planner.plan.PullFilterNode.WindowBounds; import io.confluent.ksql.planner.plan.PullProjectNode; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.util.KsqlException; @@ -48,7 +47,6 @@ import java.util.Collections; import java.util.List; import java.util.Objects; -import java.util.Optional; /** * Traverses the logical plan top-down and creates a physical plan for pull queries. 
@@ -66,8 +64,7 @@ public class PullPhysicalPlanBuilder { private final QueryId queryId; private final Materialization mat; - private List keys; - private Optional windowBounds; + private List lookupConstraints; private boolean seenSelectOperator = false; public PullPhysicalPlanBuilder( @@ -145,7 +142,7 @@ public PullPhysicalPlan buildPullPhysicalPlan(final LogicalPlanNode logicalPlanN rootPhysicalOp, (rootPhysicalOp).getLogicalNode().getSchema(), queryId, - keys, + lookupConstraints, mat, dataSourceOperator); } @@ -165,8 +162,7 @@ private ProjectOperator translateProjectNode(final PullProjectNode logicalNode) } private SelectOperator translateFilterNode(final PullFilterNode logicalNode) { - keys = logicalNode.getKeyValues(); - windowBounds = logicalNode.getWindowBounds(); + lookupConstraints = logicalNode.getLookupConstraints(); final ProcessingLogger logger = processingLogContext .getLoggerFactory() @@ -181,7 +177,7 @@ private AbstractPhysicalOperator translateDataSourceNode( final DataSourceNode logicalNode ) { if (!seenSelectOperator) { - keys = Collections.emptyList(); + lookupConstraints = Collections.emptyList(); if (!logicalNode.isWindowed()) { return new TableScanOperator(mat, logicalNode); } else { @@ -191,7 +187,7 @@ private AbstractPhysicalOperator translateDataSourceNode( if (!logicalNode.isWindowed()) { return new KeyedTableLookupOperator(mat, logicalNode); } else { - return new KeyedWindowedTableLookupOperator(mat, logicalNode, windowBounds.get()); + return new KeyedWindowedTableLookupOperator(mat, logicalNode); } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/operators/KeyedTableLookupOperator.java b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/operators/KeyedTableLookupOperator.java index 4356dc417ab1..75192edd901b 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/operators/KeyedTableLookupOperator.java +++ 
b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/operators/KeyedTableLookupOperator.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericKey; +import io.confluent.ksql.execution.streams.materialization.Locator.KsqlKey; import io.confluent.ksql.execution.streams.materialization.Locator.KsqlPartitionLocation; import io.confluent.ksql.execution.streams.materialization.Materialization; import io.confluent.ksql.execution.streams.materialization.Row; @@ -60,7 +61,7 @@ public void open() { if (!nextLocation.getKeys().isPresent()) { throw new IllegalStateException("Table lookup queries should be done with keys"); } - keyIterator = nextLocation.getKeys().get().iterator(); + keyIterator = nextLocation.getKeys().get().stream().map(KsqlKey::getKey).iterator(); if (keyIterator.hasNext()) { nextKey = keyIterator.next(); resultIterator = mat.nonWindowed() @@ -85,7 +86,7 @@ public Object next() { if (!nextLocation.getKeys().isPresent()) { throw new IllegalStateException("Table lookup queries should be done with keys"); } - keyIterator = nextLocation.getKeys().get().iterator(); + keyIterator = nextLocation.getKeys().get().stream().map(KsqlKey::getKey).iterator(); } nextKey = keyIterator.next(); resultIterator = mat.nonWindowed() diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/operators/KeyedWindowedTableLookupOperator.java b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/operators/KeyedWindowedTableLookupOperator.java index 63ad136062bf..8b33cf171c5a 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/operators/KeyedWindowedTableLookupOperator.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/operators/KeyedWindowedTableLookupOperator.java @@ -15,11 +15,12 @@ package io.confluent.ksql.physical.pull.operators; -import io.confluent.ksql.GenericKey; +import io.confluent.ksql.execution.streams.materialization.Locator.KsqlKey; import 
io.confluent.ksql.execution.streams.materialization.Locator.KsqlPartitionLocation; import io.confluent.ksql.execution.streams.materialization.Materialization; import io.confluent.ksql.execution.streams.materialization.WindowedRow; import io.confluent.ksql.planner.plan.DataSourceNode; +import io.confluent.ksql.planner.plan.KeyConstraint.KeyConstraintKey; import io.confluent.ksql.planner.plan.PlanNode; import io.confluent.ksql.planner.plan.PullFilterNode.WindowBounds; import java.util.Iterator; @@ -37,24 +38,21 @@ public class KeyedWindowedTableLookupOperator private final Materialization mat; private final DataSourceNode logicalNode; - private final WindowBounds windowBounds; private List partitionLocations; private Iterator resultIterator; - private Iterator keyIterator; + private Iterator keyIterator; private Iterator partitionLocationIterator; private KsqlPartitionLocation nextLocation; - private GenericKey nextKey; + private KsqlKey nextKey; public KeyedWindowedTableLookupOperator( final Materialization mat, - final DataSourceNode logicalNode, - final WindowBounds windowBounds + final DataSourceNode logicalNode ) { this.logicalNode = Objects.requireNonNull(logicalNode, "logicalNode"); this.mat = Objects.requireNonNull(mat, "mat"); - this.windowBounds = Objects.requireNonNull(windowBounds, "windowBounds"); } @Override @@ -65,11 +63,12 @@ public void open() { if (!nextLocation.getKeys().isPresent()) { throw new IllegalStateException("Table windowed queries should be done with keys"); } - keyIterator = nextLocation.getKeys().get().iterator(); + keyIterator = nextLocation.getKeys().get().stream().iterator(); if (keyIterator.hasNext()) { nextKey = keyIterator.next(); + final WindowBounds windowBounds = getWindowBounds(nextKey); resultIterator = mat.windowed().get( - nextKey, + nextKey.getKey(), nextLocation.getPartition(), windowBounds.getMergedStart(), windowBounds.getMergedEnd()) @@ -95,8 +94,9 @@ public Object next() { keyIterator = 
nextLocation.getKeys().get().iterator(); } nextKey = keyIterator.next(); + final WindowBounds windowBounds = getWindowBounds(nextKey); resultIterator = mat.windowed().get( - nextKey, + nextKey.getKey(), nextLocation.getPartition(), windowBounds.getMergedStart(), windowBounds.getMergedEnd()) @@ -106,6 +106,19 @@ public Object next() { } + private static WindowBounds getWindowBounds(final KsqlKey ksqlKey) { + if (!(ksqlKey instanceof KeyConstraintKey)) { + throw new IllegalStateException(String.format("Table windowed queries should be done with " + + "key constraints: %s", ksqlKey.toString())); + } + final KeyConstraintKey keyConstraintKey = (KeyConstraintKey) ksqlKey; + if (!keyConstraintKey.getWindowBounds().isPresent()) { + throw new IllegalStateException(String.format("Table windowed queries should be done with " + + "window bounds: %s", ksqlKey.toString())); + } + return keyConstraintKey.getWindowBounds().get(); + } + @Override public void close() { diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/KeyConstraint.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/KeyConstraint.java new file mode 100644 index 000000000000..738e9504aeaa --- /dev/null +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/KeyConstraint.java @@ -0,0 +1,129 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"; you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package io.confluent.ksql.planner.plan; + +import io.confluent.ksql.GenericKey; +import io.confluent.ksql.execution.streams.materialization.Locator.KsqlKey; +import io.confluent.ksql.planner.plan.PullFilterNode.WindowBounds; +import java.util.Objects; +import java.util.Optional; + +/** + * An instance of this class represents what we know about the use of keys in a given disjunct + * from an expression. The key's value, operator associated with it, and window bounds are + * available through the given methods. These are used as hints for the physical planning + * layer about how to fetch the corresponding rows. + */ +public class KeyConstraint implements LookupConstraint { + + private final ConstraintOperator operator; + private final GenericKey key; + private final Optional windowBounds; + + public KeyConstraint( + final ConstraintOperator operator, + final GenericKey key, + final Optional windowBounds + ) { + this.operator = operator; + this.key = key; + this.windowBounds = windowBounds; + } + + public static KeyConstraint equal( + final GenericKey key, + final Optional windowBounds + ) { + return new KeyConstraint(ConstraintOperator.EQUAL, key, windowBounds); + } + + // The key value. + public GenericKey getKey() { + return key; + } + + // The constraint operator associated with the value + public ConstraintOperator getConstraintOperator() { + return operator; + } + + // Window bounds, if the query is for a windowed table. 
+ public Optional getWindowBounds() { + return windowBounds; + } + + public KeyConstraintKey getKsqlKey() { + return new KeyConstraintKey(key, windowBounds); + } + + // If the operator represents a range of keys + public boolean isRangeOperator() { + return operator != ConstraintOperator.EQUAL; + } + + public enum ConstraintOperator { + EQUAL, + LESS_THAN, + LESS_THAN_OR_EQUAL, + GREATER_THAN, + GREATER_THAN_OR_EQUAL + } + + public static class KeyConstraintKey implements KsqlKey { + + private final GenericKey key; + private final Optional windowBounds; + + public KeyConstraintKey(final GenericKey key, final Optional windowBounds) { + this.key = key; + this.windowBounds = windowBounds; + } + + @Override + public GenericKey getKey() { + return key; + } + + public Optional getWindowBounds() { + return windowBounds; + } + + @Override + public int hashCode() { + return Objects.hash(key, windowBounds); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + final KeyConstraintKey that = (KeyConstraintKey) o; + return Objects.equals(this.key, that.key) + && Objects.equals(this.windowBounds, that.windowBounds); + } + + @Override + public String toString() { + return key.toString() + "-" + windowBounds.toString(); + } + } +} diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/LogicRewriter.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/LogicRewriter.java new file mode 100644 index 000000000000..c1d218f77bc6 --- /dev/null +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/LogicRewriter.java @@ -0,0 +1,308 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"; you may not use + * this file except in compliance with the License. 
You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.planner.plan; + +import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter; +import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context; +import io.confluent.ksql.execution.expression.tree.BooleanLiteral; +import io.confluent.ksql.execution.expression.tree.ComparisonExpression; +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression; +import io.confluent.ksql.execution.expression.tree.NotExpression; +import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; +import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; +import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +public final class LogicRewriter { + + + private LogicRewriter() { + } + + public static Expression rewriteNegations(final Expression expression) { + return new ExpressionTreeRewriter<>(new NotPropagator()::process) + .rewrite(expression, new NotPropagatorContext()); + } + + public static Expression rewriteCNF(final Expression expression) { + final Expression notPropagated = new ExpressionTreeRewriter<>(new NotPropagator()::process) + .rewrite(expression, new NotPropagatorContext()); + return new ExpressionTreeRewriter<>( + new DistributiveLawApplierDisjunctionOverConjunction()::process) + .rewrite(notPropagated, null); + } + + public static Expression rewriteDNF(final Expression expression) { + final 
Expression notPropagated = new ExpressionTreeRewriter<>(new NotPropagator()::process) + .rewrite(expression, new NotPropagatorContext()); + return new ExpressionTreeRewriter<>( + new DistributiveLawApplierConjunctionOverDisjunction()::process) + .rewrite(notPropagated, null); + } + + public static List extractDisjuncts(final Expression expression) { + final Expression dnf = rewriteDNF(expression); + final DisjunctExtractor disjunctExtractor = new DisjunctExtractor(); + disjunctExtractor.process(dnf, null); + return disjunctExtractor.getDisjuncts(); + } + + private static final class NotPropagator extends + VisitParentExpressionVisitor, Context> { + + @Override + public Optional visitExpression( + final Expression node, + final Context context) { + return Optional.empty(); + } + + @Override + public Optional visitUnqualifiedColumnReference( + final UnqualifiedColumnReferenceExp node, + final Context context + ) { + return handlePrimitiveTerm(node, context); + } + + @Override + public Optional visitQualifiedColumnReference( + final QualifiedColumnReferenceExp node, + final Context context) { + return handlePrimitiveTerm(node, context); + } + + @Override + public Optional visitBooleanLiteral( + final BooleanLiteral node, + final Context context) { + return handlePrimitiveTerm(node, context); + } + + private Optional handlePrimitiveTerm( + final Expression node, + final Context context) { + if (!context.getContext().isNegated()) { + return Optional.empty(); + } + return Optional.of(new NotExpression(node.getLocation(), node)); + } + + @Override + public Optional visitLogicalBinaryExpression( + final LogicalBinaryExpression node, + final Context context + ) { + final boolean isNegated = context.getContext().isNegated(); + final Expression left = process(node.getLeft(), context).orElse(node.getLeft()); + context.getContext().restore(isNegated); + final Expression right = process(node.getRight(), context).orElse(node.getRight()); + context.getContext().restore(isNegated); 
+ + LogicalBinaryExpression.Type type = node.getType(); + if (isNegated) { + type = node.getType() == LogicalBinaryExpression.Type.AND + ? LogicalBinaryExpression.Type.OR : LogicalBinaryExpression.Type.AND; + } + return Optional.of(new LogicalBinaryExpression(node.getLocation(), type, left, right)); + } + + @Override + public Optional visitComparisonExpression( + final ComparisonExpression node, + final Context context + ) { + return handlePrimitiveTerm(node, context); + } + + @Override + public Optional visitNotExpression( + final NotExpression node, final Context context) { + context.getContext().negate(); + final Expression value = process(node.getValue(), context).orElse(node.getValue()); + return Optional.of(value); + } + } + + public static final class NotPropagatorContext { + boolean isNegated = false; + + public void negate() { + isNegated = !isNegated; + } + + public boolean isNegated() { + return isNegated; + } + + public void restore(final boolean isNegated) { + this.isNegated = isNegated; + } + } + + private static final class DistributiveLawApplierDisjunctionOverConjunction extends + VisitParentExpressionVisitor, Context> { + + @Override + public Optional visitExpression( + final Expression node, + final Context context) { + return Optional.empty(); + } + + @Override + public Optional visitLogicalBinaryExpression( + final LogicalBinaryExpression node, + final Context context + ) { + final boolean isLeftLogicalExp = node.getLeft() instanceof LogicalBinaryExpression; + final boolean isRightLogicalExp = node.getRight() instanceof LogicalBinaryExpression; + if (!isLeftLogicalExp && !isRightLogicalExp) { + return Optional.empty(); + } + + final Expression left = process(node.getLeft(), context).orElse(node.getLeft()); + final Expression right = process(node.getRight(), context).orElse(node.getRight()); + + if (node.getType() == LogicalBinaryExpression.Type.OR) { + if (left instanceof LogicalBinaryExpression) { + final LogicalBinaryExpression leftLogical = 
(LogicalBinaryExpression) left; + if (leftLogical.getType() == LogicalBinaryExpression.Type.AND) { + Expression leftOr = new LogicalBinaryExpression(node.getLocation(), + LogicalBinaryExpression.Type.OR, leftLogical.getLeft(), right); + leftOr = process(leftOr, context).orElse(leftOr); + Expression rightOr = new LogicalBinaryExpression(node.getLocation(), + LogicalBinaryExpression.Type.OR, leftLogical.getRight(), right); + rightOr = process(rightOr, context).orElse(rightOr); + return Optional.of( + new LogicalBinaryExpression(node.getLocation(), LogicalBinaryExpression.Type.AND, + leftOr, rightOr)); + } + } + + if (right instanceof LogicalBinaryExpression) { + final LogicalBinaryExpression rightLogical = (LogicalBinaryExpression) right; + if (rightLogical.getType() == LogicalBinaryExpression.Type.AND) { + Expression leftOr = new LogicalBinaryExpression(node.getLocation(), + LogicalBinaryExpression.Type.OR, left, rightLogical.getLeft()); + leftOr = process(leftOr, context).orElse(leftOr); + Expression rightOr = new LogicalBinaryExpression(node.getLocation(), + LogicalBinaryExpression.Type.OR, left, rightLogical.getRight()); + rightOr = process(rightOr, context).orElse(rightOr); + return Optional.of( + new LogicalBinaryExpression(node.getLocation(), LogicalBinaryExpression.Type.AND, + leftOr, rightOr)); + } + } + } + return Optional.of( + new LogicalBinaryExpression(node.getLocation(), node.getType(), left, right)); + } + } + + private static final class DistributiveLawApplierConjunctionOverDisjunction extends + VisitParentExpressionVisitor, Context> { + + @Override + public Optional visitExpression( + final Expression node, + final Context context) { + return Optional.empty(); + } + + @Override + public Optional visitLogicalBinaryExpression( + final LogicalBinaryExpression node, + final Context context + ) { + final boolean isLeftLogicalExp = node.getLeft() instanceof LogicalBinaryExpression; + final boolean isRightLogicalExp = node.getRight() instanceof 
LogicalBinaryExpression; + if (!isLeftLogicalExp && !isRightLogicalExp) { + return Optional.empty(); + } + + final Expression left = process(node.getLeft(), context).orElse(node.getLeft()); + final Expression right = process(node.getRight(), context).orElse(node.getRight()); + + if (node.getType() == LogicalBinaryExpression.Type.AND) { + if (left instanceof LogicalBinaryExpression) { + final LogicalBinaryExpression leftLogical = (LogicalBinaryExpression) left; + if (leftLogical.getType() == LogicalBinaryExpression.Type.OR) { + Expression leftOr = new LogicalBinaryExpression(node.getLocation(), + LogicalBinaryExpression.Type.AND, leftLogical.getLeft(), right); + leftOr = process(leftOr, context).orElse(leftOr); + Expression rightOr = new LogicalBinaryExpression(node.getLocation(), + LogicalBinaryExpression.Type.AND, leftLogical.getRight(), right); + rightOr = process(rightOr, context).orElse(rightOr); + return Optional.of( + new LogicalBinaryExpression(node.getLocation(), LogicalBinaryExpression.Type.OR, + leftOr, rightOr)); + } + } + + if (right instanceof LogicalBinaryExpression) { + final LogicalBinaryExpression rightLogical = (LogicalBinaryExpression) right; + if (rightLogical.getType() == LogicalBinaryExpression.Type.OR) { + Expression leftOr = new LogicalBinaryExpression(node.getLocation(), + LogicalBinaryExpression.Type.AND, left, rightLogical.getLeft()); + leftOr = process(leftOr, context).orElse(leftOr); + Expression rightOr = new LogicalBinaryExpression(node.getLocation(), + LogicalBinaryExpression.Type.AND, left, rightLogical.getRight()); + rightOr = process(rightOr, context).orElse(rightOr); + return Optional.of( + new LogicalBinaryExpression(node.getLocation(), LogicalBinaryExpression.Type.OR, + leftOr, rightOr)); + } + } + } + return Optional.of( + new LogicalBinaryExpression(node.getLocation(), node.getType(), left, right)); + } + } + + private static final class DisjunctExtractor extends VisitParentExpressionVisitor { + private List disjuncts = new 
ArrayList<>(); + + @Override + public Void visitExpression( + final Expression node, + final Void context) { + disjuncts.add(node); + return null; + } + + @Override + public Void visitLogicalBinaryExpression( + final LogicalBinaryExpression node, + final Void context + ) { + if (node.getType() == LogicalBinaryExpression.Type.AND) { + disjuncts.add(node); + } else { + process(node.getLeft(), context); + process(node.getRight(), context); + } + return null; + } + + public List getDisjuncts() { + return disjuncts; + } + } +} diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/LookupConstraint.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/LookupConstraint.java new file mode 100644 index 000000000000..392c2ba84b68 --- /dev/null +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/LookupConstraint.java @@ -0,0 +1,25 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"; you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.planner.plan; + +/** + * The top level interface which represents information extracted from a given disjunct + * from the DNF representation of the where clause expression. Generally, implementing classes will + * be used as hints from the logical planning layer to the physical layer about how to fetch the + * associated data. Namely, use of keys. 
+ */ +public interface LookupConstraint { +} diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/NonKeyConstraint.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/NonKeyConstraint.java new file mode 100644 index 000000000000..e2e73f856e0e --- /dev/null +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/NonKeyConstraint.java @@ -0,0 +1,26 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"; you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.planner.plan; + +/** + * Means nothing could be extracted about the keys. Either the expression is too complex or + * doesn't make reference to keys, so we consider there to be no bound key constraint. + * Obviously, we'll still have to evaluate the expression for correctness on some overly + * permissive set of rows (e.g. table scan), but we cannot use a key as a hint for fetching the + * data. 
+ */ +public class NonKeyConstraint implements LookupConstraint { +} diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PullFilterNode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PullFilterNode.java index 13b40191e37a..697dcc2361cd 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PullFilterNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PullFilterNode.java @@ -22,13 +22,11 @@ import io.confluent.ksql.GenericKey; import io.confluent.ksql.analyzer.PullQueryValidator; import io.confluent.ksql.engine.generic.GenericExpressionResolver; -import io.confluent.ksql.engine.rewrite.StatementRewriteForMagicPseudoTimestamp; import io.confluent.ksql.execution.codegen.CodeGenRunner; import io.confluent.ksql.execution.codegen.ExpressionMetadata; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; import io.confluent.ksql.execution.expression.tree.ComparisonExpression.Type; import io.confluent.ksql.execution.expression.tree.Expression; -import io.confluent.ksql.execution.expression.tree.InPredicate; import io.confluent.ksql.execution.expression.tree.IntegerLiteral; import io.confluent.ksql.execution.expression.tree.Literal; import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression; @@ -50,7 +48,6 @@ import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.timestamp.PartialStringToTimestampParser; import java.time.Instant; -import java.util.ArrayList; import java.util.Arrays; import java.util.BitSet; import java.util.HashSet; @@ -78,11 +75,13 @@ public class PullFilterNode extends SingleSourcePlanNode { private final KsqlConfig ksqlConfig; private final LogicalSchema schema = getSource().getSchema(); - private Expression rewrittenPredicate; - private Optional windowBounds; - private List keyValues; - private Set keyColumns; - private Set systemColumns; + // The rewritten predicate in DNF, e.g. 
(A AND B) OR (C AND D) + private final Expression rewrittenPredicate; + // The separated disjuncts. In the above example, [(A AND B), (C AND D)] + private final List disjuncts; + private final List lookupConstraints; + private final Set keyColumns = new HashSet<>(); + private final Set systemColumns = new HashSet<>(); public PullFilterNode( final PlanNodeId id, @@ -97,20 +96,21 @@ public PullFilterNode( Objects.requireNonNull(predicate, "predicate"); this.metaStore = Objects.requireNonNull(metaStore, "metaStore"); this.ksqlConfig = Objects.requireNonNull(ksqlConfig, "ksqlConfig"); - this.rewrittenPredicate = new StatementRewriteForMagicPseudoTimestamp().rewrite(predicate); + // The predicate is rewritten as DNF. Discussion for why this format is chosen and how it helps + // to extract keys in various scenarios can be found here: + // https://github.com/confluentinc/ksql/pull/6874 + this.rewrittenPredicate = PullQueryRewriter.rewrite(predicate); + this.disjuncts = LogicRewriter.extractDisjuncts(rewrittenPredicate); this.isWindowed = isWindowed; // Basic validation of WHERE clause validateWhereClause(); - // Validation and extractions of window bounds - windowBounds = isWindowed ? 
Optional.of(extractWindowBounds()) : Optional.empty(); - // Extraction of key and system columns extractKeysAndSystemCols(); - // Extraction of key values - keyValues = extractKeyValues(); + // Extraction of lookup constraints + lookupConstraints = extractLookupConstraints(); // Compiling expression into byte code this.addAdditionalColumnsToIntermediateSchema = shouldAddAdditionalColumnsInSchema(); @@ -147,12 +147,8 @@ public boolean isWindowed() { return isWindowed; } - public List getKeyValues() { - return keyValues; - } - - public Optional getWindowBounds() { - return windowBounds; + public List getLookupConstraints() { + return lookupConstraints; } public boolean getAddAdditionalColumnsToIntermediateSchema() { @@ -164,83 +160,92 @@ public LogicalSchema getIntermediateSchema() { } private void validateWhereClause() { - final Validator validator = new Validator(); - validator.process(rewrittenPredicate, null); - if (!validator.isKeyedQuery) { - throw invalidWhereClauseException("WHERE clause missing key column", isWindowed); - } + for (Expression disjunct : disjuncts) { + final Validator validator = new Validator(); + validator.process(disjunct, null); + if (!validator.isKeyedQuery) { + throw invalidWhereClauseException("WHERE clause missing key column for disjunct: " + + disjunct.toString(), isWindowed); + } - if (!validator.seenKeys.isEmpty() && validator.seenKeys.cardinality() != schema.key().size()) { - final List seenKeyNames = validator.seenKeys - .stream() - .boxed() - .map(i -> schema.key().get(i)) - .map(Column::name) - .collect(Collectors.toList()); - throw invalidWhereClauseException( - "Multi-column sources must specify every key in the WHERE clause. 
Specified: " - + seenKeyNames + " Expected: " + schema.key(), isWindowed); + if (!validator.seenKeys.isEmpty() + && validator.seenKeys.cardinality() != schema.key().size()) { + final List seenKeyNames = validator.seenKeys + .stream() + .boxed() + .map(i -> schema.key().get(i)) + .map(Column::name) + .collect(Collectors.toList()); + throw invalidWhereClauseException( + "Multi-column sources must specify every key in the WHERE clause. Specified: " + + seenKeyNames + " Expected: " + schema.key(), isWindowed); + } } } private void extractKeysAndSystemCols() { - keyColumns = new HashSet<>(); - systemColumns = new HashSet<>(); new KeyAndSystemColsExtractor().process(rewrittenPredicate, null); } /** - * The WHERE clause is currently limited to either having a single IN predicate - * or equality conditions on the keys. - * inKeys has the key values as specified in the IN predicate. + * The WHERE clause is in DNF and this method extracts key constraints from each disjunct. + * In order to do that successfully, a given disjunct must have equality conditions on the keys. + * For example, for "KEY = 1 AND WINDOWSTART > 0 OR COUNT > 5 AND WINDOWEND < 10", the disjunct + * "KEY = 1 AND WINDOWSTART > 0" has a key equality constraint for value 1. The second + * disjunct "COUNT > 5 AND WINDOWEND < 10" does not and so has an unbound key constraint. * seenKeys is used to make sure that all columns of a multi-column * key are constrained via an equality condition. * keyContents has the key values for each columns of a key. - * @return the constrains on the key values used to to do keyed lookup. + * @return the constraints on the key values used to do keyed lookup. 
*/ - private List extractKeyValues() { - final KeyValueExtractor keyValueExtractor = new KeyValueExtractor(); - keyValueExtractor.process(rewrittenPredicate, null); - if (!keyValueExtractor.inKeys.isEmpty()) { - return keyValueExtractor.inKeys; - } - - return ImmutableList.of(GenericKey.fromList(Arrays.asList(keyValueExtractor.keyContents))); - } - - private WindowBounds extractWindowBounds() { - final WindowBounds windowBounds = new WindowBounds(); + private List extractLookupConstraints() { + final ImmutableList.Builder constraintPerDisjunct = ImmutableList.builder(); + for (Expression disjunct : disjuncts) { + final KeyValueExtractor keyValueExtractor = new KeyValueExtractor(); + keyValueExtractor.process(disjunct, null); + + // Validation and extractions of window bounds + final Optional optionalWindowBounds; + if (isWindowed) { + final WindowBounds windowBounds = new WindowBounds(); + new WindowBoundsExtractor().process(disjunct, windowBounds); + optionalWindowBounds = Optional.of(windowBounds); + } else { + optionalWindowBounds = Optional.empty(); + } - new WindowBoundsExtractor().process(rewrittenPredicate, windowBounds); - return windowBounds; + if (keyValueExtractor.seenKeys.isEmpty()) { + constraintPerDisjunct.add(new NonKeyConstraint()); + } else { + constraintPerDisjunct.add(KeyConstraint.equal( + GenericKey.fromList(Arrays.asList(keyValueExtractor.keyContents)), + optionalWindowBounds)); + } + } + return constraintPerDisjunct.build(); } /** - * Validate the WHERE clause for pull queries. - * 1. There must be exactly one equality condition per key - * or one IN predicate that involves a key. - * 2. An IN predicate can refer to a single key. - * 3. The IN predicate cannot be combined with other conditions. - * 4. Only AND is allowed. - * 5. If there is a multi-key, conditions on all keys must be specified. - * 6. The IN predicate cannot use multi-keys. + * Validate the WHERE clause for pull queries. 
Each of these validation steps is taken for each + * disjunct of a DNF expression. + * 1. There must be exactly one equality condition per key. + * 2. An IN predicate has been transformed to equality conditions and therefore isn't handled. + * 3. Only AND is allowed. + * 4. If there is a multi-key, conditions on all keys must be specified. */ private final class Validator extends TraversalExpressionVisitor { private final BitSet seenKeys; - private boolean containsINkeys; private boolean isKeyedQuery; Validator() { isKeyedQuery = false; seenKeys = new BitSet(schema.key().size()); - containsINkeys = false; } @Override public Void process(final Expression node, final Object context) { - if (!(node instanceof LogicalBinaryExpression) - && !(node instanceof ComparisonExpression) - && !(node instanceof InPredicate)) { + if (!(node instanceof LogicalBinaryExpression) + && !(node instanceof ComparisonExpression)) { throw invalidWhereClauseException("Unsupported expression in WHERE clause: " + node, false); } super.process(node, context); @@ -293,7 +298,7 @@ public Void visitComparisonExpression( + "' must currently be '='", isWindowed); } - if (containsINkeys || seenKeys.get(col.index())) { + if (seenKeys.get(col.index())) { throw invalidWhereClauseException( "An equality condition on the key column cannot be combined with other comparisons" + " such as an IN predicate", @@ -310,36 +315,6 @@ public Void visitComparisonExpression( ); } } - - @Override - public Void visitInPredicate( - final InPredicate node, - final Object context - ) { - if (schema.key().size() > 1) { - throw invalidWhereClauseException( - "Schemas with multiple KEY columns are not supported for IN predicates", false); - } - - final UnqualifiedColumnReferenceExp column - = (UnqualifiedColumnReferenceExp) node.getValue(); - final Optional col = schema.findColumn(column.getColumnName()); - if (col.isPresent() && col.get().namespace() == Namespace.KEY) { - if (!seenKeys.isEmpty()) { - throw 
invalidWhereClauseException( - "The IN predicate cannot be combined with other comparisons on the key column", - isWindowed); - } - containsINkeys = true; - isKeyedQuery = true; - } else { - throw invalidWhereClauseException( - "WHERE clause on unsupported column: " + column.getColumnName().text(), - false - ); - } - return null; - } } private UnqualifiedColumnReferenceExp getColumnRefSide(final ComparisonExpression comp) { @@ -377,12 +352,10 @@ public Void visitUnqualifiedColumnReference( * Necessary so that we can do key lookups when scanning the data stores. */ private final class KeyValueExtractor extends TraversalExpressionVisitor { - private final List inKeys; private final BitSet seenKeys; private final Object[] keyContents; KeyValueExtractor() { - inKeys = new ArrayList<>(); keyContents = new Object[schema.key().size()]; seenKeys = new BitSet(schema.key().size()); } @@ -403,23 +376,6 @@ public Void visitComparisonExpression( return null; } - @Override - public Void visitInPredicate( - final InPredicate node, final Object context) { - final UnqualifiedColumnReferenceExp column - = (UnqualifiedColumnReferenceExp) node.getValue(); - final Optional col = schema.findColumn(column.getColumnName()); - if (col.isPresent() && col.get().namespace() == Namespace.KEY) { - inKeys.addAll(node.getValueList() - .getValues() - .stream() - .map(expression -> resolveKey(expression, col.get(), metaStore, ksqlConfig, node)) - .map(GenericKey::genericKey) - .collect(Collectors.toList())); - } - return null; - } - private Object resolveKey( final Expression exp, final Column keyColumn, diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PullQueryRewriter.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PullQueryRewriter.java new file mode 100644 index 000000000000..40683c08d8a9 --- /dev/null +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PullQueryRewriter.java @@ -0,0 +1,80 @@ +/* + * Copyright 2020 Confluent Inc. 
+ * + * Licensed under the Confluent Community License (the "License"; you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.planner.plan; + +import com.google.common.collect.Lists; +import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter; +import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context; +import io.confluent.ksql.engine.rewrite.StatementRewriteForMagicPseudoTimestamp; +import io.confluent.ksql.execution.expression.tree.ComparisonExpression; +import io.confluent.ksql.execution.expression.tree.ComparisonExpression.Type; +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.InPredicate; +import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression; +import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor; +import java.util.Optional; + +public final class PullQueryRewriter { + + private PullQueryRewriter() { } + + public static Expression rewrite(final Expression expression) { + final Expression pseudoTimestamp = new StatementRewriteForMagicPseudoTimestamp() + .rewrite(expression); + final Expression inPredicatesRemoved = rewriteInPredicates(pseudoTimestamp); + return LogicRewriter.rewriteDNF(inPredicatesRemoved); + } + + public static Expression rewriteInPredicates(final Expression expression) { + return new ExpressionTreeRewriter<>(new InPredicateRewriter()::process) + .rewrite(expression, null); + } + + private static final class InPredicateRewriter extends + 
VisitParentExpressionVisitor, Context> { + + @Override + public Optional visitExpression( + final Expression node, + final Context context) { + return Optional.empty(); + } + + @Override + public Optional visitInPredicate( + final InPredicate node, + final Context context + ) { + Expression currentExpression = null; + for (Expression inValueListExp : Lists.reverse(node.getValueList().getValues())) { + final ComparisonExpression comparisonExpression = new ComparisonExpression( + node.getLocation(), Type.EQUAL, node.getValue(), + inValueListExp); + if (currentExpression == null) { + currentExpression = comparisonExpression; + continue; + } + currentExpression = new LogicalBinaryExpression( + node.getLocation(), LogicalBinaryExpression.Type.OR, comparisonExpression, + currentExpression); + } + if (currentExpression != null) { + return Optional.of(currentExpression); + } + throw new IllegalStateException("Shouldn't have an empty in predicate"); + } + } +} diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/operators/KeyedTableLookupOperatorTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/operators/KeyedTableLookupOperatorTest.java index 1f86b381fb1a..05423813d85b 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/operators/KeyedTableLookupOperatorTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/operators/KeyedTableLookupOperatorTest.java @@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.confluent.ksql.GenericKey; +import io.confluent.ksql.execution.streams.materialization.Locator.KsqlKey; import io.confluent.ksql.execution.streams.materialization.Locator.KsqlNode; import io.confluent.ksql.execution.streams.materialization.Locator.KsqlPartitionLocation; import io.confluent.ksql.execution.streams.materialization.Materialization; @@ -33,6 +34,7 @@ import java.util.ArrayList; import java.util.List; import 
java.util.Optional; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -54,13 +56,21 @@ public class KeyedTableLookupOperatorTest { @Mock private DataSourceNode logicalNode; @Mock - private GenericKey KEY1; + private KsqlKey KEY1; @Mock - private GenericKey KEY2; + private KsqlKey KEY2; @Mock - private GenericKey KEY3; + private KsqlKey KEY3; @Mock - private GenericKey KEY4; + private KsqlKey KEY4; + @Mock + private GenericKey GKEY1; + @Mock + private GenericKey GKEY2; + @Mock + private GenericKey GKEY3; + @Mock + private GenericKey GKEY4; @Mock private Row ROW1; @Mock @@ -68,6 +78,14 @@ public class KeyedTableLookupOperatorTest { @Mock private Row ROW4; + @Before + public void setUp() { + when(KEY1.getKey()).thenReturn(GKEY1); + when(KEY2.getKey()).thenReturn(GKEY2); + when(KEY3.getKey()).thenReturn(GKEY3); + when(KEY4.getKey()).thenReturn(GKEY4); + } + @Test public void shouldLookupRowsForSingleKey() { //Given: @@ -83,10 +101,10 @@ public void shouldLookupRowsForSingleKey() { final KeyedTableLookupOperator lookupOperator = new KeyedTableLookupOperator(materialization, logicalNode); when(materialization.nonWindowed()).thenReturn(nonWindowedTable); - when(materialization.nonWindowed().get(KEY1, 1)).thenReturn(Optional.of(ROW1)); - when(materialization.nonWindowed().get(KEY2, 2)).thenReturn(Optional.empty()); - when(materialization.nonWindowed().get(KEY3, 3)).thenReturn(Optional.of(ROW3)); - when(materialization.nonWindowed().get(KEY4, 3)).thenReturn(Optional.of(ROW4)); + when(materialization.nonWindowed().get(GKEY1, 1)).thenReturn(Optional.of(ROW1)); + when(materialization.nonWindowed().get(GKEY2, 2)).thenReturn(Optional.empty()); + when(materialization.nonWindowed().get(GKEY3, 3)).thenReturn(Optional.of(ROW3)); + when(materialization.nonWindowed().get(GKEY4, 3)).thenReturn(Optional.of(ROW4)); lookupOperator.setPartitionLocations(singleKeyPartitionLocations); @@ -110,9 +128,9 @@ public void 
shouldLookupRowsForMultipleKeys() { final KeyedTableLookupOperator lookupOperator = new KeyedTableLookupOperator(materialization, logicalNode); when(materialization.nonWindowed()).thenReturn(nonWindowedTable); - when(materialization.nonWindowed().get(KEY1, 1)).thenReturn(Optional.of(ROW1)); - when(materialization.nonWindowed().get(KEY3, 3)).thenReturn(Optional.of(ROW3)); - when(materialization.nonWindowed().get(KEY4, 3)).thenReturn(Optional.of(ROW4)); + when(materialization.nonWindowed().get(GKEY1, 1)).thenReturn(Optional.of(ROW1)); + when(materialization.nonWindowed().get(GKEY3, 3)).thenReturn(Optional.of(ROW3)); + when(materialization.nonWindowed().get(GKEY4, 3)).thenReturn(Optional.of(ROW4)); lookupOperator.setPartitionLocations(multipleKeysPartitionLocations); lookupOperator.open(); diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/operators/KeyedWindowedTableLookupOperatorTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/operators/KeyedWindowedTableLookupOperatorTest.java index 2bb03331e96e..6905b866826f 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/operators/KeyedWindowedTableLookupOperatorTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/operators/KeyedWindowedTableLookupOperatorTest.java @@ -31,12 +31,14 @@ import io.confluent.ksql.execution.streams.materialization.WindowedRow; import io.confluent.ksql.execution.streams.materialization.ks.KsLocator; import io.confluent.ksql.planner.plan.DataSourceNode; +import io.confluent.ksql.planner.plan.KeyConstraint.KeyConstraintKey; import io.confluent.ksql.planner.plan.PullFilterNode.WindowBounds; import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -56,13 +58,21 @@ public class KeyedWindowedTableLookupOperatorTest { @Mock private 
DataSourceNode logicalNode; @Mock - private GenericKey KEY1; + private KeyConstraintKey KEY1; @Mock - private GenericKey KEY2; + private KeyConstraintKey KEY2; @Mock - private GenericKey KEY3; + private KeyConstraintKey KEY3; @Mock - private GenericKey KEY4; + private KeyConstraintKey KEY4; + @Mock + private GenericKey GKEY1; + @Mock + private GenericKey GKEY2; + @Mock + private GenericKey GKEY3; + @Mock + private GenericKey GKEY4; @Mock private WindowedRow WINDOWED_ROW1; @Mock @@ -78,7 +88,27 @@ public class KeyedWindowedTableLookupOperatorTest { @Mock private Range WINDOW_END_BOUNDS; @Mock - private WindowBounds windowBounds; + private WindowBounds windowBounds1; + @Mock + private WindowBounds windowBounds2; + @Mock + private WindowBounds windowBounds3; + @Mock + private WindowBounds windowBounds4; + + @Before + public void setUp() { + when(KEY1.getKey()).thenReturn(GKEY1); + when(KEY2.getKey()).thenReturn(GKEY2); + when(KEY3.getKey()).thenReturn(GKEY3); + when(KEY4.getKey()).thenReturn(GKEY4); + when(KEY1.getWindowBounds()).thenReturn(Optional.of(windowBounds1)); + when(KEY2.getWindowBounds()).thenReturn(Optional.of(windowBounds1)); + when(KEY3.getWindowBounds()).thenReturn(Optional.of(windowBounds1)); + when(KEY4.getWindowBounds()).thenReturn(Optional.of(windowBounds1)); + when(windowBounds1.getMergedStart()).thenReturn(WINDOW_START_BOUNDS); + when(windowBounds1.getMergedEnd()).thenReturn(WINDOW_END_BOUNDS); + } @Test public void shouldLookupRowsForSingleKey() { @@ -93,17 +123,16 @@ public void shouldLookupRowsForSingleKey() { singleKeyPartitionLocations.add(new KsLocator.PartitionLocation( Optional.of(ImmutableSet.of(KEY4)), 3, ImmutableList.of(node3))); - final KeyedWindowedTableLookupOperator lookupOperator = new KeyedWindowedTableLookupOperator(materialization, logicalNode, windowBounds); - when(windowBounds.getMergedStart()).thenReturn(WINDOW_START_BOUNDS); - when(windowBounds.getMergedEnd()).thenReturn(WINDOW_END_BOUNDS); + final 
KeyedWindowedTableLookupOperator lookupOperator = new KeyedWindowedTableLookupOperator( + materialization, logicalNode); when(materialization.windowed()).thenReturn(windowedTable); - when(materialization.windowed().get(KEY1, 1, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) + when(materialization.windowed().get(GKEY1, 1, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) .thenReturn(ImmutableList.of(WINDOWED_ROW1, WINDOWED_ROW2)); - when(materialization.windowed().get(KEY2, 2, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) + when(materialization.windowed().get(GKEY2, 2, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) .thenReturn(Collections.emptyList()); - when(materialization.windowed().get(KEY3, 3, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) + when(materialization.windowed().get(GKEY3, 3, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) .thenReturn(ImmutableList.of(WINDOWED_ROW3)); - when(materialization.windowed().get(KEY4, 3, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) + when(materialization.windowed().get(GKEY4, 3, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) .thenReturn(ImmutableList.of(WINDOWED_ROW2, WINDOWED_ROW4)); lookupOperator.setPartitionLocations(singleKeyPartitionLocations); lookupOperator.open(); @@ -126,17 +155,18 @@ public void shouldLookupRowsForMultipleKey() { multipleKeysPartitionLocations.add(new KsLocator.PartitionLocation( Optional.of(ImmutableSet.of(KEY3, KEY4)), 3, ImmutableList.of(node3))); - final KeyedWindowedTableLookupOperator lookupOperator = new KeyedWindowedTableLookupOperator(materialization, logicalNode, windowBounds); - when(windowBounds.getMergedStart()).thenReturn(WINDOW_START_BOUNDS); - when(windowBounds.getMergedEnd()).thenReturn(WINDOW_END_BOUNDS); + final KeyedWindowedTableLookupOperator lookupOperator = new KeyedWindowedTableLookupOperator( + materialization, logicalNode); + when(windowBounds1.getMergedStart()).thenReturn(WINDOW_START_BOUNDS); + when(windowBounds1.getMergedEnd()).thenReturn(WINDOW_END_BOUNDS); 
when(materialization.windowed()).thenReturn(windowedTable); - when(materialization.windowed().get(KEY1, 1, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) + when(materialization.windowed().get(GKEY1, 1, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) .thenReturn(ImmutableList.of(WINDOWED_ROW1, WINDOWED_ROW2)); - when(materialization.windowed().get(KEY2, 1, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) + when(materialization.windowed().get(GKEY2, 1, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) .thenReturn(Collections.emptyList()); - when(materialization.windowed().get(KEY3, 3, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) + when(materialization.windowed().get(GKEY3, 3, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) .thenReturn(ImmutableList.of(WINDOWED_ROW3)); - when(materialization.windowed().get(KEY4, 3, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) + when(materialization.windowed().get(GKEY4, 3, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) .thenReturn(ImmutableList.of(WINDOWED_ROW2, WINDOWED_ROW4)); lookupOperator.setPartitionLocations(multipleKeysPartitionLocations); lookupOperator.open(); @@ -149,4 +179,52 @@ public void shouldLookupRowsForMultipleKey() { assertThat(lookupOperator.next(), is(WINDOWED_ROW4)); assertThat(lookupOperator.next(), is(nullValue())); } + + @Test + public void shouldUseDifferentWindowBoundsPerKey() { + //Given: + final List singleKeyPartitionLocations = new ArrayList<>(); + singleKeyPartitionLocations.add(new KsLocator.PartitionLocation( + Optional.of(ImmutableSet.of(KEY1)), 1, ImmutableList.of(node1))); + singleKeyPartitionLocations.add(new KsLocator.PartitionLocation( + Optional.of(ImmutableSet.of(KEY2)), 2, ImmutableList.of(node2))); + singleKeyPartitionLocations.add(new KsLocator.PartitionLocation( + Optional.of(ImmutableSet.of(KEY3)), 3, ImmutableList.of(node3))); + singleKeyPartitionLocations.add(new KsLocator.PartitionLocation( + Optional.of(ImmutableSet.of(KEY4)), 3, ImmutableList.of(node3))); + + final KeyedWindowedTableLookupOperator lookupOperator = new 
KeyedWindowedTableLookupOperator( + materialization, logicalNode); + when(KEY1.getWindowBounds()).thenReturn(Optional.of(windowBounds1)); + when(KEY2.getWindowBounds()).thenReturn(Optional.of(windowBounds2)); + when(KEY3.getWindowBounds()).thenReturn(Optional.of(windowBounds3)); + when(KEY4.getWindowBounds()).thenReturn(Optional.of(windowBounds4)); + when(windowBounds1.getMergedStart()).thenReturn(WINDOW_START_BOUNDS); + when(windowBounds1.getMergedEnd()).thenReturn(WINDOW_END_BOUNDS); + when(windowBounds2.getMergedStart()).thenReturn(Range.all()); + when(windowBounds2.getMergedEnd()).thenReturn(WINDOW_END_BOUNDS); + when(windowBounds3.getMergedStart()).thenReturn(WINDOW_START_BOUNDS); + when(windowBounds3.getMergedEnd()).thenReturn(Range.all()); + when(windowBounds4.getMergedStart()).thenReturn(Range.all()); + when(windowBounds4.getMergedEnd()).thenReturn(Range.all()); + when(materialization.windowed()).thenReturn(windowedTable); + when(materialization.windowed().get(GKEY1, 1, WINDOW_START_BOUNDS, WINDOW_END_BOUNDS)) + .thenReturn(ImmutableList.of(WINDOWED_ROW1, WINDOWED_ROW2)); + when(materialization.windowed().get(GKEY2, 2, Range.all(), WINDOW_END_BOUNDS)) + .thenReturn(Collections.emptyList()); + when(materialization.windowed().get(GKEY3, 3, WINDOW_START_BOUNDS, Range.all())) + .thenReturn(ImmutableList.of(WINDOWED_ROW3)); + when(materialization.windowed().get(GKEY4, 3, Range.all(), Range.all())) + .thenReturn(ImmutableList.of(WINDOWED_ROW2, WINDOWED_ROW4)); + lookupOperator.setPartitionLocations(singleKeyPartitionLocations); + lookupOperator.open(); + + //Then: + assertThat(lookupOperator.next(), is(WINDOWED_ROW1)); + assertThat(lookupOperator.next(), is(WINDOWED_ROW2)); + assertThat(lookupOperator.next(), is(WINDOWED_ROW3)); + assertThat(lookupOperator.next(), is(WINDOWED_ROW2)); + assertThat(lookupOperator.next(), is(WINDOWED_ROW4)); + assertThat(lookupOperator.next(), is(nullValue())); + } } diff --git 
a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/LogicRewriterTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/LogicRewriterTest.java new file mode 100644 index 000000000000..6b7518fa96b9 --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/LogicRewriterTest.java @@ -0,0 +1,159 @@ +package io.confluent.ksql.planner.plan; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.mock; + +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.metastore.MetaStore; +import io.confluent.ksql.parser.tree.Query; +import io.confluent.ksql.util.KsqlParserTestUtil; +import io.confluent.ksql.util.MetaStoreFixture; +import java.util.List; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class LogicRewriterTest { + + private MetaStore metaStore; + + @Before + public void init() { + metaStore = MetaStoreFixture.getNewMetaStore(mock(FunctionRegistry.class)); + } + + @Test + public void shouldPropagateNots() { + assertNeg("ORDERS", "ORDERUNITS > 1", "(ORDERS.ORDERUNITS > 1)"); + assertNeg("ORDERS", "NOT ORDERUNITS > 1", "(NOT (ORDERS.ORDERUNITS > 1))"); + assertNeg("TEST3", "NOT COL4", "(NOT TEST3.COL4)"); + assertNeg("TEST3", "NOT NOT COL4", "TEST3.COL4"); + assertNeg("TEST3", "NOT NOT NOT COL4", "(NOT TEST3.COL4)"); + assertNeg("TEST3", "NOT true", "(NOT true)"); + assertNeg("TEST3", "NOT true", "(NOT true)"); + assertNeg("ORDERS", "ORDERUNITS > 1 AND ITEMID = 'a'", + "((ORDERS.ORDERUNITS > 1) AND (ORDERS.ITEMID = 'a'))"); + assertNeg("ORDERS", "NOT ORDERUNITS > 1 AND ITEMID = 'a'", + "((NOT (ORDERS.ORDERUNITS > 1)) AND (ORDERS.ITEMID = 'a'))"); + assertNeg("ORDERS", "NOT (ORDERUNITS > 1 AND ITEMID = 'a')", + "((NOT 
(ORDERS.ORDERUNITS > 1)) OR (NOT (ORDERS.ITEMID = 'a')))"); + assertNeg("ORDERS", "NOT (ORDERUNITS + 2 > 1 AND ITEMID = 'a' + 'b')", + "((NOT ((ORDERS.ORDERUNITS + 2) > 1)) OR (NOT (ORDERS.ITEMID = ('a' + 'b'))))"); + assertNeg("ORDERS", "NOT (ORDERUNITS > 1 OR ITEMID = 'a')", + "((NOT (ORDERS.ORDERUNITS > 1)) AND (NOT (ORDERS.ITEMID = 'a')))"); + assertNeg("TEST5", "NOT (A AND B)", "((NOT TEST5.A) OR (NOT TEST5.B))"); + assertNeg("TEST5", "NOT (A OR B)", "((NOT TEST5.A) AND (NOT TEST5.B))"); + } + + @Test + public void shouldBeCNF() { + assertCNF("TEST5", "A", "TEST5.A"); + assertCNF("TEST5", "A AND B", "(TEST5.A AND TEST5.B)"); + assertCNF("TEST5", "A OR B", "(TEST5.A OR TEST5.B)"); + assertCNF("TEST5", "(A OR B) AND (C OR D)", + "((TEST5.A OR TEST5.B) AND (TEST5.C OR TEST5.D))"); + assertCNF("TEST5", "(A AND B) OR (C AND D)", + "(((TEST5.A OR TEST5.C) AND (TEST5.A OR TEST5.D)) AND " + + "((TEST5.B OR TEST5.C) AND (TEST5.B OR TEST5.D)))"); + assertCNF("TEST5", "(A AND B) OR (C AND D) OR (E AND F)", + "(((((TEST5.A OR TEST5.C) OR TEST5.E) AND ((TEST5.A OR TEST5.C) OR TEST5.F)) AND " + + "(((TEST5.A OR TEST5.D) OR TEST5.E) AND ((TEST5.A OR TEST5.D) OR TEST5.F))) AND " + + "((((TEST5.B OR TEST5.C) OR TEST5.E) AND ((TEST5.B OR TEST5.C) OR TEST5.F)) AND " + + "(((TEST5.B OR TEST5.D) OR TEST5.E) AND ((TEST5.B OR TEST5.D) OR TEST5.F))))"); + + assertCNF("TEST5", "(NOT A AND B) OR (C AND NOT D)", + "((((NOT TEST5.A) OR TEST5.C) AND ((NOT TEST5.A) OR (NOT TEST5.D))) AND " + + "((TEST5.B OR TEST5.C) AND (TEST5.B OR (NOT TEST5.D))))"); + assertCNF("TEST5", "NOT (A AND B) OR NOT (C AND D)", + "(((NOT TEST5.A) OR (NOT TEST5.B)) OR ((NOT TEST5.C) OR (NOT TEST5.D)))"); + } + + @Test + public void shouldBeDNF() { + assertDNF("TEST5", "A", "TEST5.A"); + assertDNF("TEST5", "A AND B", "(TEST5.A AND TEST5.B)"); + assertDNF("TEST5", "A OR B", "(TEST5.A OR TEST5.B)"); + assertDNF("TEST5", "(A OR B) AND (C OR D)", + "(((TEST5.A AND TEST5.C) OR (TEST5.A AND TEST5.D)) OR " + + "((TEST5.B 
AND TEST5.C) OR (TEST5.B AND TEST5.D)))"); + assertDNF("TEST5", "(A AND B) OR (C AND D)", + "((TEST5.A AND TEST5.B) OR (TEST5.C AND TEST5.D))"); + assertDNF("TEST5", "(A OR B) AND (C OR D) AND (E OR F)", + "(((((TEST5.A AND TEST5.C) AND TEST5.E) OR ((TEST5.A AND TEST5.C) AND TEST5.F)) OR " + + "(((TEST5.A AND TEST5.D) AND TEST5.E) OR ((TEST5.A AND TEST5.D) AND TEST5.F))) OR " + + "((((TEST5.B AND TEST5.C) AND TEST5.E) OR ((TEST5.B AND TEST5.C) AND TEST5.F)) OR " + + "(((TEST5.B AND TEST5.D) AND TEST5.E) OR ((TEST5.B AND TEST5.D) AND TEST5.F))))"); + + assertDNF("TEST5", "(NOT A OR B) AND (C OR NOT D)", + "((((NOT TEST5.A) AND TEST5.C) OR ((NOT TEST5.A) AND (NOT TEST5.D))) OR " + + "((TEST5.B AND TEST5.C) OR (TEST5.B AND (NOT TEST5.D))))"); + assertDNF("TEST5", "NOT (A OR B) AND NOT (C OR D)", + "(((NOT TEST5.A) AND (NOT TEST5.B)) AND ((NOT TEST5.C) AND (NOT TEST5.D)))"); + } + + @Test + public void shouldExtractDisjuncts() { + assertExtractDisjuncts("TEST5", "A", "TEST5.A"); + assertExtractDisjuncts("TEST5", "A AND B", "(TEST5.A AND TEST5.B)"); + assertExtractDisjuncts("TEST5", "A OR B", "TEST5.A", "TEST5.B"); + assertExtractDisjuncts("TEST5", "(A OR B) AND (C OR D)", "(TEST5.A AND TEST5.C)", + "(TEST5.A AND TEST5.D)", "(TEST5.B AND TEST5.C)", "(TEST5.B AND TEST5.D)"); + + assertExtractDisjuncts("ORDERS", "ORDERUNITS > 1", "(ORDERS.ORDERUNITS > 1)"); + assertExtractDisjuncts("ORDERS", "ORDERUNITS > 1 AND ITEMID = 'a'", + "((ORDERS.ORDERUNITS > 1) AND (ORDERS.ITEMID = 'a'))"); + assertExtractDisjuncts("ORDERS", "NOT (ORDERUNITS > 1 AND ITEMID = 'a')", + "(NOT (ORDERS.ORDERUNITS > 1))", "(NOT (ORDERS.ITEMID = 'a'))"); + } + + private void assertNeg(final String table, final String expressionStr, final String expectedStr) { + Expression expression = getWhereExpression(table, expressionStr); + Expression converted = LogicRewriter.rewriteNegations(expression); + + // When + assertThat(converted.toString(), is(expectedStr)); + } + + private void assertCNF(final String 
table, final String expressionStr, final String expectedStr) { + Expression expression = getWhereExpression(table, expressionStr); + Expression converted = LogicRewriter.rewriteCNF(expression); + + // When + assertThat(converted.toString(), is(expectedStr)); + } + + private void assertDNF(final String table, final String expressionStr, final String expectedStr) { + Expression expression = getWhereExpression(table, expressionStr); + Expression converted = LogicRewriter.rewriteDNF(expression); + + // When + assertThat(converted.toString(), is(expectedStr)); + } + + private void assertExtractDisjuncts(final String table, final String expressionStr, + final String... expectedStrs) { + Expression expression = getWhereExpression(table, expressionStr); + List disjuncts = LogicRewriter.extractDisjuncts(expression); + + assertThat(disjuncts.size(), is(expectedStrs.length)); + + // When + int i = 0; + for (Expression e : disjuncts) { + assertThat(e.toString(), is(expectedStrs[i++])); + } + } + + private Expression getWhereExpression(final String table, String expression) { + final Query statement = (Query) KsqlParserTestUtil + .buildSingleAst("SELECT * FROM " + table + " WHERE " + expression + ";", metaStore) + .getStatement(); + + assertThat(statement.getWhere().isPresent(), is(true)); + return statement.getWhere().get(); + } +} diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/PullFilterNodeTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/PullFilterNodeTest.java index 217050637220..c4d074c43c01 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/PullFilterNodeTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/PullFilterNodeTest.java @@ -108,12 +108,14 @@ public void shouldExtractKeyValueFromLiteralEquals() { ); // When: - final List keys = filterNode.getKeyValues(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, 
is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(false)); - assertThat(filterNode.getWindowBounds(), is(Optional.empty())); + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.empty())); } @Test @@ -160,12 +162,55 @@ public void shouldExtractKeyValueFromExpressionEquals() { ); // When: - final List keys = filterNode.getKeyValues(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(-1)))); assertThat(filterNode.isWindowed(), is(false)); - assertThat(filterNode.getWindowBounds(), is(Optional.empty())); + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(-1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.empty())); + } + + @Test + public void shouldExtractKeyValueFromExpressionEquals_multipleDisjuncts() { + // Given: + final Expression keyExp1 = new ComparisonExpression( + Type.EQUAL, + new UnqualifiedColumnReferenceExp(ColumnName.of("K")), + new IntegerLiteral(1) + ); + final Expression keyExp2 = new ComparisonExpression( + Type.EQUAL, + new UnqualifiedColumnReferenceExp(ColumnName.of("K")), + new IntegerLiteral(2) + ); + final Expression expression = new LogicalBinaryExpression( + LogicalBinaryExpression.Type.OR, + keyExp1, + keyExp2 + ); + PullFilterNode filterNode = new PullFilterNode( + NODE_ID, + source, + expression, + metaStore, + ksqlConfig, + false + ); + + // When: + final List keys = filterNode.getLookupConstraints(); + + // Then: + assertThat(filterNode.isWindowed(), is(false)); + assertThat(keys.size(), is(2)); + final KeyConstraint keyConstraint1 = (KeyConstraint) keys.get(0); + assertThat(keyConstraint1.getKey(), is(GenericKey.genericKey(1))); + 
assertThat(keyConstraint1.getWindowBounds(), is(Optional.empty())); + final KeyConstraint keyConstraint2 = (KeyConstraint) keys.get(1); + assertThat(keyConstraint2.getKey(), is(GenericKey.genericKey(2))); + assertThat(keyConstraint2.getWindowBounds(), is(Optional.empty())); } @Test @@ -187,12 +232,17 @@ public void shouldExtractKeyValuesFromInExpression() { ); // When: - final List keys = filterNode.getKeyValues(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1),GenericKey.genericKey(2)))); assertThat(filterNode.isWindowed(), is(false)); - assertThat(filterNode.getWindowBounds(), is(Optional.empty())); + assertThat(keys.size(), is(2)); + final KeyConstraint keyConstraint0 = (KeyConstraint) keys.get(0); + assertThat(keyConstraint0.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint0.getWindowBounds(), is(Optional.empty())); + final KeyConstraint keyConstraint1 = (KeyConstraint) keys.get(1); + assertThat(keyConstraint1.getKey(), is(GenericKey.genericKey(2))); + assertThat(keyConstraint1.getWindowBounds(), is(Optional.empty())); } // We should refactor the WindowBounds class to encompass the functionality around @@ -215,14 +265,14 @@ public void shouldExtractKeyValueFromExpressionWithNoWindowBounds() { ); // When: - final List keys = filterNode.getKeyValues(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(filterNode.getWindowBounds(), is(Optional.of( - new WindowBounds() - ))); + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of(new WindowBounds()))); } @Test @@ -253,13 +303,14 @@ public void 
shouldExtractKeyValueAndWindowBountsFromExpressionWithGTWindowStart( ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange( null, null, Range.downTo(Instant.ofEpochMilli(2), BoundType.OPEN)), @@ -296,13 +347,14 @@ public void shouldExtractKeyValueAndWindowBoundsFromExpressionWithGTEWindowStart ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange( null, null, Range.downTo(Instant.ofEpochMilli(2), BoundType.CLOSED)), @@ -339,13 +391,14 @@ public void shouldExtractKeyValueAndWindowBoundsFromExpressionWithLTWindowStart( ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), 
is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange( null, Range.upTo(Instant.ofEpochMilli(2), BoundType.OPEN), null), @@ -382,13 +435,14 @@ public void shouldExtractKeyValueAndWindowBoundsFromExpressionWithLTEWindowStart ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange( null, Range.upTo(Instant.ofEpochMilli(2), BoundType.CLOSED), null), @@ -425,17 +479,18 @@ public void shouldExtractKeyValueAndWindowBoundsFromExpressionWithGTWindowEnd() ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange(), new WindowRange( - null, null, Range.downTo(Instant.ofEpochMilli(2), BoundType.OPEN)) + null, null, Range.downTo(Instant.ofEpochMilli(2), BoundType.OPEN)) ) ))); } @@ -468,13 +523,14 @@ public void 
shouldExtractKeyValueAndWindowBoundsFromExpressionWithGTEWindowEnd() ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange(), new WindowRange( @@ -511,13 +567,14 @@ public void shouldExtractKeyValueAndWindowBoundsFromExpressionWithLTWindowEnd() ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange(), new WindowRange( @@ -554,13 +611,14 @@ public void shouldExtractKeyValueAndWindowBoundsFromExpressionWithLTEWindowEnd() ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + 
assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange(), new WindowRange( @@ -597,13 +655,14 @@ public void shouldExtractKeyValueAndWindowBoundsFromExpressionWithEQWindowEnd() ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange(), new WindowRange( @@ -651,13 +710,14 @@ public void shouldExtractKeyValueAndWindowBoundsFromExpressionWithBothWindowBoun ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange( null, @@ -699,13 +759,14 @@ public void shouldExtractKeyValueAndWindowBoundsFromExpressionWithGTWindowStartT ); // When: - final List keys = filterNode.getKeyValues(); - final Optional windowBounds = filterNode.getWindowBounds(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1)))); 
assertThat(filterNode.isWindowed(), is(true)); - assertThat(windowBounds, is(Optional.of( + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1))); + assertThat(keyConstraint.getWindowBounds(), is(Optional.of( new WindowBounds( new WindowRange( null, null, Range.downTo(Instant.ofEpochMilli(1577836800_000L), BoundType.OPEN)), @@ -743,10 +804,76 @@ public void shouldSupportMultiKeyExpressions() { ); // When: - final List keys = filterNode.getKeyValues(); + final List keys = filterNode.getLookupConstraints(); // Then: - assertThat(keys, is(ImmutableList.of(GenericKey.genericKey(1, 2)))); + assertThat(keys.size(), is(1)); + final KeyConstraint keyConstraint = (KeyConstraint) keys.get(0); + assertThat(keyConstraint.getKey(), is(GenericKey.genericKey(1, 2))); + } + + @Test + public void shouldThrowKeyExpressionThatDoestCoverKey() { + // Given: + when(source.getSchema()).thenReturn(INPUT_SCHEMA); + final Expression expression = new ComparisonExpression( + Type.EQUAL, + new UnqualifiedColumnReferenceExp(ColumnName.of("WINDOWSTART")), + new IntegerLiteral(1234) + ); + + // When: + final KsqlException e = assertThrows( + KsqlException.class, + () -> new PullFilterNode( + NODE_ID, + source, + expression, + metaStore, + ksqlConfig, + true + )); + + // Then: + assertThat(e.getMessage(), containsString("WHERE clause missing key column for disjunct: " + + "(WINDOWSTART = 1234)")); + } + + @Test + public void shouldThrowKeyExpressionThatDoestCoverKey_multipleDisjuncts() { + // Given: + when(source.getSchema()).thenReturn(INPUT_SCHEMA); + final Expression keyExp1 = new ComparisonExpression( + Type.EQUAL, + new UnqualifiedColumnReferenceExp(ColumnName.of("WINDOWSTART")), + new IntegerLiteral(1) + ); + final Expression keyExp2 = new ComparisonExpression( + Type.EQUAL, + new UnqualifiedColumnReferenceExp(ColumnName.of("K")), + new IntegerLiteral(2) + ); + final Expression 
expression = new LogicalBinaryExpression( + LogicalBinaryExpression.Type.OR, + keyExp1, + keyExp2 + ); + + // When: + final KsqlException e = assertThrows( + KsqlException.class, + () -> new PullFilterNode( + NODE_ID, + source, + expression, + metaStore, + ksqlConfig, + true + )); + + // Then: + assertThat(e.getMessage(), containsString("WHERE clause missing key column for disjunct: " + + "(WINDOWSTART = 1)")); } @@ -884,7 +1011,8 @@ public void shouldThrowOnInAndComparisonExpression() { )); // Then: - assertThat(e.getMessage(), containsString("The IN predicate cannot be combined with other comparisons")); + assertThat(e.getMessage(), containsString("An equality condition on the key column cannot be " + + "combined with other comparisons")); } @Test @@ -1256,5 +1384,4 @@ public void shouldThrowOnUsageOfWindowBoundOnNonwindowedTable() { // Then: assertThat(e.getMessage(), containsString("Cannot use WINDOWSTART/WINDOWEND on non-windowed source.")); } - } \ No newline at end of file diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/PullQueryRewriterTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/PullQueryRewriterTest.java new file mode 100644 index 000000000000..533266e924f5 --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/PullQueryRewriterTest.java @@ -0,0 +1,57 @@ +package io.confluent.ksql.planner.plan; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.mock; + +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.metastore.MetaStore; +import io.confluent.ksql.parser.tree.Query; +import io.confluent.ksql.util.KsqlParserTestUtil; +import io.confluent.ksql.util.MetaStoreFixture; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + 
+@RunWith(MockitoJUnitRunner.class) +public class PullQueryRewriterTest { + + private MetaStore metaStore; + + @Before + public void init() { + metaStore = MetaStoreFixture.getNewMetaStore(mock(FunctionRegistry.class)); + } + + @Test + public void shouldRewriteInPredicate() { + assertRewrite("ORDERS", "ITEMID in ('a', 'b', 'c')", + "((ORDERS.ITEMID = 'a') OR ((ORDERS.ITEMID = 'b') OR (ORDERS.ITEMID = 'c')))"); + assertRewrite("ORDERS", "ORDERID > 2 AND ITEMID in ('a', 'b', 'c')", + "(((ORDERS.ORDERID > 2) AND (ORDERS.ITEMID = 'a')) OR (((ORDERS.ORDERID > 2) AND " + + "(ORDERS.ITEMID = 'b')) OR ((ORDERS.ORDERID > 2) AND (ORDERS.ITEMID = 'c'))))"); + assertRewrite("ORDERS", "ORDERID > 2 OR ITEMID in ('a', 'b', 'c')", + "((ORDERS.ORDERID > 2) OR ((ORDERS.ITEMID = 'a') OR ((ORDERS.ITEMID = 'b') OR " + + "(ORDERS.ITEMID = 'c'))))"); + } + + private void assertRewrite(final String table, final String expressionStr, + final String expectedStr) { + Expression expression = getWhereExpression(table, expressionStr); + Expression converted = PullQueryRewriter.rewrite(expression); + + // When + assertThat(converted.toString(), is(expectedStr)); + } + + private Expression getWhereExpression(final String table, String expression) { + final Query statement = (Query) KsqlParserTestUtil + .buildSingleAst("SELECT * FROM " + table + " WHERE " + expression + ";", metaStore) + .getStatement(); + + assertThat(statement.getWhere().isPresent(), is(true)); + return statement.getWhere().get(); + } +} diff --git a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json index 5898278d20ec..f5ec2c313200 100644 --- a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json +++ 
b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json @@ -1032,19 +1032,6 @@ "status": 400 } }, - { - "name": "fail on unsupported query feature: where multiple keys", - "statements": [ - "CREATE STREAM INPUT (ID STRING KEY, IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE TABLE AGGREGATE AS SELECT ID, COUNT(1) AS COUNT FROM INPUT GROUP BY ID;", - "SELECT * FROM AGGREGATE WHERE ID='10' OR ID='11';" - ], - "expectedError": { - "type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage", - "message": "Only AND expressions are supported: ((ID = '10') OR (ID = '11'))", - "status": 400 - } - }, { "name": "fail on unsupported query feature: where key range", "statements": [ @@ -1907,7 +1894,7 @@ ], "expectedError": { "type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage", - "message": "Primary key columns can not be NULL: (ID IN (null))", + "message": "Primary key columns can not be NULL: (ID = null)", "status": 400 } }, @@ -2109,6 +2096,85 @@ {"row":{"columns":["11", 10, 14000, 15000, 2]}} ]} ] + }, + { + "name": "non-windowed General WHERE - OR lookup", + "statements": [ + "CREATE STREAM INPUT (ID STRING KEY, IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT ID, COUNT(1) AS COUNT FROM INPUT GROUP BY ID;", + "SELECT * FROM AGGREGATE WHERE ID = '10' OR ID = '8';" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12345, "key": "11", "value": {}}, + {"topic": "test_topic", "timestamp": 12365, "key": "10", "value": {}}, + {"topic": "test_topic", "timestamp": 12366, "key": "9", "value": {}}, + {"topic": "test_topic", "timestamp": 12367, "key": "8", "value": {}}, + {"topic": "test_topic", "timestamp": 12368, "key": "12", "value": {}} + ], + "responses": [ + {"admin": {"@type": "currentStatus"}}, + {"admin": {"@type": "currentStatus"}}, + {"query": [ + {"header":{"schema":"`ID` STRING KEY, `COUNT` 
BIGINT"}}, + {"row":{"columns":["10", 1]}}, + {"row":{"columns":["8", 1]}} + ]} + ] + }, + { + "name": "windowed General WHERE - window range", + "statements": [ + "CREATE STREAM INPUT (ID STRING KEY, IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT ID, COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ID;", + "SELECT * FROM AGGREGATE WHERE WINDOWSTART >= 12000 AND (ID = '10' OR ID = '8' OR ID = '12');", + "SELECT * FROM AGGREGATE WHERE WINDOWSTART >= 12000 AND WINDOWEND <= 13000 AND (ID = '10' OR ID = '8' OR ID = '12');", + "SELECT * FROM AGGREGATE WHERE WINDOWSTART >= 12000 AND ID = '10' OR ID = '8' OR WINDOWEND <= 13000 AND ID = '12';", + "SELECT * FROM AGGREGATE WHERE WINDOWSTART >= 12000 AND WINDOWEND <= 13000 AND ID = '8' OR WINDOWSTART >= 14000 AND WINDOWEND <= 15000 AND ID = '8';" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12346, "key": "11", "value": {"val": 1}}, + {"topic": "test_topic", "timestamp": 11345, "key": "10", "value": {"val": 2}}, + {"topic": "test_topic", "timestamp": 12345, "key": "10", "value": {"val": 2}}, + {"topic": "test_topic", "timestamp": 12366, "key": "9", "value": {"val": 3}}, + {"topic": "test_topic", "timestamp": 12367, "key": "8", "value": {"val": 4}}, + {"topic": "test_topic", "timestamp": 14367, "key": "8", "value": {"val": 4}}, + {"topic": "test_topic", "timestamp": 15367, "key": "8", "value": {"val": 4}}, + {"topic": "test_topic", "timestamp": 12368, "key": "12", "value": {"val": 5}}, + {"topic": "test_topic", "timestamp": 13368, "key": "12", "value": {"val": 5}}, + {"topic": "test_topic", "timestamp": 16369, "key": "13", "value": {"val": 6}} + ], + "responses": [ + {"admin": {"@type": "currentStatus"}}, + {"admin": {"@type": "currentStatus"}}, + {"query": [ + {"header":{"schema":"`ID` STRING KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY, `COUNT` BIGINT"}}, + {"row":{"columns":["10", 12000, 13000, 1]}}, + {"row":{"columns":["12", 
12000, 13000, 1]}}, + {"row":{"columns":["12", 13000, 14000, 1]}}, + {"row":{"columns":["8", 12000, 13000, 1]}}, + {"row":{"columns":["8", 14000, 15000, 1]}}, + {"row":{"columns":["8", 15000, 16000, 1]}} + ]}, + {"query": [ + {"header":{"schema":"`ID` STRING KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY, `COUNT` BIGINT"}}, + {"row":{"columns":["10", 12000, 13000, 1]}}, + {"row":{"columns":["12", 12000, 13000, 1]}}, + {"row":{"columns":["8", 12000, 13000, 1]}} + ]}, + {"query": [ + {"header":{"schema":"`ID` STRING KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY, `COUNT` BIGINT"}}, + {"row":{"columns":["10", 12000, 13000, 1]}}, + {"row":{"columns":["12", 12000, 13000, 1]}}, + {"row":{"columns":["8", 12000, 13000, 1]}}, + {"row":{"columns":["8", 14000, 15000, 1]}}, + {"row":{"columns":["8", 15000, 16000, 1]}} + ]}, + {"query": [ + {"header":{"schema":"`ID` STRING KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY, `COUNT` BIGINT"}}, + {"row":{"columns":["8", 12000, 13000, 1]}}, + {"row":{"columns":["8", 14000, 15000, 1]}} + ]} + ] } ] } \ No newline at end of file diff --git a/ksqldb-metastore/src/test/java/io/confluent/ksql/util/MetaStoreFixture.java b/ksqldb-metastore/src/test/java/io/confluent/ksql/util/MetaStoreFixture.java index 0811edc6fba5..d7e52ad86337 100644 --- a/ksqldb-metastore/src/test/java/io/confluent/ksql/util/MetaStoreFixture.java +++ b/ksqldb-metastore/src/test/java/io/confluent/ksql/util/MetaStoreFixture.java @@ -269,6 +269,31 @@ public static MutableMetaStore getNewMetaStore( metaStore.putSource(ksqlStreamSensorReadings, false); + final LogicalSchema testTable5 = LogicalSchema.builder() + .keyColumn(ColumnName.of("A"), SqlTypes.BOOLEAN) + .valueColumn(ColumnName.of("B"), SqlTypes.BOOLEAN) + .valueColumn(ColumnName.of("C"), SqlTypes.BOOLEAN) + .valueColumn(ColumnName.of("D"), SqlTypes.BOOLEAN) + .valueColumn(ColumnName.of("E"), SqlTypes.BOOLEAN) + .valueColumn(ColumnName.of("F"), SqlTypes.BOOLEAN) + .valueColumn(ColumnName.of("G"), 
SqlTypes.BOOLEAN) + .build(); + + final KsqlTopic ksqlTopic5 = new KsqlTopic( + "test5", + keyFormat, + valueFormat + ); + final KsqlTable ksqlTable5 = new KsqlTable<>( + "sqlexpression", + SourceName.of("TEST5"), + testTable5, + Optional.empty(), + false, + ksqlTopic5 + ); + metaStore.putSource(ksqlTable5, false); + return metaStore; } } diff --git a/ksqldb-streams/src/main/java/io/confluent/ksql/execution/streams/materialization/Locator.java b/ksqldb-streams/src/main/java/io/confluent/ksql/execution/streams/materialization/Locator.java index 382b084a65da..c330983702af 100644 --- a/ksqldb-streams/src/main/java/io/confluent/ksql/execution/streams/materialization/Locator.java +++ b/ksqldb-streams/src/main/java/io/confluent/ksql/execution/streams/materialization/Locator.java @@ -42,7 +42,7 @@ public interface Locator { * @return the list of nodes, that can potentially serve the key. */ List locate( - List keys, + List keys, RoutingOptions routingOptions, RoutingFilterFactory routingFilterFactory ); @@ -76,6 +76,17 @@ interface KsqlPartitionLocation { * @return the keys associated with the data we want to access, if any. Keys may not be present * for queries which don't enumerate them up front, such as range queries. 
*/ - Optional> getKeys(); + Optional> getKeys(); + } + + /** + * Wrapper around a GenericKey + */ + interface KsqlKey { + + /** + * Gets the key associated with this KsqlKey + */ + GenericKey getKey(); } } diff --git a/ksqldb-streams/src/main/java/io/confluent/ksql/execution/streams/materialization/ks/KsLocator.java b/ksqldb-streams/src/main/java/io/confluent/ksql/execution/streams/materialization/ks/KsLocator.java index 3b9be710f490..bfc4d0031c3d 100644 --- a/ksqldb-streams/src/main/java/io/confluent/ksql/execution/streams/materialization/ks/KsLocator.java +++ b/ksqldb-streams/src/main/java/io/confluent/ksql/execution/streams/materialization/ks/KsLocator.java @@ -82,7 +82,7 @@ public final class KsLocator implements Locator { @Override public List locate( - final List keys, + final List keys, final RoutingOptions routingOptions, final RoutingFilterFactory routingFilterFactory ) { @@ -102,7 +102,7 @@ public List locate( final HostInfo activeHost = partitionMetadata.getActiveHost(); final Set standByHosts = partitionMetadata.getStandbyHosts(); final int partition = partitionMetadata.getPartition(); - final Optional> partitionKeys = partitionMetadata.getKeys(); + final Optional> partitionKeys = partitionMetadata.getKeys(); // For a given partition, find the ordered, filtered list of hosts to consider final List filteredHosts = getFilteredHosts(routingOptions, routingFilterFactory, @@ -122,16 +122,16 @@ public List locate( * @return The metadata associated with the keys */ private List getMetadataForKeys( - final List keys, + final List keys, final Set filterPartitions ) { // Maintain request order for reproducibility by using a LinkedHashMap, even though it's // not a guarantee of the API. 
final Map metadataByPartition = new LinkedHashMap<>(); - final Map> keysByPartition = new HashMap<>(); - for (GenericKey key : keys) { + final Map> keysByPartition = new HashMap<>(); + for (KsqlKey key : keys) { final KeyQueryMetadata metadata = kafkaStreams - .queryMetadataForKey(stateStoreName, key, keySerializer); + .queryMetadataForKey(stateStoreName, key.getKey(), keySerializer); // Fail fast if Streams not ready. Let client handle it if (metadata == KeyQueryMetadata.NOT_AVAILABLE) { @@ -344,18 +344,18 @@ public int hashCode() { @VisibleForTesting public static final class PartitionLocation implements KsqlPartitionLocation { - private final Optional> keys; + private final Optional> keys; private final int partition; private final List nodes; - public PartitionLocation(final Optional> keys, final int partition, + public PartitionLocation(final Optional> keys, final int partition, final List nodes) { this.keys = keys; this.partition = partition; this.nodes = nodes; } - public Optional> getKeys() { + public Optional> getKeys() { return keys; } @@ -383,13 +383,13 @@ private static class PartitionMetadata { private final HostInfo activeHost; private final Set standbyHosts; private final int partition; - private final Optional> keys; + private final Optional> keys; PartitionMetadata( final HostInfo activeHost, final Set standbyHosts, final int partition, - final Optional> keys + final Optional> keys ) { this.activeHost = activeHost; this.standbyHosts = standbyHosts; @@ -421,7 +421,7 @@ public int getPartition() { /** * @return the set of keys associated with the partition, if they exist */ - public Optional> getKeys() { + public Optional> getKeys() { return keys; } } diff --git a/ksqldb-streams/src/test/java/io/confluent/ksql/execution/streams/materialization/ks/KsLocatorTest.java b/ksqldb-streams/src/test/java/io/confluent/ksql/execution/streams/materialization/ks/KsLocatorTest.java index 7d230518b86b..5d278a0ae927 100644 --- 
a/ksqldb-streams/src/test/java/io/confluent/ksql/execution/streams/materialization/ks/KsLocatorTest.java +++ b/ksqldb-streams/src/test/java/io/confluent/ksql/execution/streams/materialization/ks/KsLocatorTest.java @@ -35,6 +35,7 @@ import io.confluent.ksql.execution.streams.RoutingFilter.RoutingFilterFactory; import io.confluent.ksql.execution.streams.RoutingFilters; import io.confluent.ksql.execution.streams.RoutingOptions; +import io.confluent.ksql.execution.streams.materialization.Locator.KsqlKey; import io.confluent.ksql.execution.streams.materialization.Locator.KsqlNode; import io.confluent.ksql.execution.streams.materialization.Locator.KsqlPartitionLocation; import io.confluent.ksql.execution.streams.materialization.MaterializationException; @@ -71,6 +72,10 @@ public class KsLocatorTest { private static final GenericKey SOME_KEY1 = GenericKey.genericKey(2); private static final GenericKey SOME_KEY2 = GenericKey.genericKey(3); private static final GenericKey SOME_KEY3 = GenericKey.genericKey(4); + private static final KsqlKey KEY = new TestKey(SOME_KEY); + private static final KsqlKey KEY1 = new TestKey(SOME_KEY1); + private static final KsqlKey KEY2 = new TestKey(SOME_KEY2); + private static final KsqlKey KEY3 = new TestKey(SOME_KEY3); private static final KsqlHostInfo ACTIVE_HOST = new KsqlHostInfo("remoteHost", 2345); private static final KsqlHostInfo STANDBY_HOST1 = new KsqlHostInfo("standby1", 1234); private static final KsqlHostInfo STANDBY_HOST2 = new KsqlHostInfo("standby2", 5678); @@ -165,7 +170,7 @@ public void shouldThrowIfMetadataNotAvailable() { // When: final Exception e = assertThrows( MaterializationException.class, - () -> locator.locate(ImmutableList.of(SOME_KEY), routingOptions, routingFilterFactoryActive) + () -> locator.locate(ImmutableList.of(KEY), routingOptions, routingFilterFactoryActive) ); // Then: @@ -179,7 +184,7 @@ public void shouldReturnOwnerIfKnown() { getActiveAndStandbyMetadata(); // When: - final List result = 
locator.locate(ImmutableList.of(SOME_KEY), routingOptions, + final List result = locator.locate(ImmutableList.of(KEY), routingOptions, routingFilterFactoryActive); // Then: @@ -203,7 +208,7 @@ public void shouldReturnLocalOwnerIfSameAsSuppliedLocalHost() { .thenReturn(true); // When: - final List result = locator.locate(ImmutableList.of(SOME_KEY), + final List result = locator.locate(ImmutableList.of(KEY), routingOptions, routingFilterFactoryActive); // Then: @@ -223,7 +228,7 @@ public void shouldReturnLocalOwnerIfExplicitlyLocalHostOnSamePortAsSuppliedLocal .thenReturn(true); // When: - final List result = locator.locate(ImmutableList.of(SOME_KEY), routingOptions, + final List result = locator.locate(ImmutableList.of(KEY), routingOptions, routingFilterFactoryActive); // Then: @@ -243,7 +248,7 @@ public void shouldReturnRemoteOwnerForDifferentHost() { .thenReturn(true); // When: - final List result = locator.locate(ImmutableList.of(SOME_KEY), routingOptions, + final List result = locator.locate(ImmutableList.of(KEY), routingOptions, routingFilterFactoryActive); // Then: @@ -263,7 +268,7 @@ public void shouldReturnRemoteOwnerForDifferentPort() { .thenReturn(true); // When: - final List result = locator.locate(ImmutableList.of(SOME_KEY), routingOptions, + final List result = locator.locate(ImmutableList.of(KEY), routingOptions, routingFilterFactoryActive); // Then: @@ -284,7 +289,7 @@ public void shouldReturnRemoteOwnerForDifferentPortOnLocalHost() { .thenReturn(true); // When: - final List result = locator.locate(ImmutableList.of(SOME_KEY), routingOptions, + final List result = locator.locate(ImmutableList.of(KEY), routingOptions, routingFilterFactoryActive); // Then: @@ -299,7 +304,7 @@ public void shouldReturnActiveWhenRoutingStandbyNotEnabledHeartBeatNotEnabled() getActiveAndStandbyMetadata(); // When: - final List result = locator.locate(ImmutableList.of(SOME_KEY), routingOptions, + final List result = locator.locate(ImmutableList.of(KEY), routingOptions, 
routingFilterFactoryActive); // Then: @@ -314,7 +319,7 @@ public void shouldReturnActiveAndStandBysWhenRoutingStandbyEnabledHeartBeatNotEn getActiveAndStandbyMetadata(); // When: - final List result = locator.locate(ImmutableList.of(SOME_KEY), routingOptions, + final List result = locator.locate(ImmutableList.of(KEY), routingOptions, routingFilterFactoryStandby); // Then: @@ -332,7 +337,7 @@ public void shouldReturnStandBysWhenActiveDown() { .thenReturn(false); // When: - final List result = locator.locate(ImmutableList.of(SOME_KEY), routingOptions, + final List result = locator.locate(ImmutableList.of(KEY), routingOptions, routingFilterFactoryStandby); // Then: @@ -351,7 +356,7 @@ public void shouldReturnOneStandByWhenActiveAndOtherStandByDown() { .thenReturn(false); // When: - final List result = locator.locate(ImmutableList.of(SOME_KEY), routingOptions, + final List result = locator.locate(ImmutableList.of(KEY), routingOptions, routingFilterFactoryStandby); // Then: @@ -370,22 +375,22 @@ public void shouldGroupKeysByLocation() { // When: final List result = locator.locate( - ImmutableList.of(SOME_KEY, SOME_KEY1, SOME_KEY2, SOME_KEY3), routingOptions, + ImmutableList.of(KEY, KEY1, KEY2, KEY3), routingOptions, routingFilterFactoryStandby); // Then: assertThat(result.size(), is(3)); - assertThat(result.get(0).getKeys().get(), contains(SOME_KEY, SOME_KEY2)); + assertThat(result.get(0).getKeys().get(), contains(KEY, KEY2)); List nodeList = result.get(0).getNodes(); assertThat(nodeList.size(), is(2)); assertThat(nodeList.get(0), is(activeNode)); assertThat(nodeList.get(1), is(standByNode1)); - assertThat(result.get(1).getKeys().get(), contains(SOME_KEY1)); + assertThat(result.get(1).getKeys().get(), contains(KEY1)); nodeList = result.get(1).getNodes(); assertThat(nodeList.size(), is(2)); assertThat(nodeList.get(0), is(standByNode1)); assertThat(nodeList.get(1), is(activeNode)); - assertThat(result.get(2).getKeys().get(), contains(SOME_KEY3)); + 
assertThat(result.get(2).getKeys().get(), contains(KEY3)); nodeList = result.get(2).getNodes(); assertThat(nodeList.size(), is(2)); assertThat(nodeList.get(0), is(activeNode)); @@ -469,4 +474,22 @@ private static URL localHost() { throw new AssertionError("Failed to build URL", e); } } + + private static class TestKey implements KsqlKey { + + private final GenericKey genericKey; + + public TestKey(final GenericKey genericKey) { + this.genericKey = genericKey; + } + + @Override + public GenericKey getKey() { + return genericKey; + } + + public String toString() { + return genericKey.toString(); + } + } } \ No newline at end of file