chore: pick up key name from field name on GROUP BY and PARTITION BY #4902

Merged
Changes from 2 commits
@@ -25,6 +25,7 @@
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
@@ -73,6 +74,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -511,30 +513,48 @@ private LogicalSchema buildAggregateSchema(
) {
final LogicalSchema sourceSchema = sourcePlanNode.getSchema();

final LogicalSchema projectionSchema = buildProjectionSchema(
sourceSchema
.withMetaAndKeyColsInValue(analysis.getWindowExpression().isPresent()),
projectionExpressions
);

final Supplier<ColumnName> keyColNameGen = ColumnNames
.columnAliasGenerator(Stream.of(sourceSchema, projectionSchema));

final ColumnName keyName;
final SqlType keyType;

if (groupByExps.size() != 1) {
keyName = SchemaUtil.ROWKEY_NAME;
if (ksqlConfig.getBoolean(KsqlConfig.KSQL_ANY_KEY_NAME_ENABLED)) {
keyName = keyColNameGen.get();
} else {
keyName = SchemaUtil.ROWKEY_NAME;
}
keyType = SqlTypes.STRING;
} else {
final Expression expression = groupByExps.get(0);

keyName = exactlyMatchesKeyColumns(expression, sourceSchema)
? ((ColumnReferenceExp) expression).getColumnName()
: SchemaUtil.ROWKEY_NAME;
if (ksqlConfig.getBoolean(KsqlConfig.KSQL_ANY_KEY_NAME_ENABLED)) {
if (expression instanceof ColumnReferenceExp) {
keyName = ((ColumnReferenceExp) expression).getColumnName();
} else if (expression instanceof DereferenceExpression) {
keyName = ColumnName.of(((DereferenceExpression) expression).getFieldName());
} else {
keyName = keyColNameGen.get();
}
} else {
keyName = exactlyMatchesKeyColumns(expression, sourceSchema)
? ((ColumnReferenceExp) expression).getColumnName()
: SchemaUtil.ROWKEY_NAME;
}

final ExpressionTypeManager typeManager =
new ExpressionTypeManager(sourceSchema, functionRegistry);

keyType = typeManager.getExpressionSqlType(expression);
}

final LogicalSchema projectionSchema = buildProjectionSchema(
sourceSchema
.withMetaAndKeyColsInValue(analysis.getWindowExpression().isPresent()),
projectionExpressions
);

return LogicalSchema.builder()
.withRowTime()
.keyColumn(keyName, keyType)
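With KSQL_ANY_KEY_NAME_ENABLED on, the hunk above names the aggregate key column after the grouping expression where possible. A minimal standalone sketch of that rule, with illustrative stand-in types (ColumnRef and Dereference mirror ColumnReferenceExp and DereferenceExpression; not the real ksqlDB classes):

import java.util.List;
import java.util.function.Supplier;

final class KeyNameSketch {

  interface Expression { }

  // Mirrors ColumnReferenceExp: GROUP BY some_col
  static final class ColumnRef implements Expression {
    final String columnName;
    ColumnRef(final String columnName) { this.columnName = columnName; }
  }

  // Mirrors DereferenceExpression: GROUP BY some_struct->some_field
  static final class Dereference implements Expression {
    final String fieldName;
    Dereference(final String fieldName) { this.fieldName = fieldName; }
  }

  // aliasGen stands in for the ColumnNames.columnAliasGenerator(...) supplier.
  static String keyName(final List<Expression> groupByExps, final Supplier<String> aliasGen) {
    if (groupByExps.size() != 1) {
      return aliasGen.get();                // multiple grouping expressions: generated alias
    }
    final Expression e = groupByExps.get(0);
    if (e instanceof ColumnRef) {
      return ((ColumnRef) e).columnName;    // key named after the column itself
    }
    if (e instanceof Dereference) {
      return ((Dereference) e).fieldName;   // key named after the struct field
    }
    return aliasGen.get();                  // e.g. GROUP BY UCASE(col): generated alias
  }
}

With the flag off, the else branch in the hunk keeps the old behaviour: the source key column's name when the expression exactly matches it, otherwise ROWKEY.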
@@ -26,6 +26,7 @@
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.context.QueryContext.Stacker;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
@@ -50,12 +51,9 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;


@@ -78,6 +76,7 @@ public class AggregateNode extends PlanNode {
private final ImmutableList<ColumnReferenceExp> requiredColumns;
private final Optional<Expression> havingExpressions;
private final ImmutableList<SelectExpression> finalSelectExpressions;
private final ValueFormat valueFormat;

public AggregateNode(
final PlanNodeId id,
@@ -117,6 +116,10 @@ public AggregateNode(
.map(exp -> ExpressionTreeRewriter.rewriteWith(aggregateExpressionRewriter::process, exp));
this.keyField = KeyField.of(requireNonNull(keyFieldName, "keyFieldName"))
.validateKeyExistsIn(schema);
this.valueFormat = getTheSourceNode()
.getDataSource()
.getKsqlTopic()
.getValueFormat();
}

@Override
@@ -157,158 +160,129 @@ public <C, R> R accept(final PlanVisitor<C, R> visitor, final C context) {
@Override
public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
final QueryContext.Stacker contextStacker = builder.buildNodeContext(getId().toString());
final DataSourceNode streamSourceNode = getTheSourceNode();
final SchemaKStream<?> sourceSchemaKStream = getSource().buildStream(builder);

// Pre aggregate computations
final InternalSchema internalSchema = new InternalSchema(
getSchema(),
requiredColumns,
aggregateFunctionArguments
);
final InternalSchema internalSchema =
new InternalSchema(requiredColumns, aggregateFunctionArguments);

final SchemaKStream<?> aggregateArgExpanded = sourceSchemaKStream.select(
internalSchema.getAggArgExpansionList(),
contextStacker.push(PREPARE_OP_NAME),
builder
);
final SchemaKStream<?> preSelected =
selectRequiredInputColumns(sourceSchemaKStream, internalSchema, contextStacker, builder);

final QueryContext.Stacker groupByContext = contextStacker.push(GROUP_BY_OP_NAME);
final SchemaKGroupedStream grouped = groupBy(contextStacker, preSelected);

final ValueFormat valueFormat = streamSourceNode
.getDataSource()
.getKsqlTopic()
.getValueFormat();
SchemaKTable<?> aggregated = aggregate(grouped, internalSchema, contextStacker);

final List<Expression> internalGroupByColumns = internalSchema.resolveGroupByExpressions(
groupByExpressions,
aggregateArgExpanded
);
aggregated = applyHavingFilter(aggregated, contextStacker);

final SchemaKGroupedStream schemaKGroupedStream = aggregateArgExpanded.groupBy(
valueFormat,
internalGroupByColumns,
groupByContext
return selectRequiredOutputColumns(aggregated, contextStacker, builder);
}
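Taken together, the refactor turns the old monolithic buildStream into a linear pipeline of named stages. A condensed, runnable sketch of just the ordering, with plain strings standing in for the real SchemaKStream/SchemaKTable operators (stage labels are the constant-name stems from the diff, not their actual string values):

final class AggregatePipelineSketch {

  static String build(final String source) {
    final String preSelected = step(source, "PREPARE");        // selectRequiredInputColumns
    final String grouped     = step(preSelected, "GROUP_BY");  // groupBy
    final String aggregated  = step(grouped, "AGGREGATION");   // aggregate
    final String filtered    = step(aggregated, "HAVING");     // applyHavingFilter
    return step(filtered, "PROJECT");                          // selectRequiredOutputColumns
  }

  private static String step(final String in, final String opName) {
    return in + " -> " + opName;
  }

  public static void main(final String[] args) {
    // Prints: source -> PREPARE -> GROUP_BY -> AGGREGATION -> HAVING -> PROJECT
    System.out.println(build("source"));
  }
}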

protected int getPartitions(final KafkaTopicClient kafkaTopicClient) {
return source.getPartitions(kafkaTopicClient);
}

private static SchemaKStream<?> selectRequiredInputColumns(
final SchemaKStream<?> sourceSchemaKStream,
final InternalSchema internalSchema,
final Stacker contextStacker,
final KsqlQueryBuilder builder
) {
return sourceSchemaKStream.select(
internalSchema.getAggArgExpansionList(),
contextStacker.push(PREPARE_OP_NAME),
builder
);
}

final List<FunctionCall> functionsWithInternalIdentifiers = functionList.stream()
.map(
fc -> new FunctionCall(
fc.getName(),
internalSchema.getInternalArgsExpressionList(fc.getArguments())
)
)
.collect(Collectors.toList());
private SchemaKTable<?> aggregate(
final SchemaKGroupedStream grouped,
final InternalSchema internalSchema,
final Stacker contextStacker
) {
final List<FunctionCall> functions = internalSchema.updateFunctionList(functionList);

final QueryContext.Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME);
final Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME);

final List<ColumnName> requiredColumnNames = requiredColumns.stream()
.map(e -> (UnqualifiedColumnReferenceExp) internalSchema.resolveToInternal(e))
.map(UnqualifiedColumnReferenceExp::getColumnName)
.collect(Collectors.toList());

SchemaKTable<?> aggregated = schemaKGroupedStream.aggregate(
return grouped.aggregate(
requiredColumnNames,
functionsWithInternalIdentifiers,
functions,
windowExpression,
valueFormat,
aggregationContext
);
}

final Optional<Expression> havingExpression = havingExpressions
.map(internalSchema::resolveToInternal);

if (havingExpression.isPresent()) {
aggregated = aggregated.filter(
havingExpression.get(),
contextStacker.push(HAVING_FILTER_OP_NAME)
);
}

final List<SelectExpression> finalSelects = internalSchema
.updateFinalSelectExpressions(finalSelectExpressions);
private SchemaKTable<?> applyHavingFilter(
final SchemaKTable<?> aggregated,
final Stacker contextStacker
) {
return havingExpressions.isPresent()
? aggregated.filter(havingExpressions.get(), contextStacker.push(HAVING_FILTER_OP_NAME))
: aggregated;
}

return aggregated.select(
finalSelects,
contextStacker.push(PROJECT_OP_NAME),
builder
);
private SchemaKStream<?> selectRequiredOutputColumns(
final SchemaKTable<?> aggregated,
final Stacker contextStacker,
final KsqlQueryBuilder builder
) {
return aggregated
.select(finalSelectExpressions, contextStacker.push(PROJECT_OP_NAME), builder);
}

protected int getPartitions(final KafkaTopicClient kafkaTopicClient) {
return source.getPartitions(kafkaTopicClient);
private SchemaKGroupedStream groupBy(
final Stacker contextStacker,
final SchemaKStream<?> preSelected
) {
return preSelected
.groupBy(valueFormat, groupByExpressions, contextStacker.push(GROUP_BY_OP_NAME));
}

private static class InternalSchema {

private final Optional<ColumnName> singleKeyColumn;
private final List<SelectExpression> aggArgExpansions = new ArrayList<>();
private final Map<String, ColumnName> expressionToInternalColumnName = new HashMap<>();

InternalSchema(
final LogicalSchema schema,
final List<ColumnReferenceExp> requiredColumns,
final List<Expression> aggregateFunctionArguments
) {
this.singleKeyColumn = schema.key().size() == 1
? Optional.of(schema.key().get(0).name())
: Optional.empty();

final Set<String> seen = new HashSet<>();
collectAggregateArgExpressions(requiredColumns, seen);
collectAggregateArgExpressions(aggregateFunctionArguments, seen);
collectAggregateArgExpressions(requiredColumns);
collectAggregateArgExpressions(aggregateFunctionArguments);
}

private void collectAggregateArgExpressions(
final Collection<? extends Expression> expressions,
final Set<String> seen
final Collection<? extends Expression> expressions
) {
for (final Expression expression : expressions) {
if (seen.contains(expression.toString())) {
final String sql = expression.toString();
if (expressionToInternalColumnName.containsKey(sql)) {
continue;
}

seen.add(expression.toString());
final ColumnName internalName = expression instanceof ColumnReferenceExp
? ((ColumnReferenceExp) expression).getColumnName()
: ColumnName.of(INTERNAL_COLUMN_NAME_PREFIX + aggArgExpansions.size());

final String internalName = INTERNAL_COLUMN_NAME_PREFIX + aggArgExpansions.size();

aggArgExpansions.add(SelectExpression.of(ColumnName.of(internalName), expression));
expressionToInternalColumnName
.putIfAbsent(expression.toString(), ColumnName.of(internalName));
aggArgExpansions.add(SelectExpression.of(internalName, expression));
expressionToInternalColumnName.put(sql, internalName);
}
}
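After this change, a plain column reference keeps its own name as its internal alias, while any other expression still gets a generated alias, de-duplicated by the expression's SQL text. A small sketch of that rule (the PREFIX value mirrors what ksqlDB uses for INTERNAL_COLUMN_NAME_PREFIX and is an assumption here, not taken from this diff):

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class InternalNameSketch {
  // Stand-in for INTERNAL_COLUMN_NAME_PREFIX.
  private static final String PREFIX = "KSQL_INTERNAL_COL_";

  private final Map<String, String> bySql = new LinkedHashMap<>();
  private final List<String> aggArgExpansions = new ArrayList<>();

  // isColumnRef stands in for `expression instanceof ColumnReferenceExp`.
  String internalName(final String sql, final boolean isColumnRef, final String columnName) {
    final String existing = bySql.get(sql);
    if (existing != null) {
      return existing;                        // same expression seen before: reuse its alias
    }
    final String name = isColumnRef
        ? columnName                          // column refs keep their own name
        : PREFIX + aggArgExpansions.size();   // everything else gets a generated alias
    aggArgExpansions.add(name);
    bySql.put(sql, name);
    return name;
  }

  public static void main(final String[] args) {
    final InternalNameSketch s = new InternalNameSketch();
    System.out.println(s.internalName("COL0", true, "COL0"));      // COL0
    System.out.println(s.internalName("ABS(COL1)", false, null));  // KSQL_INTERNAL_COL_1
    System.out.println(s.internalName("COL0", true, "COL0"));      // COL0 (de-duplicated)
  }
}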

List<Expression> resolveGroupByExpressions(
final List<Expression> expressionList,
final SchemaKStream<?> aggregateArgExpanded
) {
final boolean specialRowTimeHandling = !(aggregateArgExpanded instanceof SchemaKTable);

final Function<Expression, Expression> mapper = e -> {
final boolean rowKey = singleKeyColumn.isPresent()
&& e instanceof UnqualifiedColumnReferenceExp
&& ((UnqualifiedColumnReferenceExp) e).getColumnName().equals(singleKeyColumn.get());

if (!rowKey || !specialRowTimeHandling) {
return resolveToInternal(e);
}

return e;
};

return expressionList.stream()
.map(mapper)
.collect(Collectors.toList());
}

/**
* Return the aggregate function arguments based on the internal expressions.
* Currently we support aggregate functions with at most two arguments where
* the second argument should be a literal.
* Return the aggregate function arguments based on the internal expressions. Currently we
* support aggregate functions with at most two arguments where the second argument should be a
* literal.
*
* @param argExpressionList The list of parameters for the aggregate function.
* @return The list of arguments based on the internal expressions for the aggregate function.
*/
List<Expression> getInternalArgsExpressionList(final List<Expression> argExpressionList) {
List<Expression> updateArgsExpressionList(final List<Expression> argExpressionList) {
// Currently we only support aggregations on one column only
if (argExpressionList.size() > 2) {
throw new KsqlException("Currently, KSQL UDAFs can only have two arguments.");
@@ -319,7 +293,7 @@ List<Expression> getInternalArgsExpressionList(final List<Expression> argExpress
final List<Expression> internalExpressionList = new ArrayList<>();
internalExpressionList.add(resolveToInternal(argExpressionList.get(0)));
if (argExpressionList.size() == 2) {
if (! (argExpressionList.get(1) instanceof Literal)) {
if (!(argExpressionList.get(1) instanceof Literal)) {
throw new KsqlException("Currently, second argument in UDAF should be literal.");
}
internalExpressionList.add(argExpressionList.get(1));
@@ -328,14 +302,9 @@ List<Expression> getInternalArgsExpressionList(final List<Expression> argExpress

}

List<SelectExpression> updateFinalSelectExpressions(
final List<SelectExpression> finalSelectExpressions
) {
return finalSelectExpressions.stream()
.map(finalSelectExpression -> {
final Expression internal = resolveToInternal(finalSelectExpression.getExpression());
return SelectExpression.of(finalSelectExpression.getAlias(), internal);
})
List<FunctionCall> updateFunctionList(final ImmutableList<FunctionCall> functions) {
return functions.stream()
.map(fc -> new FunctionCall(fc.getName(), updateArgsExpressionList(fc.getArguments())))
.collect(Collectors.toList());
}

@@ -356,6 +325,7 @@ private Expression resolveToInternal(final Expression exp) {

private final class ResolveToInternalRewriter
extends VisitParentExpressionVisitor<Optional<Expression>, Context<Void>> {

private ResolveToInternalRewriter() {
super(Optional.empty());
}
@@ -19,6 +19,8 @@
@SuppressWarnings({"unused", "MethodMayBeStatic"})
public class BadUdf {

private int i;

@Udf(description = "throws")
public String blowUp(final int arg1) {
throw new RuntimeException("boom!");
@@ -32,4 +34,9 @@ public int mightThrow(final boolean arg) {

return 0;
}

@Udf(description = "returns null every other invocation")
public String returnNull(final String arg) {
return i++ % 2 == 0 ? null : arg;
}
}
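The new returnNull fixture alternates between null and its argument, which is what the mutable i field enables. A quick hypothetical harness showing that behaviour, not part of the test suite:

final class BadUdfDemo {
  public static void main(final String[] args) {
    final BadUdf udf = new BadUdf();
    System.out.println(udf.returnNull("x")); // null (i was 0, even)
    System.out.println(udf.returnNull("x")); // x    (i was 1, odd)
  }
}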