Skip to content

Commit

Permalink
feat: move aggregation to plan builder (#3391)
Browse files Browse the repository at this point in the history
This patch moves the code for building aggregations into a plan builder:

To express windowed aggregates, I've added a new execution step type
called StreamWindowedAggregate. Adding a new type ensures that when we
implement the visitor that builds the streams app, the handler for windowed
aggregations is type-safe.

I've moved the windowing POJOs into the ksql-execution module, and split them off
from the AST. This way we can serialize these into the aggregation plan
nodes to express windows. I've also removed the code that builds
aggregations from these pojos and moved it into a visitor inside
StreamAggregateBuilder.

To implement window start/end projections, I had to move
WindowSelectMapper to ksql-execution. This also requires having
WindowSelectMapper own the start/end udaf names.

This patch also includes a refactor of AggregateNode to pass down the
aggregation function call expressions rather than the resolved aggregation
functions. The code for resolving the function call expressions against the
internal schema and building the aggregators, initializers, and undo
aggregators has been moved into a class called AggregateParams, which
the aggregate builders use to make the aggregation calls in streams.

The rest of the patch implements the actual aggregation from step builders.
  • Loading branch information
rodesai authored Sep 25, 2019
1 parent fd411a1 commit 3aaeb73
Show file tree
Hide file tree
Showing 53 changed files with 2,351 additions and 1,203 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.windows.KsqlWindowExpression;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.metastore.model.DataSource;
import io.confluent.ksql.name.ColumnName;
Expand All @@ -40,7 +41,6 @@
import io.confluent.ksql.parser.tree.GroupingElement;
import io.confluent.ksql.parser.tree.Join;
import io.confluent.ksql.parser.tree.JoinOn;
import io.confluent.ksql.parser.tree.KsqlWindowExpression;
import io.confluent.ksql.parser.tree.Query;
import io.confluent.ksql.parser.tree.Select;
import io.confluent.ksql.parser.tree.SelectItem;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import io.confluent.ksql.parser.tree.GroupingElement;
import io.confluent.ksql.parser.tree.InsertInto;
import io.confluent.ksql.parser.tree.Join;
import io.confluent.ksql.parser.tree.KsqlWindowExpression;
import io.confluent.ksql.parser.tree.Query;
import io.confluent.ksql.parser.tree.RegisterType;
import io.confluent.ksql.parser.tree.Relation;
Expand Down Expand Up @@ -236,14 +235,8 @@ protected AstNode visitWindowExpression(final WindowExpression node, final C con
return new WindowExpression(
node.getLocation(),
node.getWindowName(),
(KsqlWindowExpression) rewriter.apply(node.getKsqlWindowExpression(), context));
}

@Override
protected AstNode visitKsqlWindowExpression(
final KsqlWindowExpression node,
final C context) {
return node;
node.getKsqlWindowExpression()
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package io.confluent.ksql.function;

import io.confluent.ksql.execution.function.TableAggregationFunction;
import io.confluent.ksql.function.udaf.TableUdaf;
import java.util.List;
import java.util.Optional;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

package io.confluent.ksql.function.udaf.count;

import io.confluent.ksql.execution.function.TableAggregationFunction;
import io.confluent.ksql.function.AggregateFunctionArguments;
import io.confluent.ksql.function.BaseAggregateFunction;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.TableAggregationFunction;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

package io.confluent.ksql.function.udaf.sum;

import io.confluent.ksql.execution.function.TableAggregationFunction;
import io.confluent.ksql.function.AggregateFunctionArguments;
import io.confluent.ksql.function.BaseAggregateFunction;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.TableAggregationFunction;
import io.confluent.ksql.util.DecimalUtil;
import java.math.BigDecimal;
import java.math.MathContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

package io.confluent.ksql.function.udaf.sum;

import io.confluent.ksql.execution.function.TableAggregationFunction;
import io.confluent.ksql.function.AggregateFunctionArguments;
import io.confluent.ksql.function.BaseAggregateFunction;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.TableAggregationFunction;
import java.util.Collections;
import java.util.function.Function;
import org.apache.kafka.connect.data.Schema;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

package io.confluent.ksql.function.udaf.sum;

import io.confluent.ksql.execution.function.TableAggregationFunction;
import io.confluent.ksql.function.AggregateFunctionArguments;
import io.confluent.ksql.function.BaseAggregateFunction;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.TableAggregationFunction;
import java.util.Collections;
import java.util.function.Function;
import org.apache.kafka.connect.data.Schema;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

package io.confluent.ksql.function.udaf.sum;

import io.confluent.ksql.execution.function.TableAggregationFunction;
import io.confluent.ksql.function.AggregateFunctionArguments;
import io.confluent.ksql.function.BaseAggregateFunction;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.TableAggregationFunction;
import java.util.Collections;
import java.util.function.Function;
import org.apache.kafka.connect.data.Schema;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package io.confluent.ksql.function.udaf.window;

import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper;
import io.confluent.ksql.function.udaf.TableUdaf;
import io.confluent.ksql.function.udaf.UdafDescription;
import io.confluent.ksql.function.udaf.UdafFactory;
Expand All @@ -37,7 +38,7 @@ private WindowEndKudaf() {
}

static String getFunctionName() {
return "WindowEnd";
return WindowSelectMapper.WINDOW_END_NAME;
}

@UdafFactory(description = "Extracts the window end time")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package io.confluent.ksql.function.udaf.window;

import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper;
import io.confluent.ksql.function.udaf.TableUdaf;
import io.confluent.ksql.function.udaf.UdafDescription;
import io.confluent.ksql.function.udaf.UdafFactory;
Expand All @@ -37,7 +38,7 @@ private WindowStartKudaf() {
}

static String getFunctionName() {
return "WindowStart";
return WindowSelectMapper.WINDOW_START_NAME;
}

@UdafFactory(description = "Extracts the window start time")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import static java.util.Objects.requireNonNull;

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
Expand All @@ -29,24 +28,20 @@
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.Literal;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
import io.confluent.ksql.execution.function.UdafUtil;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.util.ExpressionTypeManager;
import io.confluent.ksql.function.AggregateFunctionArguments;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.udaf.KudafInitializer;
import io.confluent.ksql.materialization.MaterializationInfo;
import io.confluent.ksql.metastore.model.KeyField;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.parser.tree.WindowExpression;
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.PhysicalSchema;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.SchemaConverters.ConnectToSqlTypeConverter;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.serde.SerdeOption;
import io.confluent.ksql.serde.ValueFormat;
import io.confluent.ksql.services.KafkaTopicClient;
import io.confluent.ksql.structured.SchemaKGroupedStream;
Expand All @@ -66,8 +61,6 @@
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Schema;


public class AggregateNode extends PlanNode {
Expand Down Expand Up @@ -234,54 +227,43 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
builder
);

// Aggregate computations
final KudafInitializer initializer = new KudafInitializer(requiredColumns.size());

final Map<Integer, KsqlAggregateFunction> aggValToFunctionMap = createAggValToFunctionMap(
aggregateArgExpanded,
initializer,
requiredColumns.size(),
builder.getFunctionRegistry(),
internalSchema
);
final List<FunctionCall> functionsWithInternalIdentifiers = functionList.stream()
.map(
fc -> new FunctionCall(
fc.getName(),
internalSchema.getInternalArgsExpressionList(fc.getArguments())
)
)
.collect(Collectors.toList());

// This is the schema of the aggregation change log topic and associated state store.
// It contains all columns from prepareSchema and columns for any aggregating functions
// It uses internal column names, e.g. KSQL_INTERNAL_COL_0 and KSQL_AGG_VARIABLE_0
final LogicalSchema aggregationSchema = buildLogicalSchema(
prepareSchema,
aggValToFunctionMap,
functionsWithInternalIdentifiers,
builder.getFunctionRegistry(),
true
);

final QueryContext.Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME);

final Serde<GenericRow> aggValueGenericRowSerde = builder.buildValueSerde(
valueFormat.getFormatInfo(),
PhysicalSchema.from(aggregationSchema, SerdeOption.none()),
aggregationContext.getQueryContext()
);

final List<FunctionCall> functionsWithInternalIdentifiers = functionList.stream()
.map(internalSchema::resolveToInternal)
.map(FunctionCall.class::cast)
.collect(Collectors.toList());

final LogicalSchema outputSchema = buildLogicalSchema(
prepareSchema,
aggValToFunctionMap,
false);
functionsWithInternalIdentifiers,
builder.getFunctionRegistry(),
false
);

