Skip to content

Commit

Permalink
fix(3356): struct rewritter missed EXPLAIN (#3398)
Browse files Browse the repository at this point in the history
* fix(3356): struct rewritter missed EXPLAIN

KSQL uses a statement rewritter to convert STRUCT field dereferences into `FETCH_FIELD_FROM_STRUCT` function calls.  The rewrite is done for `Query`, `C*AS` and `INSERT INTO` statements, but was missing `EXPLAIN`.

It is valid to execute `EXPLAIN` on a query string, e.g.  `EXPLAIN SELECT address->street FROM Y;`.  It is therefore important that the query within the `EXPLAIN` statement has the same rewrites applied to it as would be applied should the statement be executed directly.
  • Loading branch information
big-andy-coates authored Sep 23, 2019
1 parent bf51808 commit daf974b
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,8 @@ public Void visitDereferenceExpression(
final DereferenceExpression node,
final ExpressionTypeContext expressionTypeContext
) {
final Column schemaColumn = schema.findValueColumn(node.toString())
.orElseThrow(() ->
new KsqlException(String.format("Invalid Expression %s.", node.toString())));

final Schema schema = SQL_TO_CONNECT_SCHEMA_CONVERTER.toConnectSchema(schemaColumn.type());
expressionTypeContext.setSchema(schemaColumn.type(), schema);
return null;
throw new IllegalArgumentException("Dereferenced expressions should have been rewritten to "
+ FetchFieldFromStruct.FUNCTION_NAME + " by this point");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.confluent.ksql.execution.expression.tree.BooleanLiteral;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression.Type;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.InListExpression;
Expand All @@ -45,6 +46,7 @@
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.NotExpression;
import io.confluent.ksql.execution.expression.tree.QualifiedName;
import io.confluent.ksql.execution.expression.tree.QualifiedNameReference;
import io.confluent.ksql.execution.expression.tree.SearchedCaseExpression;
import io.confluent.ksql.execution.expression.tree.SimpleCaseExpression;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
Expand Down Expand Up @@ -274,7 +276,23 @@ public void shouldHandleNestedUdfs() {
}

@Test
public void shouldHandleStruct() {
public void shouldThrowOnStructFieldDereference() {
// Given:
final Expression expression = new DereferenceExpression(
Optional.empty(),
new QualifiedNameReference(QualifiedName.of("TEST1", "COL6")),
"STREET"
);

// Then:
expectedException.expect(IllegalArgumentException.class);

// When:
expressionTypeManager.getExpressionSqlType(expression);
}

@Test
public void shouldHandleRewrittenStruct() {
final Expression expression = new FunctionCall(
QualifiedName.of(FetchFieldFromStruct.FUNCTION_NAME),
ImmutableList.of(ADDRESS, new StringLiteral("NUMBER"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
import io.confluent.ksql.parser.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.parser.tree.Explain;
import io.confluent.ksql.parser.tree.Query;
import io.confluent.ksql.parser.tree.QueryContainer;
import io.confluent.ksql.parser.tree.Statement;
Expand All @@ -46,7 +47,8 @@ public Statement rewriteForStruct() {

public static boolean requiresRewrite(final Statement statement) {
return statement instanceof Query
|| statement instanceof QueryContainer;
|| statement instanceof QueryContainer
|| statement instanceof Explain;
}

private static final class Plugin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.confluent.ksql.parser.tree.CreateTable;
import io.confluent.ksql.parser.tree.CreateTableAsSelect;
import io.confluent.ksql.parser.tree.DropTable;
import io.confluent.ksql.parser.tree.Explain;
import io.confluent.ksql.parser.tree.GroupBy;
import io.confluent.ksql.parser.tree.GroupingElement;
import io.confluent.ksql.parser.tree.InsertInto;
Expand Down Expand Up @@ -105,6 +106,11 @@ private Expression processExpression(final Expression node, final C context) {
return expressionRewriter.apply(node, context);
}

@Override
protected AstNode visitNode(final AstNode node, final C context) {
return node;
}

@Override
protected AstNode visitStatements(final Statements node, final C context) {
final List<Statement> rewrittenStatements = node.getStatements()
Expand Down Expand Up @@ -150,6 +156,22 @@ protected Query visitQuery(final Query node, final C context) {
);
}

@Override
protected AstNode visitExplain(final Explain node, final C context) {
if (!node.getStatement().isPresent()) {
return node;
}

final Statement original = node.getStatement().get();
final Statement rewritten = (Statement) rewriter.apply(original, context);

return new Explain(
node.getLocation(),
node.getQueryId(),
Optional.of(rewritten)
);
}

@Override
protected AstNode visitSelect(final Select node, final C context) {
final List<SelectItem> rewrittenItems = node.getSelectItems()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.core.IsInstanceOf.instanceOf;
import static org.mockito.Mockito.mock;

Expand All @@ -29,15 +31,16 @@
import io.confluent.ksql.parser.tree.CreateStreamAsSelect;
import io.confluent.ksql.parser.tree.CreateTable;
import io.confluent.ksql.parser.tree.CreateTableAsSelect;
import io.confluent.ksql.parser.tree.Explain;
import io.confluent.ksql.parser.tree.InsertInto;
import io.confluent.ksql.parser.tree.Query;
import io.confluent.ksql.parser.tree.SingleColumn;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.util.MetaStoreFixture;
import org.easymock.EasyMock;
import org.junit.Before;
import org.junit.Test;

@SuppressWarnings("OptionalGetWithoutIsPresent")
public class StatementRewriteForStructTest {

private MetaStore metaStore;
Expand All @@ -54,7 +57,7 @@ public void shouldCreateCorrectFunctionCallExpression() {
.getStatement();

final Query query = getQuery(statement);
assertThat(query.getSelect().getSelectItems().size(), equalTo(2));
assertThat(query.getSelect().getSelectItems(), hasSize(2));
final Expression col0 = ((SingleColumn) query.getSelect().getSelectItems().get(0))
.getExpression();
final Expression col1 = ((SingleColumn) query.getSelect().getSelectItems().get(1))
Expand All @@ -66,7 +69,6 @@ public void shouldCreateCorrectFunctionCallExpression() {
assertThat(col0.toString(), equalTo(
"FETCH_FIELD_FROM_STRUCT(FETCH_FIELD_FROM_STRUCT(ORDERS.ITEMINFO, 'CATEGORY'), 'NAME')"));
assertThat(col1.toString(), equalTo("FETCH_FIELD_FROM_STRUCT(ORDERS.ADDRESS, 'STATE')"));

}

@Test
Expand All @@ -76,7 +78,7 @@ public void shouldNotCreateFunctionCallIfNotNeeded() {
.getStatement();

final Query query = getQuery(statement);
assertThat(query.getSelect().getSelectItems().size(), equalTo(1));
assertThat(query.getSelect().getSelectItems(), hasSize(1));
final Expression col0 = ((SingleColumn) query.getSelect().getSelectItems().get(0))
.getExpression();

Expand All @@ -91,7 +93,7 @@ public void shouldCreateCorrectFunctionCallExpressionWithSubscript() {
.getStatement();

final Query query = getQuery(statement);
assertThat(query.getSelect().getSelectItems().size(), equalTo(2));
assertThat(query.getSelect().getSelectItems(), hasSize(2));
final Expression col0 = ((SingleColumn) query.getSelect().getSelectItems().get(0))
.getExpression();
final Expression col1 = ((SingleColumn) query.getSelect().getSelectItems().get(1))
Expand All @@ -113,7 +115,7 @@ public void shouldCreateCorrectFunctionCallExpressionWithSubscriptWithExpression
.getStatement();

final Query query = getQuery(statement);
assertThat(query.getSelect().getSelectItems().size(), equalTo(2));
assertThat(query.getSelect().getSelectItems(), hasSize(2));
final Expression col0 = ((SingleColumn) query.getSelect().getSelectItems().get(0))
.getExpression();
final Expression col1 = ((SingleColumn) query.getSelect().getSelectItems().get(1))
Expand All @@ -128,17 +130,38 @@ public void shouldCreateCorrectFunctionCallExpressionWithSubscriptWithExpression
equalTo("FETCH_FIELD_FROM_STRUCT(NESTED_STREAM.MAPCOL['key'], 'NAME')"));
}

@Test
public void shouldRewriteExplainQuery() {
// When:
final Explain statement = KsqlParserTestUtil.<Explain>buildSingleAst(
"EXPLAIN SELECT address->state FROM orders;", metaStore)
.getStatement();

// Then:
assertThat(
statement.getStatement().toString(),
containsString("FETCH_FIELD_FROM_STRUCT(ORDERS.ADDRESS, 'STATE')")
);
}

@Test
public void shouldEnsureRewriteRequirementCorrectly() {
assertThat("Query should be valid for rewrite for struct.", StatementRewriteForStruct.requiresRewrite(EasyMock.mock(Query.class)));
assertThat("CSAS should be valid for rewrite for struct.", StatementRewriteForStruct.requiresRewrite(EasyMock.mock(CreateStreamAsSelect.class)));
assertThat("CTAS should be valid for rewrite for struct.", StatementRewriteForStruct.requiresRewrite(EasyMock.mock(CreateTableAsSelect.class)));
assertThat("Insert Into should be valid for rewrite for struct.", StatementRewriteForStruct.requiresRewrite(EasyMock.mock(InsertInto.class)));
assertThat("Query should be valid for rewrite for struct.",
StatementRewriteForStruct.requiresRewrite(mock(Query.class)));
assertThat("CSAS should be valid for rewrite for struct.",
StatementRewriteForStruct.requiresRewrite(mock(CreateStreamAsSelect.class)));
assertThat("CTAS should be valid for rewrite for struct.",
StatementRewriteForStruct.requiresRewrite(mock(CreateTableAsSelect.class)));
assertThat("Insert Into should be valid for rewrite for struct.",
StatementRewriteForStruct.requiresRewrite(mock(InsertInto.class)));
assertThat("Explain should be valid for rewrite for struct.",
StatementRewriteForStruct.requiresRewrite(mock(Explain.class)));
}

@Test
public void shouldFailTestIfStatementShouldBeRewritten() {
assertThat("Incorrect rewrite requirement enforcement.", !StatementRewriteForStruct.requiresRewrite(EasyMock.mock(CreateTable.class)));
assertThat("Incorrect rewrite requirement enforcement.",
!StatementRewriteForStruct.requiresRewrite(mock(CreateTable.class)));
}

private static Query getQuery(final Statement statement) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.confluent.ksql.parser.rewrite;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand All @@ -17,6 +19,7 @@
import io.confluent.ksql.parser.tree.CreateStreamAsSelect;
import io.confluent.ksql.parser.tree.CreateTable;
import io.confluent.ksql.parser.tree.CreateTableAsSelect;
import io.confluent.ksql.parser.tree.Explain;
import io.confluent.ksql.parser.tree.GroupBy;
import io.confluent.ksql.parser.tree.GroupingElement;
import io.confluent.ksql.parser.tree.InsertInto;
Expand Down Expand Up @@ -48,6 +51,7 @@
import org.mockito.junit.MockitoRule;

public class StatementRewriterTest {

@Mock
private BiFunction<Expression, Object, Expression> expressionRewriter;
@Mock
Expand Down Expand Up @@ -88,7 +92,6 @@ public class StatementRewriterTest {
private CreateSourceAsProperties csasProperties;
@Mock
private ResultMaterialization resultMaterialization;
private boolean staticQuery;

private StatementRewriter<Object> rewriter;

Expand Down Expand Up @@ -139,7 +142,7 @@ private Query givenQuery(
groupBy,
having,
resultMaterialization,
staticQuery,
false,
optionalInt
);
}
Expand All @@ -163,7 +166,7 @@ public void shouldRewriteQuery() {
Optional.empty(),
Optional.empty(),
resultMaterialization,
staticQuery,
false,
optionalInt))
);
}
Expand All @@ -188,7 +191,7 @@ public void shouldRewriteQueryWithFilter() {
Optional.empty(),
Optional.empty(),
resultMaterialization,
staticQuery,
false,
optionalInt))
);
}
Expand All @@ -215,7 +218,7 @@ public void shouldRewriteQueryWithGroupBy() {
Optional.of(rewrittenGroupBy),
Optional.empty(),
resultMaterialization,
staticQuery,
false,
optionalInt))
);
}
Expand All @@ -242,7 +245,7 @@ public void shouldRewriteQueryWithWindow() {
Optional.empty(),
Optional.empty(),
resultMaterialization,
staticQuery,
false,
optionalInt))
);
}
Expand All @@ -267,7 +270,7 @@ public void shouldRewriteQueryWithHaving() {
Optional.empty(),
Optional.of(rewrittenExpression),
resultMaterialization,
staticQuery,
false,
optionalInt))
);
}
Expand Down Expand Up @@ -632,4 +635,33 @@ public void shouldRewriteSimpleGroupBy() {
)
);
}

