Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: do not allow grouping sets #4942

Merged
merged 4 commits into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@
import io.confluent.ksql.planner.plan.JoinNode.JoinType;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.serde.SerdeOption;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -63,7 +65,7 @@ public class Analysis implements ImmutableAnalysis {
private Optional<Expression> whereExpression = Optional.empty();
private final List<SelectItem> selectItems = new ArrayList<>();
private final Set<ColumnName> selectColumnNames = new HashSet<>();
private final List<Expression> groupByExpressions = new ArrayList<>();
private final Set<Expression> groupByExpressions = new LinkedHashSet<>();
private Optional<WindowExpression> windowExpression = Optional.empty();
private Optional<Expression> partitionBy = Optional.empty();
private Optional<Expression> havingExpression = Optional.empty();
Expand Down Expand Up @@ -129,8 +131,12 @@ public List<Expression> getGroupByExpressions() {
return ImmutableList.copyOf(groupByExpressions);
}

void addGroupByExpressions(final Set<Expression> expressions) {
groupByExpressions.addAll(expressions);
void setGroupByExpressions(final List<Expression> expressions) {
expressions.forEach(exp -> {
if (!groupByExpressions.add(exp)) {
throw new KsqlException("Duplicate GROUP BY expression: " + exp);
}
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import io.confluent.ksql.parser.tree.AllColumns;
import io.confluent.ksql.parser.tree.AstNode;
import io.confluent.ksql.parser.tree.GroupBy;
import io.confluent.ksql.parser.tree.GroupingElement;
import io.confluent.ksql.parser.tree.Join;
import io.confluent.ksql.parser.tree.JoinOn;
import io.confluent.ksql.parser.tree.JoinedSource;
Expand Down Expand Up @@ -482,22 +481,13 @@ protected AstNode visitSelect(final Select node, final Void context) {
return null;
}

@Override
protected AstNode visitGroupBy(final GroupBy node, final Void context) {
return null;
}

private void analyzeWhere(final Expression node) {
analysis.setWhereExpression(node);
}

private void analyzeGroupBy(final GroupBy groupBy) {
isGroupBy = true;

for (final GroupingElement groupingElement : groupBy.getGroupingElements()) {
final Set<Expression> groupingSet = groupingElement.enumerateGroupingSets().get(0);
analysis.addGroupByExpressions(groupingSet);
}
analysis.setGroupByExpressions(groupBy.getGroupingExpressions());
}

private void analyzePartitionBy(final Expression partitionBy) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import io.confluent.ksql.parser.tree.DropTable;
import io.confluent.ksql.parser.tree.Explain;
import io.confluent.ksql.parser.tree.GroupBy;
import io.confluent.ksql.parser.tree.GroupingElement;
import io.confluent.ksql.parser.tree.InsertInto;
import io.confluent.ksql.parser.tree.Join;
import io.confluent.ksql.parser.tree.JoinCriteria;
Expand All @@ -41,7 +40,6 @@
import io.confluent.ksql.parser.tree.Relation;
import io.confluent.ksql.parser.tree.Select;
import io.confluent.ksql.parser.tree.SelectItem;
import io.confluent.ksql.parser.tree.SimpleGroupBy;
import io.confluent.ksql.parser.tree.SingleColumn;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.parser.tree.Statements;
Expand Down Expand Up @@ -441,30 +439,13 @@ protected AstNode visitGroupBy(final GroupBy node, final C context) {
return result.get();
}

final List<GroupingElement> rewrittenGroupings = node.getGroupingElements().stream()
.map(groupingElement -> (GroupingElement) rewriter.apply(groupingElement, context))
final List<Expression> rewrittenGroupings = node.getGroupingExpressions().stream()
.map(exp -> processExpression(exp, context))
.collect(Collectors.toList());

return new GroupBy(node.getLocation(), rewrittenGroupings);
}

@Override
protected AstNode visitSimpleGroupBy(final SimpleGroupBy node, final C context) {
final Optional<AstNode> result = plugin.apply(node, new Context<>(context, this));
if (result.isPresent()) {
return result.get();
}

final List<Expression> columns = node.getColumns().stream()
.map(ce -> processExpression(ce, context))
.collect(Collectors.toList());

return new SimpleGroupBy(
node.getLocation(),
columns
);
}

@Override
public AstNode visitRegisterType(final RegisterType node, final C context) {
final Optional<AstNode> result = plugin.apply(node, new Context<>(context, this));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@

package io.confluent.ksql.analyzer;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.metastore.model.KsqlStream;
import io.confluent.ksql.model.WindowType;
import io.confluent.ksql.name.ColumnName;
Expand All @@ -32,7 +38,9 @@
import io.confluent.ksql.serde.FormatInfo;
import io.confluent.ksql.serde.KeyFormat;
import io.confluent.ksql.serde.WindowInfo;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
Expand Down Expand Up @@ -63,6 +71,10 @@ public class AnalysisTest {
private Function<Map<SourceName, LogicalSchema>, SourceSchemas> sourceSchemasFactory;
@Mock
private WindowExpression windowExpression;
@Mock(name = "anExpression")
private Expression exp1;
@Mock(name = "anotherExpression")
private Expression exp2;

private Analysis analysis;

Expand Down Expand Up @@ -177,6 +189,34 @@ public void shouldGetWindowedGroupBySourceSchemasPostAggregate() {
));
}

@Test
public void shouldMaintainGroupByOrder() {
// Given:
final List<Expression> original = ImmutableList.of(exp1, exp2);

analysis.setGroupByExpressions(original);

// When:
final List<Expression> result = analysis.getGroupByExpressions();

// Then:
assertThat(result, is(original));
}

@Test
public void shouldThrowOnDuplicateGroupBy() {
// Given:
final List<Expression> withDuplicate = ImmutableList.of(exp1, exp1);

// When:
final KsqlException e = assertThrows(
KsqlException.class,
() -> analysis.setGroupByExpressions(withDuplicate)
);

// Then:
assertThat(e.getMessage(), containsString("Duplicate GROUP BY expression: anExpression"));
}

private static void givenNoneWindowedSource(final KsqlStream<?> dataSource) {
final KsqlTopic topic = mock(KsqlTopic.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import io.confluent.ksql.parser.tree.CreateTableAsSelect;
import io.confluent.ksql.parser.tree.Explain;
import io.confluent.ksql.parser.tree.GroupBy;
import io.confluent.ksql.parser.tree.GroupingElement;
import io.confluent.ksql.parser.tree.InsertInto;
import io.confluent.ksql.parser.tree.Join;
import io.confluent.ksql.parser.tree.JoinCriteria;
Expand All @@ -52,7 +51,6 @@
import io.confluent.ksql.parser.tree.Relation;
import io.confluent.ksql.parser.tree.ResultMaterialization;
import io.confluent.ksql.parser.tree.Select;
import io.confluent.ksql.parser.tree.SimpleGroupBy;
import io.confluent.ksql.parser.tree.SingleColumn;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.parser.tree.Statements;
Expand Down Expand Up @@ -769,16 +767,16 @@ public void shouldRewriteInsertIntoUsingPlugin() {
@Test
public void shouldRewriteGroupBy() {
// Given:
final GroupingElement groupingElement1 = mock(GroupingElement.class);
final GroupingElement groupingElement2 = mock(GroupingElement.class);
final GroupingElement rewrittenGroupingElement1 = mock(GroupingElement.class);
final GroupingElement rewrittenGroupingElement2 = mock(GroupingElement.class);
final Expression exp1 = mock(Expression.class);
final Expression exp2 = mock(Expression.class);
final Expression rewrittenExp1 = mock(Expression.class);
final Expression rewrittenExp2 = mock(Expression.class);
final GroupBy groupBy = new GroupBy(
location,
ImmutableList.of(groupingElement1, groupingElement2)
ImmutableList.of(exp1, exp2)
);
when(mockRewriter.apply(groupingElement1, context)).thenReturn(rewrittenGroupingElement1);
when(mockRewriter.apply(groupingElement2, context)).thenReturn(rewrittenGroupingElement2);
when(expressionRewriter.apply(exp1, context)).thenReturn(rewrittenExp1);
when(expressionRewriter.apply(exp2, context)).thenReturn(rewrittenExp2);

// When:
final AstNode rewritten = rewriter.rewrite(groupBy, context);
Expand All @@ -789,32 +787,7 @@ public void shouldRewriteGroupBy() {
equalTo(
new GroupBy(
location,
ImmutableList.of(rewrittenGroupingElement1, rewrittenGroupingElement2)
)
)
);
}

@Test
public void shouldRewriteSimpleGroupBy() {
// Given:
final Expression expression2 = mock(Expression.class);
final Expression rewrittenExpression2 = mock(Expression.class);
final SimpleGroupBy groupBy =
new SimpleGroupBy(location, ImmutableList.of(expression, expression2));
when(expressionRewriter.apply(expression, context)).thenReturn(rewrittenExpression);
when(expressionRewriter.apply(expression2, context)).thenReturn(rewrittenExpression2);

// When:
final AstNode rewritten = rewriter.rewrite(groupBy, context);

// Then:
assertThat(
rewritten,
equalTo(
new SimpleGroupBy(
location,
ImmutableList.of(rewrittenExpression, rewrittenExpression2)
ImmutableList.of(rewrittenExp1, rewrittenExp2)
)
)
);
Expand Down
Loading