Skip to content

Commit

Permalink
feat: move groupBy into plan builders
Browse files Browse the repository at this point in the history
This patch moves the code for regrouping streams/tables into plan
builders. This also required adding a new execution step for
groupByKey, which we missed the first go-round.
  • Loading branch information
rodesai committed Sep 16, 2019
1 parent 06aa252 commit f6212f9
Show file tree
Hide file tree
Showing 28 changed files with 989 additions and 192 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -219,20 +219,14 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
.getKsqlTopic()
.getValueFormat();

final Serde<GenericRow> genericRowSerde = builder.buildValueSerde(
valueFormat.getFormatInfo(),
PhysicalSchema.from(prepareSchema, SerdeOption.none()),
groupByContext.getQueryContext()
);

final List<Expression> internalGroupByColumns = internalSchema.getInternalExpressionList(
getGroupByExpressions());

final SchemaKGroupedStream schemaKGroupedStream = aggregateArgExpanded.groupBy(
valueFormat,
genericRowSerde,
internalGroupByColumns,
groupByContext
groupByContext,
builder
);

// Aggregate computations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package io.confluent.ksql.streams;

import io.confluent.ksql.execution.streams.GroupedFactory;
import io.confluent.ksql.execution.streams.MaterializedFactory;
import io.confluent.ksql.util.KsqlConfig;
import java.util.Objects;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import com.google.common.collect.ImmutableList;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.codegen.CodeGenRunner;
import io.confluent.ksql.execution.codegen.ExpressionMetadata;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.context.QueryLoggerUtil;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
Expand All @@ -37,14 +35,18 @@
import io.confluent.ksql.execution.plan.JoinType;
import io.confluent.ksql.execution.plan.LogicalSchemaWithMetaAndKeyFields;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.plan.StreamGroupBy;
import io.confluent.ksql.execution.plan.StreamGroupByKey;
import io.confluent.ksql.execution.plan.StreamMapValues;
import io.confluent.ksql.execution.plan.StreamSource;
import io.confluent.ksql.execution.plan.StreamToTable;
import io.confluent.ksql.execution.streams.ExecutionStepFactory;
import io.confluent.ksql.execution.streams.StreamGroupByBuilder;
import io.confluent.ksql.execution.streams.StreamMapValuesBuilder;
import io.confluent.ksql.execution.streams.StreamSourceBuilder;
import io.confluent.ksql.execution.streams.StreamToTableBuilder;
import io.confluent.ksql.execution.streams.StreamsUtil;
import io.confluent.ksql.execution.util.StructKeyUtil;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogContext;
import io.confluent.ksql.metastore.model.DataSource;
Expand All @@ -68,10 +70,10 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.Topology.AutoOffsetReset;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.JoinWindows;
import org.apache.kafka.streams.kstream.KGroupedStream;
import org.apache.kafka.streams.kstream.KStream;
Expand All @@ -88,6 +90,8 @@ public class SchemaKStream<K> {
private static final FormatOptions FORMAT_OPTIONS =
FormatOptions.of(IdentifierUtil::needsQuotes);

static final String GROUP_BY_COLUMN_SEPARATOR = "|+|";

public enum Type { SOURCE, PROJECT, FILTER, AGGREGATE, SINK, REKEY, JOIN }

final KStream<K, GenericRow> kstream;
Expand Down Expand Up @@ -809,36 +813,30 @@ private boolean rekeyRequired(final List<Expression> groupByExpressions) {
@SuppressWarnings("unchecked")
public SchemaKGroupedStream groupBy(
final ValueFormat valueFormat,
final Serde<GenericRow> valSerde,
final List<Expression> groupByExpressions,
final QueryContext.Stacker contextStacker
final QueryContext.Stacker contextStacker,
final KsqlQueryBuilder queryBuilder
) {
final boolean rekey = rekeyRequired(groupByExpressions);
final KeyFormat rekeyedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo());
if (!rekey) {
final Grouped<K, GenericRow> grouped = streamsFactories.getGroupedFactory()
.create(
StreamsUtil.buildOpName(contextStacker.getQueryContext()),
keySerde,
valSerde
);

final KGroupedStream kgroupedStream = kstream.groupByKey(grouped);

if (keySerde.isWindowed()) {
throw new UnsupportedOperationException("Group by on windowed should always require rekey");
}

final KeySerde<Struct> structKeySerde = (KeySerde) keySerde;
final ExecutionStep<KGroupedStream<Struct, GenericRow>> step =
ExecutionStepFactory.streamGroupBy(
final StreamGroupByKey<KStream<K, GenericRow>, KGroupedStream<Struct, GenericRow>> step =
ExecutionStepFactory.streamGroupByKey(
contextStacker,
sourceStep,
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none())
);
return new SchemaKGroupedStream(
kgroupedStream,
StreamGroupByBuilder.build(
(KStream<Object, GenericRow>) kstream,
step,
queryBuilder,
streamsFactories.getGroupedFactory()
),
step,
keyFormat,
structKeySerde,
Expand All @@ -849,36 +847,29 @@ public SchemaKGroupedStream groupBy(
);
}

final GroupBy groupBy = new GroupBy(groupByExpressions);

final KeySerde<Struct> groupedKeySerde = keySerde
.rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA);

final Grouped<Struct, GenericRow> grouped = streamsFactories.getGroupedFactory()
.create(
StreamsUtil.buildOpName(contextStacker.getQueryContext()),
groupedKeySerde,
valSerde
);

final KGroupedStream kgroupedStream = kstream
.filter((key, value) -> value != null)
.groupBy(groupBy.mapper, grouped);

final String aggregateKeyName = groupedKeyNameFor(groupByExpressions);
final LegacyField legacyKeyField = LegacyField
.notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING);

final Optional<String> newKeyCol = getSchema().findValueColumn(groupBy.aggregateKeyName)
.notInSchema(aggregateKeyName, SqlTypes.STRING);
final Optional<String> newKeyCol = getSchema().findValueColumn(aggregateKeyName)
.map(Column::name);
final ExecutionStep<KGroupedStream<Struct, GenericRow>> source =

final StreamGroupBy<KStream<K, GenericRow>, KGroupedStream<Struct, GenericRow>> source =
ExecutionStepFactory.streamGroupBy(
contextStacker,
sourceStep,
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);
return new SchemaKGroupedStream(
kgroupedStream,
StreamGroupByBuilder.build(
(KStream<Object, GenericRow>) kstream,
source,
queryBuilder,
streamsFactories.getGroupedFactory()
),
source,
rekeyedKeyFormat,
groupedKeySerde,
Expand Down Expand Up @@ -946,18 +937,10 @@ public FunctionRegistry getFunctionRegistry() {
return functionRegistry;
}

class GroupBy {

final String aggregateKeyName;
final GroupByMapper mapper;

GroupBy(final List<Expression> expressions) {
final List<ExpressionMetadata> groupBy = CodeGenRunner.compileExpressions(
expressions.stream(), "Group By", getSchema(), ksqlConfig, functionRegistry);

this.mapper = new GroupByMapper(groupBy);
this.aggregateKeyName = GroupByMapper.keyNameFor(expressions);
}
String groupedKeyNameFor(final List<Expression> groupByExpressions) {
  // Builds the synthetic aggregate key name: the textual form of each
  // group-by expression, concatenated with GROUP_BY_COLUMN_SEPARATOR.
  final StringBuilder keyName = new StringBuilder();
  for (int i = 0; i < groupByExpressions.size(); i++) {
    if (i > 0) {
      keyName.append(GROUP_BY_COLUMN_SEPARATOR);
    }
    keyName.append(groupByExpressions.get(i).toString());
  }
  return keyName.toString();
}

protected static class KsqlValueJoiner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
import io.confluent.ksql.execution.plan.Formats;
import io.confluent.ksql.execution.plan.JoinType;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.plan.TableGroupBy;
import io.confluent.ksql.execution.plan.TableMapValues;
import io.confluent.ksql.execution.streams.ExecutionStepFactory;
import io.confluent.ksql.execution.streams.StreamsUtil;
import io.confluent.ksql.execution.streams.TableGroupByBuilder;
import io.confluent.ksql.execution.streams.TableMapValuesBuilder;
import io.confluent.ksql.execution.util.StructKeyUtil;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogContext;
import io.confluent.ksql.metastore.model.KeyField;
Expand All @@ -50,8 +52,6 @@
import java.util.Set;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KGroupedTable;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
Expand Down Expand Up @@ -241,48 +241,38 @@ public ExecutionStep<KTable<K, GenericRow>> getSourceTableStep() {
}

@Override
@SuppressWarnings("unchecked")
public SchemaKGroupedStream groupBy(
final ValueFormat valueFormat,
final Serde<GenericRow> valSerde,
final List<Expression> groupByExpressions,
final QueryContext.Stacker contextStacker
final QueryContext.Stacker contextStacker,
final KsqlQueryBuilder queryBuilder
) {

final GroupBy groupBy = new GroupBy(groupByExpressions);
final KeyFormat groupedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo());

final KeySerde<Struct> groupedKeySerde = keySerde
.rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA);

final Grouped<Struct, GenericRow> grouped = streamsFactories.getGroupedFactory()
.create(
StreamsUtil.buildOpName(contextStacker.getQueryContext()),
groupedKeySerde,
valSerde
);

final KGroupedTable kgroupedTable = ktable
.filter((key, value) -> value != null)
.groupBy(
(key, value) -> new KeyValue<>(groupBy.mapper.apply(key, value), value),
grouped
);

final LegacyField legacyKeyField = LegacyField
.notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING);

final Optional<String> newKeyField = getSchema().findValueColumn(groupBy.aggregateKeyName)
.map(Column::fullName);
final String aggregateKeyName = groupedKeyNameFor(groupByExpressions);
final LegacyField legacyKeyField = LegacyField.notInSchema(aggregateKeyName, SqlTypes.STRING);
final Optional<String> newKeyField =
getSchema().findValueColumn(aggregateKeyName).map(Column::fullName);

final ExecutionStep<KGroupedTable<Struct, GenericRow>> step =
final TableGroupBy<KTable<K, GenericRow>, KGroupedTable<Struct, GenericRow>> step =
ExecutionStepFactory.tableGroupBy(
contextStacker,
sourceTableStep,
Formats.of(groupedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);
return new SchemaKGroupedTable(
kgroupedTable,
TableGroupByBuilder.build(
(KTable<Object, GenericRow>) ktable,
step,
queryBuilder,
streamsFactories.getGroupedFactory()
),
step,
groupedKeyFormat,
groupedKeySerde,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.PersistenceSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.FormatInfo;
import io.confluent.ksql.serde.KeySerde;
import io.confluent.ksql.serde.WindowInfo;
import io.confluent.ksql.structured.SchemaKStream;
Expand Down Expand Up @@ -401,7 +402,7 @@ private SchemaKStream buildQuery(final AggregateNode aggregateNode, final KsqlCo
when(ksqlStreamBuilder.buildNodeContext(any())).thenAnswer(inv ->
new QueryContext.Stacker(queryId)
.push(inv.getArgument(0).toString()));
when(ksqlStreamBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde);
when(ksqlStreamBuilder.buildKeySerde(any(FormatInfo.class), any(), any())).thenReturn(keySerde);

return aggregateNode.buildStream(ksqlStreamBuilder);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ public void before() {
new QueryContext.Stacker(queryId)
.push(inv.getArgument(0).toString()));

when(ksqlStreamBuilder.buildKeySerde(any(), any(), any())).thenReturn((KeySerde)keySerde);
when(ksqlStreamBuilder.buildKeySerde(any(FormatInfo.class), any(), any()))
.thenReturn((KeySerde)keySerde);
when(ksqlStreamBuilder.buildKeySerde(any(), any(), any(), any())).thenReturn((KeySerde)keySerde);
when(ksqlStreamBuilder.buildValueSerde(any(), any(), any())).thenReturn(rowSerde);
when(ksqlStreamBuilder.getFunctionRegistry()).thenReturn(functionRegistry);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ public void setUp() {
when(ksqlStreamBuilder.buildNodeContext(any())).thenAnswer(inv ->
new QueryContext.Stacker(queryId)
.push(inv.getArgument(0).toString()));
when(ksqlStreamBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde);
when(ksqlStreamBuilder.buildKeySerde(any(FormatInfo.class), any(), any()))
.thenReturn(keySerde);

when(keySerde.rebind(any(PersistenceSchema.class))).thenReturn(reboundKeySerde);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.FormatInfo;
import io.confluent.ksql.serde.KeySerde;
import io.confluent.ksql.structured.SchemaKStream;
import io.confluent.ksql.testutils.AnalysisTestUtil;
Expand Down Expand Up @@ -89,7 +90,8 @@ public void before() {
when(ksqlStreamBuilder.buildNodeContext(any())).thenAnswer(inv ->
new QueryContext.Stacker(queryId)
.push(inv.getArgument(0).toString()));
when(ksqlStreamBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde);
when(ksqlStreamBuilder.buildKeySerde(any(FormatInfo.class), any(), any()))
.thenReturn(keySerde);

final KsqlBareOutputNode planNode = (KsqlBareOutputNode) AnalysisTestUtil
.buildLogicalPlan(ksqlConfig, SIMPLE_SELECT_WITH_FILTER, metaStore);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.streams.GroupedFactory;
import io.confluent.ksql.util.KsqlConfig;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.streams.StreamsConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.common.collect.ImmutableMap;
import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
Expand Down Expand Up @@ -138,6 +139,8 @@ public class SchemaKGroupedTableTest {
private KsqlAggregateFunction otherFunc;
@Mock
private TableAggregationFunction tableFunc;
@Mock
private KsqlQueryBuilder queryBuilder;

private KTable kTable;
private KsqlTable<?> ksqlTable;
Expand All @@ -161,6 +164,9 @@ public void init() {
Consumed.with(Serdes.String(), rowSerde)
);

when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry);
when(queryBuilder.getKsqlConfig()).thenReturn(ksqlConfig);

when(aggregateSchema.findValueColumn("GROUPING_COLUMN"))
.thenReturn(Optional.of(Column.of("GROUPING_COLUMN", SqlTypes.STRING)));

Expand Down Expand Up @@ -210,7 +216,7 @@ private SchemaKGroupedTable buildSchemaKGroupedTableFromQuery(
);

final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy(
valueFormat, rowSerde, groupByExpressions, queryContext);
valueFormat, groupByExpressions, queryContext, queryBuilder);
Assert.assertThat(groupedSchemaKTable, instanceOf(SchemaKGroupedTable.class));
return (SchemaKGroupedTable)groupedSchemaKTable;
}
Expand Down
Loading

0 comments on commit f6212f9

Please sign in to comment.