SchemaKTable<?> aggregated = schemaKGroupedStream.aggregate(
aggregationSchema,
outputSchema,
initializer,
requiredColumns.size(),
functionsWithInternalIdentifiers,
aggValToFunctionMap,
windowExpression,
valueFormat,
aggValueGenericRowSerde,
aggregationContext
aggregationContext,
builder
);

final Optional<Expression> havingExpression = Optional.ofNullable(havingExpressions)
Expand Down Expand Up @@ -316,61 +298,12 @@ protected int getPartitions(final KafkaTopicClient kafkaTopicClient) {
return source.getPartitions(kafkaTopicClient);
}

private Map<Integer, KsqlAggregateFunction> createAggValToFunctionMap(
final SchemaKStream aggregateArgExpanded,
final KudafInitializer initializer,
final int initialUdafIndex,
final FunctionRegistry functionRegistry,
final InternalSchema internalSchema
) {
int udafIndexInAggSchema = initialUdafIndex;
final Map<Integer, KsqlAggregateFunction> aggValToAggFunctionMap = new HashMap<>();
for (final FunctionCall functionCall : functionList) {
final KsqlAggregateFunction aggregateFunction = getAggregateFunction(
functionRegistry,
internalSchema,
functionCall, aggregateArgExpanded.getSchema());

aggValToAggFunctionMap.put(udafIndexInAggSchema++, aggregateFunction);
initializer.addAggregateIntializer(aggregateFunction.getInitialValueSupplier());
}
return aggValToAggFunctionMap;
}

