Skip to content

Commit

Permalink
feat: Support subscript and nested functions in grouping queries (#5998)
Browse files Browse the repository at this point in the history
* fixed handling subscript and nested functions

* remove intermediate topics from test file

* addressed comments, handle struct and arithmetic, added plans

* fixed window bounds error

* adding historic plans

* addressed Almog's comments

* minor fixes
  • Loading branch information
vpapavas authored Aug 21, 2020
1 parent 1b58c9b commit 8d383db
Show file tree
Hide file tree
Showing 29 changed files with 3,694 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,13 @@ private static final class AggAnalyzer {
private final FunctionRegistry functionRegistry;
private final Set<Expression> groupBy;

// The list of columns from the source schema that are used in aggregate columns, but not as
// parameters to the aggregate functions and which are not part of the GROUP BY clause:
private final List<ColumnReferenceExp> aggSelectsNotPartOfGroupBy = new ArrayList<>();

// The list of non-aggregate select expression which are not part of the GROUP BY clause:
// The list of expressions that appear in the SELECT clause outside of aggregate functions.
// Used for throwing an error if these columns are not part of the GROUP BY clause.
private final List<Expression> nonAggSelectsNotPartOfGroupBy = new ArrayList<>();

// The list of columns from the source schema that are used in the HAVING clause outside
// of aggregate functions which are not part of the GROUP BY clause:
private final List<ColumnReferenceExp> nonAggHavingNotPartOfGroupBy = new ArrayList<>();
private final List<Expression> nonAggHavingNotPartOfGroupBy = new ArrayList<>();

AggAnalyzer(
final ImmutableAnalysis analysis,
Expand Down Expand Up @@ -107,69 +104,78 @@ public void process(final List<SelectExpression> finalProjection) {
}

private void processSelect(final Expression expression) {
final Set<ColumnReferenceExp> nonAggParams = new HashSet<>();
final AggregateVisitor visitor = new AggregateVisitor(this, (aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throwOnWindowBoundColumnIfWindowedAggregate(node);
} else {
nonAggParams.add(node);
}
});
final Set<Expression> nonAggParams = new HashSet<>();
final AggregateVisitor visitor = new AggregateVisitor(
this,
groupBy,
(aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throwOnWindowBoundColumnIfWindowedAggregate(node);
} else {
if (!groupBy.contains(node)) {
nonAggParams.add(node);
}
}
});

visitor.process(expression, null);

if (visitor.visitedAggFunction) {
captureAggregateSelectNotPartOfGroupBy(nonAggParams);
} else {
captureNonAggregateSelectNotPartOfGroupBy(expression, nonAggParams);
}

captureNonAggregateSelectNotPartOfGroupBy(expression, nonAggParams);
aggregateAnalysis.addFinalSelectExpression(expression);
}

private void processGroupBy(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor(this, (aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throw new KsqlException("GROUP BY does not support aggregate functions: "
+ aggFuncName.get().text() + " is an aggregate function.");
}
throwOnWindowBoundColumnIfWindowedAggregate(node);
});
final AggregateVisitor visitor = new AggregateVisitor(
this,
groupBy,
(aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throw new KsqlException("GROUP BY does not support aggregate functions: "
+ aggFuncName.get().text() + " is an aggregate function.");
}
throwOnWindowBoundColumnIfWindowedAggregate(node);
});

visitor.process(expression, null);
}

private void processWhere(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor(this, (aggFuncName, node) ->
throwOnWindowBoundColumnIfWindowedAggregate(node));
final AggregateVisitor visitor = new AggregateVisitor(
this,
groupBy,
(aggFuncName, node) ->
throwOnWindowBoundColumnIfWindowedAggregate(node));

visitor.process(expression, null);
}

private void processHaving(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor(this, (aggFuncName, node) -> {
throwOnWindowBoundColumnIfWindowedAggregate(node);
final AggregateVisitor visitor = new AggregateVisitor(
this,
groupBy,
(aggFuncName, node) -> {
throwOnWindowBoundColumnIfWindowedAggregate(node);

if (!aggFuncName.isPresent()) {
captureNoneAggregateHavingNotPartOfGroupBy(node);
}
});
if (!aggFuncName.isPresent()) {
captureNonAggregateHavingNotPartOfGroupBy(node);
}
});

visitor.process(expression, null);

aggregateAnalysis.setHavingExpression(expression);
}

