diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/QueryEndpoint.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/QueryEndpoint.java index 99aab9c52714..7c4b1a6f4cb7 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/QueryEndpoint.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/QueryEndpoint.java @@ -97,6 +97,7 @@ public QueryEndpoint( public QueryPublisher createQueryPublisher( final String sql, final JsonObject properties, + final JsonObject sessionVariables, final Context context, final WorkerExecutor workerExecutor, final ServiceContext serviceContext, @@ -104,7 +105,8 @@ public QueryPublisher createQueryPublisher( // Must be run on worker as all this stuff is slow VertxUtils.checkIsWorker(); - final ConfiguredStatement statement = createStatement(sql, properties.getMap()); + final ConfiguredStatement statement = createStatement( + sql, properties.getMap(), sessionVariables.getMap()); if (statement.getStatement().isPullQuery()) { return createPullQueryPublisher( @@ -199,7 +201,7 @@ private QueryPublisher createPullQueryPublisher( } private ConfiguredStatement createStatement(final String queryString, - final Map properties) { + final Map properties, final Map sessionVariables) { final List statements = ksqlEngine.parse(queryString); if ((statements.size() != 1)) { throw new KsqlStatementException( @@ -207,7 +209,12 @@ private ConfiguredStatement createStatement(final String queryString, .format("Expected exactly one KSQL statement; found %d instead", statements.size()), queryString); } - final PreparedStatement ps = ksqlEngine.prepare(statements.get(0)); + final PreparedStatement ps = ksqlEngine.prepare( + statements.get(0), + sessionVariables.entrySet() + .stream() + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue().toString())) + ); final Statement statement = ps.getStatement(); if (!(statement instanceof Query)) { throw new KsqlStatementException("Not a query", queryString); diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java index 5b4bcd1c2d48..ea57a3f0ae9a 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java @@ -80,8 +80,8 @@ public void handle(final RoutingContext routingContext) { final MetricsCallbackHolder metricsCallbackHolder = new MetricsCallbackHolder(); final long startTimeNanos = Time.SYSTEM.nanoseconds(); endpoints.createQueryPublisher(queryStreamArgs.get().sql, queryStreamArgs.get().properties, - context, server.getWorkerExecutor(), DefaultApiSecurityContext.create(routingContext), - metricsCallbackHolder) + queryStreamArgs.get().sessionVariables, context, server.getWorkerExecutor(), + DefaultApiSecurityContext.create(routingContext), metricsCallbackHolder) .thenAccept(queryPublisher -> { final QueryResponseMetadata metadata; diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/spi/Endpoints.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/spi/Endpoints.java index dee610423a11..0edada5ec894 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/spi/Endpoints.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/spi/Endpoints.java @@ -52,8 +52,8 @@ public interface Endpoints { * @return A CompletableFuture representing the future result of the operation */ CompletableFuture createQueryPublisher(String sql, JsonObject properties, - Context context, WorkerExecutor workerExecutor, ApiSecurityContext apiSecurityContext, - MetricsCallbackHolder metricsCallbackHolder); + JsonObject sessionVariables, Context context, WorkerExecutor workerExecutor, + ApiSecurityContext apiSecurityContext, MetricsCallbackHolder metricsCallbackHolder); /** * Create a subscriber which will receive a stream of inserts from the API server and process diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlServerEndpoints.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlServerEndpoints.java index b0d909fcf8ff..4a43705cb86d 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlServerEndpoints.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlServerEndpoints.java @@ -141,6 +141,7 @@ public KsqlServerEndpoints( @Override public CompletableFuture createQueryPublisher(final String sql, final JsonObject properties, + final JsonObject sessionVariables, final Context context, final WorkerExecutor workerExecutor, final ApiSecurityContext apiSecurityContext, @@ -155,6 +156,7 @@ public CompletableFuture createQueryPublisher(final String sql, .createQueryPublisher( sql, properties, + sessionVariables, context, workerExecutor, ksqlSecurityContext.getServiceContext(), diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/KsqlResource.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/KsqlResource.java index 5126cd4a495f..1dd1d736872a 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/KsqlResource.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/KsqlResource.java @@ -296,7 +296,8 @@ public EndpointResponse handleKsqlStatements( configProperties, localHost, localUrl, - requestConfig.getBoolean(KsqlRequestConfig.KSQL_REQUEST_INTERNAL_REQUEST) + requestConfig.getBoolean(KsqlRequestConfig.KSQL_REQUEST_INTERNAL_REQUEST), + request.getSessionVariables() ) ); diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/ApiTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/ApiTest.java index fbe87053f2ca..05c24022b5fb 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/ApiTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/ApiTest.java @@ -115,6 +115,33 @@ public void shouldExecutePullQuery() throws Exception { assertThat(queryId, is(nullValue())); } + @Test + public void shouldExecutePullQueryWithVariableSubstitution() throws Exception { + + // Given + JsonObject requestBody = new JsonObject().put("sql", "select * from ${name} where rowkey='1234';"); + JsonObject properties = new JsonObject().put("prop1", "val1").put("prop2", 23); + JsonObject sessionVariables = new JsonObject().put("name", "foo"); + requestBody.put("properties", properties).put("sessionVariables", sessionVariables); + + // When + HttpResponse response = sendPostRequest("/query-stream", requestBody.toBuffer()); + + // Then + assertThat(response.statusCode(), is(200)); + assertThat(response.statusMessage(), is("OK")); + assertThat(testEndpoints.getLastSql(), is("select * from ${name} where rowkey='1234';")); + assertThat(testEndpoints.getLastProperties(), is(properties)); + assertThat(testEndpoints.getLastSessionVariables(), is(sessionVariables)); + QueryResponse queryResponse = new QueryResponse(response.bodyAsString()); + assertThat(queryResponse.responseObject.getJsonArray("columnNames"), is(DEFAULT_COLUMN_NAMES)); + assertThat(queryResponse.responseObject.getJsonArray("columnTypes"), is(DEFAULT_COLUMN_TYPES)); + assertThat(queryResponse.rows, is(DEFAULT_JSON_ROWS)); + assertThat(server.getQueryIDs(), hasSize(0)); + String queryId = queryResponse.responseObject.getString("queryId"); + assertThat(queryId, is(nullValue())); + } + @Test @CoreApiTest public void shouldExecutePushQuery() throws Exception { @@ -134,6 +161,29 @@ public void shouldExecutePushQuery() throws Exception { assertThat(server.getQueryIDs().contains(new PushQueryId(queryId)), is(true)); } + @Test + public void shouldExecutePushQueryWithVariableSubstitution() throws Exception { + + // When + JsonObject requestBody = new JsonObject().put("sql", "select * from ${name} emit changes;"); + JsonObject properties = new JsonObject().put("prop1", "val1").put("prop2", 23); + JsonObject sessionVariables = new JsonObject().put("name", "foo"); + requestBody.put("properties", properties).put("sessionVariables", sessionVariables); + QueryResponse queryResponse = executePushQueryAndWaitForRows(requestBody); + + // Then + assertThat(testEndpoints.getLastSql(), is("select * from ${name} emit changes;")); + assertThat(testEndpoints.getLastProperties(), is(DEFAULT_PUSH_QUERY_REQUEST_PROPERTIES)); + assertThat(testEndpoints.getLastSessionVariables(), is(sessionVariables)); + assertThat(queryResponse.responseObject.getJsonArray("columnNames"), is(DEFAULT_COLUMN_NAMES)); + assertThat(queryResponse.responseObject.getJsonArray("columnTypes"), is(DEFAULT_COLUMN_TYPES)); + assertThat(queryResponse.rows, is(DEFAULT_JSON_ROWS)); + assertThat(server.getQueryIDs(), hasSize(1)); + String queryId = queryResponse.responseObject.getString("queryId"); + assertThat(queryId, is(notNullValue())); + assertThat(server.getQueryIDs().contains(new PushQueryId(queryId)), is(true)); + } + @Test public void shouldExecuteMultiplePushQueries() throws Exception { diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/TestEndpoints.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/TestEndpoints.java index 4b91a31af8c6..e8c234e19329 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/TestEndpoints.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/TestEndpoints.java @@ -54,6 +54,7 @@ public class TestEndpoints implements Endpoints { private List ksqlEndpointResponse; private String lastSql; private JsonObject lastProperties; + private JsonObject lastSessionVariables; private String lastTarget; private Set queryPublishers = new HashSet<>(); private int acksBeforePublisherError = -1; @@ -65,7 +66,7 @@ public class TestEndpoints implements Endpoints { @Override public synchronized CompletableFuture createQueryPublisher(final String sql, - final JsonObject properties, final Context context, final WorkerExecutor workerExecutor, + final JsonObject properties, JsonObject sessionVariables, final Context context, final WorkerExecutor workerExecutor, final ApiSecurityContext apiSecurityContext, final MetricsCallbackHolder metricsCallbackHolder) { CompletableFuture completableFuture = new CompletableFuture<>(); @@ -75,6 +76,7 @@ public synchronized CompletableFuture createQueryPublisher(final } else { this.lastSql = sql; this.lastProperties = properties; + this.lastSessionVariables = sessionVariables; this.lastApiSecurityContext = apiSecurityContext; final boolean push = sql.toLowerCase().contains("emit changes"); final int limit = extractLimit(sql); @@ -240,6 +242,10 @@ public synchronized JsonObject getLastProperties() { return lastProperties; } + public synchronized JsonObject getLastSessionVariables() { + return lastSessionVariables; + } + public synchronized Set getQueryPublishers() { return queryPublishers; } diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/integration/ApiIntegrationTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/integration/ApiIntegrationTest.java index 6a91a9409756..8d0bcacb9150 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/integration/ApiIntegrationTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/integration/ApiIntegrationTest.java @@ -174,6 +174,22 @@ public void shouldExecutePushQueryWithLimit() { assertThat(response.responseObject.getString("queryId"), is(notNullValue())); } + @Test + public void shouldExecutePushQueryWithVariableSubstitution() { + + // Given: + String sql = "SELECT DEC AS ${name} from " + TEST_STREAM + " EMIT CHANGES LIMIT 2;"; + + // When: + QueryResponse response = executeQueryWithVariables(sql, new JsonObject().put("name", "COL")); + + // Then: + assertThat(response.rows, hasSize(2)); + assertThat(response.responseObject.getJsonArray("columnNames"), is(new JsonArray().add("COL"))); + assertThat(response.responseObject.getJsonArray("columnTypes"), is(new JsonArray().add("DECIMAL(4, 2)"))); + assertThat(response.responseObject.getString("queryId"), is(notNullValue())); + } + @Test public void shouldFailPushQueryWithInvalidSql() { @@ -290,6 +306,35 @@ public void shouldExecutePullQuery() { assertThat(response.rows.get(0).getLong(1), is(1L)); // latest_by_offset(long) } + @Test + public void shouldExecutePullQueryWithVariableSubstitution() { + + // Given: + String sql = "SELECT * from ${AGG_TABLE} WHERE K=" + AN_AGG_KEY + ";"; + final JsonObject variables = new JsonObject().put("AGG_TABLE", AGG_TABLE); + + // When: + // Maybe need to retry as populating agg table is async + AtomicReference atomicReference = new AtomicReference<>(); + assertThatEventually(() -> { + QueryResponse queryResponse = executeQueryWithVariables(sql, variables); + atomicReference.set(queryResponse); + return queryResponse.rows; + }, hasSize(1)); + + QueryResponse response = atomicReference.get(); + + // Then: + JsonArray expectedColumnNames = new JsonArray().add("K").add("LONG"); + JsonArray expectedColumnTypes = new JsonArray().add("STRUCT<`F1` ARRAY>").add("BIGINT"); + assertThat(response.rows, hasSize(1)); + assertThat(response.responseObject.getJsonArray("columnNames"), is(expectedColumnNames)); + assertThat(response.responseObject.getJsonArray("columnTypes"), is(expectedColumnTypes)); + assertThat(response.responseObject.getString("queryId"), is(nullValue())); + assertThat(response.rows.get(0).getJsonObject(0).getJsonArray("F1").getString(0), is("a")); // rowkey + assertThat(response.rows.get(0).getLong(1), is(1L)); // latest_by_offset(long) + } + @Test public void shouldFailPullQueryWithInvalidSql() { @@ -613,9 +658,13 @@ private void shouldFailToExecuteQuery(final String sql, final String message) { } private QueryResponse executeQuery(final String sql) { + return executeQueryWithVariables(sql, new JsonObject()); + } + + private QueryResponse executeQueryWithVariables(final String sql, final JsonObject variables) { JsonObject properties = new JsonObject(); JsonObject requestBody = new JsonObject() - .put("sql", sql).put("properties", properties); + .put("sql", sql).put("properties", properties).put("sessionVariables", variables); HttpResponse response = sendRequest("/query-stream", requestBody.toBuffer()); return new QueryResponse(response.bodyAsString()); } diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/InsertsStreamRunner.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/InsertsStreamRunner.java index 2e6d5ba5bd4d..91962fc0bb2d 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/InsertsStreamRunner.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/InsertsStreamRunner.java @@ -161,6 +161,7 @@ private class InsertsStreamEndpoints implements Endpoints { @Override public CompletableFuture createQueryPublisher(final String sql, final JsonObject properties, + final JsonObject sessionVariables, final Context context, final WorkerExecutor workerExecutor, final ApiSecurityContext apiSecurityContext, diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/PullQueryRunner.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/PullQueryRunner.java index 50776c89fb82..863a5786c38a 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/PullQueryRunner.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/PullQueryRunner.java @@ -119,6 +119,7 @@ private static class PullQueryEndpoints implements Endpoints { @Override public synchronized CompletableFuture createQueryPublisher(final String sql, final JsonObject properties, + final JsonObject sessionVariables, final Context context, final WorkerExecutor workerExecutor, final ApiSecurityContext apiSecurityContext, diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/QueryStreamRunner.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/QueryStreamRunner.java index 4fdc408c8e7e..8ff54cc0dc77 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/QueryStreamRunner.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/QueryStreamRunner.java @@ -98,6 +98,7 @@ private class QueryStreamEndpoints implements Endpoints { @Override public synchronized CompletableFuture createQueryPublisher(final String sql, final JsonObject properties, + final JsonObject sessionVariables, final Context context, final WorkerExecutor workerExecutor, final ApiSecurityContext apiSecurityContext, diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java index a7b08b7d203a..88b0f22cb0ac 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java @@ -55,8 +55,10 @@ import io.confluent.ksql.rest.entity.CommandStatus; import io.confluent.ksql.rest.entity.CommandStatus.Status; import io.confluent.ksql.rest.entity.CommandStatuses; +import io.confluent.ksql.rest.entity.KsqlEntity; import io.confluent.ksql.rest.entity.KsqlMediaType; import io.confluent.ksql.rest.entity.KsqlRequest; +import io.confluent.ksql.rest.entity.Queries; import io.confluent.ksql.rest.entity.QueryStreamArgs; import io.confluent.ksql.rest.entity.ServerClusterId; import io.confluent.ksql.rest.entity.ServerInfo; @@ -151,6 +153,11 @@ public class RestApiTest { resource(TOPIC, "X"), ops(ALL) ) + .withAcl( + NORMAL_USER, + resource(TOPIC, "Y"), + ops(ALL) + ) .withAcl( NORMAL_USER, resource(TOPIC, AGG_TABLE), @@ -614,7 +621,7 @@ public void shouldFailToExecutePullQueryOverRestHttp2() { public void shouldExecutePullQueryOverHttp2QueryStream() { QueryStreamArgs queryStreamArgs = new QueryStreamArgs( "SELECT COUNT, USERID from " + AGG_TABLE + " WHERE USERID='" + AN_AGG_KEY + "';", - Collections.emptyMap()); + Collections.emptyMap(), Collections.emptyMap()); QueryResponse[] queryResponse = new QueryResponse[1]; assertThatEventually(() -> { @@ -674,6 +681,23 @@ public void shouldDeleteTopic() { assertThat("Expected topic X to be deleted", !topicExists("X")); } + @Test + public void shouldCreateStreamWithVariableSubstitution() { + // Given: + // When: + makeKsqlRequestWithVariables( + "CREATE STREAM Y AS SELECT * FROM " + PAGE_VIEW_STREAM + " WHERE USERID='${id}';", + ImmutableMap.of("id", "USER_1") + ); + + // Then: + final List query = ((Queries) makeKsqlRequest("SHOW QUERIES;").get(0)) + .getQueries().stream().map(q -> q.getQueryString()) + .filter(q -> q.contains("WHERE (PAGEVIEW_KSTREAM.USERID = 'USER_1')")) + .collect(Collectors.toList()); + assertThat(query.size(), is(1)); + } + @Test public void shouldFailToExecuteQueryUsingRestWithHttp2() { // Given: @@ -699,8 +723,12 @@ private ServiceContext getServiceContext() { return serviceContext; } - private static void makeKsqlRequest(final String sql) { - RestIntegrationTestUtil.makeKsqlRequest(REST_APP, sql); + private static List makeKsqlRequest(final String sql) { + return RestIntegrationTestUtil.makeKsqlRequest(REST_APP, sql); + } + + private static void makeKsqlRequestWithVariables(final String sql, final Map variables) { + RestIntegrationTestUtil.makeKsqlRequestWithVariables(REST_APP, sql, variables); } private static String rawRestQueryRequest(final String sql, final String mediaType) { diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java index 88135e0106ec..987fc2801dcd 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java @@ -35,6 +35,7 @@ import io.confluent.ksql.rest.entity.KsqlEntity; import io.confluent.ksql.rest.entity.KsqlEntityList; import io.confluent.ksql.rest.entity.KsqlErrorMessage; +import io.confluent.ksql.rest.entity.KsqlMediaType; import io.confluent.ksql.rest.entity.KsqlRequest; import io.confluent.ksql.rest.entity.ServerClusterId; import io.confluent.ksql.rest.entity.ServerInfo; @@ -91,6 +92,15 @@ public static List makeKsqlRequest(final TestKsqlRestApp restApp, fi return makeKsqlRequest(restApp, sql, Optional.empty()); } + public static String makeKsqlRequestWithVariables( + final TestKsqlRestApp restApp, final String sql, final Map variables) { + final KsqlRequest request = + new KsqlRequest(sql, ImmutableMap.of(), ImmutableMap.of(), variables, null); + + return rawRestRequest(restApp, HTTP_1_1, POST, "/ksql", request, KsqlMediaType.KSQL_V1_JSON.mediaType(), + Optional.empty()).body().toString(); + } + static List makeKsqlRequest( final TestKsqlRestApp restApp, final String sql, diff --git a/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/client/KsqlClientUtilTest.java b/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/client/KsqlClientUtilTest.java index bf45d477a8ff..6ab631a1601a 100644 --- a/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/client/KsqlClientUtilTest.java +++ b/ksqldb-rest-client/src/test/java/io/confluent/ksql/rest/client/KsqlClientUtilTest.java @@ -184,7 +184,7 @@ public void shouldSerialiseDeserialise() { // Then: assertThat(buff, is(notNullValue())); String expectedJson = "{\"ksql\":\"some ksql\",\"streamsProperties\":{\"auto.offset.reset\":\"" - + "latest\"},\"requestProperties\":{},\"commandSequenceNumber\":21345}"; + + "latest\"},\"requestProperties\":{},\"commandSequenceNumber\":21345,\"sessionVariables\":{}}"; assertThat(new JsonObject(buff), is(new JsonObject(expectedJson))); // When: diff --git a/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/SessionProperties.java b/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/SessionProperties.java index e9e95afaaa33..483df83e0cb0 100644 --- a/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/SessionProperties.java +++ b/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/SessionProperties.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.Objects; import java.util.TreeMap; +import java.util.stream.Collectors; /** * Wraps the incoming {@link io.confluent.ksql.rest.entity.KsqlRequest} streamsProperties @@ -38,22 +39,44 @@ public class SessionProperties { /** * @param mutableScopedProperties The streamsProperties of the incoming request - * @param ksqlHostInfo The ksqlHostInfo of the server that handles the request + * @param ksqlHostInfo The ksqlHostInfo of the server that handles the request * @param localUrl The url of the server that handles the request * @param internalRequest Flag indicating if request is from within the KSQL cluster + * @param sessionVariables Initial session variables */ public SessionProperties( final Map mutableScopedProperties, final KsqlHostInfo ksqlHostInfo, final URL localUrl, - final boolean internalRequest + final boolean internalRequest, + final Map sessionVariables ) { - this.mutableScopedProperties = + this.mutableScopedProperties = new HashMap<>(Objects.requireNonNull(mutableScopedProperties, "mutableScopedProperties")); this.ksqlHostInfo = Objects.requireNonNull(ksqlHostInfo, "ksqlHostInfo"); this.localUrl = Objects.requireNonNull(localUrl, "localUrl"); this.internalRequest = internalRequest; this.sessionVariables = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + this.sessionVariables.putAll( + Objects.requireNonNull(sessionVariables, "sessionVariables") + .entrySet() + .stream() + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue().toString()))); + } + + /** + * @param mutableScopedProperties The streamsProperties of the incoming request + * @param ksqlHostInfo The ksqlHostInfo of the server that handles the request + * @param localUrl The url of the server that handles the request + * @param internalRequest Flag indicating if request is from within the KSQL cluster + */ + public SessionProperties( + final Map mutableScopedProperties, + final KsqlHostInfo ksqlHostInfo, + final URL localUrl, + final boolean internalRequest + ) { + this(mutableScopedProperties, ksqlHostInfo, localUrl, internalRequest, Collections.EMPTY_MAP); } public Map getMutableScopedProperties() { diff --git a/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/KsqlRequest.java b/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/KsqlRequest.java index 4fd580e973c4..509b66c6063e 100644 --- a/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/KsqlRequest.java +++ b/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/KsqlRequest.java @@ -39,13 +39,24 @@ public class KsqlRequest { private final String ksql; private final ImmutableMap configOverrides; private final ImmutableMap requestProperties; + private final ImmutableMap sessionVariables; private final Optional commandSequenceNumber; + public KsqlRequest( + @JsonProperty("ksql") final String ksql, + @JsonProperty("streamsProperties") final Map configOverrides, + @JsonProperty("requestProperties") final Map requestProperties, + @JsonProperty("commandSequenceNumber") final Long commandSequenceNumber + ) { + this(ksql, configOverrides, requestProperties, null, commandSequenceNumber); + } + @JsonCreator public KsqlRequest( @JsonProperty("ksql") final String ksql, @JsonProperty("streamsProperties") final Map configOverrides, @JsonProperty("requestProperties") final Map requestProperties, + @JsonProperty("sessionVariables") final Map sessionVariables, @JsonProperty("commandSequenceNumber") final Long commandSequenceNumber ) { this.ksql = ksql == null ? "" : ksql; @@ -55,6 +66,9 @@ public KsqlRequest( this.requestProperties = requestProperties == null ? ImmutableMap.of() : ImmutableMap.copyOf(serializeClassValues(requestProperties)); + this.sessionVariables = sessionVariables == null + ? ImmutableMap.of() + : ImmutableMap.copyOf(serializeClassValues(sessionVariables)); this.commandSequenceNumber = Optional.ofNullable(commandSequenceNumber); } @@ -71,6 +85,10 @@ public Map getRequestProperties() { return coerceTypes(requestProperties); } + public Map getSessionVariables() { + return sessionVariables; + } + public Optional getCommandSequenceNumber() { return commandSequenceNumber; } @@ -89,12 +107,14 @@ public boolean equals(final Object o) { return Objects.equals(ksql, that.ksql) && Objects.equals(configOverrides, that.configOverrides) && Objects.equals(requestProperties, that.requestProperties) + && Objects.equals(sessionVariables, that.sessionVariables) && Objects.equals(commandSequenceNumber, that.commandSequenceNumber); } @Override public int hashCode() { - return Objects.hash(ksql, configOverrides, requestProperties, commandSequenceNumber); + return Objects.hash(ksql, configOverrides, requestProperties, + sessionVariables, commandSequenceNumber); } @Override @@ -103,6 +123,7 @@ public String toString() { + "ksql='" + ksql + '\'' + ", configOverrides=" + configOverrides + ", requestProperties=" + requestProperties + + ", sessionVariables=" + sessionVariables + ", commandSequenceNumber=" + commandSequenceNumber + '}'; } diff --git a/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/QueryStreamArgs.java b/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/QueryStreamArgs.java index fe8d3eff8b39..2f48242e467d 100644 --- a/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/QueryStreamArgs.java +++ b/ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/entity/QueryStreamArgs.java @@ -29,12 +29,18 @@ public class QueryStreamArgs { public final String sql; public final JsonObject properties; + public final JsonObject sessionVariables; public QueryStreamArgs(final @JsonProperty(value = "sql", required = true) String sql, final @JsonProperty(value = "properties") - Map properties) { + Map properties, + final @JsonProperty(value = "sessionVariables") + Map sessionVariables) { this.sql = Objects.requireNonNull(sql); this.properties = properties == null ? new JsonObject() : new JsonObject(properties); + this.sessionVariables = sessionVariables == null + ? new JsonObject() + : new JsonObject(sessionVariables); } @Override @@ -42,6 +48,7 @@ public String toString() { return "QueryStreamArgs{" + "sql='" + sql + '\'' + ", properties=" + properties + + ", sessionVariables=" + sessionVariables + '}'; } } diff --git a/ksqldb-rest-model/src/test/java/io/confluent/ksql/rest/entity/KsqlRequestTest.java b/ksqldb-rest-model/src/test/java/io/confluent/ksql/rest/entity/KsqlRequestTest.java index c8a25fb287d6..0e365e58fd77 100644 --- a/ksqldb-rest-model/src/test/java/io/confluent/ksql/rest/entity/KsqlRequestTest.java +++ b/ksqldb-rest-model/src/test/java/io/confluent/ksql/rest/entity/KsqlRequestTest.java @@ -65,6 +65,7 @@ public class KsqlRequestTest { + "\"" + KsqlRequestConfig.KSQL_REQUEST_INTERNAL_REQUEST + "\":true," + "\"" + KsqlRequestConfig.KSQL_REQUEST_QUERY_PULL_SKIP_FORWARDING + "\":true" + "}," + + "\"sessionVariables\":{}," + "\"commandSequenceNumber\":2}"; private static final String A_JSON_REQUEST_WITH_NULL_COMMAND_NUMBER = "{" + "\"ksql\":\"sql\"," @@ -77,6 +78,7 @@ public class KsqlRequestTest { + "\"" + KsqlRequestConfig.KSQL_REQUEST_INTERNAL_REQUEST + "\":true," + "\"" + KsqlRequestConfig.KSQL_REQUEST_QUERY_PULL_SKIP_FORWARDING + "\":true" + "}," + + "\"sessionVariables\":{}," + "\"commandSequenceNumber\":null}"; private static final ImmutableMap SOME_PROPS = ImmutableMap.of( @@ -110,6 +112,13 @@ public void shouldHandleNullProps() { is(Collections.emptyMap())); } + @Test + public void shouldHandleNullSessionVariables() { + assertThat( + new KsqlRequest("sql", SOME_PROPS, Collections.emptyMap(), null, SOME_COMMAND_NUMBER).getSessionVariables(), + is(Collections.emptyMap())); + } + @Test public void shouldHandleNullCommandNumber() { assertThat( @@ -166,12 +175,15 @@ public void shouldSerializeToJsonWithCommandNumber() { public void shouldImplementHashCodeAndEqualsCorrectly() { new EqualsTester() .addEqualityGroup(new KsqlRequest("sql", SOME_PROPS, SOME_REQUEST_PROPS, SOME_COMMAND_NUMBER), - new KsqlRequest("sql", SOME_PROPS, SOME_REQUEST_PROPS, SOME_COMMAND_NUMBER)) + new KsqlRequest("sql", SOME_PROPS, SOME_REQUEST_PROPS, SOME_COMMAND_NUMBER), + new KsqlRequest("sql", SOME_PROPS, SOME_REQUEST_PROPS, ImmutableMap.of(), SOME_COMMAND_NUMBER), + new KsqlRequest("sql", SOME_PROPS, SOME_REQUEST_PROPS, null, SOME_COMMAND_NUMBER)) .addEqualityGroup( new KsqlRequest("different-sql", SOME_PROPS, SOME_REQUEST_PROPS, SOME_COMMAND_NUMBER)) .addEqualityGroup( new KsqlRequest("sql", ImmutableMap.of(), SOME_REQUEST_PROPS, SOME_COMMAND_NUMBER)) .addEqualityGroup(new KsqlRequest("sql", SOME_PROPS, SOME_REQUEST_PROPS, null)) + .addEqualityGroup(new KsqlRequest("sql", SOME_PROPS, SOME_REQUEST_PROPS, ImmutableMap.of("", ""), null)) .testEquals(); }