Skip to content

Commit

Permalink
fix: do not allow grouping sets (#4942)
Browse files Browse the repository at this point in the history
* fix: do not allow grouping sets

fixes: #4941

Change the ksqlDB syntax to now allow grouping sets, which aren't supported yet.

Co-authored-by: Andy Coates <[email protected]>
  • Loading branch information
big-andy-coates and big-andy-coates authored Mar 31, 2020
1 parent 9bce85f commit 51bb9f6
Show file tree
Hide file tree
Showing 19 changed files with 383 additions and 375 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@
import io.confluent.ksql.planner.plan.JoinNode.JoinType;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.serde.SerdeOption;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -63,7 +65,7 @@ public class Analysis implements ImmutableAnalysis {
private Optional<Expression> whereExpression = Optional.empty();
private final List<SelectItem> selectItems = new ArrayList<>();
private final Set<ColumnName> selectColumnNames = new HashSet<>();
private final List<Expression> groupByExpressions = new ArrayList<>();
private final Set<Expression> groupByExpressions = new LinkedHashSet<>();
private Optional<WindowExpression> windowExpression = Optional.empty();
private Optional<Expression> partitionBy = Optional.empty();
private Optional<Expression> havingExpression = Optional.empty();
Expand Down Expand Up @@ -129,8 +131,12 @@ public List<Expression> getGroupByExpressions() {
return ImmutableList.copyOf(groupByExpressions);
}

void addGroupByExpressions(final Set<Expression> expressions) {
groupByExpressions.addAll(expressions);
void setGroupByExpressions(final List<Expression> expressions) {
expressions.forEach(exp -> {
if (!groupByExpressions.add(exp)) {
throw new KsqlException("Duplicate GROUP BY expression: " + exp);
}
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import io.confluent.ksql.parser.tree.AllColumns;
import io.confluent.ksql.parser.tree.AstNode;
import io.confluent.ksql.parser.tree.GroupBy;
import io.confluent.ksql.parser.tree.GroupingElement;
import io.confluent.ksql.parser.tree.Join;
import io.confluent.ksql.parser.tree.JoinOn;
import io.confluent.ksql.parser.tree.JoinedSource;
Expand Down Expand Up @@ -482,22 +481,13 @@ protected AstNode visitSelect(final Select node, final Void context) {
return null;
}

@Override
protected AstNode visitGroupBy(final GroupBy node, final Void context) {
return null;
}

private void analyzeWhere(final Expression node) {
analysis.setWhereExpression(node);
}

private void analyzeGroupBy(final GroupBy groupBy) {
isGroupBy = true;

for (final GroupingElement groupingElement : groupBy.getGroupingElements()) {
final Set<Expression> groupingSet = groupingElement.enumerateGroupingSets().get(0);
analysis.addGroupByExpressions(groupingSet);
}
analysis.setGroupByExpressions(groupBy.getGroupingExpressions());
}

private void analyzePartitionBy(final Expression partitionBy) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import io.confluent.ksql.parser.tree.DropTable;
import io.confluent.ksql.parser.tree.Explain;
import io.confluent.ksql.parser.tree.GroupBy;
import io.confluent.ksql.parser.tree.GroupingElement;
import io.confluent.ksql.parser.tree.InsertInto;
import io.confluent.ksql.parser.tree.Join;
import io.confluent.ksql.parser.tree.JoinCriteria;
Expand All @@ -41,7 +40,6 @@
import io.confluent.ksql.parser.tree.Relation;
import io.confluent.ksql.parser.tree.Select;
import io.confluent.ksql.parser.tree.SelectItem;
import io.confluent.ksql.parser.tree.SimpleGroupBy;
import io.confluent.ksql.parser.tree.SingleColumn;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.parser.tree.Statements;
Expand Down Expand Up @@ -441,30 +439,13 @@ protected AstNode visitGroupBy(final GroupBy node, final C context) {
return result.get();
}

final List<GroupingElement> rewrittenGroupings = node.getGroupingElements().stream()
.map(groupingElement -> (GroupingElement) rewriter.apply(groupingElement, context))
final List<Expression> rewrittenGroupings = node.getGroupingExpressions().stream()
.map(exp -> processExpression(exp, context))
.collect(Collectors.toList());

return new GroupBy(node.getLocation(), rewrittenGroupings);
}

@Override
protected AstNode visitSimpleGroupBy(final SimpleGroupBy node, final C context) {
final Optional<AstNode> result = plugin.apply(node, new Context<>(context, this));
if (result.isPresent()) {
return result.get();
}

final List<Expression> columns = node.getColumns().stream()
.map(ce -> processExpression(ce, context))
.collect(Collectors.toList());

return new SimpleGroupBy(
node.getLocation(),
columns
);
}

@Override
public AstNode visitRegisterType(final RegisterType node, final C context) {
final Optional<AstNode> result = plugin.apply(node, new Context<>(context, this));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@

package io.confluent.ksql.analyzer;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.metastore.model.KsqlStream;
import io.confluent.ksql.model.WindowType;
import io.confluent.ksql.name.ColumnName;
Expand All @@ -32,7 +38,9 @@
import io.confluent.ksql.serde.FormatInfo;
import io.confluent.ksql.serde.KeyFormat;
import io.confluent.ksql.serde.WindowInfo;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
Expand Down Expand Up @@ -63,6 +71,10 @@ public class AnalysisTest {
private Function<Map<SourceName, LogicalSchema>, SourceSchemas> sourceSchemasFactory;
@Mock
private WindowExpression windowExpression;
@Mock(name = "anExpression")
private Expression exp1;
@Mock(name = "anotherExpression")
private Expression exp2;

private Analysis analysis;

Expand Down Expand Up @@ -177,6 +189,34 @@ public void shouldGetWindowedGroupBySourceSchemasPostAggregate() {
));
}

@Test
public void shouldMaintainGroupByOrder() {
// Given:
final List<Expression> original = ImmutableList.of(exp1, exp2);

analysis.setGroupByExpressions(original);

// When:
final List<Expression> result = analysis.getGroupByExpressions();

// Then:
assertThat(result, is(original));
}

@Test
public void shouldThrowOnDuplicateGroupBy() {
// Given:
final List<Expression> withDuplicate = ImmutableList.of(exp1, exp1);

// When:
final KsqlException e = assertThrows(
KsqlException.class,
() -> analysis.setGroupByExpressions(withDuplicate)
);

// Then:
assertThat(e.getMessage(), containsString("Duplicate GROUP BY expression: anExpression"));
}

private static void givenNoneWindowedSource(final KsqlStream<?> dataSource) {
final KsqlTopic topic = mock(KsqlTopic.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import io.confluent.ksql.parser.tree.CreateTableAsSelect;
import io.confluent.ksql.parser.tree.Explain;
import io.confluent.ksql.parser.tree.GroupBy;
import io.confluent.ksql.parser.tree.GroupingElement;
import io.confluent.ksql.parser.tree.InsertInto;
import io.confluent.ksql.parser.tree.Join;
import io.confluent.ksql.parser.tree.JoinCriteria;
Expand All @@ -52,7 +51,6 @@
import io.confluent.ksql.parser.tree.Relation;
import io.confluent.ksql.parser.tree.ResultMaterialization;
import io.confluent.ksql.parser.tree.Select;
import io.confluent.ksql.parser.tree.SimpleGroupBy;
import io.confluent.ksql.parser.tree.SingleColumn;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.parser.tree.Statements;
Expand Down Expand Up @@ -769,16 +767,16 @@ public void shouldRewriteInsertIntoUsingPlugin() {
@Test
public void shouldRewriteGroupBy() {
// Given:
final GroupingElement groupingElement1 = mock(GroupingElement.class);
final GroupingElement groupingElement2 = mock(GroupingElement.class);
final GroupingElement rewrittenGroupingElement1 = mock(GroupingElement.class);
final GroupingElement rewrittenGroupingElement2 = mock(GroupingElement.class);
final Expression exp1 = mock(Expression.class);
final Expression exp2 = mock(Expression.class);
final Expression rewrittenExp1 = mock(Expression.class);
final Expression rewrittenExp2 = mock(Expression.class);
final GroupBy groupBy = new GroupBy(
location,
ImmutableList.of(groupingElement1, groupingElement2)
ImmutableList.of(exp1, exp2)
);
when(mockRewriter.apply(groupingElement1, context)).thenReturn(rewrittenGroupingElement1);
when(mockRewriter.apply(groupingElement2, context)).thenReturn(rewrittenGroupingElement2);
when(expressionRewriter.apply(exp1, context)).thenReturn(rewrittenExp1);
when(expressionRewriter.apply(exp2, context)).thenReturn(rewrittenExp2);

// When:
final AstNode rewritten = rewriter.rewrite(groupBy, context);
Expand All @@ -789,32 +787,7 @@ public void shouldRewriteGroupBy() {
equalTo(
new GroupBy(
location,
ImmutableList.of(rewrittenGroupingElement1, rewrittenGroupingElement2)
)
)
);
}

@Test
public void shouldRewriteSimpleGroupBy() {
// Given:
final Expression expression2 = mock(Expression.class);
final Expression rewrittenExpression2 = mock(Expression.class);
final SimpleGroupBy groupBy =
new SimpleGroupBy(location, ImmutableList.of(expression, expression2));
when(expressionRewriter.apply(expression, context)).thenReturn(rewrittenExpression);
when(expressionRewriter.apply(expression2, context)).thenReturn(rewrittenExpression2);

// When:
final AstNode rewritten = rewriter.rewrite(groupBy, context);

// Then:
assertThat(
rewritten,
equalTo(
new SimpleGroupBy(
location,
ImmutableList.of(rewrittenExpression, rewrittenExpression2)
ImmutableList.of(rewrittenExp1, rewrittenExp2)
)
)
);
Expand Down
Loading

0 comments on commit 51bb9f6

Please sign in to comment.