private void throwOnWindowBoundColumnIfWindowedAggregate(
final ColumnReferenceExp node
) {
private void throwOnWindowBoundColumnIfWindowedAggregate(final Expression node) {
// Window bounds are supported for operations on windowed sources
if (!analysis.getWindowExpression().isPresent()) {
return;
}

if (!(node instanceof ColumnReferenceExp)) {
return;
}
// For non-windowed sources, with a windowed GROUP BY, they are only supported in selects:
if (SystemColumns.isWindowBound(node.getColumnName())) {
if (SystemColumns.isWindowBound(((ColumnReferenceExp)node).getColumnName())) {
throw new KsqlException(
"Window bounds column " + node + " can only be used in the SELECT clause of "
+ "windowed aggregations and can not be passed to aggregate functions."
Expand Down Expand Up @@ -204,7 +210,7 @@ private static Set<Expression> getGroupByExpressions(

private void captureNonAggregateSelectNotPartOfGroupBy(
final Expression expression,
final Set<ColumnReferenceExp> nonAggParams
final Set<Expression> nonAggParams
) {
final boolean matchesGroupBy = groupBy.contains(expression);
if (matchesGroupBy) {
Expand All @@ -220,16 +226,8 @@ private void captureNonAggregateSelectNotPartOfGroupBy(
nonAggSelectsNotPartOfGroupBy.add(expression);
}

private void captureAggregateSelectNotPartOfGroupBy(
final Set<ColumnReferenceExp> nonAggParams
) {
nonAggParams.stream()
.filter(param -> !groupBy.contains(param))
.forEach(aggSelectsNotPartOfGroupBy::add);
}

private void captureNoneAggregateHavingNotPartOfGroupBy(final ColumnReferenceExp nonAggColumn) {
if (groupBy.contains(new UnqualifiedColumnReferenceExp(nonAggColumn.getColumnName()))) {
private void captureNonAggregateHavingNotPartOfGroupBy(final Expression nonAggColumn) {
if (groupBy.contains(nonAggColumn)) {
return;
}

Expand All @@ -256,21 +254,7 @@ private void enforceAggregateRules() {
"Non-aggregate SELECT expression(s) not part of GROUP BY: "
+ unmatchedSelects
+ System.lineSeparator()
+ "Either add the column to the GROUP BY or remove it from the SELECT."
);
}

final String unmatchedSelectsAgg = aggSelectsNotPartOfGroupBy.stream()
.map(Objects::toString)
.collect(Collectors.joining(", "));

if (!unmatchedSelectsAgg.isEmpty()) {
throw new KsqlException(
"Column used in aggregate SELECT expression(s) outside of aggregate functions "
+ "not part of GROUP BY: "
+ unmatchedSelectsAgg
+ System.lineSeparator()
+ "Either add the column to the GROUP BY or remove it from the SELECT."
+ "Either add the column(s) to the GROUP BY or remove them from the SELECT."
);
}

Expand All @@ -288,26 +272,60 @@ private void enforceAggregateRules() {
}
}

/**
* This visitor performs two tasks: Create the input schema to the AggregateNode and validations.
*
* <p>For creating the input schema, it checks if any expression along the path from the root
* expression to the leaf (UnqualifiedColumnReference) is part of the groupBy. If at least one is,
* then the UnqualifiedColumnReference is added to the schema.
*
* <p>For validation, the visitor checks that:
* <ol>
* <li> expressions not in aggregate functions are part of the grouping clause </li>
* <li> aggregate functions are not nested </li>
* <li> window clauses (windowstart, windowend) don't appear in aggregate functions or
* groupBy </li>
* <li> aggregate functions don't appear in the groupBy clause </li>
* <li> expressions in the having clause are either aggregate functions or grouping keys </li>
* </ol>
*/
private static final class AggregateVisitor extends TraversalExpressionVisitor<Void> {

private final BiConsumer<Optional<FunctionName>, ColumnReferenceExp> dereferenceCollector;
private final BiConsumer<Optional<FunctionName>, Expression> dereferenceCollector;
private final ColumnReferenceExp defaultArgument;
private final MutableAggregateAnalysis aggregateAnalysis;
private final FunctionRegistry functionRegistry;
private final Set<Expression> groupBy;
private Expression currentlyInExpressionPartOfGroupBy;

private Optional<FunctionName> aggFunctionName = Optional.empty();
private boolean visitedAggFunction = false;
private boolean currentlyInAggregateFunction = false;

private AggregateVisitor(
final AggAnalyzer aggAnalyzer,
final BiConsumer<Optional<FunctionName>, ColumnReferenceExp> dereferenceCollector
final Set<Expression> groupBy,
final BiConsumer<Optional<FunctionName>, Expression> dereferenceCollector
) {
this.defaultArgument = aggAnalyzer.analysis.getDefaultArgument();
this.aggregateAnalysis = aggAnalyzer.aggregateAnalysis;
this.functionRegistry = aggAnalyzer.functionRegistry;
this.groupBy = groupBy;
this.dereferenceCollector = requireNonNull(dereferenceCollector, "dereferenceCollector");
}

@Override
public Void process(final Expression node, final Void context) {
if (groupBy.contains(node) && currentlyInExpressionPartOfGroupBy == null) {
currentlyInExpressionPartOfGroupBy = node;
}
super.process(node, context);
if (currentlyInExpressionPartOfGroupBy != null
&& currentlyInExpressionPartOfGroupBy == node) {
currentlyInExpressionPartOfGroupBy = null;
}
return null;
}

@Override
public Void visitFunctionCall(final FunctionCall node, final Void context) {
final FunctionName functionName = node.getName();
Expand All @@ -323,7 +341,7 @@ public Void visitFunctionCall(final FunctionCall node, final Void context) {
+ aggFunctionName.get().text() + "(" + functionName.text() + "())");
}

visitedAggFunction = true;
currentlyInAggregateFunction = true;
aggFunctionName = Optional.of(functionName);

functionCall.getArguments().forEach(aggregateAnalysis::addAggregateFunctionArgument);
Expand All @@ -344,9 +362,13 @@ public Void visitUnqualifiedColumnReference(
final UnqualifiedColumnReferenceExp node,
final Void context
) {
dereferenceCollector.accept(aggFunctionName, node);

if (currentlyInExpressionPartOfGroupBy == null
|| currentlyInAggregateFunction
|| SystemColumns.isWindowBound(node.getColumnName())) {
dereferenceCollector.accept(aggFunctionName, node);
}
if (!SystemColumns.isWindowBound(node.getColumnName())) {
// Used to infer the required columns in the INPUT schema of the aggregate node
aggregateAnalysis.addRequiredColumn(node);
}
return null;
Expand Down
Loading

0 comments on commit 8d383db

Please sign in to comment.