Skip to content

Commit

Permalink
feat: enable variable substitution for /query-stream and /ksql endpoi…
Browse files Browse the repository at this point in the history
…nts (#7271)

* feat: enable variable substitution for query-stream and ksql endpoints

* address review comments

* refactor sessionProperties

* checkstyle

* fix unit tests

* Update KsqlRequest test

* fix unit test

* checkstyle

* fix test
  • Loading branch information
Zara Lim authored Mar 29, 2021
1 parent 71f501b commit f6dd212
Show file tree
Hide file tree
Showing 18 changed files with 239 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,16 @@ public QueryEndpoint(
public QueryPublisher createQueryPublisher(
final String sql,
final JsonObject properties,
final JsonObject sessionVariables,
final Context context,
final WorkerExecutor workerExecutor,
final ServiceContext serviceContext,
final MetricsCallbackHolder metricsCallbackHolder) {
// Must be run on worker as all this stuff is slow
VertxUtils.checkIsWorker();

final ConfiguredStatement<Query> statement = createStatement(sql, properties.getMap());
final ConfiguredStatement<Query> statement = createStatement(
sql, properties.getMap(), sessionVariables.getMap());

if (statement.getStatement().isPullQuery()) {
return createPullQueryPublisher(
Expand Down Expand Up @@ -199,15 +201,20 @@ private QueryPublisher createPullQueryPublisher(
}

private ConfiguredStatement<Query> createStatement(final String queryString,
final Map<String, Object> properties) {
final Map<String, Object> properties, final Map<String, Object> sessionVariables) {
final List<ParsedStatement> statements = ksqlEngine.parse(queryString);
if ((statements.size() != 1)) {
throw new KsqlStatementException(
String
.format("Expected exactly one KSQL statement; found %d instead", statements.size()),
queryString);
}
final PreparedStatement<?> ps = ksqlEngine.prepare(statements.get(0));
final PreparedStatement<?> ps = ksqlEngine.prepare(
statements.get(0),
sessionVariables.entrySet()
.stream()
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue().toString()))
);
final Statement statement = ps.getStatement();
if (!(statement instanceof Query)) {
throw new KsqlStatementException("Not a query", queryString);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ public void handle(final RoutingContext routingContext) {
final MetricsCallbackHolder metricsCallbackHolder = new MetricsCallbackHolder();
final long startTimeNanos = Time.SYSTEM.nanoseconds();
endpoints.createQueryPublisher(queryStreamArgs.get().sql, queryStreamArgs.get().properties,
context, server.getWorkerExecutor(), DefaultApiSecurityContext.create(routingContext),
metricsCallbackHolder)
queryStreamArgs.get().sessionVariables, context, server.getWorkerExecutor(),
DefaultApiSecurityContext.create(routingContext), metricsCallbackHolder)
.thenAccept(queryPublisher -> {

final QueryResponseMetadata metadata;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ public interface Endpoints {
* @return A CompletableFuture representing the future result of the operation
*/
CompletableFuture<QueryPublisher> createQueryPublisher(String sql, JsonObject properties,
Context context, WorkerExecutor workerExecutor, ApiSecurityContext apiSecurityContext,
MetricsCallbackHolder metricsCallbackHolder);
JsonObject sessionVariables, Context context, WorkerExecutor workerExecutor,
ApiSecurityContext apiSecurityContext, MetricsCallbackHolder metricsCallbackHolder);

/**
* Create a subscriber which will receive a stream of inserts from the API server and process
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ public KsqlServerEndpoints(
@Override
public CompletableFuture<QueryPublisher> createQueryPublisher(final String sql,
final JsonObject properties,
final JsonObject sessionVariables,
final Context context,
final WorkerExecutor workerExecutor,
final ApiSecurityContext apiSecurityContext,
Expand All @@ -155,6 +156,7 @@ public CompletableFuture<QueryPublisher> createQueryPublisher(final String sql,
.createQueryPublisher(
sql,
properties,
sessionVariables,
context,
workerExecutor,
ksqlSecurityContext.getServiceContext(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ public EndpointResponse handleKsqlStatements(
configProperties,
localHost,
localUrl,
requestConfig.getBoolean(KsqlRequestConfig.KSQL_REQUEST_INTERNAL_REQUEST)
requestConfig.getBoolean(KsqlRequestConfig.KSQL_REQUEST_INTERNAL_REQUEST),
request.getSessionVariables()
)
);

Expand Down
50 changes: 50 additions & 0 deletions ksqldb-rest-app/src/test/java/io/confluent/ksql/api/ApiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,33 @@ public void shouldExecutePullQuery() throws Exception {
assertThat(queryId, is(nullValue()));
}

@Test
public void shouldExecutePullQueryWithVariableSubstitution() throws Exception {

// Given
JsonObject requestBody = new JsonObject().put("sql", "select * from ${name} where rowkey='1234';");
JsonObject properties = new JsonObject().put("prop1", "val1").put("prop2", 23);
JsonObject sessionVariables = new JsonObject().put("name", "foo");
requestBody.put("properties", properties).put("sessionVariables", sessionVariables);

// When
HttpResponse<Buffer> response = sendPostRequest("/query-stream", requestBody.toBuffer());

// Then
assertThat(response.statusCode(), is(200));
assertThat(response.statusMessage(), is("OK"));
assertThat(testEndpoints.getLastSql(), is("select * from ${name} where rowkey='1234';"));
assertThat(testEndpoints.getLastProperties(), is(properties));
assertThat(testEndpoints.getLastSessionVariables(), is(sessionVariables));
QueryResponse queryResponse = new QueryResponse(response.bodyAsString());
assertThat(queryResponse.responseObject.getJsonArray("columnNames"), is(DEFAULT_COLUMN_NAMES));
assertThat(queryResponse.responseObject.getJsonArray("columnTypes"), is(DEFAULT_COLUMN_TYPES));
assertThat(queryResponse.rows, is(DEFAULT_JSON_ROWS));
assertThat(server.getQueryIDs(), hasSize(0));
String queryId = queryResponse.responseObject.getString("queryId");
assertThat(queryId, is(nullValue()));
}

@Test
@CoreApiTest
public void shouldExecutePushQuery() throws Exception {
Expand All @@ -134,6 +161,29 @@ public void shouldExecutePushQuery() throws Exception {
assertThat(server.getQueryIDs().contains(new PushQueryId(queryId)), is(true));
}

@Test
public void shouldExecutePushQueryWithVariableSubstitution() throws Exception {

// When
JsonObject requestBody = new JsonObject().put("sql", "select * from ${name} emit changes;");
JsonObject properties = new JsonObject().put("prop1", "val1").put("prop2", 23);
JsonObject sessionVariables = new JsonObject().put("name", "foo");
requestBody.put("properties", properties).put("sessionVariables", sessionVariables);
QueryResponse queryResponse = executePushQueryAndWaitForRows(requestBody);

// Then
assertThat(testEndpoints.getLastSql(), is("select * from ${name} emit changes;"));
assertThat(testEndpoints.getLastProperties(), is(DEFAULT_PUSH_QUERY_REQUEST_PROPERTIES));
assertThat(testEndpoints.getLastSessionVariables(), is(sessionVariables));
assertThat(queryResponse.responseObject.getJsonArray("columnNames"), is(DEFAULT_COLUMN_NAMES));
assertThat(queryResponse.responseObject.getJsonArray("columnTypes"), is(DEFAULT_COLUMN_TYPES));
assertThat(queryResponse.rows, is(DEFAULT_JSON_ROWS));
assertThat(server.getQueryIDs(), hasSize(1));
String queryId = queryResponse.responseObject.getString("queryId");
assertThat(queryId, is(notNullValue()));
assertThat(server.getQueryIDs().contains(new PushQueryId(queryId)), is(true));
}

@Test
public void shouldExecuteMultiplePushQueries() throws Exception {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class TestEndpoints implements Endpoints {
private List<KsqlEntity> ksqlEndpointResponse;
private String lastSql;
private JsonObject lastProperties;
private JsonObject lastSessionVariables;
private String lastTarget;
private Set<TestQueryPublisher> queryPublishers = new HashSet<>();
private int acksBeforePublisherError = -1;
Expand All @@ -65,7 +66,7 @@ public class TestEndpoints implements Endpoints {

@Override
public synchronized CompletableFuture<QueryPublisher> createQueryPublisher(final String sql,
final JsonObject properties, final Context context, final WorkerExecutor workerExecutor,
final JsonObject properties, JsonObject sessionVariables, final Context context, final WorkerExecutor workerExecutor,
final ApiSecurityContext apiSecurityContext,
final MetricsCallbackHolder metricsCallbackHolder) {
CompletableFuture<QueryPublisher> completableFuture = new CompletableFuture<>();
Expand All @@ -75,6 +76,7 @@ public synchronized CompletableFuture<QueryPublisher> createQueryPublisher(final
} else {
this.lastSql = sql;
this.lastProperties = properties;
this.lastSessionVariables = sessionVariables;
this.lastApiSecurityContext = apiSecurityContext;
final boolean push = sql.toLowerCase().contains("emit changes");
final int limit = extractLimit(sql);
Expand Down Expand Up @@ -240,6 +242,10 @@ public synchronized JsonObject getLastProperties() {
return lastProperties;
}

public synchronized JsonObject getLastSessionVariables() {
return lastSessionVariables;
}

public synchronized Set<TestQueryPublisher> getQueryPublishers() {
return queryPublishers;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,22 @@ public void shouldExecutePushQueryWithLimit() {
assertThat(response.responseObject.getString("queryId"), is(notNullValue()));
}

@Test
public void shouldExecutePushQueryWithVariableSubstitution() {

// Given:
String sql = "SELECT DEC AS ${name} from " + TEST_STREAM + " EMIT CHANGES LIMIT 2;";

// When:
QueryResponse response = executeQueryWithVariables(sql, new JsonObject().put("name", "COL"));

// Then:
assertThat(response.rows, hasSize(2));
assertThat(response.responseObject.getJsonArray("columnNames"), is(new JsonArray().add("COL")));
assertThat(response.responseObject.getJsonArray("columnTypes"), is(new JsonArray().add("DECIMAL(4, 2)")));
assertThat(response.responseObject.getString("queryId"), is(notNullValue()));
}

@Test
public void shouldFailPushQueryWithInvalidSql() {

Expand Down Expand Up @@ -290,6 +306,35 @@ public void shouldExecutePullQuery() {
assertThat(response.rows.get(0).getLong(1), is(1L)); // latest_by_offset(long)
}

@Test
public void shouldExecutePullQueryWithVariableSubstitution() {

// Given:
String sql = "SELECT * from ${AGG_TABLE} WHERE K=" + AN_AGG_KEY + ";";
final JsonObject variables = new JsonObject().put("AGG_TABLE", AGG_TABLE);

// When:
// Maybe need to retry as populating agg table is async
AtomicReference<QueryResponse> atomicReference = new AtomicReference<>();
assertThatEventually(() -> {
QueryResponse queryResponse = executeQueryWithVariables(sql, variables);
atomicReference.set(queryResponse);
return queryResponse.rows;
}, hasSize(1));

QueryResponse response = atomicReference.get();

// Then:
JsonArray expectedColumnNames = new JsonArray().add("K").add("LONG");
JsonArray expectedColumnTypes = new JsonArray().add("STRUCT<`F1` ARRAY<STRING>>").add("BIGINT");
assertThat(response.rows, hasSize(1));
assertThat(response.responseObject.getJsonArray("columnNames"), is(expectedColumnNames));
assertThat(response.responseObject.getJsonArray("columnTypes"), is(expectedColumnTypes));
assertThat(response.responseObject.getString("queryId"), is(nullValue()));
assertThat(response.rows.get(0).getJsonObject(0).getJsonArray("F1").getString(0), is("a")); // rowkey
assertThat(response.rows.get(0).getLong(1), is(1L)); // latest_by_offset(long)
}

@Test
public void shouldFailPullQueryWithInvalidSql() {

Expand Down Expand Up @@ -613,9 +658,13 @@ private void shouldFailToExecuteQuery(final String sql, final String message) {
}

private QueryResponse executeQuery(final String sql) {
return executeQueryWithVariables(sql, new JsonObject());
}

private QueryResponse executeQueryWithVariables(final String sql, final JsonObject variables) {
JsonObject properties = new JsonObject();
JsonObject requestBody = new JsonObject()
.put("sql", sql).put("properties", properties);
.put("sql", sql).put("properties", properties).put("sessionVariables", variables);
HttpResponse<Buffer> response = sendRequest("/query-stream", requestBody.toBuffer());
return new QueryResponse(response.bodyAsString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ private class InsertsStreamEndpoints implements Endpoints {
@Override
public CompletableFuture<QueryPublisher> createQueryPublisher(final String sql,
final JsonObject properties,
final JsonObject sessionVariables,
final Context context,
final WorkerExecutor workerExecutor,
final ApiSecurityContext apiSecurityContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ private static class PullQueryEndpoints implements Endpoints {
@Override
public synchronized CompletableFuture<QueryPublisher> createQueryPublisher(final String sql,
final JsonObject properties,
final JsonObject sessionVariables,
final Context context,
final WorkerExecutor workerExecutor,
final ApiSecurityContext apiSecurityContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ private class QueryStreamEndpoints implements Endpoints {
@Override
public synchronized CompletableFuture<QueryPublisher> createQueryPublisher(final String sql,
final JsonObject properties,
final JsonObject sessionVariables,
final Context context,
final WorkerExecutor workerExecutor,
final ApiSecurityContext apiSecurityContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@
import io.confluent.ksql.rest.entity.CommandStatus;
import io.confluent.ksql.rest.entity.CommandStatus.Status;
import io.confluent.ksql.rest.entity.CommandStatuses;
import io.confluent.ksql.rest.entity.KsqlEntity;
import io.confluent.ksql.rest.entity.KsqlMediaType;
import io.confluent.ksql.rest.entity.KsqlRequest;
import io.confluent.ksql.rest.entity.Queries;
import io.confluent.ksql.rest.entity.QueryStreamArgs;
import io.confluent.ksql.rest.entity.ServerClusterId;
import io.confluent.ksql.rest.entity.ServerInfo;
Expand Down Expand Up @@ -151,6 +153,11 @@ public class RestApiTest {
resource(TOPIC, "X"),
ops(ALL)
)
.withAcl(
NORMAL_USER,
resource(TOPIC, "Y"),
ops(ALL)
)
.withAcl(
NORMAL_USER,
resource(TOPIC, AGG_TABLE),
Expand Down Expand Up @@ -614,7 +621,7 @@ public void shouldFailToExecutePullQueryOverRestHttp2() {
public void shouldExecutePullQueryOverHttp2QueryStream() {
QueryStreamArgs queryStreamArgs = new QueryStreamArgs(
"SELECT COUNT, USERID from " + AGG_TABLE + " WHERE USERID='" + AN_AGG_KEY + "';",
Collections.emptyMap());
Collections.emptyMap(), Collections.emptyMap());

QueryResponse[] queryResponse = new QueryResponse[1];
assertThatEventually(() -> {
Expand Down Expand Up @@ -674,6 +681,23 @@ public void shouldDeleteTopic() {
assertThat("Expected topic X to be deleted", !topicExists("X"));
}

@Test
public void shouldCreateStreamWithVariableSubstitution() {
// Given:
// When:
makeKsqlRequestWithVariables(
"CREATE STREAM Y AS SELECT * FROM " + PAGE_VIEW_STREAM + " WHERE USERID='${id}';",
ImmutableMap.of("id", "USER_1")
);

// Then:
final List<String> query = ((Queries) makeKsqlRequest("SHOW QUERIES;").get(0))
.getQueries().stream().map(q -> q.getQueryString())
.filter(q -> q.contains("WHERE (PAGEVIEW_KSTREAM.USERID = 'USER_1')"))
.collect(Collectors.toList());
assertThat(query.size(), is(1));
}

@Test
public void shouldFailToExecuteQueryUsingRestWithHttp2() {
// Given:
Expand All @@ -699,8 +723,12 @@ private ServiceContext getServiceContext() {
return serviceContext;
}

private static void makeKsqlRequest(final String sql) {
RestIntegrationTestUtil.makeKsqlRequest(REST_APP, sql);
private static List<KsqlEntity> makeKsqlRequest(final String sql) {
return RestIntegrationTestUtil.makeKsqlRequest(REST_APP, sql);
}

private static void makeKsqlRequestWithVariables(final String sql, final Map<String, Object> variables) {
RestIntegrationTestUtil.makeKsqlRequestWithVariables(REST_APP, sql, variables);
}

private static String rawRestQueryRequest(final String sql, final String mediaType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.confluent.ksql.rest.entity.KsqlEntity;
import io.confluent.ksql.rest.entity.KsqlEntityList;
import io.confluent.ksql.rest.entity.KsqlErrorMessage;
import io.confluent.ksql.rest.entity.KsqlMediaType;
import io.confluent.ksql.rest.entity.KsqlRequest;
import io.confluent.ksql.rest.entity.ServerClusterId;
import io.confluent.ksql.rest.entity.ServerInfo;
Expand Down Expand Up @@ -91,6 +92,15 @@ public static List<KsqlEntity> makeKsqlRequest(final TestKsqlRestApp restApp, fi
return makeKsqlRequest(restApp, sql, Optional.empty());
}

public static String makeKsqlRequestWithVariables(
final TestKsqlRestApp restApp, final String sql, final Map<String, Object> variables) {
final KsqlRequest request =
new KsqlRequest(sql, ImmutableMap.of(), ImmutableMap.of(), variables, null);

return rawRestRequest(restApp, HTTP_1_1, POST, "/ksql", request, KsqlMediaType.KSQL_V1_JSON.mediaType(),
Optional.empty()).body().toString();
}

static List<KsqlEntity> makeKsqlRequest(
final TestKsqlRestApp restApp,
final String sql,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ public void shouldSerialiseDeserialise() {
// Then:
assertThat(buff, is(notNullValue()));
String expectedJson = "{\"ksql\":\"some ksql\",\"streamsProperties\":{\"auto.offset.reset\":\""
+ "latest\"},\"requestProperties\":{},\"commandSequenceNumber\":21345}";
+ "latest\"},\"requestProperties\":{},\"commandSequenceNumber\":21345,\"sessionVariables\":{}}";
assertThat(new JsonObject(buff), is(new JsonObject(expectedJson)));

// When:
Expand Down
Loading

0 comments on commit f6dd212

Please sign in to comment.