diff --git a/ksqldb-api-client/src/test/java/io/confluent/ksql/api/client/integration/ClientIntegrationTest.java b/ksqldb-api-client/src/test/java/io/confluent/ksql/api/client/integration/ClientIntegrationTest.java index f41d62758b62..467e1a0f8787 100644 --- a/ksqldb-api-client/src/test/java/io/confluent/ksql/api/client/integration/ClientIntegrationTest.java +++ b/ksqldb-api-client/src/test/java/io/confluent/ksql/api/client/integration/ClientIntegrationTest.java @@ -434,14 +434,14 @@ public void shouldHandleErrorResponseFromTerminatePushQuery() { public void shouldInsertInto() throws Exception { // Given final KsqlObject insertRow = new KsqlObject() - .put("STR", "HELLO") - .put("LONG", 100L) + .put("str", "HELLO") // Column names are case-insensitive + .put("`LONG`", 100L) // Quotes may be used to preserve case-sensitivity .put("DEC", new BigDecimal("13.31")) .put("ARRAY", new KsqlArray().add("v1").add("v2")) .put("MAP", new KsqlObject().put("some_key", "a_value").put("another_key", "")); // When - client.insertInto(EMPTY_TEST_STREAM, insertRow).get(); + client.insertInto(EMPTY_TEST_STREAM.toLowerCase(), insertRow).get(); // Stream name is case-insensitive // Then: should receive new row final String query = "SELECT * FROM " + EMPTY_TEST_STREAM + " EMIT CHANGES LIMIT 1;"; diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/InsertsStreamEndpoint.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/InsertsStreamEndpoint.java index 9e1fe5c00e98..fe76755938f1 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/InsertsStreamEndpoint.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/InsertsStreamEndpoint.java @@ -21,6 +21,7 @@ import io.confluent.ksql.api.server.InsertResult; import io.confluent.ksql.api.server.InsertsStreamSubscriber; import io.confluent.ksql.api.server.KsqlApiException; +import io.confluent.ksql.api.server.ServerUtils; import io.confluent.ksql.engine.KsqlEngine; import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.metastore.model.DataSource; @@ -48,18 +49,28 @@ public InsertsStreamEndpoint(final KsqlEngine ksqlEngine, final KsqlConfig ksqlC this.reservedInternalTopics = reservedInternalTopics; } - public InsertsStreamSubscriber createInsertsSubscriber(final String target, + public InsertsStreamSubscriber createInsertsSubscriber(final String caseInsensitiveTarget, final JsonObject properties, final Subscriber acksSubscriber, final Context context, final WorkerExecutor workerExecutor, final ServiceContext serviceContext) { VertxUtils.checkIsWorker(); + if (!ksqlConfig.getBoolean(KsqlConfig.KSQL_INSERT_INTO_VALUES_ENABLED)) { throw new KsqlApiException("The server has disabled INSERT INTO ... VALUES functionality. " + "To enable it, restart your ksqlDB server " + "with 'ksql.insert.into.values.enabled'=true", ERROR_CODE_BAD_REQUEST); } + + final String target; + try { + target = ServerUtils.getIdentifierText(caseInsensitiveTarget); + } catch (IllegalArgumentException e) { + throw new KsqlApiException( + "Invalid target name: " + e.getMessage(), ERROR_CODE_BAD_STATEMENT); + } + final DataSource dataSource = getDataSource(ksqlEngine.getMetaStore(), SourceName.of(target)); if (dataSource.getDataSourceType() == DataSourceType.KTABLE) { diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/InsertsSubscriber.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/InsertsSubscriber.java index 9190c9567f30..63c15ba66a3d 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/InsertsSubscriber.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/InsertsSubscriber.java @@ -15,6 +15,8 @@ package io.confluent.ksql.api.impl; +import static io.confluent.ksql.api.impl.KeyValueExtractor.convertColumnNameCase; + import io.confluent.ksql.GenericRow; import io.confluent.ksql.api.server.InsertResult; import io.confluent.ksql.api.server.InsertsStreamSubscriber; @@ -132,9 +134,11 @@ protected void afterSubscribe(final Subscription subscription) { } @Override - protected void handleValue(final JsonObject jsonObject) { + protected void handleValue(final JsonObject jsonObjectWithCaseInsensitiveFields) { try { + final JsonObject jsonObject = convertColumnNameCase(jsonObjectWithCaseInsensitiveFields); + final Struct key = extractKey(jsonObject); final GenericRow values = extractValues(jsonObject); diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/KeyValueExtractor.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/KeyValueExtractor.java index 1eb45679d10b..e3534243e310 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/KeyValueExtractor.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/KeyValueExtractor.java @@ -17,6 +17,7 @@ import io.confluent.ksql.GenericRow; import io.confluent.ksql.api.server.KsqlApiException; +import io.confluent.ksql.api.server.ServerUtils; import io.confluent.ksql.rest.Errors; import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; @@ -30,6 +31,7 @@ import java.math.RoundingMode; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.apache.kafka.connect.data.Field; import org.apache.kafka.connect.data.Struct; @@ -69,6 +71,28 @@ public static GenericRow extractValues(final JsonObject values, final LogicalSch return GenericRow.fromList(vals); } + static JsonObject convertColumnNameCase(final JsonObject jsonObjectWithCaseInsensitiveFields) { + final JsonObject jsonObject = new JsonObject(); + + for (Map.Entry entry : + jsonObjectWithCaseInsensitiveFields.getMap().entrySet()) { + final String key; + try { + key = ServerUtils.getIdentifierText(entry.getKey()); + } catch (IllegalArgumentException e) { + throw new KsqlApiException( + String.format("Invalid column name. Column: %s. Reason: %s", + entry.getKey(), e.getMessage()), + Errors.ERROR_CODE_BAD_REQUEST + ); + } + + jsonObject.put(key, entry.getValue()); + } + + return jsonObject; + } + private static Object coerceObject( final Object value, final SqlType sqlType, diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ServerUtils.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ServerUtils.java index b880492b7983..7f526572e462 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ServerUtils.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ServerUtils.java @@ -97,6 +97,20 @@ public static String convertCommaSeparatedWilcardsToRegex(final String csv) { return out.toString(); } + // See ParserUtil#getIdentifierText() + public static String getIdentifierText(final String text) { + if (text.isEmpty()) { + return ""; + } + + final char firstChar = text.charAt(0); + if (firstChar == '`' || firstChar == '"') { + return unquote(text, firstChar); + } + + return text.toUpperCase(); + } + public static boolean checkHttp2(final RoutingContext routingContext) { if (routingContext.request().version() != HttpVersion.HTTP_2) { routingContext.fail(BAD_REQUEST.code(), @@ -131,4 +145,28 @@ static Void handleEndpointException( ERROR_CODE_SERVER_ERROR)); return null; } + + private static String unquote(final String value, final char quote) { + if (value.charAt(0) != quote) { + throw new IllegalStateException("Value must begin with quote"); + } + if (value.charAt(value.length() - 1) != quote || value.length() < 2) { + throw new IllegalArgumentException("Expected matching quote at end of value"); + } + + int i = 1; + while (i < value.length() - 1) { + if (value.charAt(i) == quote) { + if (value.charAt(i + 1) != quote || i + 1 == value.length() - 1) { + throw new IllegalArgumentException("Un-escaped quote in middle of value at index " + i); + } + i += 2; + } else { + i++; + } + } + + return value.substring(1, value.length() - 1) + .replace("" + quote + quote, "" + quote); + } } diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/integration/ApiIntegrationTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/integration/ApiIntegrationTest.java index 95179fd15325..8e651648bdb8 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/integration/ApiIntegrationTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/integration/ApiIntegrationTest.java @@ -21,6 +21,8 @@ import static io.confluent.ksql.test.util.EmbeddedSingleNodeKafkaCluster.VALID_USER2; import static io.confluent.ksql.util.KsqlConfig.KSQL_STREAMS_PREFIX; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; @@ -221,7 +223,7 @@ public void shouldExecutePushQueryNoLimit() throws Exception { } catch (Throwable t) { return Integer.MAX_VALUE; } - }, is(6)); + }, greaterThanOrEqualTo(6)); // The response shouldn't have ended yet assertThat(writeStream.isEnded(), is(false)); @@ -416,6 +418,79 @@ public void shouldInsertWithMissingValueField() { shouldInsert(row); } + @Test + public void shouldInsertWithCaseInsensitivity() { + + // Given: lowercase fields names and stream name + String target = TEST_STREAM.toLowerCase(); + JsonObject row = new JsonObject() + .put("str", "HELLO") + .put("dec", 12.21) // JsonObject does not accept BigDecimal + .put("array", new JsonArray().add("a").add("b")) + .put("map", new JsonObject().put("k1", "v1").put("k2", "v2")); + + // Then: + shouldInsert(target, row); + } + + @Test + public void shouldTreatInsertTargetAsCaseSensitiveIfQuotedWithBackticks() { + // Given: + String target = "`" + TEST_STREAM.toLowerCase() + "`"; + JsonObject row = new JsonObject() + .put("STR", "HELLO") + .put("LONG", 1000L) + .put("DEC", 12.21) // JsonObject does not accept BigDecimal + .put("ARRAY", new JsonArray().add("a").add("b")) + .put("MAP", new JsonObject().put("k1", "v1").put("k2", "v2")); + + // Then: request fails because stream name is invalid + shouldRejectInsertRequest(target, row, "Cannot insert values into an unknown stream: " + target); + } + + @Test + public void shouldTreatInsertTargetAsCaseSensitiveIfQuotedWithDoubleQuotes() { + // Given: + String target = "\"" + TEST_STREAM.toLowerCase() + "\""; + JsonObject row = new JsonObject() + .put("STR", "HELLO") + .put("LONG", 1000L) + .put("DEC", 12.21) // JsonObject does not accept BigDecimal + .put("ARRAY", new JsonArray().add("a").add("b")) + .put("MAP", new JsonObject().put("k1", "v1").put("k2", "v2")); + + // Then: request fails because stream name is invalid + shouldRejectInsertRequest(target, row, "Cannot insert values into an unknown stream: `" + TEST_STREAM.toLowerCase() + "`"); + } + + @Test + public void shouldTreatInsertColumnNamesAsCaseSensitiveIfQuotedWithBackticks() { + // Given: + JsonObject row = new JsonObject() + .put("`str`", "HELLO") + .put("LONG", 1000L) + .put("DEC", 12.21) // JsonObject does not accept BigDecimal + .put("ARRAY", new JsonArray().add("a").add("b")) + .put("MAP", new JsonObject().put("k1", "v1").put("k2", "v2")); + + // Then: request fails because column name is incorrect + shouldFailToInsert(row, ERROR_CODE_BAD_REQUEST, "Key field must be specified: STR"); + } + + @Test + public void shouldTreatInsertColumnNamesAsCaseSensitiveIfQuotedWithDoubleQuotes() { + // Given: + JsonObject row = new JsonObject() + .put("\"str\"", "HELLO") + .put("LONG", 1000L) + .put("DEC", 12.21) // JsonObject does not accept BigDecimal + .put("ARRAY", new JsonArray().add("a").add("b")) + .put("MAP", new JsonObject().put("k1", "v1").put("k2", "v2")); + + // Then: request fails because column name is incorrect + shouldFailToInsert(row, ERROR_CODE_BAD_REQUEST, "Key field must be specified: STR"); + } + @Test public void shouldExecutePushQueryFromLatestOffset() { @@ -494,15 +569,7 @@ private QueryResponse executeQuery(final String sql) { } private void shouldFailToInsert(final JsonObject row, final int errorCode, final String message) { - JsonObject properties = new JsonObject(); - JsonObject requestBody = new JsonObject() - .put("target", TEST_STREAM).put("properties", properties); - Buffer bodyBuffer = requestBody.toBuffer(); - bodyBuffer.appendString("\n"); - - bodyBuffer.appendBuffer(row.toBuffer()).appendString("\n"); - - HttpResponse response = sendRequest("/inserts-stream", bodyBuffer); + final HttpResponse response = makeInsertsRequest(TEST_STREAM, row); assertThat(response.statusCode(), is(200)); @@ -515,15 +582,11 @@ private void shouldFailToInsert(final JsonObject row, final int errorCode, final } private void shouldInsert(final JsonObject row) { - JsonObject properties = new JsonObject(); - JsonObject requestBody = new JsonObject() - .put("target", TEST_STREAM).put("properties", properties); - Buffer bodyBuffer = requestBody.toBuffer(); - bodyBuffer.appendString("\n"); - - bodyBuffer.appendBuffer(row.toBuffer()).appendString("\n"); + shouldInsert(TEST_STREAM, row); + } - HttpResponse response = sendRequest("/inserts-stream", bodyBuffer); + private void shouldInsert(final String target, final JsonObject row) { + HttpResponse response = makeInsertsRequest(target, row); assertThat(response.statusCode(), is(200)); @@ -532,6 +595,29 @@ private void shouldInsert(final JsonObject row) { assertThat(insertsResponse.error, is(nullValue())); } + private void shouldRejectInsertRequest(final String target, final JsonObject row, final String message) { + HttpResponse response = makeInsertsRequest(target, row); + + assertThat(response.statusCode(), is(400)); + assertThat(response.statusMessage(), is("Bad Request")); + + QueryResponse queryResponse = new QueryResponse(response.bodyAsString()); + assertThat(queryResponse.responseObject.getInteger("error_code"), is(ERROR_CODE_BAD_STATEMENT)); + assertThat(queryResponse.responseObject.getString("message"), containsString(message)); + } + + private HttpResponse makeInsertsRequest(final String target, final JsonObject row) { + JsonObject properties = new JsonObject(); + JsonObject requestBody = new JsonObject() + .put("target", target).put("properties", properties); + Buffer bodyBuffer = requestBody.toBuffer(); + bodyBuffer.appendString("\n"); + + bodyBuffer.appendBuffer(row.toBuffer()).appendString("\n"); + + return sendRequest("/inserts-stream", bodyBuffer); + } + private WebClient createClient() { WebClientOptions options = new WebClientOptions(). setProtocolVersion(HttpVersion.HTTP_2).setHttp2ClearTextUpgrade(false)