@SuppressWarnings("deprecation") // Need to migrate away from Connect Schema use.
private static KsqlAggregateFunction getAggregateFunction(
final FunctionRegistry functionRegistry,
final InternalSchema internalSchema,
final FunctionCall functionCall,
final LogicalSchema schema
) {
try {
final ExpressionTypeManager expressionTypeManager =
new ExpressionTypeManager(schema, functionRegistry);
final List<Expression> functionArgs = internalSchema.getInternalArgsExpressionList(
functionCall.getArguments());
final Schema expressionType = expressionTypeManager.getExpressionSchema(functionArgs.get(0));
final KsqlAggregateFunction aggregateFunctionInfo = functionRegistry
.getAggregate(functionCall.getName().name(), expressionType);

final List<String> args = functionArgs.stream()
.map(Expression::toString)
.collect(Collectors.toList());

final int udafIndex = Integer
.parseInt(args.get(0).substring(INTERNAL_COLUMN_NAME_PREFIX.length()));

return aggregateFunctionInfo.getInstance(new AggregateFunctionArguments(udafIndex, args));
} catch (final Exception e) {
throw new KsqlException("Failed to create aggregate function: " + functionCall, e);
}
}

private LogicalSchema buildLogicalSchema(
final LogicalSchema inputSchema,
final Map<Integer, KsqlAggregateFunction> aggregateFunctions,
final boolean useAggregate) {

final List<FunctionCall> aggregations,
final FunctionRegistry functionRegistry,
final boolean useAggregate
) {
final LogicalSchema.Builder schemaBuilder = LogicalSchema.builder();
final List<Column> cols = inputSchema.value();

Expand All @@ -382,18 +315,13 @@ private LogicalSchema buildLogicalSchema(

final ConnectToSqlTypeConverter converter = SchemaConverters.connectToSqlConverter();

for (int idx = 0; idx < aggregateFunctions.size(); idx++) {

final KsqlAggregateFunction aggregateFunction = aggregateFunctions
.get(requiredColumns.size() + idx);

final ColumnName colName = ColumnName.aggregate(idx);
SqlType fieldType = null;
if (useAggregate) {
fieldType = converter.toSqlType(aggregateFunction.getAggregateType());
} else {
fieldType = converter.toSqlType(aggregateFunction.getReturnType());
}
for (int i = 0; i < aggregations.size(); i++) {
final KsqlAggregateFunction aggregateFunction =
UdafUtil.resolveAggregateFunction(functionRegistry, aggregations.get(i), inputSchema);
final ColumnName colName = ColumnName.aggregate(i);
final SqlType fieldType = converter.toSqlType(
useAggregate ? aggregateFunction.getAggregateType() : aggregateFunction.getReturnType()
);
schemaBuilder.valueColumn(colName, fieldType);
}

Expand Down
Loading

0 comments on commit 3aaeb73

Please sign in to comment.