chore: add GROUP BY support for any key names
fixes: #4898

With this commit, the result of a GROUP BY on a single column reference has a schema whose key column matches the name of that column, e.g.

```sql
-- source schema: A -> B, C
CREATE STREAM OUTPUT AS SELECT COUNT(1) AS COUNT FROM INPUT GROUP BY B;
-- output schema: B -> COUNT
```

If the GROUP BY is on anything other than a single column reference, the key column gets a unique generated name, e.g.

```sql
-- source schema: A -> B, C
CREATE STREAM OUTPUT AS SELECT COUNT(1) FROM INPUT GROUP BY B+1;
-- output schema: KSQL_COL_1 -> KSQL_COL_0  (Both names are generated)
```
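
As a hedged aside (a hypothetical variant of the example above, not taken from this commit): aliasing the aggregate in the projection controls the value column's name, but the key column name is still generated when grouping by a non-column-reference expression.

```sql
-- source schema: A -> B, C
CREATE STREAM OUTPUT AS SELECT COUNT(1) AS CNT FROM INPUT GROUP BY B+1;
-- output schema: KSQL_COL_0 -> CNT  (key name is generated; exact index may differ)
```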

BREAKING CHANGE: Existing queries that reference a single GROUP BY column in the projection will fail if resubmitted, due to a duplicate column. Queries that are already running are unaffected: they continue under the old query semantics, so this change applies only to newly submitted queries.
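
For illustration, a hypothetical query of the affected shape, following the source schema used above. Under the new semantics, the GROUP BY column becomes the key column, so also projecting it as a value column is rejected on resubmission:

```sql
-- source schema: A -> B, C
-- B is now both the key column and a projected value column, so
-- resubmitting this query fails with a duplicate column error:
CREATE STREAM OUTPUT AS SELECT B, COUNT(1) AS COUNT FROM INPUT GROUP BY B;
```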
big-andy-coates committed Mar 26, 2020
1 parent 258d0b0 commit 1b5312d
Showing 612 changed files with 61,362 additions and 279 deletions.
@@ -73,6 +73,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -511,30 +512,46 @@ private LogicalSchema buildAggregateSchema(
) {
final LogicalSchema sourceSchema = sourcePlanNode.getSchema();

final LogicalSchema projectionSchema = buildProjectionSchema(
sourceSchema
.withMetaAndKeyColsInValue(analysis.getWindowExpression().isPresent()),
projectionExpressions
);

final Supplier<ColumnName> keyColNameGen = ColumnNames
.columnAliasGenerator(Stream.of(sourceSchema, projectionSchema));

final ColumnName keyName;
final SqlType keyType;

if (groupByExps.size() != 1) {
keyName = SchemaUtil.ROWKEY_NAME;
if (ksqlConfig.getBoolean(KsqlConfig.KSQL_ANY_KEY_NAME_ENABLED)) {
keyName = keyColNameGen.get();
} else {
keyName = SchemaUtil.ROWKEY_NAME;
}
keyType = SqlTypes.STRING;
} else {
final Expression expression = groupByExps.get(0);

keyName = exactlyMatchesKeyColumns(expression, sourceSchema)
? ((ColumnReferenceExp) expression).getColumnName()
: SchemaUtil.ROWKEY_NAME;
if (ksqlConfig.getBoolean(KsqlConfig.KSQL_ANY_KEY_NAME_ENABLED)) {
if (expression instanceof ColumnReferenceExp) {
keyName = ((ColumnReferenceExp) expression).getColumnName();
} else {
keyName = keyColNameGen.get();
}
} else {
keyName = exactlyMatchesKeyColumns(expression, sourceSchema)
? ((ColumnReferenceExp) expression).getColumnName()
: SchemaUtil.ROWKEY_NAME;
}

final ExpressionTypeManager typeManager =
new ExpressionTypeManager(sourceSchema, functionRegistry);

keyType = typeManager.getExpressionSqlType(expression);
}

final LogicalSchema projectionSchema = buildProjectionSchema(
sourceSchema
.withMetaAndKeyColsInValue(analysis.getWindowExpression().isPresent()),
projectionExpressions
);

return LogicalSchema.builder()
.withRowTime()
.keyColumn(keyName, keyType)
@@ -26,6 +26,7 @@
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.context.QueryContext.Stacker;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
@@ -50,12 +51,9 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;


@@ -78,6 +76,7 @@ public class AggregateNode extends PlanNode {
private final ImmutableList<ColumnReferenceExp> requiredColumns;
private final Optional<Expression> havingExpressions;
private final ImmutableList<SelectExpression> finalSelectExpressions;
private final ValueFormat valueFormat;

public AggregateNode(
final PlanNodeId id,
@@ -117,6 +116,10 @@ public AggregateNode(
.map(exp -> ExpressionTreeRewriter.rewriteWith(aggregateExpressionRewriter::process, exp));
this.keyField = KeyField.of(requireNonNull(keyFieldName, "keyFieldName"))
.validateKeyExistsIn(schema);
this.valueFormat = getTheSourceNode()
.getDataSource()
.getKsqlTopic()
.getValueFormat();
}

@Override
@@ -157,158 +160,129 @@ public <C, R> R accept(final PlanVisitor<C, R> visitor, final C context) {
@Override
public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
final QueryContext.Stacker contextStacker = builder.buildNodeContext(getId().toString());
final DataSourceNode streamSourceNode = getTheSourceNode();
final SchemaKStream<?> sourceSchemaKStream = getSource().buildStream(builder);

// Pre aggregate computations
final InternalSchema internalSchema = new InternalSchema(
getSchema(),
requiredColumns,
aggregateFunctionArguments
);
final InternalSchema internalSchema =
new InternalSchema(requiredColumns, aggregateFunctionArguments);

final SchemaKStream<?> aggregateArgExpanded = sourceSchemaKStream.select(
internalSchema.getAggArgExpansionList(),
contextStacker.push(PREPARE_OP_NAME),
builder
);
final SchemaKStream<?> preSelected =
selectRequiredInputColumns(sourceSchemaKStream, internalSchema, contextStacker, builder);

final QueryContext.Stacker groupByContext = contextStacker.push(GROUP_BY_OP_NAME);
final SchemaKGroupedStream grouped = groupBy(contextStacker, preSelected);

final ValueFormat valueFormat = streamSourceNode
.getDataSource()
.getKsqlTopic()
.getValueFormat();
SchemaKTable<?> aggregated = aggregate(grouped, internalSchema, contextStacker);

final List<Expression> internalGroupByColumns = internalSchema.resolveGroupByExpressions(
groupByExpressions,
aggregateArgExpanded
);
aggregated = applyHavingFilter(aggregated, contextStacker);

final SchemaKGroupedStream schemaKGroupedStream = aggregateArgExpanded.groupBy(
valueFormat,
internalGroupByColumns,
groupByContext
return selectRequiredOutputColumns(aggregated, contextStacker, builder);
}

protected int getPartitions(final KafkaTopicClient kafkaTopicClient) {
return source.getPartitions(kafkaTopicClient);
}

private static SchemaKStream<?> selectRequiredInputColumns(
final SchemaKStream<?> sourceSchemaKStream,
final InternalSchema internalSchema,
final Stacker contextStacker,
final KsqlQueryBuilder builder
) {
return sourceSchemaKStream.select(
internalSchema.getAggArgExpansionList(),
contextStacker.push(PREPARE_OP_NAME),
builder
);
}

final List<FunctionCall> functionsWithInternalIdentifiers = functionList.stream()
.map(
fc -> new FunctionCall(
fc.getName(),
internalSchema.getInternalArgsExpressionList(fc.getArguments())
)
)
.collect(Collectors.toList());
private SchemaKTable<?> aggregate(
final SchemaKGroupedStream grouped,
final InternalSchema internalSchema,
final Stacker contextStacker
) {
final List<FunctionCall> functions = internalSchema.updateFunctionList(functionList);

final QueryContext.Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME);
final Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME);

final List<ColumnName> requiredColumnNames = requiredColumns.stream()
.map(e -> (UnqualifiedColumnReferenceExp) internalSchema.resolveToInternal(e))
.map(UnqualifiedColumnReferenceExp::getColumnName)
.collect(Collectors.toList());

SchemaKTable<?> aggregated = schemaKGroupedStream.aggregate(
return grouped.aggregate(
requiredColumnNames,
functionsWithInternalIdentifiers,
functions,
windowExpression,
valueFormat,
aggregationContext
);
}

final Optional<Expression> havingExpression = havingExpressions
.map(internalSchema::resolveToInternal);

if (havingExpression.isPresent()) {
aggregated = aggregated.filter(
havingExpression.get(),
contextStacker.push(HAVING_FILTER_OP_NAME)
);
}

final List<SelectExpression> finalSelects = internalSchema
.updateFinalSelectExpressions(finalSelectExpressions);
private SchemaKTable<?> applyHavingFilter(
final SchemaKTable<?> aggregated,
final Stacker contextStacker
) {
return havingExpressions.isPresent()
? aggregated.filter(havingExpressions.get(), contextStacker.push(HAVING_FILTER_OP_NAME))
: aggregated;
}

return aggregated.select(
finalSelects,
contextStacker.push(PROJECT_OP_NAME),
builder
);
private SchemaKStream<?> selectRequiredOutputColumns(
final SchemaKTable<?> aggregated,
final Stacker contextStacker,
final KsqlQueryBuilder builder
) {
return aggregated
.select(finalSelectExpressions, contextStacker.push(PROJECT_OP_NAME), builder);
}

protected int getPartitions(final KafkaTopicClient kafkaTopicClient) {
return source.getPartitions(kafkaTopicClient);
private SchemaKGroupedStream groupBy(
final Stacker contextStacker,
final SchemaKStream<?> preSelected
) {
return preSelected
.groupBy(valueFormat, groupByExpressions, contextStacker.push(GROUP_BY_OP_NAME));
}

private static class InternalSchema {

private final Optional<ColumnName> singleKeyColumn;
private final List<SelectExpression> aggArgExpansions = new ArrayList<>();
private final Map<String, ColumnName> expressionToInternalColumnName = new HashMap<>();

InternalSchema(
final LogicalSchema schema,
final List<ColumnReferenceExp> requiredColumns,
final List<Expression> aggregateFunctionArguments
) {
this.singleKeyColumn = schema.key().size() == 1
? Optional.of(schema.key().get(0).name())
: Optional.empty();

final Set<String> seen = new HashSet<>();
collectAggregateArgExpressions(requiredColumns, seen);
collectAggregateArgExpressions(aggregateFunctionArguments, seen);
collectAggregateArgExpressions(requiredColumns);
collectAggregateArgExpressions(aggregateFunctionArguments);
}

private void collectAggregateArgExpressions(
final Collection<? extends Expression> expressions,
final Set<String> seen
final Collection<? extends Expression> expressions
) {
for (final Expression expression : expressions) {
if (seen.contains(expression.toString())) {
final String sql = expression.toString();
if (expressionToInternalColumnName.containsKey(sql)) {
continue;
}

seen.add(expression.toString());
final ColumnName internalName = expression instanceof ColumnReferenceExp
? ((ColumnReferenceExp) expression).getColumnName()
: ColumnName.of(INTERNAL_COLUMN_NAME_PREFIX + aggArgExpansions.size());

final String internalName = INTERNAL_COLUMN_NAME_PREFIX + aggArgExpansions.size();

aggArgExpansions.add(SelectExpression.of(ColumnName.of(internalName), expression));
expressionToInternalColumnName
.putIfAbsent(expression.toString(), ColumnName.of(internalName));
aggArgExpansions.add(SelectExpression.of(internalName, expression));
expressionToInternalColumnName.put(sql, internalName);
}
}

List<Expression> resolveGroupByExpressions(
final List<Expression> expressionList,
final SchemaKStream<?> aggregateArgExpanded
) {
final boolean specialRowTimeHandling = !(aggregateArgExpanded instanceof SchemaKTable);

final Function<Expression, Expression> mapper = e -> {
final boolean rowKey = singleKeyColumn.isPresent()
&& e instanceof UnqualifiedColumnReferenceExp
&& ((UnqualifiedColumnReferenceExp) e).getColumnName().equals(singleKeyColumn.get());

if (!rowKey || !specialRowTimeHandling) {
return resolveToInternal(e);
}

return e;
};

return expressionList.stream()
.map(mapper)
.collect(Collectors.toList());
}

/**
* Return the aggregate function arguments based on the internal expressions.
* Currently we support aggregate functions with at most two arguments where
* the second argument should be a literal.
* Return the aggregate function arguments based on the internal expressions. Currently we
* support aggregate functions with at most two arguments where the second argument should be a
* literal.
*
* @param argExpressionList The list of parameters for the aggregate function.
* @return The list of arguments based on the internal expressions for the aggregate function.
*/
List<Expression> getInternalArgsExpressionList(final List<Expression> argExpressionList) {
List<Expression> updateArgsExpressionList(final List<Expression> argExpressionList) {
// Currently we support aggregations on one column only
if (argExpressionList.size() > 2) {
throw new KsqlException("Currently, KSQL UDAFs can only have two arguments.");
@@ -319,7 +293,7 @@ List<Expression> getInternalArgsExpressionList(final List<Expression> argExpress
final List<Expression> internalExpressionList = new ArrayList<>();
internalExpressionList.add(resolveToInternal(argExpressionList.get(0)));
if (argExpressionList.size() == 2) {
if (! (argExpressionList.get(1) instanceof Literal)) {
if (!(argExpressionList.get(1) instanceof Literal)) {
throw new KsqlException("Currently, second argument in UDAF should be literal.");
}
internalExpressionList.add(argExpressionList.get(1));
@@ -328,14 +302,9 @@ List<Expression> getInternalArgsExpressionList(final List<Expression> argExpress

}

List<SelectExpression> updateFinalSelectExpressions(
final List<SelectExpression> finalSelectExpressions
) {
return finalSelectExpressions.stream()
.map(finalSelectExpression -> {
final Expression internal = resolveToInternal(finalSelectExpression.getExpression());
return SelectExpression.of(finalSelectExpression.getAlias(), internal);
})
List<FunctionCall> updateFunctionList(final ImmutableList<FunctionCall> functions) {
return functions.stream()
.map(fc -> new FunctionCall(fc.getName(), updateArgsExpressionList(fc.getArguments())))
.collect(Collectors.toList());
}

@@ -356,6 +325,7 @@ private Expression resolveToInternal(final Expression exp) {

private final class ResolveToInternalRewriter
extends VisitParentExpressionVisitor<Optional<Expression>, Context<Void>> {

private ResolveToInternalRewriter() {
super(Optional.empty());
}
@@ -19,6 +19,8 @@
@SuppressWarnings({"unused", "MethodMayBeStatic"})
public class BadUdf {

private int i;

@Udf(description = "throws")
public String blowUp(final int arg1) {
throw new RuntimeException("boom!");
@@ -32,4 +34,9 @@ public int mightThrow(final boolean arg) {

return 0;
}

@Udf(description = "returns null every other invocation")
public String returnNull(final String arg) {
return i++ % 2 == 0 ? null : arg;
}
}