Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support subscript and nested functions in grouping queries #5998

Merged
merged 7 commits into from
Aug 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,13 @@ private static final class AggAnalyzer {
private final FunctionRegistry functionRegistry;
private final Set<Expression> groupBy;

// The list of columns from the source schema that are used in aggregate columns, but not as
// parameters to the aggregate functions and which are not part of the GROUP BY clause:
private final List<ColumnReferenceExp> aggSelectsNotPartOfGroupBy = new ArrayList<>();

// The list of non-aggregate select expression which are not part of the GROUP BY clause:
// The list of expressions that appear in the SELECT clause outside of aggregate functions.
// Used for throwing an error if these expressions are not part of the GROUP BY clause.
private final List<Expression> nonAggSelectsNotPartOfGroupBy = new ArrayList<>();

// The list of columns from the source schema that are used in the HAVING clause outside
// of aggregate functions which are not part of the GROUP BY clause:
private final List<ColumnReferenceExp> nonAggHavingNotPartOfGroupBy = new ArrayList<>();
private final List<Expression> nonAggHavingNotPartOfGroupBy = new ArrayList<>();

AggAnalyzer(
final ImmutableAnalysis analysis,
Expand Down Expand Up @@ -107,69 +104,78 @@ public void process(final List<SelectExpression> finalProjection) {
}

private void processSelect(final Expression expression) {
final Set<ColumnReferenceExp> nonAggParams = new HashSet<>();
final AggregateVisitor visitor = new AggregateVisitor(this, (aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throwOnWindowBoundColumnIfWindowedAggregate(node);
} else {
nonAggParams.add(node);
}
});
final Set<Expression> nonAggParams = new HashSet<>();
final AggregateVisitor visitor = new AggregateVisitor(
this,
groupBy,
(aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throwOnWindowBoundColumnIfWindowedAggregate(node);
} else {
if (!groupBy.contains(node)) {
nonAggParams.add(node);
}
}
});

visitor.process(expression, null);

if (visitor.visitedAggFunction) {
captureAggregateSelectNotPartOfGroupBy(nonAggParams);
} else {
captureNonAggregateSelectNotPartOfGroupBy(expression, nonAggParams);
}

captureNonAggregateSelectNotPartOfGroupBy(expression, nonAggParams);
aggregateAnalysis.addFinalSelectExpression(expression);
}

private void processGroupBy(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor(this, (aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throw new KsqlException("GROUP BY does not support aggregate functions: "
+ aggFuncName.get().text() + " is an aggregate function.");
}
throwOnWindowBoundColumnIfWindowedAggregate(node);
});
final AggregateVisitor visitor = new AggregateVisitor(
this,
groupBy,
(aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throw new KsqlException("GROUP BY does not support aggregate functions: "
+ aggFuncName.get().text() + " is an aggregate function.");
}
throwOnWindowBoundColumnIfWindowedAggregate(node);
});

visitor.process(expression, null);
}

private void processWhere(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor(this, (aggFuncName, node) ->
throwOnWindowBoundColumnIfWindowedAggregate(node));
final AggregateVisitor visitor = new AggregateVisitor(
this,
groupBy,
(aggFuncName, node) ->
throwOnWindowBoundColumnIfWindowedAggregate(node));

visitor.process(expression, null);
}

private void processHaving(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor(this, (aggFuncName, node) -> {
throwOnWindowBoundColumnIfWindowedAggregate(node);
final AggregateVisitor visitor = new AggregateVisitor(
this,
groupBy,
(aggFuncName, node) -> {
throwOnWindowBoundColumnIfWindowedAggregate(node);

if (!aggFuncName.isPresent()) {
captureNoneAggregateHavingNotPartOfGroupBy(node);
}
});
if (!aggFuncName.isPresent()) {
captureNonAggregateHavingNotPartOfGroupBy(node);
}
});

visitor.process(expression, null);

aggregateAnalysis.setHavingExpression(expression);
}

