Skip to content

Commit

Permalink
Partition by prep (#4781)
Browse files Browse the repository at this point in the history
* refactor: projection expression handling

Prep for #4749.

This commit changes the way the engine resolves '*' in a projection, e.g. `SELECT * FROM X;`.

Previously, the `Analyzer` was responsible for expanding the `*` into the set of columns of each source. However, this code was getting complicated and would be much more complicated once the key column can have any name, (#3536). The complexity comes about because the `Analyzer` would need to determine the presence of joins, group bys, partition bys, etc, which can effect how `*` is resolved.  This logic duplicates the logic in the `LogicalPlanner` and `PlanNode` sub-classes.

With this commit sees the logical plan and planner being responsible for resolving any `*` in the projection. This is achieved by asking the parent of the projection node to resolve the `*` into the set of columns. Parent node types that do not know how to resolve the `*`, e.g. `FilterNode`, forward requests to their parents. In this way, the resolution request ripples up the logical plan until it reaches a `DataSourceNode`, which can resolve the `*` into a list of columns. `JoinNode` knows how forward `*`, `left.*` and `right.*` appropriately.

Previously, the list of `SelectExpressions` was passed down from parent `PlanNode` to child, allowing some nodes to rewrite the expressions. For example, `FlatMapNode` would rewrite any expression involving a TableFunction to use the internal names like `KSQL_SYNTH_0`.

With this commit this is no longer necessary. Instead, when building a projection node the planner asks it's parent node to resolve any selects, allowing the parent to perform any rewrite.

At the moment, the planner is still responsible for much of this work. In the future, this logic may move into the plan itself. However, such a change would increase the complexity of this commit.

Co-authored-by: Andy Coates <[email protected]>
  • Loading branch information
big-andy-coates and big-andy-coates authored Mar 18, 2020
1 parent 2ed4dc8 commit 5a332a0
Show file tree
Hide file tree
Showing 117 changed files with 6,104 additions and 2,735 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2018 Confluent Inc.
* Copyright 2020 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
Expand All @@ -15,47 +15,72 @@

package io.confluent.ksql.analyzer;

import static java.util.Objects.requireNonNull;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor;
import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.HashSet;
import java.util.Objects;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

class AggregateAnalyzer {
public class AggregateAnalyzer {

private final MutableAggregateAnalysis aggregateAnalysis;
private final QualifiedColumnReferenceExp defaultArgument;
private final FunctionRegistry functionRegistry;
private final boolean hasWindowExpression;

AggregateAnalyzer(
final MutableAggregateAnalysis aggregateAnalysis,
final QualifiedColumnReferenceExp defaultArgument,
final boolean hasWindowExpression,
final FunctionRegistry functionRegistry
public AggregateAnalyzer(final FunctionRegistry functionRegistry) {
this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry");
}

public AggregateAnalysisResult analyze(
final ImmutableAnalysis analysis,
final List<SelectExpression> finalProjection
) {
this.aggregateAnalysis = Objects.requireNonNull(aggregateAnalysis, "aggregateAnalysis");
this.defaultArgument = Objects.requireNonNull(defaultArgument, "defaultArgument");
this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry");
this.hasWindowExpression = hasWindowExpression;
if (analysis.getGroupByExpressions().isEmpty()) {
throw new IllegalArgumentException("Not an aggregate query");
}

final Context context = new Context(analysis);

finalProjection.stream()
.map(SelectExpression::getExpression)
.forEach(exp -> processSelect(exp, context));

analysis.getWhereExpression()
.ifPresent(exp -> processWhere(exp, context));

analysis.getGroupByExpressions()
.forEach(exp -> processGroupBy(exp, context));

analysis.getHavingExpression()
.ifPresent(exp -> processHaving(exp, context));

enforceAggregateRules(context);

return context.aggregateAnalysis;
}

void processSelect(final Expression expression) {
private void processSelect(final Expression expression, final Context context) {
final Set<ColumnReferenceExp> nonAggParams = new HashSet<>();
final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> {
final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throwOnWindowBoundColumnIfWindowedAggregate(node);
throwOnWindowBoundColumnIfWindowedAggregate(node, context);
} else {
nonAggParams.add(node);
}
Expand All @@ -64,45 +89,51 @@ void processSelect(final Expression expression) {
visitor.process(expression, null);

if (visitor.visitedAggFunction) {
aggregateAnalysis.addAggregateSelectField(nonAggParams);
context.aggregateAnalysis.addAggregateSelectField(nonAggParams);
} else {
aggregateAnalysis.addNonAggregateSelectExpression(expression, nonAggParams);
context.aggregateAnalysis.addNonAggregateSelectExpression(expression, nonAggParams);
}

context.aggregateAnalysis.addFinalSelectExpression(expression);
}

void processGroupBy(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> {
private void processGroupBy(final Expression expression, final Context context) {
final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) -> {
if (aggFuncName.isPresent()) {
throw new KsqlException("GROUP BY does not support aggregate functions: "
+ aggFuncName.get().text() + " is an aggregate function.");
}
throwOnWindowBoundColumnIfWindowedAggregate(node);
throwOnWindowBoundColumnIfWindowedAggregate(node, context);
});

visitor.process(expression, null);
}

void processWhere(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> {
throwOnWindowBoundColumnIfWindowedAggregate(node);
});
private void processWhere(final Expression expression, final Context context) {
final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) ->
throwOnWindowBoundColumnIfWindowedAggregate(node, context));

visitor.process(expression, null);
}

void processHaving(final Expression expression) {
final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> {
throwOnWindowBoundColumnIfWindowedAggregate(node);
private void processHaving(final Expression expression, final Context context) {
final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) -> {
throwOnWindowBoundColumnIfWindowedAggregate(node, context);
if (!aggFuncName.isPresent()) {
aggregateAnalysis.addNonAggregateHavingField(node);
context.aggregateAnalysis.addNonAggregateHavingField(node);
}
});
visitor.process(expression, null);

context.aggregateAnalysis.setHavingExpression(expression);
}

private void throwOnWindowBoundColumnIfWindowedAggregate(final ColumnReferenceExp node) {
private static void throwOnWindowBoundColumnIfWindowedAggregate(
final ColumnReferenceExp node,
final Context context
) {
// Window bounds are supported for operations on windowed sources
if (!hasWindowExpression) {
if (!context.analysis.getWindowExpression().isPresent()) {
return;
}

Expand All @@ -117,18 +148,91 @@ private void throwOnWindowBoundColumnIfWindowedAggregate(final ColumnReferenceEx
}
}

private static void enforceAggregateRules(
final Context context
) {
if (context.aggregateAnalysis.getAggregateFunctions().isEmpty()) {
throw new KsqlException(
"GROUP BY requires columns using aggregate functions in SELECT clause.");
}

final Set<Expression> groupByExprs = getGroupByExpressions(context.analysis);

final List<String> unmatchedSelects = context.aggregateAnalysis
.getNonAggregateSelectExpressions()
.entrySet()
.stream()
// Remove any that exactly match a group by expression:
.filter(e -> !groupByExprs.contains(e.getKey()))
// Remove any that are constants,
// or expressions where all params exactly match a group by expression:
.filter(e -> !Sets.difference(e.getValue(), groupByExprs).isEmpty())
.map(Map.Entry::getKey)
.map(Expression::toString)
.sorted()
.collect(Collectors.toList());

if (!unmatchedSelects.isEmpty()) {
throw new KsqlException(
"Non-aggregate SELECT expression(s) not part of GROUP BY: " + unmatchedSelects);
}

final SetView<ColumnReferenceExp> unmatchedSelectsAgg = Sets
.difference(context.aggregateAnalysis.getAggregateSelectFields(), groupByExprs);
if (!unmatchedSelectsAgg.isEmpty()) {
throw new KsqlException(
"Column used in aggregate SELECT expression(s) "
+ "outside of aggregate functions not part of GROUP BY: " + unmatchedSelectsAgg);
}

final Set<ColumnReferenceExp> havingColumns = context.aggregateAnalysis
.getNonAggregateHavingFields().stream()
.map(ref -> new UnqualifiedColumnReferenceExp(ref.getColumnName()))
.collect(Collectors.toSet());

final Set<ColumnReferenceExp> havingOnly = Sets.difference(havingColumns, groupByExprs);
if (!havingOnly.isEmpty()) {
throw new KsqlException(
"Non-aggregate HAVING expression not part of GROUP BY: " + havingOnly);
}
}

private static Set<Expression> getGroupByExpressions(
final ImmutableAnalysis analysis
) {
if (!analysis.getWindowExpression().isPresent()) {
return ImmutableSet.copyOf(analysis.getGroupByExpressions());
}

// Add in window bounds columns as implicit group by columns:
final Set<UnqualifiedColumnReferenceExp> windowBoundColumnRefs =
SchemaUtil.windowBoundsColumnNames().stream()
.map(UnqualifiedColumnReferenceExp::new)
.collect(Collectors.toSet());

return ImmutableSet.<Expression>builder()
.addAll(analysis.getGroupByExpressions())
.addAll(windowBoundColumnRefs)
.build();
}

private final class AggregateVisitor extends TraversalExpressionVisitor<Void> {

private final BiConsumer<Optional<FunctionName>, ColumnReferenceExp>
dereferenceCollector;
private final ColumnReferenceExp defaultArgument;
private final MutableAggregateAnalysis aggregateAnalysis;

private Optional<FunctionName> aggFunctionName = Optional.empty();
private boolean visitedAggFunction = false;

private AggregateVisitor(
final Context context,
final BiConsumer<Optional<FunctionName>, ColumnReferenceExp> dereferenceCollector
) {
this.dereferenceCollector =
Objects.requireNonNull(dereferenceCollector, "dereferenceCollector");
this.defaultArgument = context.analysis.getDefaultArgument();
this.aggregateAnalysis = context.aggregateAnalysis;
this.dereferenceCollector = requireNonNull(dereferenceCollector, "dereferenceCollector");
}

@Override
Expand Down Expand Up @@ -180,12 +284,17 @@ public Void visitQualifiedColumnReference(
final QualifiedColumnReferenceExp node,
final Void context
) {
dereferenceCollector.accept(aggFunctionName, node);
throw new UnsupportedOperationException("Should of been converted to unqualified");
}
}

if (!SchemaUtil.isWindowBound(node.getColumnName())) {
aggregateAnalysis.addRequiredColumn(node);
}
return null;
private static final class Context {

final ImmutableAnalysis analysis;
final MutableAggregateAnalysis aggregateAnalysis = new MutableAggregateAnalysis();

Context(final ImmutableAnalysis analysis) {
this.analysis = requireNonNull(analysis, "analysis");
}
}
}
Loading

0 comments on commit 5a332a0

Please sign in to comment.