@Test
public void shouldRewriteExplainWithQuery() {
// Given:
final Explain explain = new Explain(location, Optional.empty(), Optional.of(query));
when(mockRewriter.apply(query, context)).thenReturn(rewrittenQuery);

// When:
final AstNode rewritten = rewriter.rewrite(explain, context);

// Then:
assertThat(rewritten, is(new Explain(
location,
Optional.empty(),
Optional.of(rewrittenQuery)
)));
}

@Test
public void shouldNotRewriteExplainWithId() {
// Given:
final Explain explain = new Explain(location, Optional.of("id"), Optional.empty());

// When:
final AstNode rewritten = rewriter.rewrite(explain, context);

// Then:
assertThat(rewritten, is(sameInstance(explain)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ public class TemporaryEngine extends ExternalResource {
public static final LogicalSchema SCHEMA = LogicalSchema.builder()
.valueColumn("val", SqlTypes.STRING)
.valueColumn("val2", SqlTypes.decimal(2, 1))
.valueColumn("ADDRESS", SqlTypes.struct()
.field("STREET", SqlTypes.STRING)
.field("STATE", SqlTypes.STRING)
.build())
.build();

private MutableMetaStore metaStore;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,25 @@ public void shouldExplainStatement() {
assertThat(query.getQueryDescription().getSources(), containsInAnyOrder("Y"));
}

@Test
public void shouldExplainStatementWithStructFieldDereference() {
// Given:
engine.givenSource(DataSourceType.KSTREAM, "Y");
final String statementText = "SELECT address->street FROM Y EMIT CHANGES;";
final ConfiguredStatement<?> explain = engine.configure("EXPLAIN " + statementText);

// When:
final QueryDescriptionEntity query = (QueryDescriptionEntity) CustomExecutors.EXPLAIN.execute(
explain,
engine.getEngine(),
engine.getServiceContext()
).orElseThrow(IllegalStateException::new);

// Then:
assertThat(query.getQueryDescription().getStatementText(), equalTo(statementText));
assertThat(query.getQueryDescription().getSources(), containsInAnyOrder("Y"));
}

@Test
public void shouldFailOnNonQueryExplain() {
// Expect:
Expand Down

0 comments on commit daf974b

Please sign in to comment.