Skip to content

Commit

Permalink
chore: simplify group by schema resolving code (#4154)
Browse files Browse the repository at this point in the history
  • Loading branch information
big-andy-coates authored Dec 18, 2019
1 parent e92d2f3 commit 56ac607
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ public static GroupByParams build(
return new GroupByParams(schema, mapper);
}

static LogicalSchema multiExpressionSchema(
private static LogicalSchema multiExpressionSchema(
final LogicalSchema sourceSchema
) {
return buildSchema(sourceSchema, SqlTypes.STRING);
}

static LogicalSchema singleExpressionSchema(
private static LogicalSchema singleExpressionSchema(
final LogicalSchema sourceSchema,
final SqlType rowKeyType
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

package io.confluent.ksql.execution.streams;

import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.codegen.CodeGenRunner;
import io.confluent.ksql.execution.codegen.ExpressionMetadata;
import io.confluent.ksql.execution.plan.AbstractStreamSource;
import io.confluent.ksql.execution.plan.ExecutionStep;
import io.confluent.ksql.execution.plan.StreamAggregate;
Expand Down Expand Up @@ -54,6 +55,7 @@
/**
* Computes the schema produced by an execution step, given the schema(s) going into the step.
*/
@SuppressWarnings("MethodMayBeStatic") // Methods can not be used in HANDLERS is static.
public final class StepSchemaResolver {
private static final HandlerMaps.ClassHandlerMapR2
<ExecutionStep, StepSchemaResolver, LogicalSchema, LogicalSchema> HANDLERS
Expand Down Expand Up @@ -166,16 +168,15 @@ private LogicalSchema handleGroupBy(
final LogicalSchema sourceSchema,
final StreamGroupBy<?> streamGroupBy
) {
final List<Expression> groupBy = streamGroupBy.getGroupByExpressions();

if (groupBy.size() != 1) {
return GroupByParamsFactory.multiExpressionSchema(sourceSchema);
}

final SqlType rowKeyType = new ExpressionTypeManager(sourceSchema, functionRegistry)
.getExpressionSqlType(groupBy.get(0));
final List<ExpressionMetadata> compiledGroupBy = CodeGenRunner.compileExpressions(
streamGroupBy.getGroupByExpressions().stream(),
"Group By",
sourceSchema,
ksqlConfig,
functionRegistry
);

return GroupByParamsFactory.singleExpressionSchema(sourceSchema, rowKeyType);
return GroupByParamsFactory.build(sourceSchema, compiledGroupBy).getSchema();
}

private LogicalSchema handleStreamSelect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,17 @@

@RunWith(MockitoJUnitRunner.class)
public class StepSchemaResolverTest {

private static final KsqlConfig CONFIG = new KsqlConfig(Collections.emptyMap());

private static final LogicalSchema SCHEMA = LogicalSchema.builder()
.valueColumn(ColumnName.of("ORANGE"), SqlTypes.INTEGER)
.valueColumn(ColumnName.of("APPLE"), SqlTypes.BIGINT)
.valueColumn(ColumnName.of("BANANA"), SqlTypes.STRING)
.build();

private static final ColumnRef ORANGE_COL_REF = ColumnRef.withoutSource(ColumnName.of("ORANGE"));

private static final ExecutionStepPropertiesV1 PROPERTIES = new ExecutionStepPropertiesV1(
new QueryContext.Stacker().getQueryContext()
);
Expand Down Expand Up @@ -233,14 +238,17 @@ public void shouldResolveSchemaForStreamGroupBy() {
PROPERTIES,
streamSource,
formats,
Collections.emptyList()
ImmutableList.of(new ColumnReferenceExp(Optional.empty(), ORANGE_COL_REF))
);

// When:
final LogicalSchema result = resolver.resolve(step, SCHEMA);

// Then:
assertThat(result, is(SCHEMA));
assertThat(result, is(LogicalSchema.builder()
.keyColumn(SchemaUtil.ROWKEY_NAME, SqlTypes.INTEGER)
.valueColumns(SCHEMA.value())
.build()));
}

@Test
Expand Down

0 comments on commit 56ac607

Please sign in to comment.