chore: group-by primitive key support #4108

Merged
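This change removes the engine restriction that GROUP BY keys must be STRING. When a query groups by a single expression, the logical planner now gives the resulting ROWKEY column that expression's SQL type (for example BIGINT); grouping by multiple expressions still produces a single combined STRING key. StructKeyUtil gains a keyBuilder(SqlType) factory so struct keys of any supported type can be built, and the affected tests are updated to match.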
==== LogicalPlanner.java ====

@@ -47,6 +47,7 @@
 import io.confluent.ksql.schema.ksql.LogicalSchema;
 import io.confluent.ksql.schema.ksql.LogicalSchema.Builder;
 import io.confluent.ksql.schema.ksql.types.SqlType;
+import io.confluent.ksql.schema.ksql.types.SqlTypes;
 import io.confluent.ksql.util.KsqlConfig;
 import io.confluent.ksql.util.KsqlException;
 import io.confluent.ksql.util.SchemaUtil;
@@ -147,24 +148,26 @@ private Optional<TimestampColumn> getTimestampColumn(
   }
 
   private AggregateNode buildAggregateNode(final PlanNode sourcePlanNode) {
-    final Expression groupBy = analysis.getGroupByExpressions().size() == 1
-        ? analysis.getGroupByExpressions().get(0)
-        : null;
+    final List<Expression> groupByExps = analysis.getGroupByExpressions();
 
-    final LogicalSchema schema = buildProjectionSchema(sourcePlanNode);
+    final LogicalSchema schema = buildAggregateSchema(sourcePlanNode, groupByExps);
 
+    final Expression groupBy = groupByExps.size() == 1
+        ? groupByExps.get(0)
+        : null;
+
     final Optional<ColumnName> keyFieldName = getSelectAliasMatching((expression, alias) ->
         expression.equals(groupBy)
             && !SchemaUtil.isFieldName(alias.name(), SchemaUtil.ROWTIME_NAME.name())
             && !SchemaUtil.isFieldName(alias.name(), SchemaUtil.ROWKEY_NAME.name()),
         sourcePlanNode);
 
     return new AggregateNode(
         new PlanNodeId("Aggregate"),
         sourcePlanNode,
         schema,
         keyFieldName.map(ColumnRef::withoutSource),
-        analysis.getGroupByExpressions(),
+        groupByExps,
         analysis.getWindowExpression(),
         aggregateAnalysis.getAggregateFunctionArguments(),
         aggregateAnalysis.getAggregateFunctions(),
@@ -344,6 +347,28 @@ private LogicalSchema buildProjectionSchema(final PlanNode sourcePlanNode) {
     return builder.build();
   }
 
+  private LogicalSchema buildAggregateSchema(
+      final PlanNode sourcePlanNode,
+      final List<Expression> groupByExps
+  ) {
+    final SqlType keyType;
+    if (groupByExps.size() != 1) {
+      keyType = SqlTypes.STRING;
+    } else {
+      final ExpressionTypeManager typeManager =
+          new ExpressionTypeManager(sourcePlanNode.getSchema(), functionRegistry);
+
+      keyType = typeManager.getExpressionSqlType(groupByExps.get(0));
+    }
+
+    final LogicalSchema sourceSchema = buildProjectionSchema(sourcePlanNode);
+
+    return LogicalSchema.builder()
+        .keyColumn(SchemaUtil.ROWKEY_NAME, keyType)
+        .valueColumns(sourceSchema.value())
+        .build();
+  }
+
   private LogicalSchema buildRepartitionedSchema(
       final PlanNode sourceNode,
       final Expression partitionBy
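Since buildAggregateSchema carries the core rule of the change, here is a minimal, self-contained sketch of that rule in plain Java. The enum, the typeOf resolver, and the COL0/BIGINT pairing are hypothetical stand-ins; the real code resolves expression types with ExpressionTypeManager and the SqlType hierarchy.

    import java.util.Arrays;
    import java.util.List;

    // Stand-alone sketch of the key-typing rule in buildAggregateSchema.
    public class GroupByKeyTypeSketch {

      enum SqlBaseType { STRING, INTEGER, BIGINT, DOUBLE, BOOLEAN }

      // Hypothetical resolver: pretend COL0 is a BIGINT column.
      static SqlBaseType typeOf(final String expression) {
        return "COL0".equals(expression) ? SqlBaseType.BIGINT : SqlBaseType.STRING;
      }

      static SqlBaseType keyType(final List<String> groupByExps) {
        // One grouping expression: the key takes that expression's type.
        // Several expressions: they are combined into one string, so STRING.
        return groupByExps.size() == 1
            ? typeOf(groupByExps.get(0))
            : SqlBaseType.STRING;
      }

      public static void main(final String[] args) {
        System.out.println(keyType(Arrays.asList("COL0")));          // BIGINT
        System.out.println(keyType(Arrays.asList("COL0", "COL1"))); // STRING
      }
    }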
==== AggregateNode.java ====

@@ -34,7 +34,6 @@
 import io.confluent.ksql.parser.tree.WindowExpression;
 import io.confluent.ksql.schema.ksql.ColumnRef;
 import io.confluent.ksql.schema.ksql.LogicalSchema;
-import io.confluent.ksql.schema.ksql.SqlBaseType;
 import io.confluent.ksql.serde.ValueFormat;
 import io.confluent.ksql.services.KafkaTopicClient;
 import io.confluent.ksql.structured.SchemaKGroupedStream;
@@ -109,10 +108,6 @@ public AggregateNode(
     this.havingExpressions = havingExpressions;
     this.keyField = KeyField.of(requireNonNull(keyFieldName, "keyFieldName"))
         .validateKeyExistsIn(schema);
-
-    if (schema.key().get(0).type().baseType() != SqlBaseType.STRING) {
-      throw new KsqlException("GROUP BY is not supported with non-STRING keys");
-    }
   }
 
   @Override
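The planner now types the aggregate schema's key column from the grouping expression, so the constructor guard above would reject exactly the queries this change enables; it is removed together with the now-unused SqlBaseType import.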
==== BadUdf.java ====

@@ -0,0 +1,26 @@
+/*
+ * Copyright 2019 Confluent Inc.
+ *
+ * Licensed under the Confluent Community License (the "License"); you may not use
+ * this file except in compliance with the License. You may obtain a copy of the
+ * License at
+ *
+ * http://www.confluent.io/confluent-community-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package io.confluent.ksql.function.udf;
+
+@UdfDescription(name="bad_udf", description = "throws exceptions when called")
+@SuppressWarnings("unused")
+public class BadUdf {
+
+  @Udf(description = "throws")
+  public String blowUp(final int arg1) {
+    throw new RuntimeException("boom!");
+  }
+}
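A trivial UDF that always throws at runtime, presumably added so functional tests can assert how UDF failures surface during query execution; the tests that invoke it are outside this excerpt.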
==== SelectValueMapperIntegrationTest.java ====

@@ -30,6 +30,7 @@
 import io.confluent.ksql.planner.plan.PlanNode;
 import io.confluent.ksql.planner.plan.ProjectNode;
 import io.confluent.ksql.schema.ksql.LogicalSchema;
+import io.confluent.ksql.schema.ksql.types.SqlTypes;
 import io.confluent.ksql.testutils.AnalysisTestUtil;
 import io.confluent.ksql.util.KsqlConfig;
 import io.confluent.ksql.util.MetaStoreFixture;
@@ -45,7 +46,8 @@
 
 public class SelectValueMapperIntegrationTest {
 
-  private static final Struct NON_WINDOWED_KEY = StructKeyUtil.asStructKey("someKey");
+  private static final Struct NON_WINDOWED_KEY = StructKeyUtil.keyBuilder(SqlTypes.STRING)
+      .build("someKey");
 
   private final MetaStore metaStore = MetaStoreFixture
       .getNewMetaStore(TestFunctionRegistry.INSTANCE.get());
====

@@ -203,14 +203,14 @@ public void shouldCreateExecutionPlan() {
     final String[] lines = planText.split("\n");
 
     assertThat(lines[0], startsWith(
-        " > [ PROJECT ] | Schema: [ROWKEY STRING KEY, COL0 BIGINT, KSQL_COL_1 DOUBLE, "
+        " > [ PROJECT ] | Schema: [ROWKEY BIGINT KEY, COL0 BIGINT, KSQL_COL_1 DOUBLE, "
             + "KSQL_COL_2 BIGINT] |"));
     assertThat(lines[1], startsWith(
-        "\t\t > [ AGGREGATE ] | Schema: [ROWKEY STRING KEY, KSQL_INTERNAL_COL_0 BIGINT, "
+        "\t\t > [ AGGREGATE ] | Schema: [ROWKEY BIGINT KEY, KSQL_INTERNAL_COL_0 BIGINT, "
             + "KSQL_INTERNAL_COL_1 DOUBLE, KSQL_AGG_VARIABLE_0 DOUBLE, "
             + "KSQL_AGG_VARIABLE_1 BIGINT] |"));
     assertThat(lines[2], startsWith(
-        "\t\t\t\t > [ GROUP_BY ] | Schema: [ROWKEY STRING KEY, KSQL_INTERNAL_COL_0 BIGINT, "
+        "\t\t\t\t > [ GROUP_BY ] | Schema: [ROWKEY BIGINT KEY, KSQL_INTERNAL_COL_0 BIGINT, "
             + "KSQL_INTERNAL_COL_1 DOUBLE] |"
     ));
     assertThat(lines[3], startsWith(
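The expected plan text now reads ROWKEY BIGINT KEY at every step (GROUP_BY, AGGREGATE, PROJECT): the query under test groups by a BIGINT column, and with this change that type flows through to the key column instead of being forced to STRING.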
====

@@ -19,7 +19,6 @@
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
 
 import com.google.common.collect.ImmutableList;
 import io.confluent.ksql.execution.context.QueryContext;
==== SchemaKTableTest.java ====

@@ -65,6 +65,7 @@
 import io.confluent.ksql.execution.streams.StreamsFactories;
 import io.confluent.ksql.execution.streams.StreamsUtil;
 import io.confluent.ksql.execution.util.StructKeyUtil;
+import io.confluent.ksql.execution.util.StructKeyUtil.KeyBuilder;
 import io.confluent.ksql.function.InternalFunctionRegistry;
 import io.confluent.ksql.logging.processing.ProcessingLogContext;
 import io.confluent.ksql.metastore.MetaStore;
@@ -121,6 +122,8 @@
 @RunWith(MockitoJUnitRunner.class)
 public class SchemaKTableTest {
 
+  private static final KeyBuilder STRING_KEY_BUILDER = StructKeyUtil.keyBuilder(SqlTypes.STRING);
+
   private final KsqlConfig ksqlConfig = new KsqlConfig(Collections.emptyMap());
   private final MetaStore metaStore = MetaStoreFixture.getNewMetaStore(new InternalFunctionRegistry());
   private final GroupedFactory groupedFactory = mock(GroupedFactory.class);
@@ -584,7 +587,7 @@ public void shouldGroupKeysCorrectly() {
         (KeyValue<String, GenericRow>) keySelector.apply("key", value);
 
     // Validate that the captured mapper produces the correct key
-    assertThat(keyValue.key, equalTo(StructKeyUtil.asStructKey("bar|+|foo")));
+    assertThat(keyValue.key, equalTo(STRING_KEY_BUILDER.build("bar|+|foo")));
     assertThat(keyValue.value, equalTo(value));
   }
 
==== StructKeyUtil.java ====

@@ -42,19 +42,17 @@ public final class StructKeyUtil {
   private StructKeyUtil() {
   }
 
-  public static Struct asStructKey(String rowKey) {
-    Struct keyStruct = new Struct(ROWKEY_STRUCT_SCHEMA);
-    keyStruct.put(ROWKEY_FIELD, rowKey);
-    return keyStruct;
-  }
-
-  public static KeyBuilder keySchema(final LogicalSchema schema) {
+  public static KeyBuilder keyBuilder(final LogicalSchema schema) {
     final List<Column> keyCols = schema.key();
     if (keyCols.size() != 1) {
       throw new UnsupportedOperationException("Only single keys supported");
     }
 
     final SqlType sqlType = keyCols.get(0).type();
+    return keyBuilder(sqlType);
+  }
+
+  public static KeyBuilder keyBuilder(final SqlType sqlType) {
     final Schema connectSchema = SchemaConverters.sqlToConnectConverter().toConnectSchema(sqlType);
 
     return new KeyBuilder(SchemaBuilder
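The refactor splits key building into a schema-driven and a type-driven entry point, replacing the STRING-only asStructKey helper. A hypothetical call site for the new overload might look like the sketch below; it assumes KeyBuilder.build accepts the matching Java value, as the updated tests do with keyBuilder(SqlTypes.STRING).build("someKey").

    import io.confluent.ksql.execution.util.StructKeyUtil;
    import io.confluent.ksql.execution.util.StructKeyUtil.KeyBuilder;
    import io.confluent.ksql.schema.ksql.types.SqlTypes;
    import org.apache.kafka.connect.data.Struct;

    class KeyBuilderExample {

      // keyBuilder(SqlType) lets callers build non-STRING keys directly,
      // which asStructKey(String) could not express.
      static Struct bigintKey() {
        final KeyBuilder builder = StructKeyUtil.keyBuilder(SqlTypes.BIGINT);
        return builder.build(10L); // a Struct whose single ROWKEY field is 10L
      }
    }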
==== StructKeyUtilTest.java ====

@@ -39,13 +39,13 @@ public class StructKeyUtilTest {
 
   @Before
   public void setUp() {
-    builder = StructKeyUtil.keySchema(LOGICAL_SCHEMA);
+    builder = StructKeyUtil.keyBuilder(LOGICAL_SCHEMA);
   }
 
   @Test(expected = UnsupportedOperationException.class)
   public void shouldThrowOnMultipleKeyColumns() {
     // Only single key columns initially supported
-    StructKeyUtil.keySchema(LogicalSchema.builder()
+    StructKeyUtil.keyBuilder(LogicalSchema.builder()
         .keyColumn(ColumnName.of("BOB"), SqlTypes.STRING)
         .keyColumn(ColumnName.of("JOHN"), SqlTypes.STRING)
         .build());
====

@@ -20,6 +20,7 @@
 import io.confluent.ksql.schema.ksql.LogicalSchema;
 import io.confluent.ksql.schema.ksql.SchemaConverters;
 import io.confluent.ksql.schema.ksql.types.SqlType;
+import io.confluent.ksql.test.TestFrameworkException;
 import io.confluent.ksql.test.serde.SerdeSupplier;
 import java.util.List;
 import java.util.Map;
@@ -105,17 +106,24 @@ public byte[] serialize(final String topic, final Object value) {
   private final class RowDeserializer implements Deserializer<Object> {
 
     private Deserializer<Object> delegate;
+    private String type;
 
     @Override
    public void configure(final Map<String, ?> configs, final boolean isKey) {
+      this.type = isKey ? "key" : "value";
       final SqlType sqlType = getColumnType(isKey);
       delegate = getSerde(sqlType).deserializer();
       delegate.configure(configs, isKey);
     }
 
     @Override
     public Object deserialize(final String topic, final byte[] bytes) {
-      return delegate.deserialize(topic, bytes);
+      try {
+        return delegate.deserialize(topic, bytes);
+      } catch (final Exception e) {
+        throw new TestFrameworkException("Failed to deserialize " + type + ". "
+            + e.getMessage(), e);
+      }
     }
   }
 }
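Recording at configure() time whether the deserializer handles the key or the value lets the wrapped TestFrameworkException say which side failed, which is more useful now that keys are typed and key deserialization can fail in new ways.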