diff --git a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java
index 9f3d372ccf98..3521d22cd0d9 100644
--- a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java
+++ b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java
@@ -63,7 +63,7 @@ public class Analysis implements ImmutableAnalysis {
   private final Set<ColumnRef> selectColumnRefs = new HashSet<>();
   private final List<Expression> groupByExpressions = new ArrayList<>();
   private Optional<WindowExpression> windowExpression = Optional.empty();
-  private Optional<ColumnRef> partitionBy = Optional.empty();
+  private Optional<Expression> partitionBy = Optional.empty();
   private ImmutableSet<SerdeOption> serdeOptions = ImmutableSet.of();
   private Optional<Expression> havingExpression = Optional.empty();
   private OptionalInt limitClause = OptionalInt.empty();
@@ -134,11 +134,11 @@ void setHavingExpression(final Expression havingExpression) {
     this.havingExpression = Optional.of(havingExpression);
   }
 
-  public Optional<ColumnRef> getPartitionBy() {
+  public Optional<Expression> getPartitionBy() {
     return partitionBy;
   }
 
-  void setPartitionBy(final ColumnRef partitionBy) {
+  void setPartitionBy(final Expression partitionBy) {
     this.partitionBy = Optional.of(partitionBy);
   }
 
diff --git a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java
index c1bba9e9b244..8d152e8d2028 100644
--- a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java
+++ b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java
@@ -509,13 +509,7 @@ private void analyzeGroupBy(final GroupBy groupBy) {
     }
 
     private void analyzePartitionBy(final Expression partitionBy) {
-      if (partitionBy instanceof ColumnReferenceExp) {
-        analysis.setPartitionBy(((ColumnReferenceExp) partitionBy).getReference());
-        return;
-      }
-
-      throw new KsqlException(
-          "Expected partition by to be a valid column but got " + partitionBy);
+      analysis.setPartitionBy(partitionBy);
     }
 
     private void analyzeWindowExpression(final WindowExpression windowExpression) {
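Note: with the analyzer no longer rejecting anything other than a bare column reference ("Expected partition by to be a valid column but got ..."), the PARTITION BY target is stored in the Analysis as a general Expression and handed to the logical planner as-is. As a rough illustration of the kind of statement this admits (stream and column names here are hypothetical, not taken from this change):

    CREATE STREAM REKEYED AS
        SELECT id, amount FROM purchases
        PARTITION BY amount + 1;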
diff --git a/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/StatementRewriter.java b/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/StatementRewriter.java
index 626a49b2c69a..b1b8244a16d6 100644
--- a/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/StatementRewriter.java
+++ b/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/StatementRewriter.java
@@ -188,9 +188,8 @@ protected AstNode visitQuery(final Query node, final C context) {
     final Optional<GroupBy> groupBy = node.getGroupBy()
         .map(exp -> ((GroupBy) rewriter.apply(exp, context)));
 
-    // don't rewrite the partitionBy because we expect it to be
-    // exactly as it was (a single, un-aliased, column reference)
-    final Optional<Expression> partitionBy = node.getPartitionBy();
+    final Optional<Expression> partitionBy = node.getPartitionBy()
+        .map(exp -> processExpression(exp, context));
 
     final Optional<Expression> having = node.getHaving()
         .map(exp -> (processExpression(exp, context)));
 
diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java
index f76b4f68459f..7dfdf4ba9f47 100644
--- a/ksql-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java
+++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java
@@ -47,7 +47,6 @@
 import io.confluent.ksql.schema.ksql.LogicalSchema.Builder;
 import io.confluent.ksql.schema.ksql.types.SqlType;
 import io.confluent.ksql.util.KsqlConfig;
-import io.confluent.ksql.util.KsqlException;
 import io.confluent.ksql.util.SchemaUtil;
 import java.util.List;
 import java.util.Optional;
@@ -204,21 +203,26 @@ private static FilterNode buildFilterNode(
 
   private static RepartitionNode buildRepartitionNode(
       final PlanNode sourceNode,
-      final ColumnRef partitionBy
+      final Expression partitionBy
   ) {
-    if (!sourceNode.getSchema().withoutAlias().findValueColumn(partitionBy).isPresent()) {
-      throw new KsqlException("Invalid identifier for PARTITION BY clause: '" + partitionBy
-          + "'. Only columns from the source schema can be referenced in the PARTITION BY clause.");
+    if (!(partitionBy instanceof ColumnReferenceExp)) {
+      return new RepartitionNode(
+          new PlanNodeId("PartitionBy"),
+          sourceNode,
+          partitionBy,
+          KeyField.none());
     }
 
-    final KeyField keyField;
+    final ColumnRef partitionColumn = ((ColumnReferenceExp) partitionBy).getReference();
     final LogicalSchema schema = sourceNode.getSchema();
-    if (schema.isMetaColumn(partitionBy.name())) {
+
+    final KeyField keyField;
+    if (schema.isMetaColumn(partitionColumn.name())) {
       keyField = KeyField.none();
-    } else if (schema.isKeyColumn(partitionBy.name())) {
+    } else if (schema.isKeyColumn(partitionColumn.name())) {
       keyField = sourceNode.getKeyField();
     } else {
-      keyField = KeyField.of(partitionBy);
+      keyField = KeyField.of(partitionColumn);
     }
 
     return new RepartitionNode(
@@ -226,6 +230,7 @@ private static RepartitionNode buildRepartitionNode(
         sourceNode,
         partitionBy,
         keyField);
+
   }
 
   private FlatMapNode buildFlatMapNode(final PlanNode sourcePlanNode) {
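The key-field bookkeeping in buildRepartitionNode now depends on the shape of the PARTITION BY target: a plain column reference keeps roughly the old behaviour (meta column -> no key field, key column -> the source's key field, any other value column -> that column), while any other expression yields a RepartitionNode with KeyField.none(). The key-field.json change further down exercises exactly this case:

    CREATE STREAM OUTPUT AS SELECT foo + bar FROM INPUT PARTITION BY foo + bar;

which is now expected to run and to leave OUTPUT with a null key field, rather than failing to parse.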
diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java
index ce10c85369c7..897479977e7b 100644
--- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java
+++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java
@@ -20,6 +20,7 @@
 import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
 import io.confluent.ksql.execution.context.QueryContext;
 import io.confluent.ksql.execution.context.QueryContext.Stacker;
+import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
 import io.confluent.ksql.execution.plan.SelectExpression;
 import io.confluent.ksql.execution.streams.JoinParamsFactory;
 import io.confluent.ksql.metastore.model.DataSource.DataSourceType;
@@ -285,7 +286,7 @@ static SchemaKStream maybeRePartitionByKey(
       final ColumnRef joinFieldName,
       final Stacker contextStacker
   ) {
-    return stream.selectKey(joinFieldName, contextStacker);
+    return stream.selectKey(new ColumnReferenceExp(joinFieldName), contextStacker);
   }
 
   static ValueFormat getFormatForSource(final DataSourceNode sourceNode) {
diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/RepartitionNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/RepartitionNode.java
index 767faf0719c3..d80bba1dd8e3 100644
--- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/RepartitionNode.java
+++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/RepartitionNode.java
@@ -15,13 +15,13 @@
 package io.confluent.ksql.planner.plan;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
 import com.google.errorprone.annotations.Immutable;
 import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
+import io.confluent.ksql.execution.expression.tree.Expression;
 import io.confluent.ksql.execution.plan.SelectExpression;
 import io.confluent.ksql.metastore.model.KeyField;
-import io.confluent.ksql.name.SourceName;
-import io.confluent.ksql.schema.ksql.ColumnRef;
 import io.confluent.ksql.schema.ksql.LogicalSchema;
 import io.confluent.ksql.services.KafkaTopicClient;
 import io.confluent.ksql.structured.SchemaKStream;
@@ -32,14 +32,18 @@ public class RepartitionNode extends PlanNode {
   private final PlanNode source;
-  private final ColumnRef partitionBy;
+  private final Expression partitionBy;
   private final KeyField keyField;
 
-  public RepartitionNode(PlanNodeId id, PlanNode source, ColumnRef partitionBy, KeyField keyField) {
+  public RepartitionNode(
+      PlanNodeId id,
+      PlanNode source,
+      Expression partitionBy,
+      KeyField keyField
+  ) {
     super(id, source.getNodeOutputType());
-    final SourceName alias = source.getTheSourceNode().getAlias();
     this.source = Objects.requireNonNull(source, "source");
-    this.partitionBy = Objects.requireNonNull(partitionBy, "partitionBy").withSource(alias);
+    this.partitionBy = Objects.requireNonNull(partitionBy, "partitionBy");
     this.keyField = Objects.requireNonNull(keyField, "keyField");
   }
 
@@ -73,4 +77,9 @@ public SchemaKStream buildStream(KsqlQueryBuilder builder) {
     return source.buildStream(builder)
         .selectKey(partitionBy, builder.buildNodeContext(getId().toString()));
   }
+
+  @VisibleForTesting
+  public Expression getPartitionBy() {
+    return partitionBy;
+  }
 }
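The SchemaKStream change below moves the "is a repartition actually needed" decision into needsRepartition(): partitioning by the stream's existing key column or by ROWKEY returns the same stream unchanged, an unknown column now fails with a KsqlException rather than an IllegalArgumentException, and a non-column expression always repartitions with no key field. The new SchemaKStreamTest cases later in this diff cover each case against the TEST1 fixture (whose key field appears to be COL0), e.g.:

    SELECT col0, col2, col3 FROM test1 PARTITION BY col0 EMIT CHANGES;         -- no repartition
    SELECT col0, col2, col3 FROM test1 PARTITION BY ROWKEY EMIT CHANGES;       -- no repartition
    SELECT col0, col2, col3 FROM test1 PARTITION BY col2 + 'foo' EMIT CHANGES; -- repartition, key field none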
diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java
index 097c0670b15a..ffb3d8897bbc 100644
--- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java
+++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java
@@ -45,12 +45,14 @@
 import io.confluent.ksql.name.ColumnName;
 import io.confluent.ksql.schema.ksql.Column;
 import io.confluent.ksql.schema.ksql.ColumnRef;
+import io.confluent.ksql.schema.ksql.FormatOptions;
 import io.confluent.ksql.schema.ksql.LogicalSchema;
 import io.confluent.ksql.schema.ksql.types.SqlTypes;
 import io.confluent.ksql.serde.KeyFormat;
 import io.confluent.ksql.serde.SerdeOption;
 import io.confluent.ksql.serde.ValueFormat;
 import io.confluent.ksql.util.KsqlConfig;
+import io.confluent.ksql.util.KsqlException;
 import io.confluent.ksql.util.SchemaUtil;
 import java.util.List;
 import java.util.Objects;
@@ -321,56 +323,65 @@ public SchemaKStream outerJoin(
 
   @SuppressWarnings("unchecked")
   public SchemaKStream selectKey(
-      final ColumnRef columnRef,
+      final Expression keyExpression,
       final QueryContext.Stacker contextStacker
   ) {
     if (keyFormat.isWindowed()) {
       throw new UnsupportedOperationException("Can not selectKey of windowed stream");
     }
 
-    final Optional<Column> existingKey = keyField.resolve(getSchema());
-
-    final Column proposedKey = getSchema().findValueColumn(columnRef)
-        .orElseThrow(IllegalArgumentException::new);
-
-    final KeyField resultantKeyField = isRowKey(columnRef)
-        ? keyField
-        : KeyField.of(columnRef);
-
-    final boolean namesMatch = existingKey
-        .map(kf -> kf.ref().equals(proposedKey.ref()))
-        .orElse(false);
-
-    if (namesMatch || isRowKey(proposedKey.ref())) {
-      return (SchemaKStream) new SchemaKStream<>(
-          sourceStep,
-          schema,
-          keyFormat,
-          resultantKeyField,
-          ksqlConfig,
-          functionRegistry
-      );
+    if (!needsRepartition(keyExpression)) {
+      return (SchemaKStream) this;
     }
 
-    final KeyField newKeyField = getSchema().isMetaColumn(columnRef.name())
-        ? KeyField.none()
-        : resultantKeyField;
-
     final StreamSelectKey step = ExecutionStepFactory.streamSelectKey(
         contextStacker,
         sourceStep,
-        columnRef
+        keyExpression
     );
+
     return new SchemaKStream<>(
         step,
         resolveSchema(step),
         keyFormat,
-        newKeyField,
+        getNewKeyField(keyExpression),
         ksqlConfig,
         functionRegistry
     );
   }
 
+  private KeyField getNewKeyField(final Expression expression) {
+    if (!(expression instanceof ColumnReferenceExp)) {
+      return KeyField.none();
+    }
+
+    final ColumnRef columnRef = ((ColumnReferenceExp) expression).getReference();
+    final KeyField newKeyField = isRowKey(columnRef) ? keyField : KeyField.of(columnRef);
+    return getSchema().isMetaColumn(columnRef.name()) ? KeyField.none() : newKeyField;
+  }
+
+  private boolean needsRepartition(final Expression expression) {
+    if (!(expression instanceof ColumnReferenceExp)) {
+      return true;
+    }
+
+    final ColumnRef columnRef = ((ColumnReferenceExp) expression).getReference();
+    final Optional<Column> existingKey = keyField.resolve(getSchema());
+
+    final Column proposedKey = getSchema()
+        .findValueColumn(columnRef)
+        .orElseThrow(() -> new KsqlException("Invalid identifier for PARTITION BY clause: '"
+            + columnRef.name().toString(FormatOptions.noEscape()) + "' Only columns from the "
+            + "source schema can be referenced in the PARTITION BY clause."));
+
+
+    final boolean namesMatch = existingKey
+        .map(kf -> kf.ref().equals(proposedKey.ref()))
+        .orElse(false);
+
+    return !namesMatch && !isRowKey(columnRef);
+  }
+
   private static boolean isRowKey(final ColumnRef fieldName) {
     return fieldName.name().equals(SchemaUtil.ROWKEY_NAME);
   }
diff --git a/ksql-engine/src/test/java/io/confluent/ksql/analyzer/PullQueryValidatorTest.java b/ksql-engine/src/test/java/io/confluent/ksql/analyzer/PullQueryValidatorTest.java
index 5c4e78d702e0..37db1c22f29a 100644
--- a/ksql-engine/src/test/java/io/confluent/ksql/analyzer/PullQueryValidatorTest.java
+++ b/ksql-engine/src/test/java/io/confluent/ksql/analyzer/PullQueryValidatorTest.java
@@ -20,6 +20,7 @@
 import com.google.common.collect.ImmutableList;
 import io.confluent.ksql.analyzer.Analysis.Into;
+import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
 import io.confluent.ksql.execution.expression.tree.Expression;
 import io.confluent.ksql.name.ColumnName;
 import io.confluent.ksql.parser.tree.ResultMaterialization;
@@ -126,7 +127,7 @@ public void shouldThrowOnGroupBy() {
   public void shouldThrowOnPartitionBy() {
     // Given:
     when(analysis.getPartitionBy())
-        .thenReturn(Optional.of(ColumnRef.withoutSource(ColumnName.of("Something"))));
+        .thenReturn(Optional.of(new ColumnReferenceExp(ColumnRef.withoutSource(ColumnName.of("Something")))));
 
     // Then:
     expectedException.expect(KsqlException.class);
diff --git a/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java b/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java
index 869f30f9e4ef..879207de8f15 100644
--- a/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java
+++ b/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java
@@ -277,10 +277,11 @@ public void shouldRewriteQueryWithGroupBy() {
   }
 
   @Test
-  public void shouldNotRewriteQueryWithPartitionBy() {
+  public void shouldRewriteQueryWithPartitionBy() {
     // Given:
     final Query query = givenQuery(Optional.empty(), Optional.empty(), Optional.empty(),
         Optional.of(expression), Optional.empty());
+    when(expressionRewriter.apply(expression, context)).thenReturn(rewrittenExpression);
 
     //
When: final AstNode rewritten = rewriter.rewrite(query, context); @@ -293,7 +294,7 @@ public void shouldNotRewriteQueryWithPartitionBy() { Optional.empty(), Optional.empty(), Optional.empty(), - Optional.of(expression), + Optional.of(rewrittenExpression), Optional.empty(), resultMaterialization, false, diff --git a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java index e2bc3c93c53f..b10278e272c9 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java @@ -34,6 +34,7 @@ import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.ddl.commands.KsqlTopic; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.streams.KSPlanBuilder; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.InternalFunctionRegistry; @@ -878,7 +879,7 @@ public void shouldSelectLeftKeyField() { // Then: verify(leftSchemaKStream).selectKey( - eq(LEFT_JOIN_FIELD_REF), + eq(new ColumnReferenceExp(LEFT_JOIN_FIELD_REF)), any() ); } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java index e83a44daef71..d3682df93790 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java @@ -49,6 +49,7 @@ import io.confluent.ksql.planner.plan.FilterNode; import io.confluent.ksql.planner.plan.PlanNode; import io.confluent.ksql.planner.plan.ProjectNode; +import io.confluent.ksql.planner.plan.RepartitionNode; import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; @@ -59,6 +60,7 @@ import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.testutils.AnalysisTestUtil; import io.confluent.ksql.util.KsqlConfig; +import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.MetaStoreFixture; import io.confluent.ksql.util.Pair; import java.time.Duration; @@ -198,6 +200,93 @@ public void shouldBuildSchemaForSelect() { ); } + @Test + public void shouldNotRepartitionIfSameKeyField() { + // Given: + final PlanNode logicalPlan = givenInitialKStreamOf( + "SELECT col0, col2, col3 FROM test1 PARTITION BY col0 EMIT CHANGES;"); + final RepartitionNode repartitionNode = (RepartitionNode) logicalPlan.getSources().get(0).getSources().get(0); + + // When: + final SchemaKStream result = initialSchemaKStream + .selectKey(repartitionNode.getPartitionBy(), childContextStacker); + + // Then: + assertThat(result, is(initialSchemaKStream)); + } + + @Test + public void shouldNotRepartitionIfRowkey() { + // Given: + final PlanNode logicalPlan = givenInitialKStreamOf( + "SELECT col0, col2, col3 FROM test1 PARTITION BY ROWKEY EMIT CHANGES;"); + final RepartitionNode repartitionNode = (RepartitionNode) logicalPlan.getSources().get(0).getSources().get(0); + + // When: + final SchemaKStream result = initialSchemaKStream + .selectKey(repartitionNode.getPartitionBy(), childContextStacker); + + // Then: + assertThat(result, is(initialSchemaKStream)); + } + + @Test + public void shouldUpdateKeyOnPartitionByColumn() { + // Given: + final 
PlanNode logicalPlan = givenInitialKStreamOf( + "SELECT col0, col2, col3 FROM test1 PARTITION BY col2 EMIT CHANGES;"); + final RepartitionNode repartitionNode = (RepartitionNode) logicalPlan.getSources().get(0).getSources().get(0); + + // When: + final SchemaKStream result = initialSchemaKStream + .selectKey(repartitionNode.getPartitionBy(), childContextStacker); + + // Then: + assertThat(result.getKeyField(), + is(KeyField.of(ColumnRef.of(SourceName.of("TEST1"), ColumnName.of("COL2"))))); + } + + @Test + public void shouldUpdateKeyToNoneOnPartitionByMetaColumn() { + // Given: + final PlanNode logicalPlan = givenInitialKStreamOf( + "SELECT col0, col2, col3 FROM test1 PARTITION BY ROWTIME EMIT CHANGES;"); + final RepartitionNode repartitionNode = (RepartitionNode) logicalPlan.getSources().get(0).getSources().get(0); + + // When: + final SchemaKStream result = initialSchemaKStream + .selectKey(repartitionNode.getPartitionBy(), childContextStacker); + + // Then: + assertThat(result.getKeyField(), is(KeyField.none())); + } + + @Test + public void shouldUpdateKeyToNoneOnPartitionByExpression() { + // Given: + final PlanNode logicalPlan = givenInitialKStreamOf( + "SELECT col0, col2, col3 FROM test1 PARTITION BY col2 + 'foo' EMIT CHANGES;"); + final RepartitionNode repartitionNode = (RepartitionNode) logicalPlan.getSources().get(0).getSources().get(0); + + // When: + final SchemaKStream result = initialSchemaKStream + .selectKey(repartitionNode.getPartitionBy(), childContextStacker); + + // Then: + assertThat(result.getKeyField(), is(KeyField.none())); + } + + @Test(expected = KsqlException.class) + public void shouldThrowOnRepartitionByMissingField() { + // Given: + final PlanNode logicalPlan = givenInitialKStreamOf( + "SELECT col0, col2, col3 FROM test1 PARTITION BY not_here EMIT CHANGES;"); + final RepartitionNode repartitionNode = (RepartitionNode) logicalPlan.getSources().get(0).getSources().get(0); + + // When: + initialSchemaKStream.selectKey(repartitionNode.getPartitionBy(), childContextStacker); + } + @Test public void shouldUpdateKeyIfRenamed() { // Given: @@ -423,7 +512,7 @@ public void shouldSelectKey() { // When: final SchemaKStream rekeyedSchemaKStream = initialSchemaKStream.selectKey( - ColumnRef.of(SourceName.of("TEST1"), ColumnName.of("COL1")), + new ColumnReferenceExp(ColumnRef.of(SourceName.of("TEST1"), ColumnName.of("COL1"))), childContextStacker); // Then: @@ -437,7 +526,7 @@ public void shouldBuildStepForSelectKey() { // When: final SchemaKStream rekeyedSchemaKStream = initialSchemaKStream.selectKey( - ColumnRef.of(SourceName.of("TEST1"), ColumnName.of("COL1")), + new ColumnReferenceExp(ColumnRef.of(SourceName.of("TEST1"), ColumnName.of("COL1"))), childContextStacker); // Then: @@ -447,7 +536,7 @@ public void shouldBuildStepForSelectKey() { ExecutionStepFactory.streamSelectKey( childContextStacker, initialSchemaKStream.getSourceStep(), - ColumnRef.of(SourceName.of("TEST1"), ColumnName.of("COL1")) + new ColumnReferenceExp(ColumnRef.of(SourceName.of("TEST1"), ColumnName.of("COL1"))) ) ) ); @@ -460,7 +549,7 @@ public void shouldBuildSchemaForSelectKey() { // When: final SchemaKStream rekeyedSchemaKStream = initialSchemaKStream.selectKey( - ColumnRef.of(SourceName.of("TEST1"), ColumnName.of("COL1")), + new ColumnReferenceExp(ColumnRef.of(SourceName.of("TEST1"), ColumnName.of("COL1"))), childContextStacker); // Then: @@ -471,12 +560,12 @@ public void shouldBuildSchemaForSelectKey() { ); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = KsqlException.class) 
public void shouldThrowOnSelectKeyIfKeyNotInSchema() { givenInitialKStreamOf("SELECT col0, col2, col3 FROM test1 WHERE col0 > 100 EMIT CHANGES;"); final SchemaKStream rekeyedSchemaKStream = initialSchemaKStream.selectKey( - ColumnRef.withoutSource(ColumnName.of("won't find me")), + new ColumnReferenceExp(ColumnRef.withoutSource(ColumnName.of("won't find me"))), childContextStacker); assertThat(rekeyedSchemaKStream.getKeyField(), is(validJoinKeyField)); diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamSelectKey.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamSelectKey.java index 729ce6ac025a..0a4c70711da2 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamSelectKey.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamSelectKey.java @@ -17,7 +17,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.errorprone.annotations.Immutable; -import io.confluent.ksql.schema.ksql.ColumnRef; +import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.testing.EffectivelyImmutable; import java.util.Collections; import java.util.List; @@ -28,7 +28,7 @@ public class StreamSelectKey implements ExecutionStep> { private final ExecutionStepPropertiesV1 properties; - private final ColumnRef fieldName; + private final Expression keyExpression; @EffectivelyImmutable private final ExecutionStep> source; @@ -36,10 +36,11 @@ public StreamSelectKey( @JsonProperty(value = "properties", required = true) ExecutionStepPropertiesV1 properties, @JsonProperty(value = "source", required = true) ExecutionStep> source, - @JsonProperty(value = "fieldName", required = true) ColumnRef fieldName) { + @JsonProperty(value = "keyExpression", required = true) Expression keyExpression + ) { this.properties = Objects.requireNonNull(properties, "properties"); this.source = Objects.requireNonNull(source, "source"); - this.fieldName = Objects.requireNonNull(fieldName, "fieldName"); + this.keyExpression = Objects.requireNonNull(keyExpression, "keyExpression"); } @Override @@ -53,8 +54,8 @@ public List> getSources() { return Collections.singletonList(source); } - public ColumnRef getFieldName() { - return fieldName; + public Expression getKeyExpression() { + return keyExpression; } public ExecutionStep> getSource() { diff --git a/ksql-functional-tests/src/test/resources/expected_topology/0_6_0-pre/key-field_-_stream___initially_set___partition_by_(different)___key_in_value___no_aliasing b/ksql-functional-tests/src/test/resources/expected_topology/0_6_0-pre/key-field_-_stream___initially_set___partition_by_(different)___key_in_value___no_aliasing deleted file mode 100644 index 5f8e067ef5ba..000000000000 --- a/ksql-functional-tests/src/test/resources/expected_topology/0_6_0-pre/key-field_-_stream___initially_set___partition_by_(different)___key_in_value___no_aliasing +++ /dev/null @@ -1,63 +0,0 @@ -{ - "ksql.extension.dir" : "ext", - "ksql.streams.cache.max.bytes.buffering" : "0", - "ksql.security.extension.class" : null, - "ksql.transient.prefix" : "transient_", - "ksql.persistence.wrap.single.values" : "true", - "ksql.schema.registry.url" : "http://localhost:8081", - "ksql.streams.default.deserialization.exception.handler" : "io.confluent.ksql.errors.LogMetricAndContinueExceptionHandler", - "ksql.output.topic.name.prefix" : "", - "ksql.streams.auto.offset.reset" : "earliest", - "ksql.connect.url" : "http://localhost:8083", - 
"ksql.service.id" : "some.ksql.service.id", - "ksql.internal.topic.min.insync.replicas" : "1", - "ksql.internal.topic.replicas" : "1", - "ksql.insert.into.values.enabled" : "true", - "ksql.query.pull.enable" : "true", - "ksql.streams.default.production.exception.handler" : "io.confluent.ksql.errors.ProductionExceptionHandlerUtil$LogAndFailProductionExceptionHandler", - "ksql.access.validator.enable" : "auto", - "ksql.streams.bootstrap.servers" : "localhost:0", - "ksql.streams.commit.interval.ms" : "2000", - "ksql.metric.reporters" : "", - "ksql.streams.auto.commit.interval.ms" : "0", - "ksql.metrics.extension" : null, - "ksql.streams.topology.optimization" : "all", - "ksql.query.pull.streamsstore.rebalancing.timeout.ms" : "10000", - "ksql.streams.num.stream.threads" : "4", - "ksql.metrics.tags.custom" : "", - "ksql.udfs.enabled" : "true", - "ksql.udf.enable.security.manager" : "true", - "ksql.query.pull.skip.access.validator" : "false", - "ksql.connect.worker.config" : "", - "ksql.query.pull.routing.timeout.ms" : "30000", - "ksql.sink.window.change.log.additional.retention" : "1000000", - "ksql.udf.collect.metrics" : "false", - "ksql.persistent.prefix" : "query_", - "ksql.query.persistent.active.limit" : "2147483647" -} -CONFIGS_END -CSAS_OUTPUT_0.KsqlTopic.Source = STRUCT NOT NULL -CSAS_OUTPUT_0.OUTPUT = STRUCT NOT NULL -SCHEMAS_END -Topologies: - Sub-topology: 0 - Source: KSTREAM-SOURCE-0000000000 (topics: [input_topic]) - --> KSTREAM-TRANSFORMVALUES-0000000001 - Processor: KSTREAM-TRANSFORMVALUES-0000000001 (stores: []) - --> KSTREAM-FILTER-0000000002 - <-- KSTREAM-SOURCE-0000000000 - Processor: KSTREAM-FILTER-0000000002 (stores: []) - --> KSTREAM-KEY-SELECT-0000000003 - <-- KSTREAM-TRANSFORMVALUES-0000000001 - Processor: KSTREAM-KEY-SELECT-0000000003 (stores: []) - --> Project - <-- KSTREAM-FILTER-0000000002 - Processor: Project (stores: []) - --> KSTREAM-MAPVALUES-0000000005 - <-- KSTREAM-KEY-SELECT-0000000003 - Processor: KSTREAM-MAPVALUES-0000000005 (stores: []) - --> KSTREAM-SINK-0000000006 - <-- Project - Sink: KSTREAM-SINK-0000000006 (topic: OUTPUT) - <-- KSTREAM-MAPVALUES-0000000005 - diff --git a/ksql-functional-tests/src/test/resources/expected_topology/0_6_0-pre/key-field_-_stream___initially_set___partition_by_(same)___key_in_value___no_aliasing b/ksql-functional-tests/src/test/resources/expected_topology/0_6_0-pre/key-field_-_stream___initially_set___partition_by_(same)___key_in_value___no_aliasing deleted file mode 100644 index 84cb68985e53..000000000000 --- a/ksql-functional-tests/src/test/resources/expected_topology/0_6_0-pre/key-field_-_stream___initially_set___partition_by_(same)___key_in_value___no_aliasing +++ /dev/null @@ -1,57 +0,0 @@ -{ - "ksql.extension.dir" : "ext", - "ksql.streams.cache.max.bytes.buffering" : "0", - "ksql.security.extension.class" : null, - "ksql.transient.prefix" : "transient_", - "ksql.persistence.wrap.single.values" : "true", - "ksql.schema.registry.url" : "http://localhost:8081", - "ksql.streams.default.deserialization.exception.handler" : "io.confluent.ksql.errors.LogMetricAndContinueExceptionHandler", - "ksql.output.topic.name.prefix" : "", - "ksql.streams.auto.offset.reset" : "earliest", - "ksql.connect.url" : "http://localhost:8083", - "ksql.service.id" : "some.ksql.service.id", - "ksql.internal.topic.min.insync.replicas" : "1", - "ksql.internal.topic.replicas" : "1", - "ksql.insert.into.values.enabled" : "true", - "ksql.query.pull.enable" : "true", - "ksql.streams.default.production.exception.handler" : 
"io.confluent.ksql.errors.ProductionExceptionHandlerUtil$LogAndFailProductionExceptionHandler", - "ksql.access.validator.enable" : "auto", - "ksql.streams.bootstrap.servers" : "localhost:0", - "ksql.streams.commit.interval.ms" : "2000", - "ksql.metric.reporters" : "", - "ksql.streams.auto.commit.interval.ms" : "0", - "ksql.metrics.extension" : null, - "ksql.streams.topology.optimization" : "all", - "ksql.query.pull.streamsstore.rebalancing.timeout.ms" : "10000", - "ksql.streams.num.stream.threads" : "4", - "ksql.metrics.tags.custom" : "", - "ksql.udfs.enabled" : "true", - "ksql.udf.enable.security.manager" : "true", - "ksql.query.pull.skip.access.validator" : "false", - "ksql.connect.worker.config" : "", - "ksql.query.pull.routing.timeout.ms" : "30000", - "ksql.sink.window.change.log.additional.retention" : "1000000", - "ksql.udf.collect.metrics" : "false", - "ksql.persistent.prefix" : "query_", - "ksql.query.persistent.active.limit" : "2147483647" -} -CONFIGS_END -CSAS_OUTPUT_0.KsqlTopic.Source = STRUCT NOT NULL -CSAS_OUTPUT_0.OUTPUT = STRUCT NOT NULL -SCHEMAS_END -Topologies: - Sub-topology: 0 - Source: KSTREAM-SOURCE-0000000000 (topics: [input_topic]) - --> KSTREAM-TRANSFORMVALUES-0000000001 - Processor: KSTREAM-TRANSFORMVALUES-0000000001 (stores: []) - --> Project - <-- KSTREAM-SOURCE-0000000000 - Processor: Project (stores: []) - --> KSTREAM-MAPVALUES-0000000003 - <-- KSTREAM-TRANSFORMVALUES-0000000001 - Processor: KSTREAM-MAPVALUES-0000000003 (stores: []) - --> KSTREAM-SINK-0000000004 - <-- Project - Sink: KSTREAM-SINK-0000000004 (topic: OUTPUT) - <-- KSTREAM-MAPVALUES-0000000003 - diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/key-field.json b/ksql-functional-tests/src/test/resources/query-validation-tests/key-field.json index 67908749854c..a28442df5e57 100644 --- a/ksql-functional-tests/src/test/resources/query-validation-tests/key-field.json +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/key-field.json @@ -705,12 +705,16 @@ "CREATE STREAM INPUT (foo INT, bar INT) WITH (kafka_topic='input_topic', key='foo', value_format='JSON');", "CREATE STREAM OUTPUT AS SELECT foo + bar FROM INPUT PARTITION BY foo + bar;" ], - "comment": [ - "This test is present so that it fails if/when we support PARTITION BY on multiple fields.", - "If/when we do, this test will fail to remind us to add tests to cover keyFields for new functionality"], - "expectedException": { - "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "line 2:70: mismatched input '+' expecting ';'" + "inputs": [ + {"topic": "input_topic", "key": "1", "value": {"foo": 1, "bar": 2}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": "3", "value": {"KSQL_COL_0": 3}} + ], + "post": { + "sources": [ + {"name": "OUTPUT", "type": "stream", "keyField": null} + ] } }, { diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/partition-by.json b/ksql-functional-tests/src/test/resources/query-validation-tests/partition-by.json index 568df00e5aa7..286470ccb07e 100644 --- a/ksql-functional-tests/src/test/resources/query-validation-tests/partition-by.json +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/partition-by.json @@ -131,6 +131,31 @@ ], "inputs": [{"topic": "input", "value": {"ID": 22}, "timestamp": 10}], "outputs": [{"topic": "OUTPUT", "key": "10", "value": {"ID": 22}, "timestamp": 10}] + }, + { + "name": "partition by ROWKEY in join on non-ROWKEY", + "statements": [ + "CREATE STREAM L (A STRING, B STRING) WITH 
(kafka_topic='LEFT', value_format='JSON', KEY='A');", + "CREATE STREAM R (C STRING, D STRING) WITH (kafka_topic='RIGHT', value_format='JSON', KEY='C');", + "CREATE STREAM OUTPUT AS SELECT L.A, L.B, R.C, R.D, L.ROWKEY, R.ROWKEY FROM L JOIN R WITHIN 10 SECONDS ON L.B = R.D PARTITION BY L.ROWKEY;" + ], + "comments": [ + "This test demonstrates a problem when we JOIN on a non-ROWKEY field and then PARTITION BY ", + "a ROWKEY field. Note that the key is 'join' when it should be 'a' and the key-field is 'B' ", + "when it should be 'L_ROWKEY'" + ], + "inputs": [ + {"topic": "LEFT", "key": "a", "value": {"A": "a", "B": "join"}}, + {"topic": "RIGHT", "key": "c", "value": {"C": "c", "D": "join"}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": "join", "value": {"A": "a", "B": "join", "C": "c", "D": "join", "L_ROWKEY": "a", "R_ROWKEY": "c"}} + ], + "post": { + "sources": [ + {"name": "OUTPUT", "type": "stream", "keyField": "B"} + ] + } } ] } diff --git a/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 b/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 index 11dd95284393..aa85babceef3 100644 --- a/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 +++ b/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 @@ -77,7 +77,7 @@ query (WINDOW windowExpression)? (WHERE where=booleanExpression)? (GROUP BY groupBy)? - (PARTITION BY partitionBy=identifier)? + (PARTITION BY partitionBy=booleanExpression)? (HAVING having=booleanExpression)? (EMIT resultMaterialization)? limitClause? diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java index e802934d7c82..8500c005f021 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java @@ -404,7 +404,7 @@ public Query visitQuery(final SqlBaseParser.QueryContext context) { visitIfPresent(context.windowExpression(), WindowExpression.class), visitIfPresent(context.where, Expression.class), visitIfPresent(context.groupBy(), GroupBy.class), - getPartitionBy(context.partitionBy), + visitIfPresent(context.partitionBy, Expression.class), visitIfPresent(context.having, Expression.class), resultMaterialization, pullQuery, @@ -1121,19 +1121,6 @@ private List visit(final List contexts, .collect(toList()); } - private static Optional getPartitionBy( - final SqlBaseParser.IdentifierContext identifier - ) { - if (identifier == null) { - return Optional.empty(); - } - - final Optional location = getLocation(identifier); - final ColumnRef name = ColumnRef.withoutSource( - ColumnName.of(ParserUtil.getIdentifierText(identifier))); - return Optional.of(new ColumnReferenceExp(location, name)); - } - private static Operator getArithmeticBinaryOperator( final Token operator) { switch (operator.getType()) { diff --git a/ksql-rest-app/src/test/resources/ksql-plan-schema/schema.json b/ksql-rest-app/src/test/resources/ksql-plan-schema/schema.json index 9ab9c777b3fb..36add88927be 100644 --- a/ksql-rest-app/src/test/resources/ksql-plan-schema/schema.json +++ b/ksql-rest-app/src/test/resources/ksql-plan-schema/schema.json @@ -420,12 +420,12 @@ "source" : { "$ref" : "#/definitions/ExecutionStep" }, - "fieldName" : { + "keyExpression" : { "type" : "string" } }, "title" : "streamSelectKeyV1", - "required" : [ "@type", "properties", "source", "fieldName" ] + "required" : [ "@type", "properties", "source", "keyExpression" ] }, "StreamSink" : { "type" : "object", diff 
--git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java index 8e8d1cad4c1c..c7ea805c519f 100644 --- a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java @@ -51,7 +51,6 @@ import io.confluent.ksql.execution.timestamp.TimestampColumn; import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.name.SourceName; -import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.serde.WindowInfo; import java.time.Duration; @@ -252,7 +251,7 @@ public static StreamStreamJoin streamStreamJoin( public static StreamSelectKey streamSelectKey( final QueryContext.Stacker stacker, final ExecutionStep> source, - final ColumnRef fieldName + final Expression fieldName ) { final QueryContext queryContext = stacker.getQueryContext(); return new StreamSelectKey(new ExecutionStepPropertiesV1(queryContext), source, fieldName); diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamSelectKeyBuilder.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamSelectKeyBuilder.java index fdd91a2c38e0..e3d8b2ca1a38 100644 --- a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamSelectKeyBuilder.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamSelectKeyBuilder.java @@ -17,58 +17,46 @@ import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.codegen.CodeGenRunner; +import io.confluent.ksql.execution.codegen.ExpressionMetadata; import io.confluent.ksql.execution.plan.KStreamHolder; import io.confluent.ksql.execution.plan.KeySerdeFactory; import io.confluent.ksql.execution.plan.StreamSelectKey; import io.confluent.ksql.execution.util.StructKeyUtil; -import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; import org.apache.kafka.connect.data.Struct; import org.apache.kafka.streams.kstream.KStream; public final class StreamSelectKeyBuilder { + + private static final String EXP_TYPE = "SelectKey"; + private StreamSelectKeyBuilder() { } public static KStreamHolder build( final KStreamHolder stream, final StreamSelectKey selectKey, - final KsqlQueryBuilder queryBuilder) { + final KsqlQueryBuilder queryBuilder + ) { final LogicalSchema sourceSchema = stream.getSchema(); + final CodeGenRunner codeGen = new CodeGenRunner( + sourceSchema, + queryBuilder.getKsqlConfig(), + queryBuilder.getFunctionRegistry()); - final Column keyColumn = sourceSchema.findValueColumn(selectKey.getFieldName()) - .orElseThrow(IllegalArgumentException::new); - - final int keyIndexInValue = keyColumn.index(); + final ExpressionMetadata expression = + codeGen.buildCodeGenFromParseTree(selectKey.getKeyExpression(), EXP_TYPE); final KStream kstream = stream.getStream(); final KStream rekeyed = kstream - .filter((key, value) -> - value != null && extractColumn(sourceSchema, keyIndexInValue, value) != null - ).selectKey((key, value) -> - StructKeyUtil.asStructKey( - extractColumn(sourceSchema, keyIndexInValue, value).toString() - ) - ); + .filter((key, val) -> val != null && expression.evaluate(val) != null) + .selectKey((key, val) -> StructKeyUtil.asStructKey(expression.evaluate(val).toString())); + 
return new KStreamHolder<>( rekeyed, stream.getSchema(), KeySerdeFactory.unwindowed(queryBuilder) ); } - - private static Object extractColumn( - final LogicalSchema schema, - final int keyIndexInValue, - final GenericRow value - ) { - if (value.getColumns().size() != schema.value().size()) { - throw new IllegalStateException("Field count mismatch. " - + "Schema fields: " + schema - + ", row:" + value); - } - return value - .getColumns() - .get(keyIndexInValue); - } } diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StepSchemaResolverTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StepSchemaResolverTest.java index b2ac747c732f..f71d60f25031 100644 --- a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StepSchemaResolverTest.java +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StepSchemaResolverTest.java @@ -264,7 +264,7 @@ public void shouldResolveSchemaForStreamSelectKey() { final StreamSelectKey step = new StreamSelectKey( PROPERTIES, streamSource, - mock(ColumnRef.class) + mock(ColumnReferenceExp.class) ); // When: diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSelectKeyBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSelectKeyBuilderTest.java index d6fd9c762f22..dc98fa3a354a 100644 --- a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSelectKeyBuilderTest.java +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSelectKeyBuilderTest.java @@ -23,15 +23,18 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.plan.ExecutionStep; import io.confluent.ksql.execution.plan.ExecutionStepPropertiesV1; import io.confluent.ksql.execution.plan.KStreamHolder; import io.confluent.ksql.execution.plan.KeySerdeFactory; import io.confluent.ksql.execution.plan.PlanBuilder; import io.confluent.ksql.execution.plan.StreamSelectKey; +import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.SourceName; import io.confluent.ksql.query.QueryId; @@ -42,6 +45,7 @@ import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.util.KsqlConfig; import org.apache.kafka.connect.data.Struct; import org.apache.kafka.streams.kstream.KStream; import org.apache.kafka.streams.kstream.KeyValueMapper; @@ -65,7 +69,8 @@ public class StreamSelectKeyBuilderTest { .build() .withAlias(ALIAS) .withMetaAndKeyColsInValue(); - private static final ColumnRef KEY = ColumnRef.of(SourceName.of("ATL"), ColumnName.of("BOI")); + private static final ColumnReferenceExp KEY = + new ColumnReferenceExp(ColumnRef.of(SourceName.of("ATL"), ColumnName.of("BOI"))); @Mock private KStream kstream; @@ -77,6 +82,8 @@ public class StreamSelectKeyBuilderTest { private ExecutionStep> sourceStep; @Mock private KsqlQueryBuilder queryBuilder; + @Mock + private FunctionRegistry functionRegistry; @Captor private ArgumentCaptor> predicateCaptor; @Captor @@ -96,6 +103,8 @@ public class StreamSelectKeyBuilderTest { @SuppressWarnings("unchecked") public void init() { 
when(queryBuilder.getQueryId()).thenReturn(new QueryId("hey")); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(queryBuilder.getKsqlConfig()).thenReturn(new KsqlConfig(ImmutableMap.of())); when(sourceStep.getProperties()).thenReturn(properties); when(kstream.filter(any())).thenReturn(filteredKStream); when(filteredKStream.selectKey(any(KeyValueMapper.class))).thenReturn(rekeyedKstream);