
feat: expression support for PARTITION BY (#4032)
* feat: expression support for PARTITION BY

* chore: fix qtt test

* test: add a false-positive test for the join pb rowkey issue
agavra authored Dec 6, 2019
1 parent f90ac26 commit 0f31f8e
Showing 23 changed files with 253 additions and 249 deletions.
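
For context, the change targets statements that key their output by an arbitrary expression rather than by a single column. A minimal sketch of the kind of statement this enables, using a hypothetical clicks stream and columns (not part of this commit):

CREATE STREAM clicks (user_id VARCHAR, url VARCHAR)
  WITH (kafka_topic='clicks', value_format='JSON');

-- Previously PARTITION BY accepted only a plain column reference;
-- with this change a general expression is accepted as well.
CREATE STREAM clicks_by_upper_user AS
  SELECT * FROM clicks
  PARTITION BY UCASE(user_id);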
==== changed file ====
@@ -63,7 +63,7 @@ public class Analysis implements ImmutableAnalysis {
private final Set<ColumnRef> selectColumnRefs = new HashSet<>();
private final List<Expression> groupByExpressions = new ArrayList<>();
private Optional<WindowExpression> windowExpression = Optional.empty();
- private Optional<ColumnRef> partitionBy = Optional.empty();
+ private Optional<Expression> partitionBy = Optional.empty();
private ImmutableSet<SerdeOption> serdeOptions = ImmutableSet.of();
private Optional<Expression> havingExpression = Optional.empty();
private OptionalInt limitClause = OptionalInt.empty();
@@ -134,11 +134,11 @@ void setHavingExpression(final Expression havingExpression) {
this.havingExpression = Optional.of(havingExpression);
}

- public Optional<ColumnRef> getPartitionBy() {
+ public Optional<Expression> getPartitionBy() {
return partitionBy;
}

- void setPartitionBy(final ColumnRef partitionBy) {
+ void setPartitionBy(final Expression partitionBy) {
this.partitionBy = Optional.of(partitionBy);
}

==== changed file ====
@@ -509,13 +509,7 @@ private void analyzeGroupBy(final GroupBy groupBy) {
}

private void analyzePartitionBy(final Expression partitionBy) {
- if (partitionBy instanceof ColumnReferenceExp) {
- analysis.setPartitionBy(((ColumnReferenceExp) partitionBy).getReference());
- return;
- }
-
- throw new KsqlException(
- "Expected partition by to be a valid column but got " + partitionBy);
+ analysis.setPartitionBy(partitionBy);
}

private void analyzeWindowExpression(final WindowExpression windowExpression) {
==== changed file ====
@@ -188,9 +188,8 @@ protected AstNode visitQuery(final Query node, final C context) {
final Optional<GroupBy> groupBy = node.getGroupBy()
.map(exp -> ((GroupBy) rewriter.apply(exp, context)));

- // don't rewrite the partitionBy because we expect it to be
- // exactly as it was (a single, un-aliased, column reference)
- final Optional<Expression> partitionBy = node.getPartitionBy();
+ final Optional<Expression> partitionBy = node.getPartitionBy()
+ .map(exp -> processExpression(exp, context));

final Optional<Expression> having = node.getHaving()
.map(exp -> (processExpression(exp, context)));
==== changed file ====
@@ -47,7 +47,6 @@
import io.confluent.ksql.schema.ksql.LogicalSchema.Builder;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.util.KsqlConfig;
- import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.List;
import java.util.Optional;
@@ -204,28 +203,34 @@ private static FilterNode buildFilterNode(

private static RepartitionNode buildRepartitionNode(
final PlanNode sourceNode,
- final ColumnRef partitionBy
+ final Expression partitionBy
) {
- if (!sourceNode.getSchema().withoutAlias().findValueColumn(partitionBy).isPresent()) {
- throw new KsqlException("Invalid identifier for PARTITION BY clause: '" + partitionBy
- + "'. Only columns from the source schema can be referenced in the PARTITION BY clause.");
+ if (!(partitionBy instanceof ColumnReferenceExp)) {
+ return new RepartitionNode(
+ new PlanNodeId("PartitionBy"),
+ sourceNode,
+ partitionBy,
+ KeyField.none());
}

- final KeyField keyField;
+ final ColumnRef partitionColumn = ((ColumnReferenceExp) partitionBy).getReference();
final LogicalSchema schema = sourceNode.getSchema();
- if (schema.isMetaColumn(partitionBy.name())) {
+
+ final KeyField keyField;
+ if (schema.isMetaColumn(partitionColumn.name())) {
keyField = KeyField.none();
- } else if (schema.isKeyColumn(partitionBy.name())) {
+ } else if (schema.isKeyColumn(partitionColumn.name())) {
keyField = sourceNode.getKeyField();
} else {
- keyField = KeyField.of(partitionBy);
+ keyField = KeyField.of(partitionColumn);
}

return new RepartitionNode(
new PlanNodeId("PartitionBy"),
sourceNode,
partitionBy,
keyField);

}

private FlatMapNode buildFlatMapNode(final PlanNode sourcePlanNode) {
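
As a rough illustration of the two branches in buildRepartitionNode above (reusing the hypothetical clicks stream from the earlier sketch; not part of this commit): a plain column reference still yields a key field for the repartitioned output, while a general expression produces a RepartitionNode with KeyField.none(), since no single source column describes the new key.

-- Column reference: the output's key field is user_id.
CREATE STREAM clicks_by_user AS SELECT * FROM clicks PARTITION BY user_id;

-- General expression: the planner falls back to KeyField.none().
CREATE STREAM clicks_by_url_len AS SELECT * FROM clicks PARTITION BY LEN(url);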
==== changed file ====
@@ -20,6 +20,7 @@
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.context.QueryContext.Stacker;
+ import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.streams.JoinParamsFactory;
import io.confluent.ksql.metastore.model.DataSource.DataSourceType;
@@ -285,7 +286,7 @@ static <K> SchemaKStream<K> maybeRePartitionByKey(
final ColumnRef joinFieldName,
final Stacker contextStacker
) {
- return stream.selectKey(joinFieldName, contextStacker);
+ return stream.selectKey(new ColumnReferenceExp(joinFieldName), contextStacker);
}

static ValueFormat getFormatForSource(final DataSourceNode sourceNode) {
==== changed file ====
@@ -15,13 +15,13 @@

package io.confluent.ksql.planner.plan;

+ import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.Immutable;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
+ import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.metastore.model.KeyField;
- import io.confluent.ksql.name.SourceName;
- import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.services.KafkaTopicClient;
import io.confluent.ksql.structured.SchemaKStream;
@@ -32,14 +32,18 @@
public class RepartitionNode extends PlanNode {

private final PlanNode source;
- private final ColumnRef partitionBy;
+ private final Expression partitionBy;
private final KeyField keyField;

- public RepartitionNode(PlanNodeId id, PlanNode source, ColumnRef partitionBy, KeyField keyField) {
+ public RepartitionNode(
+ PlanNodeId id,
+ PlanNode source,
+ Expression partitionBy,
+ KeyField keyField
+ ) {
super(id, source.getNodeOutputType());
- final SourceName alias = source.getTheSourceNode().getAlias();
this.source = Objects.requireNonNull(source, "source");
this.partitionBy = Objects.requireNonNull(partitionBy, "partitionBy").withSource(alias);
this.partitionBy = Objects.requireNonNull(partitionBy, "partitionBy");
this.keyField = Objects.requireNonNull(keyField, "keyField");
}

@@ -73,4 +77,9 @@ public SchemaKStream<?> buildStream(KsqlQueryBuilder builder) {
return source.buildStream(builder)
.selectKey(partitionBy, builder.buildNodeContext(getId().toString()));
}

+ @VisibleForTesting
+ public Expression getPartitionBy() {
+ return partitionBy;
+ }
}
==== changed file ====
@@ -45,12 +45,14 @@
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.ColumnRef;
+ import io.confluent.ksql.schema.ksql.FormatOptions;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.KeyFormat;
import io.confluent.ksql.serde.SerdeOption;
import io.confluent.ksql.serde.ValueFormat;
import io.confluent.ksql.util.KsqlConfig;
+ import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.List;
import java.util.Objects;
@@ -321,56 +323,65 @@ public SchemaKStream<K> outerJoin(

@SuppressWarnings("unchecked")
public SchemaKStream<Struct> selectKey(
- final ColumnRef columnRef,
+ final Expression keyExpression,
final QueryContext.Stacker contextStacker
) {
if (keyFormat.isWindowed()) {
throw new UnsupportedOperationException("Can not selectKey of windowed stream");
}

- final Optional<Column> existingKey = keyField.resolve(getSchema());
-
- final Column proposedKey = getSchema().findValueColumn(columnRef)
- .orElseThrow(IllegalArgumentException::new);
-
- final KeyField resultantKeyField = isRowKey(columnRef)
- ? keyField
- : KeyField.of(columnRef);
-
- final boolean namesMatch = existingKey
- .map(kf -> kf.ref().equals(proposedKey.ref()))
- .orElse(false);
-
- if (namesMatch || isRowKey(proposedKey.ref())) {
- return (SchemaKStream<Struct>) new SchemaKStream<>(
- sourceStep,
- schema,
- keyFormat,
- resultantKeyField,
- ksqlConfig,
- functionRegistry
- );
+ if (!needsRepartition(keyExpression)) {
+ return (SchemaKStream<Struct>) this;
}

- final KeyField newKeyField = getSchema().isMetaColumn(columnRef.name())
- ? KeyField.none()
- : resultantKeyField;
-
final StreamSelectKey step = ExecutionStepFactory.streamSelectKey(
contextStacker,
sourceStep,
- columnRef
+ keyExpression
);

return new SchemaKStream<>(
step,
resolveSchema(step),
keyFormat,
- newKeyField,
+ getNewKeyField(keyExpression),
ksqlConfig,
functionRegistry
);
}

+ private KeyField getNewKeyField(final Expression expression) {
+ if (!(expression instanceof ColumnReferenceExp)) {
+ return KeyField.none();
+ }
+
+ final ColumnRef columnRef = ((ColumnReferenceExp) expression).getReference();
+ final KeyField newKeyField = isRowKey(columnRef) ? keyField : KeyField.of(columnRef);
+ return getSchema().isMetaColumn(columnRef.name()) ? KeyField.none() : newKeyField;
+ }
+
+ private boolean needsRepartition(final Expression expression) {
+ if (!(expression instanceof ColumnReferenceExp)) {
+ return true;
+ }
+
+ final ColumnRef columnRef = ((ColumnReferenceExp) expression).getReference();
+ final Optional<Column> existingKey = keyField.resolve(getSchema());
+
+ final Column proposedKey = getSchema()
+ .findValueColumn(columnRef)
+ .orElseThrow(() -> new KsqlException("Invalid identifier for PARTITION BY clause: '"
+ + columnRef.name().toString(FormatOptions.noEscape()) + "' Only columns from the "
+ + "source schema can be referenced in the PARTITION BY clause."));
+
+
+ final boolean namesMatch = existingKey
+ .map(kf -> kf.ref().equals(proposedKey.ref()))
+ .orElse(false);
+
+ return !namesMatch && !isRowKey(columnRef);
+ }

private static boolean isRowKey(final ColumnRef fieldName) {
return fieldName.name().equals(SchemaUtil.ROWKEY_NAME);
}
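
A sketch of what the new needsRepartition and getNewKeyField checks mean at the query level (same hypothetical clicks stream; not part of this commit): partitioning by ROWKEY, or by the column that already is the key field, skips the repartition step entirely, while a column that is not in the source schema is still rejected with the "Invalid identifier for PARTITION BY clause" error, which is now raised here rather than in the logical planner.

-- No repartition step is added: the stream is already keyed this way.
CREATE STREAM clicks_same_key AS SELECT * FROM clicks PARTITION BY ROWKEY;

-- Still an error: missing_col is not a column of the source schema.
CREATE STREAM clicks_bad AS SELECT * FROM clicks PARTITION BY missing_col;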
==== changed file ====
@@ -20,6 +20,7 @@

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.analyzer.Analysis.Into;
+ import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.parser.tree.ResultMaterialization;
@@ -126,7 +127,7 @@ public void shouldThrowOnGroupBy() {
public void shouldThrowOnPartitionBy() {
// Given:
when(analysis.getPartitionBy())
- .thenReturn(Optional.of(ColumnRef.withoutSource(ColumnName.of("Something"))));
+ .thenReturn(Optional.of(new ColumnReferenceExp(ColumnRef.withoutSource(ColumnName.of("Something")))));

// Then:
expectedException.expect(KsqlException.class);
==== changed file ====
@@ -277,10 +277,11 @@ public void shouldRewriteQueryWithGroupBy() {
}

@Test
- public void shouldNotRewriteQueryWithPartitionBy() {
+ public void shouldRewriteQueryWithPartitionBy() {
// Given:
final Query query =
givenQuery(Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(expression), Optional.empty());
+ when(expressionRewriter.apply(expression, context)).thenReturn(rewrittenExpression);

// When:
final AstNode rewritten = rewriter.rewrite(query, context);
@@ -293,7 +294,7 @@ public void shouldNotRewriteQueryWithPartitionBy() {
Optional.empty(),
Optional.empty(),
Optional.empty(),
- Optional.of(expression),
+ Optional.of(rewrittenExpression),
Optional.empty(),
resultMaterialization,
false,
==== changed file ====
@@ -34,6 +34,7 @@
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
+ import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.streams.KSPlanBuilder;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.InternalFunctionRegistry;
@@ -878,7 +879,7 @@ public void shouldSelectLeftKeyField() {

// Then:
verify(leftSchemaKStream).selectKey(
- eq(LEFT_JOIN_FIELD_REF),
+ eq(new ColumnReferenceExp(LEFT_JOIN_FIELD_REF)),
any()
);
}
(remaining changed files not shown)

0 comments on commit 0f31f8e
