diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml index 8cc57987638b..d702256067f4 100644 --- a/checkstyle/suppressions.xml +++ b/checkstyle/suppressions.xml @@ -20,10 +20,10 @@ - diff --git a/ksql-metastore/src/test/java/io/confluent/ksql/util/MetaStoreFixture.java b/ksql-metastore/src/test/java/io/confluent/ksql/util/MetaStoreFixture.java index dc64cb2bac6d..2b75bbb1403b 100644 --- a/ksql-metastore/src/test/java/io/confluent/ksql/util/MetaStoreFixture.java +++ b/ksql-metastore/src/test/java/io/confluent/ksql/util/MetaStoreFixture.java @@ -16,6 +16,7 @@ package io.confluent.ksql.util; +import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.SchemaBuilder; import io.confluent.ksql.function.FunctionRegistry; @@ -84,11 +85,37 @@ public static MetaStore getNewMetaStore(final FunctionRegistry functionRegistry) metaStore.putTopic(ksqlTopic2); metaStore.putSource(ksqlTable); - SchemaBuilder schemaBuilderOrders = SchemaBuilder.struct() - .field("ORDERTIME", SchemaBuilder.OPTIONAL_INT64_SCHEMA) - .field("ORDERID", SchemaBuilder.OPTIONAL_STRING_SCHEMA) - .field("ITEMID", SchemaBuilder.OPTIONAL_STRING_SCHEMA) - .field("ORDERUNITS", SchemaBuilder.OPTIONAL_FLOAT64_SCHEMA); + + final Schema addressSchema = SchemaBuilder.struct() + .field("NUMBER", Schema.OPTIONAL_INT64_SCHEMA) + .field("STREET", Schema.OPTIONAL_STRING_SCHEMA) + .field("CITY", Schema.OPTIONAL_STRING_SCHEMA) + .field("STATE", Schema.OPTIONAL_STRING_SCHEMA) + .field("ZIPCODE", Schema.OPTIONAL_INT64_SCHEMA) + .optional().build(); + + final Schema categorySchema = SchemaBuilder.struct() + .field("ID", Schema.OPTIONAL_INT64_SCHEMA) + .field("NAME", Schema.OPTIONAL_STRING_SCHEMA) + .optional().build(); + + final Schema itemInfoSchema = SchemaBuilder.struct() + .field("ITEMID", Schema.INT64_SCHEMA) + .field("NAME", Schema.STRING_SCHEMA) + .field("CATEGORY", categorySchema) + .optional().build(); + + final SchemaBuilder schemaBuilder = SchemaBuilder.struct(); + final Schema schemaBuilderOrders = schemaBuilder + .field("ORDERTIME", Schema.INT64_SCHEMA) + .field("ORDERID", Schema.OPTIONAL_INT64_SCHEMA) + .field("ITEMID", Schema.OPTIONAL_STRING_SCHEMA) + .field("ITEMINFO", itemInfoSchema) + .field("ORDERUNITS", Schema.INT32_SCHEMA) + .field("ARRAYCOL",schemaBuilder.array(Schema.FLOAT64_SCHEMA).optional().build()) + .field("MAPCOL", schemaBuilder.map(Schema.STRING_SCHEMA, Schema.FLOAT64_SCHEMA).optional().build()) + .field("ADDRESS", addressSchema) + .build(); KsqlTopic ksqlTopicOrders = diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java index 42c2e1c35673..e8916665df63 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java @@ -1129,11 +1129,7 @@ public Node visitSubqueryExpression(SqlBaseParser.SubqueryExpressionContext cont @Override public Node visitDereference(SqlBaseParser.DereferenceContext context) { String fieldName = getIdentifierText(context.identifier()); - Expression baseExpression; - QualifiedName tableName = QualifiedName.of( - context.primaryExpression().getText().toUpperCase()); - baseExpression = new QualifiedNameReference( - getLocation(context.primaryExpression()), tableName); + Expression baseExpression = (Expression) visit(context.base); DereferenceExpression dereferenceExpression = new DereferenceExpression(getLocation(context), baseExpression, fieldName); return dereferenceExpression; @@ -1144,6 +1140,15 @@ public Node visitColumnReference(SqlBaseParser.ColumnReferenceContext context) { String columnName = getIdentifierText(context.identifier()); // If this is join. if (dataSourceExtractor.getJoinLeftSchema() != null) { + if (columnName.equalsIgnoreCase(dataSourceExtractor.getLeftAlias()) + || columnName.equalsIgnoreCase(dataSourceExtractor.getLeftName()) + || columnName.equalsIgnoreCase(dataSourceExtractor.getRightAlias()) + || columnName.equalsIgnoreCase(dataSourceExtractor.getRightName())) { + return new QualifiedNameReference( + getLocation(context), + QualifiedName.of(columnName) + ); + } if (dataSourceExtractor.getCommonFieldNames().contains(columnName)) { throw new KsqlException("Field " + columnName + " is ambiguous."); } else if (dataSourceExtractor.getLeftFieldNames().contains(columnName)) { @@ -1164,12 +1169,20 @@ public Node visitColumnReference(SqlBaseParser.ColumnReferenceContext context) { throw new InvalidColumnReferenceException("Field " + columnName + " is ambiguous."); } } else { - Expression baseExpression = - new QualifiedNameReference( - getLocation(context), - QualifiedName.of(dataSourceExtractor.getFromAlias()) - ); - return new DereferenceExpression(getLocation(context), baseExpression, columnName); + if (columnName.equalsIgnoreCase(dataSourceExtractor.getFromAlias()) + || columnName.equalsIgnoreCase(dataSourceExtractor.getFromName())) { + return new QualifiedNameReference( + getLocation(context), + QualifiedName.of(columnName) + ); + } else { + Expression baseExpression = + new QualifiedNameReference( + getLocation(context), + QualifiedName.of(dataSourceExtractor.getFromAlias()) + ); + return new DereferenceExpression(getLocation(context), baseExpression, columnName); + } } } diff --git a/ksql-parser/src/main/java/io/confluent/ksql/util/DataSourceExtractor.java b/ksql-parser/src/main/java/io/confluent/ksql/util/DataSourceExtractor.java index 0e5b5a4fa6cf..b46e55e95bcc 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/util/DataSourceExtractor.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/util/DataSourceExtractor.java @@ -26,6 +26,7 @@ import io.confluent.ksql.parser.tree.NodeLocation; import io.confluent.ksql.parser.tree.QualifiedName; import io.confluent.ksql.parser.tree.Table; +import java.util.HashMap; import org.antlr.v4.runtime.ParserRuleContext; import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; @@ -47,8 +48,11 @@ public class DataSourceExtractor extends SqlBaseBaseVisitor { private Schema joinRightSchema; private String fromAlias; + private String fromName; private String leftAlias; + private String leftName; private String rightAlias; + private String rightName; private Set commonFieldNames = new HashSet<>(); private Set leftFieldNames = new HashSet<>(); @@ -61,6 +65,17 @@ public DataSourceExtractor(final MetaStore metaStore) { this.metaStore = metaStore; } + public java.util.Map getAliasToNameMap() { + java.util.Map aliasToNameMap = new HashMap<>(); + if (rightName != null && rightAlias != null) { + aliasToNameMap.put(leftAlias, leftName); + aliasToNameMap.put(rightAlias, rightName); + } else { + aliasToNameMap.put(fromAlias, fromName); + } + return aliasToNameMap; + } + @Override public Node visitQuerySpecification(final SqlBaseParser.QuerySpecificationContext ctx) { visit(ctx.from); @@ -86,6 +101,7 @@ public Node visitAliasedRelation(final SqlBaseParser.AliasedRelationContext cont if (!isJoin) { this.fromAlias = alias; + this.fromName = table.getName().getSuffix().toUpperCase(); StructuredDataSource fromDataSource = metaStore.getSource(table.getName().getSuffix()); @@ -117,6 +133,7 @@ public Node visitJoinRelation(final SqlBaseParser.JoinRelationContext context) { } this.leftAlias = left.getAlias(); + this.leftName = ((Table) left.getRelation()).getName().getSuffix(); StructuredDataSource leftDataSource = metaStore.getSource(((Table) left.getRelation()).getName().getSuffix()); @@ -127,6 +144,7 @@ public Node visitJoinRelation(final SqlBaseParser.JoinRelationContext context) { this.joinLeftSchema = leftDataSource.getSchema(); this.rightAlias = right.getAlias(); + this.rightName = ((Table) right.getRelation()).getName().getSuffix(); StructuredDataSource rightDataSource = metaStore.getSource(((Table) right.getRelation()).getName().getSuffix()); @@ -187,6 +205,26 @@ public Set getRightFieldNames() { return rightFieldNames; } + public Schema getJoinRightSchema() { + return joinRightSchema; + } + + public String getFromName() { + return fromName; + } + + public String getLeftName() { + return leftName; + } + + public String getRightName() { + return rightName; + } + + public boolean isJoin() { + return isJoin; + } + private static QualifiedName getQualifiedName(SqlBaseParser.QualifiedNameContext context) { List parts = context .identifier().stream() diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/KsqlParserTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/KsqlParserTest.java index 066b63170ecf..029125248cf4 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/KsqlParserTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/KsqlParserTest.java @@ -25,6 +25,7 @@ import io.confluent.ksql.parser.tree.CreateStream; import io.confluent.ksql.parser.tree.CreateStreamAsSelect; import io.confluent.ksql.parser.tree.CreateTable; +import io.confluent.ksql.parser.tree.DereferenceExpression; import io.confluent.ksql.parser.tree.DropStream; import io.confluent.ksql.parser.tree.DropTable; import io.confluent.ksql.parser.tree.InsertInto; @@ -225,6 +226,31 @@ public void testBooleanLogicalExpression() { } + @Test + public void shouldParseStructFieldAccessCorrectly() { + String simpleQuery = "SELECT iteminfo.category.name, address.street FROM orders WHERE address.state = 'CA';"; + Statement statement = KSQL_PARSER.buildAst(simpleQuery, metaStore).get(0); + + + Assert.assertTrue("testSimpleQuery fails", statement instanceof Query); + Query query = (Query) statement; + assertThat("testSimpleQuery fails", query.getQueryBody(), instanceOf(QuerySpecification.class)); + QuerySpecification querySpecification = (QuerySpecification)query.getQueryBody(); + assertThat("testSimpleQuery fails", querySpecification.getSelect().getSelectItems().size(), equalTo(2)); + SingleColumn singleColumn0 = (SingleColumn) querySpecification.getSelect().getSelectItems().get(0); + SingleColumn singleColumn1 = (SingleColumn) querySpecification.getSelect().getSelectItems().get(1); + assertThat(singleColumn0.getExpression(), instanceOf(DereferenceExpression.class)); + assertThat(singleColumn0.getExpression().toString(), equalTo("ORDERS.ITEMINFO.CATEGORY.NAME")); + DereferenceExpression dereferenceExpression0 = (DereferenceExpression) singleColumn0.getExpression(); + assertThat(dereferenceExpression0.getBase().toString(), equalTo("ORDERS.ITEMINFO.CATEGORY")); + assertThat(dereferenceExpression0.getFieldName(), equalTo("NAME")); + + DereferenceExpression dereferenceExpression1 = (DereferenceExpression) singleColumn1.getExpression(); + assertThat(dereferenceExpression1.getBase().toString(), equalTo("ORDERS.ADDRESS")); + assertThat(dereferenceExpression1.getFieldName(), equalTo("STREET")); + + } + @Test public void testSimpleLeftJoin() { String @@ -434,7 +460,7 @@ public void testCreateStreamAsSelect() { Assert.assertTrue("testCreateTable failed.", createStreamAsSelect.getName().toString().equalsIgnoreCase("bigorders_json")); Assert.assertTrue("testCreateTable failed.", createStreamAsSelect.getQuery().getQueryBody() instanceof QuerySpecification); QuerySpecification querySpecification = (QuerySpecification) createStreamAsSelect.getQuery().getQueryBody(); - Assert.assertTrue("testCreateTable failed.", querySpecification.getSelect().getSelectItems().size() == 4); + Assert.assertTrue("testCreateTable failed.", querySpecification.getSelect().getSelectItems().size() == 8); Assert.assertTrue("testCreateTable failed.", querySpecification.getWhere().get().toString().equalsIgnoreCase("(ORDERS.ORDERUNITS > 5)")); Assert.assertTrue("testCreateTable failed.", ((AliasedRelation)querySpecification.getFrom()).getAlias().equalsIgnoreCase("ORDERS")); } diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/rewrite/StatementRewriterTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/rewrite/StatementRewriterTest.java index a5b649f37064..f8df66c4f366 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/rewrite/StatementRewriterTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/rewrite/StatementRewriterTest.java @@ -479,7 +479,7 @@ public void testCreateStreamAsSelect() { assertThat("testCreateTable failed.", createStreamAsSelect.getName().toString(), equalTo("BIGORDERS_JSON")); assertThat("testCreateTable failed.", createStreamAsSelect.getQuery().getQueryBody(), instanceOf(QuerySpecification.class)); QuerySpecification querySpecification = (QuerySpecification) createStreamAsSelect.getQuery().getQueryBody(); - assertThat("testCreateTable failed.", querySpecification.getSelect().getSelectItems().size() == 4); + assertThat("testCreateTable failed.", querySpecification.getSelect().getSelectItems().size() == 8); assertThat("testCreateTable failed.", querySpecification.getWhere().get().toString(), equalTo("(ORDERS.ORDERUNITS > 5)")); assertThat("testCreateTable failed.", ((AliasedRelation)querySpecification.getFrom()).getAlias(), equalTo("ORDERS")); }