fix: improve escaping of identifiers (#3295)
agavra authored Sep 3, 2019
1 parent af779dc commit 04435d7
Showing 15 changed files with 158 additions and 115 deletions.
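At a glance, the change does two things: identifier quoting is now decided by a lexer-based check (the new IdentifierUtil.needsQuotes) instead of the hard-coded reserved-word set in ParserUtil, and qualified names are escaped part by part rather than as one joined string. A minimal, self-contained Java sketch of the new formatting rule — the class and predicate below are illustrative stand-ins, not the actual ksql Field/FormatOptions/IdentifierUtil classes:

import java.util.Optional;
import java.util.function.Predicate;

public final class FieldFormatSketch {

  // Mirrors the reworked Field.toString(FormatOptions): source and name are escaped separately.
  static String format(final Optional<String> source, final String name,
      final Predicate<String> needsQuotes) {
    final String escapedName = escape(name, needsQuotes);
    return source
        .map(s -> escape(s, needsQuotes) + "." + escapedName)
        .orElse(escapedName);
  }

  // Mirrors the new private Field#escape helper: back-quote a single part only if needed.
  static String escape(final String part, final Predicate<String> needsQuotes) {
    return needsQuotes.test(part) ? "`" + part + "`" : part;
  }

  public static void main(final String[] args) {
    // Stand-in predicate: pretend "reserved" and "word" need quoting while "source" does not.
    final Predicate<String> needsQuotes = s -> s.equals("reserved") || s.equals("word");

    // The old code quoted the pre-joined "source.name" string as one unit (or not at all);
    // the new code quotes each part on its own, matching the updated FieldTest expectations.
    System.out.println(format(Optional.of("reserved"), "word", needsQuotes)); // `reserved`.`word`
    System.out.println(format(Optional.of("source"), "word", needsQuotes));   // source.`word`
  }
}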
@@ -140,10 +140,13 @@ public String toString() {
}

public String toString(final FormatOptions formatOptions) {
- final String formattedName = formatOptions.isReservedWord(fullName)
- ? "`" + fullName + "`"
- : fullName;
+ final Optional<String> base = source.map(val -> escape(val, formatOptions));
+ final String escaped = escape(name, formatOptions);
+ final String field = base.isPresent() ? base.get() + "." + escaped : escaped;
+ return field + " " + type.toString(formatOptions);
+ }

- return formattedName + " " + type.toString(formatOptions);
+ private static String escape(final String string, final FormatOptions formatOptions) {
+ return formatOptions.isReservedWord(string) ? "`" + string + "`" : string;
}
}
@@ -111,7 +111,7 @@ public void shouldToString() {
is("`SomeName` BOOLEAN"));

assertThat(Field.of("SomeSource", "SomeName", SqlTypes.INTEGER).toString(),
- is("`SomeSource.SomeName` INTEGER"));
+ is("`SomeSource`.`SomeName` INTEGER"));
}

@Test
@@ -131,17 +131,17 @@ public void shouldToStringWithReservedWords() {
is("`reserved` BIGINT"));

assertThat(Field.of("reserved", "word", SqlTypes.DOUBLE).toString(options),
- is("reserved.word DOUBLE"));
+ is("`reserved`.`word` DOUBLE"));

assertThat(Field.of("source", "word", SqlTypes.STRING).toString(options),
- is("source.word STRING"));
+ is("source.`word` STRING"));

final SqlStruct struct = SqlTypes.struct()
.field("reserved", SqlTypes.BIGINT)
.field("other", SqlTypes.BIGINT)
.build();

assertThat(Field.of("reserved", "name", struct).toString(options),
- is("`reserved.name` STRUCT<`reserved` BIGINT, other BIGINT>"));
+ is("`reserved`.name STRUCT<`reserved` BIGINT, other BIGINT>"));
}
}
@@ -616,8 +616,8 @@ public void shouldConvertAliasedSchemaToString() {
// Then:
assertThat(s, is(
"["
- + "`t.ROWKEY` STRING KEY, "
- + "`t.f0` BOOLEAN"
+ + "`t`.`ROWKEY` STRING KEY, "
+ + "`t`.`f0` BOOLEAN"
+ "]"));
}

@@ -31,8 +31,8 @@
import io.confluent.ksql.serde.Format;
import io.confluent.ksql.statement.ConfiguredStatement;
import io.confluent.ksql.statement.Injector;
+ import io.confluent.ksql.util.IdentifierUtil;
import io.confluent.ksql.util.KsqlStatementException;
- import io.confluent.ksql.util.ParserUtil;
import java.util.Objects;
import java.util.Optional;
import org.apache.kafka.connect.data.Schema;
@@ -52,7 +52,7 @@
public class DefaultSchemaInjector implements Injector {

private static final SqlSchemaFormatter FORMATTER = new SqlSchemaFormatter(
- ParserUtil::isReservedIdentifier, Option.AS_COLUMN_LIST);
+ IdentifierUtil::needsQuotes, Option.AS_COLUMN_LIST);

private final TopicSchemaSupplier schemaSupplier;

@@ -49,8 +49,8 @@
import io.confluent.ksql.streams.StreamsUtil;
import io.confluent.ksql.structured.SelectValueMapper.SelectInfo;
import io.confluent.ksql.util.ExpressionMetadata;
+ import io.confluent.ksql.util.IdentifierUtil;
import io.confluent.ksql.util.KsqlConfig;
- import io.confluent.ksql.util.ParserUtil;
import io.confluent.ksql.util.SchemaUtil;
import java.util.ArrayList;
import java.util.Collections;
@@ -79,7 +79,7 @@ public class SchemaKStream<K> {
// CHECKSTYLE_RULES.ON: ClassDataAbstractionCoupling

private static final FormatOptions FORMAT_OPTIONS =
- FormatOptions.of(ParserUtil::isReservedIdentifier);
+ FormatOptions.of(IdentifierUtil::needsQuotes);

public enum Type { SOURCE, PROJECT, FILTER, AGGREGATE, SINK, REKEY, JOIN }

@@ -33,8 +33,8 @@
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.statement.ConfiguredStatement;
import io.confluent.ksql.util.DecimalUtil;
+ import io.confluent.ksql.util.IdentifierUtil;
import io.confluent.ksql.util.KsqlConfig;
- import io.confluent.ksql.util.ParserUtil;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.junit.Assert;
@@ -50,7 +50,7 @@
public class DefaultSchemaInjectorFunctionalTest {

private static final SqlSchemaFormatter FORMATTER =
- new SqlSchemaFormatter(ParserUtil::isReservedIdentifier);
+ new SqlSchemaFormatter(IdentifierUtil::needsQuotes);

@Rule
public final ExpectedException expectedException = ExpectedException.none();
@@ -31,19 +31,19 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+ import io.confluent.ksql.execution.expression.tree.Literal;
+ import io.confluent.ksql.execution.expression.tree.QualifiedName;
+ import io.confluent.ksql.execution.expression.tree.StringLiteral;
+ import io.confluent.ksql.execution.expression.tree.Type;
import io.confluent.ksql.parser.KsqlParser.PreparedStatement;
import io.confluent.ksql.parser.properties.with.CreateSourceProperties;
import io.confluent.ksql.parser.tree.CreateSource;
import io.confluent.ksql.parser.tree.CreateStream;
import io.confluent.ksql.parser.tree.CreateTable;
- import io.confluent.ksql.execution.expression.tree.Literal;
- import io.confluent.ksql.execution.expression.tree.QualifiedName;
import io.confluent.ksql.parser.tree.Statement;
- import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.parser.tree.TableElement;
import io.confluent.ksql.parser.tree.TableElement.Namespace;
import io.confluent.ksql.parser.tree.TableElements;
- import io.confluent.ksql.execution.expression.tree.Type;
import io.confluent.ksql.schema.ksql.inference.TopicSchemaSupplier.SchemaResult;
import io.confluent.ksql.schema.ksql.types.SqlStruct;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
@@ -437,22 +437,6 @@ public void shouldEscapeAvroSchemaThatHasReservedColumnName() {
assertThat(inject.getStatementText(), containsString("`CREATE`"));
}

- @Test
- public void shouldFailIfAvroSchemaHasInvalidColumnName() {
- // Given:
- when(schemaSupplier.getValueSchema(any(), any()))
- .thenReturn(SchemaResult.success(schemaAndId(
- SchemaBuilder.struct().field("foo-bar", Schema.INT64_SCHEMA).build(),
- SCHEMA_ID)));
-
- // Expect:
- expectedException.expect(KsqlException.class);
- expectedException.expectMessage("Failed to convert schema to KSQL model");
-
- // When:
- injector.inject(ctStatement);
- }
-
@Test
public void shouldThrowIfSchemaSupplierThrows() {
// Given:
@@ -22,7 +22,7 @@
import io.confluent.ksql.execution.expression.formatter.ExpressionFormatter;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.parser.tree.GroupingElement;
- import io.confluent.ksql.util.ParserUtil;
+ import io.confluent.ksql.util.IdentifierUtil;
import java.util.List;
import java.util.Set;

@@ -38,7 +38,7 @@ public static String formatExpression(final Expression expression, final boolean
return ExpressionFormatter.formatExpression(
expression,
unmangleNames,
- ParserUtil::isReservedIdentifier
+ IdentifierUtil::needsQuotes
);
}

@@ -56,7 +56,7 @@
import io.confluent.ksql.parser.tree.TableElement.Namespace;
import io.confluent.ksql.parser.tree.TerminateQuery;
import io.confluent.ksql.parser.tree.UnsetProperty;
- import io.confluent.ksql.util.ParserUtil;
+ import io.confluent.ksql.util.IdentifierUtil;
import java.util.List;
import java.util.Optional;
import java.util.regex.Pattern;
@@ -213,7 +213,7 @@ protected Void visitAliasedRelation(final AliasedRelation node, final Integer in
process(node.getRelation(), indent);

builder.append(' ')
- .append(ParserUtil.escapeIfReservedIdentifier(node.getAlias()));
+ .append(IdentifierUtil.escape(node.getAlias()));

return null;
}
@@ -382,7 +382,7 @@ private void visitExtended() {
@Override
public Void visitRegisterType(final RegisterType node, final Integer context) {
builder.append("CREATE TYPE ");
- builder.append(ParserUtil.escapeIfReservedIdentifier(node.getName()));
+ builder.append(IdentifierUtil.escape(node.getName()));
builder.append(" AS ");
builder.append(ExpressionFormatterUtil.formatExpression(node.getType()));
builder.append(";");
@@ -481,18 +481,17 @@ private void formatCreateAs(final CreateAsSelect node, final Integer indent) {
}

private static String formatTableElement(final TableElement e) {
- return ParserUtil.escapeIfReservedIdentifier(e.getName())
+ return IdentifierUtil.escape(e.getName())
+ " "
- + ExpressionFormatter.formatExpression(
- e.getType(), true, ParserUtil::isReservedIdentifier)
+ + ExpressionFormatter.formatExpression(e.getType(), true, IdentifierUtil::needsQuotes)
+ (e.getNamespace() == Namespace.KEY ? " KEY" : "");
}
}

private static String escapedName(final QualifiedName name) {
return name.getParts()
.stream()
- .map(ParserUtil::escapeIfReservedIdentifier)
+ .map(IdentifierUtil::escape)
.collect(Collectors.joining("."));
}
}
@@ -0,0 +1,58 @@
/*
* Copyright 2019 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.util;

import io.confluent.ksql.parser.CaseInsensitiveStream;
import io.confluent.ksql.parser.SqlBaseLexer;
import io.confluent.ksql.parser.SqlBaseParser;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;

public final class IdentifierUtil {

private IdentifierUtil() { }

/**
* @param identifier the identifier
* @return whether or not {@code identifier} needs to be back-quoted to be used as an identifier
*/
public static boolean needsQuotes(final String identifier) {
final SqlBaseLexer sqlBaseLexer = new SqlBaseLexer(
new CaseInsensitiveStream(CharStreams.fromString(identifier)));
final CommonTokenStream tokenStream = new CommonTokenStream(sqlBaseLexer);
final SqlBaseParser sqlBaseParser = new SqlBaseParser(tokenStream);

// don't log or print anything in the case of error since this is expected
// for this method
sqlBaseLexer.removeErrorListeners();
sqlBaseParser.removeErrorListeners();

sqlBaseParser.identifier();

// needs quotes if the parser could not consume the entire string as a single identifier
return sqlBaseParser.getNumberOfSyntaxErrors() != 0
|| sqlBaseParser.getCurrentToken().getCharPositionInLine() != identifier.length();
}

/**
* @param identifier the identifier to escape
* @return wraps the {@code identifier} in back quotes (`) if {@link #needsQuotes(String)}
*/
public static String escape(final String identifier) {
return needsQuotes(identifier) ? '`' + identifier + '`' : identifier;
}

}
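A quick usage sketch for the new helper, assuming the ksql-parser module (and its generated SqlBaseLexer/SqlBaseParser) is on the classpath; the expected outputs in the comments are inferred from the needsQuotes logic above and from the tests touched by this commit:

import io.confluent.ksql.util.IdentifierUtil;

public final class IdentifierUtilExample {

  public static void main(final String[] args) {
    // Lexes cleanly as a single identifier token, so no quoting is needed.
    System.out.println(IdentifierUtil.escape("FOO"));     // FOO

    // Reserved word: it does not parse as a bare identifier, so it is back-quoted.
    System.out.println(IdentifierUtil.escape("CREATE"));  // `CREATE`

    // Not a single identifier token ("foo", "-", "bar"), so it is back-quoted as well;
    // the removed reserved-word check would have left a name like this untouched.
    System.out.println(IdentifierUtil.escape("foo-bar")); // `foo-bar`
  }
}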
45 changes: 0 additions & 45 deletions ksql-parser/src/main/java/io/confluent/ksql/util/ParserUtil.java
@@ -19,28 +19,18 @@
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

- import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.execution.expression.tree.DoubleLiteral;
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.Literal;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedName;
- import io.confluent.ksql.parser.DefaultKsqlParser;
- import io.confluent.ksql.parser.KsqlParser;
import io.confluent.ksql.parser.NodeLocation;
import io.confluent.ksql.parser.ParsingException;
- import io.confluent.ksql.parser.SqlBaseLexer;
import io.confluent.ksql.parser.SqlBaseParser;
import io.confluent.ksql.parser.SqlBaseParser.IntegerLiteralContext;
import io.confluent.ksql.parser.SqlBaseParser.NumberContext;
- import io.confluent.ksql.parser.exception.ParseFailedException;
import java.util.List;
- import java.util.Objects;
import java.util.Optional;
- import java.util.Set;
- import java.util.function.Predicate;
- import java.util.stream.Collectors;
- import java.util.stream.IntStream;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.tree.TerminalNode;
@@ -50,41 +40,6 @@ public final class ParserUtil {
private ParserUtil() {
}

- private static final Set<String> RESERVED_WORDS;
-
- static {
- final KsqlParser parser = new DefaultKsqlParser();
-
- final Predicate<String> isReservedWord = columnName -> {
- try {
- parser.parse(
- "CREATE STREAM x (" + columnName + " INT) "
- + "WITH(KAFKA_TOPIC='x', VALUE_FORMAT='JSON');");
- return false;
- } catch (final ParseFailedException e) {
- return true;
- }
- };
-
- final Set<String> reserved = IntStream.range(0, SqlBaseLexer.VOCABULARY.getMaxTokenType())
- .mapToObj(SqlBaseLexer.VOCABULARY::getLiteralName)
- .filter(Objects::nonNull)
- .map(l -> l.substring(1, l.length() - 1)) // literals start and end with ' - remove them
- .map(String::toUpperCase)
- .filter(isReservedWord)
- .collect(Collectors.toSet());
-
- RESERVED_WORDS = ImmutableSet.copyOf(reserved);
- }
-
- public static boolean isReservedIdentifier(final String name) {
- return RESERVED_WORDS.contains(name.toUpperCase());
- }
-
- public static String escapeIfReservedIdentifier(final String name) {
- return isReservedIdentifier(name) ? "`" + name + "`" : name;
- }
-
public static String getIdentifierText(final SqlBaseParser.IdentifierContext context) {
if (context instanceof SqlBaseParser.QuotedIdentifierAlternativeContext) {
return unquote(context.getText(), "\"");
@@ -48,16 +48,6 @@ public void setUp() {
mockLocation(decimalLiteralContext, 1, 2);
}

- @Test
- public void shouldEscapeStringIfLiteral() {
- assertThat(ParserUtil.escapeIfReservedIdentifier("END"), equalTo("`END`"));
- }
-
- @Test
- public void shouldNotEscapeStringIfNotLiteral() {
- assertThat(ParserUtil.escapeIfReservedIdentifier("NOT_A_LITERAL"), equalTo("NOT_A_LITERAL"));
- }
-
@Test
public void shouldThrowWhenParsingDecimalIfNaN() {
// Given:
@@ -97,17 +87,6 @@ public void shouldThrowWhenParsingDecimalIfOverflowsDouble() {
ParserUtil.parseDecimalLiteral(decimalLiteralContext);
}

- @Test
- public void shouldHaveReservedLiteralInReservedSet() {
- assertThat(ParserUtil.isReservedIdentifier("FROM"), is(true));
- }
-
- @Test
- public void shouldExcludeNonReservedLiteralsFromReservedSet() {
- // i.e. those in the "nonReserved" rule in SqlBase.g4
- assertThat(ParserUtil.isReservedIdentifier("SHOW"), is(false));
- }
-
private static void mockLocation(final ParserRuleContext ctx, final int line, final int col) {
final Token token = mock(Token.class);
when(token.getLine()).thenReturn(line);