Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partition by prep #4781

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2018 Confluent Inc.
* Copyright 2020 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
Expand All @@ -15,47 +15,74 @@

package io.confluent.ksql.analyzer;

import static java.util.Objects.requireNonNull;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor;
import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.HashSet;
import java.util.Objects;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

class AggregateAnalyzer {
public class AggregateAnalyzer {

private final MutableAggregateAnalysis aggregateAnalysis;
private final QualifiedColumnReferenceExp defaultArgument;
private final FunctionRegistry functionRegistry;
private final boolean hasWindowExpression;

AggregateAnalyzer(
final MutableAggregateAnalysis aggregateAnalysis,
final QualifiedColumnReferenceExp defaultArgument,
final boolean hasWindowExpression,
public AggregateAnalyzer(
final FunctionRegistry functionRegistry
) {
this.aggregateAnalysis = Objects.requireNonNull(aggregateAnalysis, "aggregateAnalysis");
this.defaultArgument = Objects.requireNonNull(defaultArgument, "defaultArgument");
this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry");
this.hasWindowExpression = hasWindowExpression;
this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry");
}

void processSelect(final Expression expression) {
/**
 * Runs aggregate analysis over an already analyzed query.
 *
 * <p>The clauses are visited in a fixed order: final projection, WHERE, GROUP BY and then
 * HAVING. Findings are accumulated in a {@code Context} created for this call, and the
 * aggregate rules are enforced once every clause has been visited.
 *
 * @param analysis the analysis of the query; must contain at least one GROUP BY expression.
 * @param finalProjection the select expressions that make up the final projection.
 * @return the aggregate analysis accumulated while visiting the clauses.
 * @throws IllegalArgumentException if {@code analysis} has no GROUP BY expressions.
 */
public AggregateAnalysisResult analyze(
    final ImmutableAnalysis analysis,
    final List<SelectExpression> finalProjection
) {
  if (analysis.getGroupByExpressions().isEmpty()) {
    throw new IllegalArgumentException("Not an aggregate query");
  }

  final Context context = new Context(analysis);

  for (final SelectExpression select : finalProjection) {
    processSelect(select.getExpression(), context);
  }

  analysis.getWhereExpression()
      .ifPresent(where -> processWhere(where, context));

  for (final Expression groupBy : analysis.getGroupByExpressions()) {
    processGroupBy(groupBy, context);
  }

  analysis.getHavingExpression()
      .ifPresent(having -> processHaving(having, context));

  // Validation runs last: it needs the state gathered from every clause above.
  enforceAggregateRules(context);

  return context.aggregateAnalysis;
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically moved from QueryAnalyzer.


private void processSelect(final Expression expression, final Context context) {
final Set<ColumnReferenceExp> nonAggParams = new HashSet<>();
final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> {
final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throwOnWindowBoundColumnIfWindowedAggregate(node);
throwOnWindowBoundColumnIfWindowedAggregate(node, context);
} else {
nonAggParams.add(node);
}
Expand All @@ -64,45 +91,51 @@ void processSelect(final Expression expression) {
visitor.process(expression, null);

if (visitor.visitedAggFunction) {
aggregateAnalysis.addAggregateSelectField(nonAggParams);
context.aggregateAnalysis.addAggregateSelectField(nonAggParams);
} else {
aggregateAnalysis.addNonAggregateSelectExpression(expression, nonAggParams);
context.aggregateAnalysis.addNonAggregateSelectExpression(expression, nonAggParams);
}

context.aggregateAnalysis.addFinalSelectExpression(expression);
}

void processGroupBy(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> {
private void processGroupBy(final Expression expression, final Context context) {
final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throw new KsqlException("GROUP BY does not support aggregate functions: "
+ aggFuncName.get().text() + " is an aggregate function.");
}
throwOnWindowBoundColumnIfWindowedAggregate(node);
throwOnWindowBoundColumnIfWindowedAggregate(node, context);
});

visitor.process(expression, null);
}

void processWhere(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> {
throwOnWindowBoundColumnIfWindowedAggregate(node);
});
private void processWhere(final Expression expression, final Context context) {
final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) ->
throwOnWindowBoundColumnIfWindowedAggregate(node, context));

visitor.process(expression, null);
}

void processHaving(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> {
throwOnWindowBoundColumnIfWindowedAggregate(node);
private void processHaving(final Expression expression, final Context context) {
final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) -> {
throwOnWindowBoundColumnIfWindowedAggregate(node, context);
if (!aggFuncName.isPresent()) {
aggregateAnalysis.addNonAggregateHavingField(node);
context.aggregateAnalysis.addNonAggregateHavingField(node);
}
});
visitor.process(expression, null);

context.aggregateAnalysis.setHavingExpression(expression);
}

private void throwOnWindowBoundColumnIfWindowedAggregate(final ColumnReferenceExp node) {
private static void throwOnWindowBoundColumnIfWindowedAggregate(
final ColumnReferenceExp node,
final Context context
) {
// Window bounds are supported for operations on windowed sources
if (!hasWindowExpression) {
if (!context.analysis.getWindowExpression().isPresent()) {
return;
}

Expand All @@ -117,18 +150,91 @@ private void throwOnWindowBoundColumnIfWindowedAggregate(final ColumnReferenceEx
}
}

/**
 * Validates the accumulated aggregate analysis against the query's GROUP BY expressions.
 *
 * <p>The checks run in the order below and the first one that fails throws:
 * <ol>
 *   <li>at least one aggregate function must have been recorded;</li>
 *   <li>every non-aggregate select expression must either equal a GROUP BY expression or
 *       have all of its recorded column parameters among the GROUP BY expressions (an
 *       expression with no recorded parameters therefore passes);</li>
 *   <li>every column recorded against an aggregate select expression must be a GROUP BY
 *       expression;</li>
 *   <li>every non-aggregate HAVING column, re-expressed as an unqualified reference, must be
 *       a GROUP BY expression.</li>
 * </ol>
 *
 * @param context the per-call state holding the source analysis and the aggregate findings.
 * @throws KsqlException if any of the rules above is violated.
 */
private static void enforceAggregateRules(
    final Context context
) {
  final MutableAggregateAnalysis aggregates = context.aggregateAnalysis;

  if (aggregates.getAggregateFunctions().isEmpty()) {
    throw new KsqlException(
        "GROUP BY requires columns using aggregate functions in SELECT clause.");
  }

  final Set<Expression> groupBys = getGroupByExpressions(context.analysis);

  // Keep only the select expressions that are neither an exact GROUP BY match nor built
  // purely from GROUP BY expressions; sorted so the error message is stable.
  final List<String> offendingSelects = aggregates
      .getNonAggregateSelectExpressions()
      .entrySet()
      .stream()
      .filter(entry ->
          !groupBys.contains(entry.getKey()) && !groupBys.containsAll(entry.getValue()))
      .map(entry -> entry.getKey().toString())
      .sorted()
      .collect(Collectors.toList());

  if (!offendingSelects.isEmpty()) {
    throw new KsqlException(
        "Non-aggregate SELECT expression(s) not part of GROUP BY: " + offendingSelects);
  }

  final SetView<ColumnReferenceExp> ungroupedAggColumns =
      Sets.difference(aggregates.getAggregateSelectFields(), groupBys);
  if (!ungroupedAggColumns.isEmpty()) {
    throw new KsqlException(
        "Column used in aggregate SELECT expression(s) "
            + "outside of aggregate functions not part of GROUP BY: " + ungroupedAggColumns);
  }

  // HAVING columns are compared in their unqualified form.
  final Set<ColumnReferenceExp> havingRefs = aggregates
      .getNonAggregateHavingFields().stream()
      .map(field -> new UnqualifiedColumnReferenceExp(field.getColumnName()))
      .collect(Collectors.toSet());

  final Set<ColumnReferenceExp> ungroupedHaving = Sets.difference(havingRefs, groupBys);
  if (!ungroupedHaving.isEmpty()) {
    throw new KsqlException(
        "Non-aggregate HAVING expression not part of GROUP BY: " + ungroupedHaving);
  }
}
Comment on lines +151 to +198
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved from QueryAnalyzer


/**
 * Builds the set of expressions treated as GROUP BY expressions for rule enforcement.
 *
 * <p>This is the query's explicit GROUP BY expressions. When the query has a window
 * expression, an unqualified reference to each window bounds column is added as well.
 *
 * @param analysis the analysis supplying the GROUP BY and window expressions.
 * @return an immutable set of the explicit and, if windowed, implicit GROUP BY expressions.
 */
private static Set<Expression> getGroupByExpressions(
    final ImmutableAnalysis analysis
) {
  final boolean windowed = analysis.getWindowExpression().isPresent();

  final ImmutableSet.Builder<Expression> groupBys = ImmutableSet.<Expression>builder()
      .addAll(analysis.getGroupByExpressions());

  if (windowed) {
    // Window bounds columns act as implicit group by columns:
    final Set<UnqualifiedColumnReferenceExp> windowBounds =
        SchemaUtil.windowBoundsColumnNames().stream()
            .map(UnqualifiedColumnReferenceExp::new)
            .collect(Collectors.toSet());

    groupBys.addAll(windowBounds);
  }

  return groupBys.build();
}
Comment on lines +200 to +217
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved from QueryAnalyzer


private final class AggregateVisitor extends TraversalExpressionVisitor<Void> {

private final BiConsumer<Optional<FunctionName>, ColumnReferenceExp>
dereferenceCollector;
private final ColumnReferenceExp defaultArgument;
private final MutableAggregateAnalysis aggregateAnalysis;

private Optional<FunctionName> aggFunctionName = Optional.empty();
private boolean visitedAggFunction = false;

private AggregateVisitor(
final Context context,
final BiConsumer<Optional<FunctionName>, ColumnReferenceExp> dereferenceCollector
) {
this.dereferenceCollector =
Objects.requireNonNull(dereferenceCollector, "dereferenceCollector");
this.defaultArgument = context.analysis.getDefaultArgument();
this.aggregateAnalysis = context.aggregateAnalysis;
this.dereferenceCollector = requireNonNull(dereferenceCollector, "dereferenceCollector");
}

@Override
Expand Down Expand Up @@ -180,12 +286,17 @@ public Void visitQualifiedColumnReference(
final QualifiedColumnReferenceExp node,
final Void context
) {
dereferenceCollector.accept(aggFunctionName, node);
throw new UnsupportedOperationException("Should of been converted to unqualified");
}
}

if (!SchemaUtil.isWindowBound(node.getColumnName())) {
aggregateAnalysis.addRequiredColumn(node);
}
return null;
/**
 * State shared by the steps of a single analysis run: the read-only source analysis and a
 * fresh mutable aggregate analysis that the steps populate.
 */
private static final class Context {

  final ImmutableAnalysis analysis;
  final MutableAggregateAnalysis aggregateAnalysis;

  /**
   * @param analysis the analysis of the query being processed; must not be null.
   */
  Context(final ImmutableAnalysis analysis) {
    this.aggregateAnalysis = new MutableAggregateAnalysis();
    this.analysis = requireNonNull(analysis, "analysis");
  }
}
}
Loading