private void throwOnWindowBoundColumnIfWindowedAggregate(
final ColumnReferenceExp node
) {
private void throwOnWindowBoundColumnIfWindowedAggregate(final Expression node) {
// Window bounds are supported for operations on windowed sources
if (!analysis.getWindowExpression().isPresent()) {
return;
}

if (!(node instanceof ColumnReferenceExp)) {
return;
}
vpapavas marked this conversation as resolved.
Show resolved Hide resolved
// For non-windowed sources, with a windowed GROUP BY, they are only supported in selects:
if (SystemColumns.isWindowBound(node.getColumnName())) {
if (SystemColumns.isWindowBound(((ColumnReferenceExp)node).getColumnName())) {
throw new KsqlException(
"Window bounds column " + node + " can only be used in the SELECT clause of "
+ "windowed aggregations and can not be passed to aggregate functions."
Expand Down Expand Up @@ -204,7 +210,7 @@ private static Set<Expression> getGroupByExpressions(

private void captureNonAggregateSelectNotPartOfGroupBy(
final Expression expression,
final Set<ColumnReferenceExp> nonAggParams
final Set<Expression> nonAggParams
) {
final boolean matchesGroupBy = groupBy.contains(expression);
if (matchesGroupBy) {
Expand All @@ -220,16 +226,8 @@ private void captureNonAggregateSelectNotPartOfGroupBy(
nonAggSelectsNotPartOfGroupBy.add(expression);
}

private void captureAggregateSelectNotPartOfGroupBy(
final Set<ColumnReferenceExp> nonAggParams
) {
nonAggParams.stream()
.filter(param -> !groupBy.contains(param))
.forEach(aggSelectsNotPartOfGroupBy::add);
}

private void captureNoneAggregateHavingNotPartOfGroupBy(final ColumnReferenceExp nonAggColumn) {
if (groupBy.contains(new UnqualifiedColumnReferenceExp(nonAggColumn.getColumnName()))) {
private void captureNonAggregateHavingNotPartOfGroupBy(final Expression nonAggColumn) {
if (groupBy.contains(nonAggColumn)) {
return;
}

Expand All @@ -256,21 +254,7 @@ private void enforceAggregateRules() {
"Non-aggregate SELECT expression(s) not part of GROUP BY: "
+ unmatchedSelects
+ System.lineSeparator()
+ "Either add the column to the GROUP BY or remove it from the SELECT."
);
}

final String unmatchedSelectsAgg = aggSelectsNotPartOfGroupBy.stream()
.map(Objects::toString)
.collect(Collectors.joining(", "));

if (!unmatchedSelectsAgg.isEmpty()) {
throw new KsqlException(
"Column used in aggregate SELECT expression(s) outside of aggregate functions "
+ "not part of GROUP BY: "
+ unmatchedSelectsAgg
+ System.lineSeparator()
+ "Either add the column to the GROUP BY or remove it from the SELECT."
+ "Either add the column(s) to the GROUP BY or remove them from the SELECT."
);
}

Expand All @@ -288,26 +272,60 @@ private void enforceAggregateRules() {
}
}

/**
* This visitor performs two tasks: creating the input schema of the AggregateNode and performing validations.
*
* <p>For creating the input schema, it checks if any expression along the path from the root
* expression to the leaf (UnqualifiedColumnReference) is part of the groupBy. If at least one is,
* then the UnqualifiedColumnReference is added to the schema.
*
* <p>For validation, the visitor checks that:
* <ol>
* <li> expressions not in aggregate functions are part of the grouping clause </li>
* <li> aggregate functions are not nested </li>
* <li> window bound columns (windowstart, windowend) don't appear in aggregate functions or
* groupBy </li>
* <li> aggregate functions don't appear in the groupBy clause </li>
* <li> expressions in the having clause are either aggregate functions or grouping keys </li>
* </ol>
*/
private static final class AggregateVisitor extends TraversalExpressionVisitor<Void> {

private final BiConsumer<Optional<FunctionName>, ColumnReferenceExp> dereferenceCollector;
private final BiConsumer<Optional<FunctionName>, Expression> dereferenceCollector;
private final ColumnReferenceExp defaultArgument;
private final MutableAggregateAnalysis aggregateAnalysis;
private final FunctionRegistry functionRegistry;
private final Set<Expression> groupBy;
private Expression currentlyInExpressionPartOfGroupBy;

private Optional<FunctionName> aggFunctionName = Optional.empty();
private boolean visitedAggFunction = false;
private boolean currentlyInAggregateFunction = false;

private AggregateVisitor(
final AggAnalyzer aggAnalyzer,
final BiConsumer<Optional<FunctionName>, ColumnReferenceExp> dereferenceCollector
final Set<Expression> groupBy,
final BiConsumer<Optional<FunctionName>, Expression> dereferenceCollector
) {
this.defaultArgument = aggAnalyzer.analysis.getDefaultArgument();
this.aggregateAnalysis = aggAnalyzer.aggregateAnalysis;
this.functionRegistry = aggAnalyzer.functionRegistry;
this.groupBy = groupBy;
this.dereferenceCollector = requireNonNull(dereferenceCollector, "dereferenceCollector");
}

@Override
public Void process(final Expression node, final Void context) {
if (groupBy.contains(node) && currentlyInExpressionPartOfGroupBy == null) {
currentlyInExpressionPartOfGroupBy = node;
}
super.process(node, context);
if (currentlyInExpressionPartOfGroupBy != null
&& currentlyInExpressionPartOfGroupBy == node) {
currentlyInExpressionPartOfGroupBy = null;
}
return null;
}

@Override
public Void visitFunctionCall(final FunctionCall node, final Void context) {
final FunctionName functionName = node.getName();
Expand All @@ -323,7 +341,7 @@ public Void visitFunctionCall(final FunctionCall node, final Void context) {
+ aggFunctionName.get().text() + "(" + functionName.text() + "())");
}

visitedAggFunction = true;
currentlyInAggregateFunction = true;
aggFunctionName = Optional.of(functionName);

functionCall.getArguments().forEach(aggregateAnalysis::addAggregateFunctionArgument);
Expand All @@ -344,9 +362,13 @@ public Void visitUnqualifiedColumnReference(
final UnqualifiedColumnReferenceExp node,
final Void context
) {
dereferenceCollector.accept(aggFunctionName, node);

if (currentlyInExpressionPartOfGroupBy == null
|| currentlyInAggregateFunction
|| SystemColumns.isWindowBound(node.getColumnName())) {
dereferenceCollector.accept(aggFunctionName, node);
}
if (!SystemColumns.isWindowBound(node.getColumnName())) {
// Used to infer the required columns in the INPUT schema of the aggregate node
aggregateAnalysis.addRequiredColumn(node);
}
return null;
Expand Down
Loading