diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java index becfc912eedf..64634ad658b8 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java @@ -219,20 +219,14 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { .getKsqlTopic() .getValueFormat(); - final Serde genericRowSerde = builder.buildValueSerde( - valueFormat.getFormatInfo(), - PhysicalSchema.from(prepareSchema, SerdeOption.none()), - groupByContext.getQueryContext() - ); - final List internalGroupByColumns = internalSchema.getInternalExpressionList( getGroupByExpressions()); final SchemaKGroupedStream schemaKGroupedStream = aggregateArgExpanded.groupBy( valueFormat, - genericRowSerde, internalGroupByColumns, - groupByContext + groupByContext, + builder ); // Aggregate computations diff --git a/ksql-engine/src/main/java/io/confluent/ksql/streams/StreamsFactories.java b/ksql-engine/src/main/java/io/confluent/ksql/streams/StreamsFactories.java index 1d31351e5826..6616f1cabd46 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/streams/StreamsFactories.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/streams/StreamsFactories.java @@ -15,6 +15,7 @@ package io.confluent.ksql.streams; +import io.confluent.ksql.execution.streams.GroupedFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; import io.confluent.ksql.util.KsqlConfig; import java.util.Objects; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java index 9d1f76fb0b3f..103ca4df07e0 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java @@ -23,8 +23,6 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; -import io.confluent.ksql.execution.codegen.CodeGenRunner; -import io.confluent.ksql.execution.codegen.ExpressionMetadata; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.context.QueryLoggerUtil; import io.confluent.ksql.execution.ddl.commands.KsqlTopic; @@ -37,14 +35,18 @@ import io.confluent.ksql.execution.plan.JoinType; import io.confluent.ksql.execution.plan.LogicalSchemaWithMetaAndKeyFields; import io.confluent.ksql.execution.plan.SelectExpression; +import io.confluent.ksql.execution.plan.StreamGroupBy; +import io.confluent.ksql.execution.plan.StreamGroupByKey; import io.confluent.ksql.execution.plan.StreamMapValues; import io.confluent.ksql.execution.plan.StreamSource; import io.confluent.ksql.execution.plan.StreamToTable; import io.confluent.ksql.execution.streams.ExecutionStepFactory; +import io.confluent.ksql.execution.streams.StreamGroupByBuilder; import io.confluent.ksql.execution.streams.StreamMapValuesBuilder; import io.confluent.ksql.execution.streams.StreamSourceBuilder; import io.confluent.ksql.execution.streams.StreamToTableBuilder; import io.confluent.ksql.execution.streams.StreamsUtil; +import io.confluent.ksql.execution.util.StructKeyUtil; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.logging.processing.ProcessingLogContext; import io.confluent.ksql.metastore.model.DataSource; @@ -68,10 +70,10 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; import org.apache.kafka.streams.Topology.AutoOffsetReset; -import org.apache.kafka.streams.kstream.Grouped; import org.apache.kafka.streams.kstream.JoinWindows; import org.apache.kafka.streams.kstream.KGroupedStream; import org.apache.kafka.streams.kstream.KStream; @@ -88,6 +90,8 @@ public class SchemaKStream { private static final FormatOptions FORMAT_OPTIONS = FormatOptions.of(IdentifierUtil::needsQuotes); + static final String GROUP_BY_COLUMN_SEPARATOR = "|+|"; + public enum Type { SOURCE, PROJECT, FILTER, AGGREGATE, SINK, REKEY, JOIN } final KStream kstream; @@ -809,36 +813,30 @@ private boolean rekeyRequired(final List groupByExpressions) { @SuppressWarnings("unchecked") public SchemaKGroupedStream groupBy( final ValueFormat valueFormat, - final Serde valSerde, final List groupByExpressions, - final QueryContext.Stacker contextStacker + final QueryContext.Stacker contextStacker, + final KsqlQueryBuilder queryBuilder ) { final boolean rekey = rekeyRequired(groupByExpressions); final KeyFormat rekeyedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); if (!rekey) { - final Grouped grouped = streamsFactories.getGroupedFactory() - .create( - StreamsUtil.buildOpName(contextStacker.getQueryContext()), - keySerde, - valSerde - ); - - final KGroupedStream kgroupedStream = kstream.groupByKey(grouped); - if (keySerde.isWindowed()) { throw new UnsupportedOperationException("Group by on windowed should always require rekey"); } - final KeySerde structKeySerde = (KeySerde) keySerde; - final ExecutionStep> step = - ExecutionStepFactory.streamGroupBy( + final StreamGroupByKey, KGroupedStream> step = + ExecutionStepFactory.streamGroupByKey( contextStacker, sourceStep, - Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()), - groupByExpressions + Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()) ); return new SchemaKGroupedStream( - kgroupedStream, + StreamGroupByBuilder.build( + (KStream) kstream, + step, + queryBuilder, + streamsFactories.getGroupedFactory() + ), step, keyFormat, structKeySerde, @@ -849,28 +847,16 @@ public SchemaKGroupedStream groupBy( ); } - final GroupBy groupBy = new GroupBy(groupByExpressions); - final KeySerde groupedKeySerde = keySerde .rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA); - final Grouped grouped = streamsFactories.getGroupedFactory() - .create( - StreamsUtil.buildOpName(contextStacker.getQueryContext()), - groupedKeySerde, - valSerde - ); - - final KGroupedStream kgroupedStream = kstream - .filter((key, value) -> value != null) - .groupBy(groupBy.mapper, grouped); - + final String aggregateKeyName = groupedKeyNameFor(groupByExpressions); final LegacyField legacyKeyField = LegacyField - .notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING); - - final Optional newKeyCol = getSchema().findValueColumn(groupBy.aggregateKeyName) + .notInSchema(aggregateKeyName, SqlTypes.STRING); + final Optional newKeyCol = getSchema().findValueColumn(aggregateKeyName) .map(Column::name); - final ExecutionStep> source = + + final StreamGroupBy, KGroupedStream> source = ExecutionStepFactory.streamGroupBy( contextStacker, sourceStep, @@ -878,7 +864,12 @@ public SchemaKGroupedStream groupBy( groupByExpressions ); return new SchemaKGroupedStream( - kgroupedStream, + StreamGroupByBuilder.build( + (KStream) kstream, + source, + queryBuilder, + streamsFactories.getGroupedFactory() + ), source, rekeyedKeyFormat, groupedKeySerde, @@ -946,18 +937,10 @@ public FunctionRegistry getFunctionRegistry() { return functionRegistry; } - class GroupBy { - - final String aggregateKeyName; - final GroupByMapper mapper; - - GroupBy(final List expressions) { - final List groupBy = CodeGenRunner.compileExpressions( - expressions.stream(), "Group By", getSchema(), ksqlConfig, functionRegistry); - - this.mapper = new GroupByMapper(groupBy); - this.aggregateKeyName = GroupByMapper.keyNameFor(expressions); - } + String groupedKeyNameFor(final List groupByExpressions) { + return groupByExpressions.stream() + .map(Expression::toString) + .collect(Collectors.joining(GROUP_BY_COLUMN_SEPARATOR)); } protected static class KsqlValueJoiner diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java index d8e7a37565d7..fc60495a92ca 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java @@ -25,10 +25,12 @@ import io.confluent.ksql.execution.plan.Formats; import io.confluent.ksql.execution.plan.JoinType; import io.confluent.ksql.execution.plan.SelectExpression; +import io.confluent.ksql.execution.plan.TableGroupBy; import io.confluent.ksql.execution.plan.TableMapValues; import io.confluent.ksql.execution.streams.ExecutionStepFactory; -import io.confluent.ksql.execution.streams.StreamsUtil; +import io.confluent.ksql.execution.streams.TableGroupByBuilder; import io.confluent.ksql.execution.streams.TableMapValuesBuilder; +import io.confluent.ksql.execution.util.StructKeyUtil; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.logging.processing.ProcessingLogContext; import io.confluent.ksql.metastore.model.KeyField; @@ -50,8 +52,6 @@ import java.util.Set; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.KeyValue; -import org.apache.kafka.streams.kstream.Grouped; import org.apache.kafka.streams.kstream.KGroupedTable; import org.apache.kafka.streams.kstream.KStream; import org.apache.kafka.streams.kstream.KTable; @@ -241,40 +241,25 @@ public ExecutionStep> getSourceTableStep() { } @Override + @SuppressWarnings("unchecked") public SchemaKGroupedStream groupBy( final ValueFormat valueFormat, - final Serde valSerde, final List groupByExpressions, - final QueryContext.Stacker contextStacker + final QueryContext.Stacker contextStacker, + final KsqlQueryBuilder queryBuilder ) { - final GroupBy groupBy = new GroupBy(groupByExpressions); final KeyFormat groupedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); final KeySerde groupedKeySerde = keySerde .rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA); - final Grouped grouped = streamsFactories.getGroupedFactory() - .create( - StreamsUtil.buildOpName(contextStacker.getQueryContext()), - groupedKeySerde, - valSerde - ); - - final KGroupedTable kgroupedTable = ktable - .filter((key, value) -> value != null) - .groupBy( - (key, value) -> new KeyValue<>(groupBy.mapper.apply(key, value), value), - grouped - ); - - final LegacyField legacyKeyField = LegacyField - .notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING); - - final Optional newKeyField = getSchema().findValueColumn(groupBy.aggregateKeyName) - .map(Column::fullName); + final String aggregateKeyName = groupedKeyNameFor(groupByExpressions); + final LegacyField legacyKeyField = LegacyField.notInSchema(aggregateKeyName, SqlTypes.STRING); + final Optional newKeyField = + getSchema().findValueColumn(aggregateKeyName).map(Column::fullName); - final ExecutionStep> step = + final TableGroupBy, KGroupedTable> step = ExecutionStepFactory.tableGroupBy( contextStacker, sourceTableStep, @@ -282,7 +267,12 @@ public SchemaKGroupedStream groupBy( groupByExpressions ); return new SchemaKGroupedTable( - kgroupedTable, + TableGroupByBuilder.build( + (KTable) ktable, + step, + queryBuilder, + streamsFactories.getGroupedFactory() + ), step, groupedKeyFormat, groupedKeySerde, diff --git a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/AggregateNodeTest.java b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/AggregateNodeTest.java index 0d98075fe00e..7e5bc7aceba7 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/AggregateNodeTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/AggregateNodeTest.java @@ -52,6 +52,7 @@ import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.PersistenceSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.serde.WindowInfo; import io.confluent.ksql.structured.SchemaKStream; @@ -401,7 +402,7 @@ private SchemaKStream buildQuery(final AggregateNode aggregateNode, final KsqlCo when(ksqlStreamBuilder.buildNodeContext(any())).thenAnswer(inv -> new QueryContext.Stacker(queryId) .push(inv.getArgument(0).toString())); - when(ksqlStreamBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(ksqlStreamBuilder.buildKeySerde(any(FormatInfo.class), any(), any())).thenReturn(keySerde); return aggregateNode.buildStream(ksqlStreamBuilder); } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java index e639af2a279b..87d18f86ef40 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java @@ -190,7 +190,8 @@ public void before() { new QueryContext.Stacker(queryId) .push(inv.getArgument(0).toString())); - when(ksqlStreamBuilder.buildKeySerde(any(), any(), any())).thenReturn((KeySerde)keySerde); + when(ksqlStreamBuilder.buildKeySerde(any(FormatInfo.class), any(), any())) + .thenReturn((KeySerde)keySerde); when(ksqlStreamBuilder.buildKeySerde(any(), any(), any(), any())).thenReturn((KeySerde)keySerde); when(ksqlStreamBuilder.buildValueSerde(any(), any(), any())).thenReturn(rowSerde); when(ksqlStreamBuilder.getFunctionRegistry()).thenReturn(functionRegistry); diff --git a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java index 8041d0846338..7c4a6a6cc949 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java @@ -188,7 +188,8 @@ public void setUp() { when(ksqlStreamBuilder.buildNodeContext(any())).thenAnswer(inv -> new QueryContext.Stacker(queryId) .push(inv.getArgument(0).toString())); - when(ksqlStreamBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(ksqlStreamBuilder.buildKeySerde(any(FormatInfo.class), any(), any())) + .thenReturn(keySerde); when(keySerde.rebind(any(PersistenceSchema.class))).thenReturn(reboundKeySerde); diff --git a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlBareOutputNodeTest.java b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlBareOutputNodeTest.java index 8a00da999e71..162553cec172 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlBareOutputNodeTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlBareOutputNodeTest.java @@ -35,6 +35,7 @@ import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.structured.SchemaKStream; import io.confluent.ksql.testutils.AnalysisTestUtil; @@ -89,7 +90,8 @@ public void before() { when(ksqlStreamBuilder.buildNodeContext(any())).thenAnswer(inv -> new QueryContext.Stacker(queryId) .push(inv.getArgument(0).toString())); - when(ksqlStreamBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(ksqlStreamBuilder.buildKeySerde(any(FormatInfo.class), any(), any())) + .thenReturn(keySerde); final KsqlBareOutputNode planNode = (KsqlBareOutputNode) AnalysisTestUtil .buildLogicalPlan(ksqlConfig, SIMPLE_SELECT_WITH_FILTER, metaStore); diff --git a/ksql-engine/src/test/java/io/confluent/ksql/streams/GroupedFactoryTest.java b/ksql-engine/src/test/java/io/confluent/ksql/streams/GroupedFactoryTest.java index 8654470bb195..e72bf4183615 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/streams/GroupedFactoryTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/streams/GroupedFactoryTest.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.streams.GroupedFactory; import io.confluent.ksql.util.KsqlConfig; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.streams.StreamsConfig; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java index b463b749688e..90a95ae9f178 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java @@ -32,6 +32,7 @@ import com.google.common.collect.ImmutableMap; import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.expression.tree.DereferenceExpression; import io.confluent.ksql.execution.expression.tree.Expression; @@ -138,6 +139,8 @@ public class SchemaKGroupedTableTest { private KsqlAggregateFunction otherFunc; @Mock private TableAggregationFunction tableFunc; + @Mock + private KsqlQueryBuilder queryBuilder; private KTable kTable; private KsqlTable ksqlTable; @@ -161,6 +164,9 @@ public void init() { Consumed.with(Serdes.String(), rowSerde) ); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(queryBuilder.getKsqlConfig()).thenReturn(ksqlConfig); + when(aggregateSchema.findValueColumn("GROUPING_COLUMN")) .thenReturn(Optional.of(Column.of("GROUPING_COLUMN", SqlTypes.STRING))); @@ -210,7 +216,7 @@ private SchemaKGroupedTable buildSchemaKGroupedTableFromQuery( ); final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( - valueFormat, rowSerde, groupByExpressions, queryContext); + valueFormat, groupByExpressions, queryContext, queryBuilder); Assert.assertThat(groupedSchemaKTable, instanceOf(SchemaKGroupedTable.class)); return (SchemaKGroupedTable)groupedSchemaKTable; } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java index fbf173874d66..50a4bc251435 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java @@ -68,6 +68,7 @@ import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.PersistenceSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; @@ -76,7 +77,7 @@ import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.serde.ValueFormat; -import io.confluent.ksql.streams.GroupedFactory; +import io.confluent.ksql.execution.streams.GroupedFactory; import io.confluent.ksql.streams.JoinedFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; import io.confluent.ksql.streams.StreamsFactories; @@ -153,7 +154,7 @@ public class SchemaKStreamTest { private Serde rightSerde; private LogicalSchema joinSchema; private Serde rowSerde; - private KeyFormat keyFormat = KeyFormat.nonWindowed(FormatInfo.of(Format.JSON)); + private KeyFormat keyFormat = KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)); private ValueFormat valueFormat = ValueFormat.of(FormatInfo.of(Format.JSON)); private ValueFormat rightFormat = ValueFormat.of(FormatInfo.of(Format.DELIMITED)); private final LogicalSchema simpleSchema = LogicalSchema.builder() @@ -581,9 +582,9 @@ public void testGroupByKey() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( valueFormat, - rowSerde, groupBy, - childContextStacker); + childContextStacker, + queryBuilder); // Then: assertThat(groupedSchemaKStream.getKeyField().name(), is(Optional.of("TEST1.COL0"))); @@ -591,7 +592,7 @@ public void testGroupByKey() { } @Test - public void shouldBuildStepForGroupBy() { + public void shouldBuildStepForGroupByKey() { // Given: givenInitialKStreamOf("SELECT col0, col1 FROM test1 WHERE col0 > 100 EMIT CHANGES;"); final List groupBy = Collections.singletonList( @@ -602,9 +603,39 @@ public void shouldBuildStepForGroupBy() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( valueFormat, - rowSerde, groupBy, - childContextStacker); + childContextStacker, + queryBuilder); + + // Then: + final KeyFormat expectedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); + assertThat( + groupedSchemaKStream.getSourceStep(), + equalTo( + ExecutionStepFactory.streamGroupByKey( + childContextStacker, + initialSchemaKStream.getSourceStep(), + Formats.of(expectedKeyFormat, valueFormat, SerdeOption.none()) + ) + ) + ); + } + + @Test + public void shouldBuildStepForGroupBy() { + // Given: + givenInitialKStreamOf("SELECT col0, col1 FROM test1 WHERE col0 > 100 EMIT CHANGES;"); + final List groupBy = Collections.singletonList( + new DereferenceExpression( + new QualifiedNameReference(QualifiedName.of("TEST1")), "COL1") + ); + + // When: + final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( + valueFormat, + groupBy, + childContextStacker, + queryBuilder); // Then: final KeyFormat expectedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); @@ -636,9 +667,9 @@ public void testGroupByMultipleColumns() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( valueFormat, - rowSerde, groupBy, - childContextStacker); + childContextStacker, + queryBuilder); // Then: assertThat(groupedSchemaKStream.getKeyField().name(), is(Optional.empty())); @@ -655,9 +686,9 @@ public void testGroupByMoreComplexExpression() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( valueFormat, - rowSerde, ImmutableList.of(groupBy), - childContextStacker); + childContextStacker, + queryBuilder); // Then: assertThat(groupedSchemaKStream.getKeyField().name(), is(Optional.empty())); @@ -674,13 +705,15 @@ public void shouldUseFactoryForGroupedWithoutRekey() { ksqlStream.getKeyField().name().get()); final List groupByExpressions = Collections.singletonList(keyExpression); givenInitialSchemaKStreamUsesMocks(); + when(queryBuilder.buildKeySerde(any(KeyFormat.class), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(leftSerde); // When: initialSchemaKStream.groupBy( valueFormat, - leftSerde, groupByExpressions, - childContextStacker); + childContextStacker, + queryBuilder); // Then: verify(mockGroupedFactory).create( @@ -689,6 +722,17 @@ public void shouldUseFactoryForGroupedWithoutRekey() { same(leftSerde) ); verify(mockKStream).groupByKey(same(grouped)); + final LogicalSchema logicalSchema = ksqlStream.getSchema().withAlias(ksqlStream.getName()); + verify(queryBuilder).buildKeySerde( + KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)), + PhysicalSchema.from(logicalSchema, SerdeOption.none()), + childContextStacker.getQueryContext() + ); + verify(queryBuilder).buildValueSerde( + valueFormat.getFormatInfo(), + PhysicalSchema.from(logicalSchema, SerdeOption.none()), + childContextStacker.getQueryContext() + ); } @Test @@ -704,13 +748,15 @@ public void shouldUseFactoryForGrouped() { new QualifiedNameReference(QualifiedName.of(ksqlStream.getName())), "COL1"); final List groupByExpressions = Arrays.asList(col1Expression, col0Expression); givenInitialSchemaKStreamUsesMocks(); + when(queryBuilder.buildKeySerde(any(KeyFormat.class), any(), any())).thenReturn(reboundKeySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(leftSerde); // When: initialSchemaKStream.groupBy( valueFormat, - leftSerde, groupByExpressions, - childContextStacker); + childContextStacker, + queryBuilder); // Then: verify(mockGroupedFactory).create( @@ -718,6 +764,17 @@ public void shouldUseFactoryForGrouped() { same(reboundKeySerde), same(leftSerde)); verify(mockKStream).groupBy(any(KeyValueMapper.class), same(grouped)); + final LogicalSchema logicalSchema = ksqlStream.getSchema().withAlias(ksqlStream.getName()); + verify(queryBuilder).buildKeySerde( + KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)), + PhysicalSchema.from(logicalSchema, SerdeOption.none()), + childContextStacker.getQueryContext() + ); + verify(queryBuilder).buildValueSerde( + valueFormat.getFormatInfo(), + PhysicalSchema.from(logicalSchema, SerdeOption.none()), + childContextStacker.getQueryContext() + ); } @Test diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java index db36fe8e405f..d63d5d32ca84 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java @@ -50,6 +50,7 @@ import io.confluent.ksql.execution.plan.JoinType; import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.streams.ExecutionStepFactory; +import io.confluent.ksql.execution.util.StructKeyUtil; import io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.logging.processing.ProcessingLogContext; import io.confluent.ksql.metastore.MetaStore; @@ -72,7 +73,7 @@ import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.serde.ValueFormat; -import io.confluent.ksql.streams.GroupedFactory; +import io.confluent.ksql.execution.streams.GroupedFactory; import io.confluent.ksql.streams.JoinedFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; import io.confluent.ksql.streams.StreamsFactories; @@ -403,15 +404,14 @@ public void testGroupBy() { final String selectQuery = "SELECT col0, col1, col2 FROM test2 EMIT CHANGES;"; final PlanNode logicalPlan = buildLogicalPlan(selectQuery); initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); - final Serde rowSerde = mock(Serde.class); final List groupByExpressions = Arrays.asList(TEST_2_COL_2, TEST_2_COL_1); // When: final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( valueFormat, - rowSerde, groupByExpressions, - childContextStacker); + childContextStacker, + queryBuilder); // Then: assertThat(groupedSchemaKTable, instanceOf(SchemaKGroupedTable.class)); @@ -426,15 +426,14 @@ public void shouldBuildStepForGroupBy() { final String selectQuery = "SELECT col0, col1, col2 FROM test2 EMIT CHANGES;"; final PlanNode logicalPlan = buildLogicalPlan(selectQuery); initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); - final Serde rowSerde = mock(Serde.class); final List groupByExpressions = Arrays.asList(TEST_2_COL_2, TEST_2_COL_1); // When: final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( valueFormat, - rowSerde, groupByExpressions, - childContextStacker); + childContextStacker, + queryBuilder); // Then: assertThat( @@ -455,6 +454,7 @@ public void shouldUseOpNameForGrouped() { // Given: final Serde valSerde = getRowSerde(ksqlTable.getKsqlTopic(), ksqlTable.getSchema().valueConnectSchema()); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valSerde); expect( groupedFactory.create( eq(StreamsUtil.buildOpName(childContextStacker.getQueryContext())), @@ -470,7 +470,7 @@ public void shouldUseOpNameForGrouped() { final SchemaKTable schemaKTable = buildSchemaKTable(ksqlTable, mockKTable, groupedFactory); // When: - schemaKTable.groupBy(valueFormat, valSerde, groupByExpressions, childContextStacker); + schemaKTable.groupBy(valueFormat, groupByExpressions, childContextStacker, queryBuilder); // Then: verify(mockKTable, groupedFactory); @@ -505,16 +505,9 @@ public void shouldGroupKeysCorrectly() { ); final List groupByExpressions = Arrays.asList(TEST_2_COL_2, TEST_2_COL_1); - final Serde rowSerde = GenericRowSerDe.from( - FormatInfo.of(Format.JSON, Optional.empty()), - PersistenceSchema.from(initialSchemaKTable.getSchema().valueConnectSchema(), false), - null, - () -> null, - "test", - processingLogContext); // Call groupBy and extract the captured mapper - initialSchemaKTable.groupBy(valueFormat, rowSerde, groupByExpressions, childContextStacker); + initialSchemaKTable.groupBy(valueFormat, groupByExpressions, childContextStacker, queryBuilder); verify(mockKTable, mockKGroupedTable); final KeyValueMapper keySelector = capturedKeySelector.getValue(); final GenericRow value = new GenericRow(Arrays.asList("key", 0, 100, "foo", "bar")); @@ -760,7 +753,7 @@ public void shouldSetKeyOnGroupBySingleExpressionThatIsInProjection() { // When: final SchemaKGroupedStream result = initialSchemaKTable - .groupBy(valueFormat, rowSerde, groupByExprs, childContextStacker); + .groupBy(valueFormat, groupByExprs, childContextStacker, queryBuilder); // Then: assertThat(result.getKeyField(), diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/builder/KsqlQueryBuilder.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/builder/KsqlQueryBuilder.java index 722409dd927d..4738c4439f27 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/builder/KsqlQueryBuilder.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/builder/KsqlQueryBuilder.java @@ -29,6 +29,7 @@ import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.GenericKeySerDe; import io.confluent.ksql.serde.GenericRowSerDe; +import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.serde.KeySerdeFactory; import io.confluent.ksql.serde.ValueSerdeFactory; @@ -152,6 +153,28 @@ public KeySerde buildKeySerde( ); } + @SuppressWarnings("unchecked") + public KeySerde buildKeySerde( + final KeyFormat keyFormat, + final PhysicalSchema physicalSchema, + final QueryContext queryContext + ) { + if (keyFormat.isWindowed()) { + return (KeySerde) buildKeySerde( + keyFormat.getFormatInfo(), + keyFormat.getWindowInfo().get(), + physicalSchema, + queryContext + ); + } else { + return (KeySerde) buildKeySerde( + keyFormat.getFormatInfo(), + physicalSchema, + queryContext + ); + } + } + public KeySerde> buildKeySerde( final FormatInfo format, final WindowInfo window, diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupBy.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupBy.java index 3a2ec496926c..4de8257bad77 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupBy.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupBy.java @@ -53,6 +53,10 @@ public List> getSources() { return Collections.singletonList(source); } + public Formats getFormats() { + return formats; + } + @Override public G build(final KsqlQueryBuilder streamsBuilder) { throw new UnsupportedOperationException(); diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupByKey.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupByKey.java new file mode 100644 index 000000000000..c5ddd17a9e9e --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupByKey.java @@ -0,0 +1,80 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License; you may not use this file + * except in compliance with the License. You may obtain a copy of the License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.plan; + +import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +@Immutable +public class StreamGroupByKey implements ExecutionStep { + private final ExecutionStepProperties properties; + private final ExecutionStep source; + private final Formats formats; + + public StreamGroupByKey( + final ExecutionStepProperties properties, + final ExecutionStep source, + final Formats formats) { + this.properties = Objects.requireNonNull(properties, "properties"); + this.formats = Objects.requireNonNull(formats, "formats"); + this.source = Objects.requireNonNull(source, "source"); + } + + @Override + public ExecutionStepProperties getProperties() { + return properties; + } + + @Override + public List> getSources() { + return Collections.singletonList(source); + } + + public ExecutionStep getSource() { + return source; + } + + public Formats getFormats() { + return formats; + } + + @Override + public G build(final KsqlQueryBuilder streamsBuilder) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final StreamGroupByKey that = (StreamGroupByKey) o; + return Objects.equals(properties, that.properties) + && Objects.equals(source, that.source) + && Objects.equals(formats, that.formats); + } + + @Override + public int hashCode() { + + return Objects.hash(properties, source, formats); + } +} diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableGroupBy.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableGroupBy.java index fb031952272b..21207a4aaf40 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableGroupBy.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableGroupBy.java @@ -50,6 +50,14 @@ public List> getSources() { return Collections.singletonList(source); } + public Formats getFormats() { + return formats; + } + + public List getGroupByExpressions() { + return groupByExpressions; + } + @Override public G build(final KsqlQueryBuilder builder) { throw new UnsupportedOperationException(); diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/StructKeyUtil.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/StructKeyUtil.java similarity index 86% rename from ksql-engine/src/main/java/io/confluent/ksql/structured/StructKeyUtil.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/util/StructKeyUtil.java index 4d23b1d2dc24..100dd68c4a17 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/StructKeyUtil.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/StructKeyUtil.java @@ -13,7 +13,7 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.structured; +package io.confluent.ksql.execution.util; import io.confluent.ksql.schema.ksql.PersistenceSchema; import io.confluent.ksql.util.SchemaUtil; @@ -25,7 +25,7 @@ /** * Helper for dealing with Struct keys. */ -final class StructKeyUtil { +public final class StructKeyUtil { private static final Schema ROWKEY_STRUCT_SCHEMA = SchemaBuilder .struct() @@ -35,7 +35,7 @@ final class StructKeyUtil { private static final org.apache.kafka.connect.data.Field ROWKEY_FIELD = ROWKEY_STRUCT_SCHEMA.fields().get(0); - static final PersistenceSchema ROWKEY_SERIALIZED_SCHEMA = PersistenceSchema.from( + public static final PersistenceSchema ROWKEY_SERIALIZED_SCHEMA = PersistenceSchema.from( (ConnectSchema) ROWKEY_STRUCT_SCHEMA, false ); @@ -43,7 +43,7 @@ final class StructKeyUtil { private StructKeyUtil() { } - static Struct asStructKey(final String rowKey) { + public static Struct asStructKey(final String rowKey) { final Struct keyStruct = new Struct(ROWKEY_STRUCT_SCHEMA); keyStruct.put(ROWKEY_FIELD, rowKey); return keyStruct; diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java index 9d7d01d3845a..8bf170eeab68 100644 --- a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java @@ -28,6 +28,7 @@ import io.confluent.ksql.execution.plan.StreamAggregate; import io.confluent.ksql.execution.plan.StreamFilter; import io.confluent.ksql.execution.plan.StreamGroupBy; +import io.confluent.ksql.execution.plan.StreamGroupByKey; import io.confluent.ksql.execution.plan.StreamMapValues; import io.confluent.ksql.execution.plan.StreamSelectKey; import io.confluent.ksql.execution.plan.StreamSink; @@ -335,6 +336,20 @@ public static TableTableJoin> tableTableJoin( ); } + public static StreamGroupByKey, KGroupedStream> + streamGroupByKey( + final QueryContext.Stacker stacker, + final ExecutionStep> sourceStep, + final Formats formats + ) { + final QueryContext queryContext = stacker.getQueryContext(); + return new StreamGroupByKey<>( + sourceStep.getProperties().withQueryContext(queryContext), + sourceStep, + formats + ); + } + public static TableAggregate, KGroupedTable> tableAggregate( final QueryContext.Stacker stacker, diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/GroupByMapper.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupByMapper.java similarity index 83% rename from ksql-engine/src/main/java/io/confluent/ksql/structured/GroupByMapper.java rename to ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupByMapper.java index c2d8e5e25fe9..1b0c0bd3dbce 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/GroupByMapper.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupByMapper.java @@ -13,22 +13,21 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.structured; +package io.confluent.ksql.execution.streams; import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.codegen.ExpressionMetadata; -import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.util.StructKeyUtil; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.IntStream; -import org.apache.kafka.connect.data.Struct; import org.apache.kafka.streams.kstream.KeyValueMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class GroupByMapper implements KeyValueMapper { +class GroupByMapper implements KeyValueMapper { private static final Logger LOG = LoggerFactory.getLogger(GroupByMapper.class); @@ -44,7 +43,7 @@ class GroupByMapper implements KeyValueMapper { } @Override - public Struct apply(final Object key, final GenericRow row) { + public Object apply(final Object key, final GenericRow row) { final String stringRowKey = IntStream.range(0, expressions.size()) .mapToObj(idx -> processColumn(idx, expressions.get(idx), row)) .collect(Collectors.joining(GROUP_BY_COLUMN_SEPARATOR)); @@ -52,12 +51,6 @@ public Struct apply(final Object key, final GenericRow row) { return StructKeyUtil.asStructKey(stringRowKey); } - static String keyNameFor(final List groupByExpressions) { - return groupByExpressions.stream() - .map(Expression::toString) - .collect(Collectors.joining(GROUP_BY_COLUMN_SEPARATOR)); - } - private static String processColumn( final int index, final ExpressionMetadata exp, @@ -70,4 +63,8 @@ private static String processColumn( return "null"; } } + + List getExpressionMetadata() { + return expressions; + } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/streams/GroupedFactory.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupedFactory.java similarity index 94% rename from ksql-engine/src/main/java/io/confluent/ksql/streams/GroupedFactory.java rename to ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupedFactory.java index 0b4a6abd56d9..7c72e0c38ca6 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/streams/GroupedFactory.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupedFactory.java @@ -13,9 +13,8 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.streams; +package io.confluent.ksql.execution.streams; -import io.confluent.ksql.execution.streams.StreamsUtil; import io.confluent.ksql.util.KsqlConfig; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.streams.kstream.Grouped; diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamGroupByBuilder.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamGroupByBuilder.java new file mode 100644 index 000000000000..a2323a95e327 --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamGroupByBuilder.java @@ -0,0 +1,108 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.codegen.CodeGenRunner; +import io.confluent.ksql.execution.codegen.ExpressionMetadata; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.StreamGroupBy; +import io.confluent.ksql.execution.plan.StreamGroupByKey; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.serde.KeySerde; +import java.util.List; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; + +public final class StreamGroupByBuilder { + private StreamGroupByBuilder() { + } + + public static KGroupedStream build( + final KStream kstream, + final StreamGroupByKey step, + final KsqlQueryBuilder queryBuilder, + final GroupedFactory groupedFactory + ) { + final LogicalSchema sourceSchema = step.getSource().getProperties().getSchema(); + final QueryContext queryContext = step.getProperties().getQueryContext(); + final Formats formats = step.getFormats(); + final Grouped grouped = buildGrouped( + formats, + sourceSchema, + queryContext, + queryBuilder, + groupedFactory + ); + return kstream.groupByKey(grouped); + } + + public static KGroupedStream build( + final KStream kstream, + final StreamGroupBy step, + final KsqlQueryBuilder queryBuilder, + final GroupedFactory groupedFactory + ) { + final LogicalSchema sourceSchema = step.getSources().get(0).getProperties().getSchema(); + final QueryContext queryContext = step.getProperties().getQueryContext(); + final Formats formats = step.getFormats(); + final Grouped grouped = buildGrouped( + formats, + sourceSchema, + queryContext, + queryBuilder, + groupedFactory + ); + final List groupBy = CodeGenRunner.compileExpressions( + step.getGroupByExpressions().stream(), + "Group By", + sourceSchema, + queryBuilder.getKsqlConfig(), + queryBuilder.getFunctionRegistry() + ); + final GroupByMapper mapper = new GroupByMapper(groupBy); + return kstream.filter((key, value) -> value != null).groupBy(mapper, grouped); + } + + private static Grouped buildGrouped( + final Formats formats, + final LogicalSchema schema, + final QueryContext queryContext, + final KsqlQueryBuilder queryBuilder, + final GroupedFactory groupedFactory + ) { + final PhysicalSchema physicalSchema = PhysicalSchema.from( + schema, + formats.getOptions() + ); + final KeySerde keySerde = queryBuilder.buildKeySerde( + formats.getKeyFormat(), + physicalSchema, + queryContext + ); + final Serde valSerde = queryBuilder.buildValueSerde( + formats.getValueFormat().getFormatInfo(), + physicalSchema, + queryContext + ); + return groupedFactory.create(StreamsUtil.buildOpName(queryContext), keySerde, valSerde); + } +} diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamToTableBuilder.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamToTableBuilder.java index 895fdcf62f0c..6683def8df22 100644 --- a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamToTableBuilder.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamToTableBuilder.java @@ -54,9 +54,8 @@ public static KTable build( queryContext ); final KeyFormat keyFormat = streamToTable.getFormats().getKeyFormat(); - final KeySerde keySerde = buildKeySerde( + final KeySerde keySerde = queryBuilder.buildKeySerde( keyFormat, - queryBuilder, physicalSchema, queryContext ); @@ -84,27 +83,4 @@ public static KTable build( (k, value, oldValue) -> value.orElse(null), materialized); } - - @SuppressWarnings("unchecked") - private static KeySerde buildKeySerde( - final KeyFormat keyFormat, - final KsqlQueryBuilder queryBuilder, - final PhysicalSchema physicalSchema, - final QueryContext queryContext - ) { - if (keyFormat.isWindowed()) { - return (KeySerde) queryBuilder.buildKeySerde( - keyFormat.getFormatInfo(), - keyFormat.getWindowInfo().get(), - physicalSchema, - queryContext - ); - } else { - return (KeySerde) queryBuilder.buildKeySerde( - keyFormat.getFormatInfo(), - physicalSchema, - queryContext - ); - } - } } diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableGroupByBuilder.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableGroupByBuilder.java new file mode 100644 index 000000000000..2df715695912 --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableGroupByBuilder.java @@ -0,0 +1,98 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.codegen.CodeGenRunner; +import io.confluent.ksql.execution.codegen.ExpressionMetadata; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.TableGroupBy; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.serde.KeySerde; +import java.util.List; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; + +public final class TableGroupByBuilder { + private TableGroupByBuilder() { + } + + public static KGroupedTable build( + final KTable ktable, + final TableGroupBy step, + final KsqlQueryBuilder queryBuilder, + final GroupedFactory groupedFactory + ) { + final LogicalSchema sourceSchema = step.getSources().get(0).getProperties().getSchema(); + final QueryContext queryContext = step.getProperties().getQueryContext(); + final Formats formats = step.getFormats(); + final PhysicalSchema physicalSchema = PhysicalSchema.from( + sourceSchema, + formats.getOptions() + ); + final KeySerde keySerde = queryBuilder.buildKeySerde( + formats.getKeyFormat(), + physicalSchema, + queryContext + ); + final Serde valSerde = queryBuilder.buildValueSerde( + formats.getValueFormat().getFormatInfo(), + physicalSchema, + queryContext + ); + final Grouped grouped = groupedFactory.create( + StreamsUtil.buildOpName(queryContext), + keySerde, + valSerde + ); + final List groupBy = CodeGenRunner.compileExpressions( + step.getGroupByExpressions().stream(), + "Group By", + sourceSchema, + queryBuilder.getKsqlConfig(), + queryBuilder.getFunctionRegistry() + ); + final GroupByMapper mapper = new GroupByMapper(groupBy); + return ktable + .filter((key, value) -> value != null) + .groupBy(new TableKeyValueMapper(mapper), grouped); + } + + public static final class TableKeyValueMapper + implements KeyValueMapper> { + private final GroupByMapper groupByMapper; + + private TableKeyValueMapper(final GroupByMapper groupByMapper) { + this.groupByMapper = groupByMapper; + } + + @Override + public KeyValue apply(final Object key, final GenericRow value) { + return new KeyValue<>(groupByMapper.apply(key, value), value); + } + + GroupByMapper getGroupByMapper() { + return groupByMapper; + } + } +} diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/GroupByMapperTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/GroupByMapperTest.java similarity index 74% rename from ksql-engine/src/test/java/io/confluent/ksql/structured/GroupByMapperTest.java rename to ksql-streams/src/test/java/io/confluent/ksql/execution/streams/GroupByMapperTest.java index f1bf48b5f29b..c8c0d3f91c56 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/GroupByMapperTest.java +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/GroupByMapperTest.java @@ -13,7 +13,7 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.structured; +package io.confluent.ksql.execution.streams; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; @@ -21,10 +21,7 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.codegen.ExpressionMetadata; -import io.confluent.ksql.execution.expression.tree.DereferenceExpression; -import io.confluent.ksql.execution.expression.tree.Expression; -import io.confluent.ksql.execution.expression.tree.QualifiedName; -import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.util.StructKeyUtil; import java.util.Collections; import org.apache.kafka.connect.data.Struct; import org.easymock.EasyMock; @@ -72,7 +69,7 @@ public void shouldGenerateGroupByKey() { EasyMock.replay(groupBy0, groupBy1); // When: - final Struct result = mapper.apply("key", row); + final Struct result = (Struct) mapper.apply("key", row); // Then: assertThat(result, is(StructKeyUtil.asStructKey("result0|+|result1"))); @@ -86,7 +83,7 @@ public void shouldSupportNullValues() { EasyMock.replay(groupBy0, groupBy1); // When: - final Struct result = mapper.apply("key", row); + final Struct result = (Struct) mapper.apply("key", row); // Then: assertThat(result, is(StructKeyUtil.asStructKey("null|+|result1"))); @@ -100,24 +97,9 @@ public void shouldUseNullIfExpressionThrows() { EasyMock.replay(groupBy0, groupBy1); // When: - final Struct result = mapper.apply("key", row); + final Struct result = (Struct) mapper.apply("key", row); // Then: assertThat(result, is(StructKeyUtil.asStructKey("null|+|result1"))); } - - @Test - public void shouldGetKeyName() { - // Given: - final Expression exp0 = new DereferenceExpression( - new QualifiedNameReference(QualifiedName.of("Fred")), "f1"); - final Expression exp1 = new DereferenceExpression( - new QualifiedNameReference(QualifiedName.of("Bob")), "b1"); - - // When: - final String result = GroupByMapper.keyNameFor(ImmutableList.of(exp0, exp1)); - - // Then: - assertThat(result, is("Fred.f1|+|Bob.b1")); - } -} \ No newline at end of file +} diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamGroupByBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamGroupByBuilderTest.java new file mode 100644 index 000000000000..df801b5af30f --- /dev/null +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamGroupByBuilderTest.java @@ -0,0 +1,264 @@ +package io.confluent.ksql.execution.streams; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.DereferenceExpression; +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.QualifiedName; +import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.ExecutionStepProperties; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.StreamGroupBy; +import io.confluent.ksql.execution.plan.StreamGroupByKey; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.FormatInfo; +import io.confluent.ksql.serde.KeyFormat; +import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; +import io.confluent.ksql.util.KsqlConfig; +import java.util.List; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Predicate; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +public class StreamGroupByBuilderTest { + private static final String ALIAS = "SOURCE"; + private static final LogicalSchema SCHEMA = LogicalSchema.builder() + .valueColumn("PAC", SqlTypes.BIGINT) + .valueColumn("MAN", SqlTypes.STRING) + .build() + .withAlias(ALIAS) + .withMetaAndKeyColsInValue(); + private static final PhysicalSchema PHYSICAL_SCHEMA = PhysicalSchema.from(SCHEMA, SerdeOption.none()); + + private final List groupByExpressions = ImmutableList.of( + dereference("PAC"), + dereference("MAN") + ); + private final QueryContext sourceContext = + new QueryContext.Stacker(new QueryId("qid")).push("foo").push("source").getQueryContext(); + private final QueryContext stepContext = + new QueryContext.Stacker(new QueryId("qid")).push("foo").push("groupby").getQueryContext(); + private final ExecutionStepProperties sourceProperties = new DefaultExecutionStepProperties( + SCHEMA, + sourceContext + ); + private final ExecutionStepProperties properties = new DefaultExecutionStepProperties( + SCHEMA, + stepContext + ); + private final Formats formats = Formats.of( + KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)), + ValueFormat.of(FormatInfo.of(Format.JSON)), + SerdeOption.none() + ); + + @Mock + private KsqlQueryBuilder queryBuilder; + @Mock + private KsqlConfig ksqlConfig; + @Mock + private FunctionRegistry functionRegistry; + @Mock + private GroupedFactory groupedFactory; + @Mock + private ExecutionStep sourceStep; + @Mock + private KeySerde keySerde; + @Mock + private Serde valueSerde; + @Mock + private Grouped grouped; + @Mock + private KStream sourceStream; + @Mock + private KStream filteredStream; + @Mock + private KGroupedStream groupedStream; + @Captor + private ArgumentCaptor mapperCaptor; + @Captor + private ArgumentCaptor predicateCaptor; + + private StreamGroupBy streamGroupBy; + private StreamGroupByKey streamGroupByKey; + + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Before + @SuppressWarnings("unchecked") + public void init() { + when(queryBuilder.getKsqlConfig()).thenReturn(ksqlConfig); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(queryBuilder.buildKeySerde(any(KeyFormat.class), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valueSerde); + when(groupedFactory.create(any(), any(), any())).thenReturn(grouped); + when(sourceStream.groupByKey(any(Grouped.class))).thenReturn(groupedStream); + when(sourceStream.filter(any())).thenReturn(filteredStream); + when(filteredStream.groupBy(any(KeyValueMapper.class), any(Grouped.class))) + .thenReturn(groupedStream); + when(sourceStep.getProperties()).thenReturn(sourceProperties); + streamGroupBy = new StreamGroupBy<>( + properties, + sourceStep, + formats, + groupByExpressions + ); + streamGroupByKey = new StreamGroupByKey<>(properties, sourceStep, formats); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldPerformGroupByCorrectly() { + // When: + final KGroupedStream result = + StreamGroupByBuilder.build(sourceStream, streamGroupBy, queryBuilder, groupedFactory); + + // Then: + assertThat(result, is(groupedStream)); + verify(sourceStream).filter(any()); + verify(filteredStream).groupBy(mapperCaptor.capture(), same(grouped)); + verifyNoMoreInteractions(filteredStream, sourceStream); + final GroupByMapper mapper = mapperCaptor.getValue(); + assertThat(mapper.getExpressionMetadata(), hasSize(2)); + assertThat( + mapper.getExpressionMetadata().get(0).getExpression(), + equalTo(groupByExpressions.get(0)) + ); + assertThat( + mapper.getExpressionMetadata().get(1).getExpression(), + equalTo(groupByExpressions.get(1)) + ); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldFilterNullRowsBeforeGroupBy() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupBy, queryBuilder, groupedFactory); + + // Then: + verify(sourceStream).filter(predicateCaptor.capture()); + final Predicate predicate = predicateCaptor.getValue(); + assertThat(predicate.test(new Object(), new GenericRow()), is(true)); + assertThat(predicate.test(new Object(), null), is(false)); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldBuildGroupedCorrectlyForGroupBy() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupBy, queryBuilder, groupedFactory); + + // Then: + verify(groupedFactory).create("foo-groupby", keySerde, valueSerde); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldBuildKeySerdeCorrectlyForGroupBy() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupBy, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildKeySerde(formats.getKeyFormat(), PHYSICAL_SCHEMA, stepContext); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldBuildValueSerdeCorrectlyForGroupBy() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupBy, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildValueSerde( + formats.getValueFormat().getFormatInfo(), + PHYSICAL_SCHEMA, + stepContext + ); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldPerformGroupByKeyCorrectly() { + // When: + final KGroupedStream result = + StreamGroupByBuilder.build(sourceStream, streamGroupByKey, queryBuilder, groupedFactory); + + // Then: + assertThat(result, is(groupedStream)); + verify(sourceStream).groupByKey(grouped); + verifyNoMoreInteractions(sourceStream); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldBuildGroupedCorrectlyForGroupByKey() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupByKey, queryBuilder, groupedFactory); + + // Then: + verify(groupedFactory).create("foo-groupby", keySerde, valueSerde); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldBuildKeySerdeCorrectlyForGroupByKey() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupByKey, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildKeySerde(formats.getKeyFormat(), PHYSICAL_SCHEMA, stepContext); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldBuildValueSerdeCorrectlyForGroupByKey() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupByKey, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildValueSerde( + formats.getValueFormat().getFormatInfo(), + PHYSICAL_SCHEMA, + stepContext + ); + } + + private static Expression dereference(final String column) { + return new DereferenceExpression(new QualifiedNameReference(QualifiedName.of(ALIAS)), column); + } +} \ No newline at end of file diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSourceBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSourceBuilderTest.java index 6fd76cf317b7..7d6cea32a16a 100644 --- a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSourceBuilderTest.java +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSourceBuilderTest.java @@ -162,7 +162,7 @@ public void setup() { when(kStream.mapValues(any(ValueMapperWithKey.class))).thenReturn(kStream); when(kStream.transformValues(any(ValueTransformerSupplier.class))).thenReturn(kStream); when(queryBuilder.buildKeySerde(any(), any(), any(), any())).thenReturn(keySerde); - when(queryBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildKeySerde(any(FormatInfo.class), any(), any())).thenReturn(keySerde); when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valueSerde); when(valueFormat.getFormatInfo()).thenReturn(valueFormatInfo); when(physicalSchemaFactory.apply(any(), any())).thenReturn(PHYSICAL_SCHEMA); diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamToTableBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamToTableBuilderTest.java index 5952dbb4c25b..0e0c754adb79 100644 --- a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamToTableBuilderTest.java +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamToTableBuilderTest.java @@ -143,7 +143,7 @@ private void givenUnwindowed() { Formats.of(keyFormat, valueFormat, SerdeOption.none()), new DefaultExecutionStepProperties(SCHEMA, queryContext) ); - when(ksqlQueryBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(ksqlQueryBuilder.buildKeySerde(any(FormatInfo.class), any(), any())).thenReturn(keySerde); } @Test diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableGroupByBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableGroupByBuilderTest.java new file mode 100644 index 000000000000..d03b04b4b8b2 --- /dev/null +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableGroupByBuilderTest.java @@ -0,0 +1,213 @@ +package io.confluent.ksql.execution.streams; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.DereferenceExpression; +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.QualifiedName; +import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.ExecutionStepProperties; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.TableGroupBy; +import io.confluent.ksql.execution.streams.TableGroupByBuilder.TableKeyValueMapper; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.FormatInfo; +import io.confluent.ksql.serde.KeyFormat; +import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; +import io.confluent.ksql.util.KsqlConfig; +import java.util.List; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Predicate; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +public class TableGroupByBuilderTest { + private static final String ALIAS = "SOURCE"; + private static final LogicalSchema SCHEMA = LogicalSchema.builder() + .valueColumn("PAC", SqlTypes.BIGINT) + .valueColumn("MAN", SqlTypes.STRING) + .build() + .withAlias(ALIAS) + .withMetaAndKeyColsInValue(); + private static final PhysicalSchema PHYSICAL_SCHEMA = PhysicalSchema.from(SCHEMA, SerdeOption.none()); + + private final List groupByExpressions = ImmutableList.of( + dereference("PAC"), + dereference("MAN") + ); + private final QueryContext sourceContext = + new QueryContext.Stacker(new QueryId("qid")).push("foo").push("source").getQueryContext(); + private final QueryContext stepContext = + new QueryContext.Stacker(new QueryId("qid")).push("foo").push("groupby").getQueryContext(); + private final ExecutionStepProperties sourceProperties = new DefaultExecutionStepProperties( + SCHEMA, + sourceContext + ); + private final ExecutionStepProperties properties = new DefaultExecutionStepProperties( + SCHEMA, + stepContext + ); + private final Formats formats = Formats.of( + KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)), + ValueFormat.of(FormatInfo.of(Format.JSON)), + SerdeOption.none() + ); + + @Mock + private KsqlQueryBuilder queryBuilder; + @Mock + private KsqlConfig ksqlConfig; + @Mock + private FunctionRegistry functionRegistry; + @Mock + private GroupedFactory groupedFactory; + @Mock + private ExecutionStep sourceStep; + @Mock + private KeySerde keySerde; + @Mock + private Serde valueSerde; + @Mock + private Grouped grouped; + @Mock + private KTable sourceTable; + @Mock + private KTable filteredTable; + @Mock + private KGroupedTable groupedTable; + @Captor + private ArgumentCaptor mapperCaptor; + @Captor + private ArgumentCaptor predicateCaptor; + + private TableGroupBy groupBy; + + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Before + @SuppressWarnings("unchecked") + public void init() { + when(queryBuilder.getKsqlConfig()).thenReturn(ksqlConfig); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(queryBuilder.buildKeySerde(any(KeyFormat.class), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valueSerde); + when(groupedFactory.create(any(), any(), any())).thenReturn(grouped); + when(sourceTable.filter(any())).thenReturn(filteredTable); + when(filteredTable.groupBy(any(KeyValueMapper.class), any(Grouped.class))) + .thenReturn(groupedTable); + when(sourceStep.getProperties()).thenReturn(sourceProperties); + groupBy = new TableGroupBy<>( + properties, + sourceStep, + formats, + groupByExpressions + ); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldPerformGroupByCorrectly() { + // When: + final KGroupedTable result = + TableGroupByBuilder.build(sourceTable, groupBy, queryBuilder, groupedFactory); + + // Then: + assertThat(result, is(groupedTable)); + verify(sourceTable).filter(any()); + verify(filteredTable).groupBy(mapperCaptor.capture(), same(grouped)); + verifyNoMoreInteractions(filteredTable, sourceTable); + final GroupByMapper mapper = mapperCaptor.getValue().getGroupByMapper(); + assertThat(mapper.getExpressionMetadata(), hasSize(2)); + assertThat( + mapper.getExpressionMetadata().get(0).getExpression(), + equalTo(groupByExpressions.get(0)) + ); + assertThat( + mapper.getExpressionMetadata().get(1).getExpression(), + equalTo(groupByExpressions.get(1)) + ); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldFilterNullRowsBeforeGroupBy() { + // When: + TableGroupByBuilder.build(sourceTable, groupBy, queryBuilder, groupedFactory); + + // Then: + verify(sourceTable).filter(predicateCaptor.capture()); + final Predicate predicate = predicateCaptor.getValue(); + assertThat(predicate.test(new Object(), new GenericRow()), is(true)); + assertThat(predicate.test(new Object(), null), is(false)); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldBuildGroupedCorrectlyForGroupBy() { + // When: + TableGroupByBuilder.build(sourceTable, groupBy, queryBuilder, groupedFactory); + + // Then: + verify(groupedFactory).create("foo-groupby", keySerde, valueSerde); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldBuildKeySerdeCorrectlyForGroupBy() { + // When: + TableGroupByBuilder.build(sourceTable, groupBy, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildKeySerde(formats.getKeyFormat(), PHYSICAL_SCHEMA, stepContext); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldBuildValueSerdeCorrectlyForGroupBy() { + // When: + TableGroupByBuilder.build(sourceTable, groupBy, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildValueSerde( + formats.getValueFormat().getFormatInfo(), + PHYSICAL_SCHEMA, + stepContext + ); + } + + private static Expression dereference(final String column) { + return new DereferenceExpression(new QualifiedNameReference(QualifiedName.of(ALIAS)), column); + } +} \ No newline at end of file