diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandStore.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandStore.java index 8803dccb783e..e7d36c168d4b 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandStore.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/computation/CommandStore.java @@ -17,6 +17,7 @@ package io.confluent.ksql.rest.server.computation; import io.confluent.ksql.parser.tree.Statement; +import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.Pair; import org.apache.kafka.clients.consumer.Consumer; @@ -94,10 +95,16 @@ public CommandId distributeStatement( String statementString, Statement statement, Map streamsProperties - ) throws Exception { + ) throws KsqlException { CommandId commandId = commandIdAssigner.getCommandId(statement); Command command = new Command(statementString, streamsProperties); - commandProducer.send(new ProducerRecord<>(commandTopic, commandId, command)).get(); + try { + commandProducer.send(new ProducerRecord<>(commandTopic, commandId, command)).get(); + } catch (Exception e) { + throw new KsqlException(String.format("Could not write the statement '%s' into the " + + "command topic" + + ".", statementString), e); + } return commandId; } diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/KsqlResource.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/KsqlResource.java index 47ea428f088c..9ee351fd3694 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/KsqlResource.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/KsqlResource.java @@ -86,7 +86,13 @@ import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; -import java.util.*; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; @@ -171,7 +177,7 @@ private KsqlEntity executeStatement( String statementText, Statement statement, Map streamsProperties - ) throws Exception { + ) throws KsqlException { if (statement instanceof ListTopics) { return listTopics(statementText); } else if (statement instanceof ListRegisteredTopics) { @@ -205,30 +211,42 @@ private KsqlEntity executeStatement( || statement instanceof DropStream || statement instanceof DropTable ) { - // getStatementExecutionPlan(statement, statementText, streamsProperties); + //Sanity check for the statement before distributing it. + validateStatement(statement, statementText, streamsProperties); return distributeStatement(statementText, statement, streamsProperties); } else { if (statement != null) { - throw new Exception(String.format( + throw new KsqlException(String.format( "Cannot handle statement of type '%s'", statement.getClass().getSimpleName() )); - } else if (statementText != null) { - throw new Exception(String.format( + } else { + throw new KsqlException(String.format( "Unable to execute statement '%s'", statementText )); - } else { - throw new Exception("Unable to execute statement"); } } } + /** + * Validate the statement by creating the execution plan for it. + * + * @param statement + * @param statementText + * @param streamsProperties + * @throws Exception + */ + private void validateStatement(Statement statement, String statementText, + Map streamsProperties) throws KsqlException { + getStatementExecutionPlan(statement, statementText, streamsProperties); + } + private CommandStatusEntity distributeStatement( String statementText, Statement statement, Map streamsProperties - ) throws Exception { + ) throws KsqlException { CommandId commandId = commandStore.distributeStatement(statementText, statement, streamsProperties); CommandStatus commandStatus; @@ -239,6 +257,9 @@ private CommandStatusEntity distributeStatement( log.warn("Timeout to get commandStatus, waited {} milliseconds:, statementText:" + statementText, distributedCommandResponseTimeout, exception); commandStatus = statementExecutor.getStatus(commandId).get(); + } catch (Exception e) { + throw new KsqlException(String.format("Could not write the statement '%s' into the command " + + "topic.", statementText), e); } return new CommandStatusEntity(statementText, commandId, commandStatus); } @@ -276,10 +297,10 @@ private Queries showQueries(String statementText) { return new Queries(statementText, runningQueries); } - private TopicDescription describeTopic(String statementText, String name) throws Exception { + private TopicDescription describeTopic(String statementText, String name) throws KsqlException { KsqlTopic ksqlTopic = ksqlEngine.getMetaStore().getTopic(name); if (ksqlTopic == null) { - throw new Exception(String.format("Could not find Topic '%s' in the Metastore", + throw new KsqlException(String.format("Could not find Topic '%s' in the Metastore", name)); } String schemaString = null; @@ -293,11 +314,11 @@ private TopicDescription describeTopic(String statementText, String name) throws return topicDescription; } - private SourceDescription describe(String statementText, String name) throws Exception { + private SourceDescription describe(String statementText, String name) throws KsqlException { StructuredDataSource dataSource = ksqlEngine.getMetaStore().getSource(name); if (dataSource == null) { - throw new Exception(String.format("Could not find STREAM/TABLE '%s' in the Metastore", + throw new KsqlException(String.format("Could not find STREAM/TABLE '%s' in the Metastore", name)); } return new SourceDescription(statementText, dataSource); @@ -316,18 +337,20 @@ private TablesList listTables(String statementText) { } private ExecutionPlan getStatementExecutionPlan(Explain explain, String statementText) - throws Exception { + throws KsqlException { return getStatementExecutionPlan(explain.getStatement(), statementText, Collections.emptyMap()); } private ExecutionPlan getStatementExecutionPlan(Statement statement, String statementText, Map properties) - throws Exception { + throws KsqlException { DDLCommandTask ddlCommandTask = ddlCommandTasks.get(statement.getClass()); if (ddlCommandTask != null) { try { return new ExecutionPlan(ddlCommandTask.execute(statement, statementText, properties)); + } catch (KsqlException ksqlException) { + throw ksqlException; } catch (Throwable t) { throw new KsqlException("Cannot RUN execution plan for this statement, " + statement, t); } diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/mock/MockCommandStore.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/mock/MockCommandStore.java index b192c43faf68..c8a47a8e1851 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/mock/MockCommandStore.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/mock/MockCommandStore.java @@ -36,6 +36,7 @@ import io.confluent.ksql.rest.server.computation.CommandIdAssigner; import io.confluent.ksql.rest.server.computation.CommandStore; import io.confluent.ksql.rest.server.utils.TestUtils; +import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.Pair; public class MockCommandStore extends CommandStore { @@ -86,7 +87,7 @@ public CommandId distributeStatement( String statementString, Statement statement, Map streamsProperties - ) throws Exception { + ) throws KsqlException { CommandId commandId = commandIdAssigner.getCommandId(statement); return commandId; } diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/KsqlResourceTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/KsqlResourceTest.java index e4dcd7d68118..84cd6e92d6e7 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/KsqlResourceTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/KsqlResourceTest.java @@ -25,8 +25,28 @@ import io.confluent.ksql.metastore.KsqlTable; import io.confluent.ksql.metastore.KsqlTopic; import io.confluent.ksql.metastore.MetaStore; -import io.confluent.ksql.parser.tree.*; -import io.confluent.ksql.rest.entity.*; +import io.confluent.ksql.parser.tree.Expression; +import io.confluent.ksql.parser.tree.ListQueries; +import io.confluent.ksql.parser.tree.ListRegisteredTopics; +import io.confluent.ksql.parser.tree.ListStreams; +import io.confluent.ksql.parser.tree.ListTables; +import io.confluent.ksql.parser.tree.QualifiedName; +import io.confluent.ksql.parser.tree.RegisterTopic; +import io.confluent.ksql.parser.tree.ShowColumns; +import io.confluent.ksql.parser.tree.Statement; +import io.confluent.ksql.parser.tree.StringLiteral; +import io.confluent.ksql.rest.entity.CommandStatus; +import io.confluent.ksql.rest.entity.CommandStatusEntity; +import io.confluent.ksql.rest.entity.ErrorMessageEntity; +import io.confluent.ksql.rest.entity.KsqlEntity; +import io.confluent.ksql.rest.entity.KsqlEntityList; +import io.confluent.ksql.rest.entity.KsqlRequest; +import io.confluent.ksql.rest.entity.KsqlTopicInfo; +import io.confluent.ksql.rest.entity.KsqlTopicsList; +import io.confluent.ksql.rest.entity.Queries; +import io.confluent.ksql.rest.entity.SourceDescription; +import io.confluent.ksql.rest.entity.StreamsList; +import io.confluent.ksql.rest.entity.TablesList; import io.confluent.ksql.rest.server.mock.MockKafkaTopicClient; import io.confluent.ksql.rest.server.KsqlRestConfig; import io.confluent.ksql.rest.server.StatementParser; @@ -45,11 +65,20 @@ import org.apache.kafka.connect.data.SchemaBuilder; import org.junit.Test; -import java.util.*; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; import java.util.concurrent.Future; import java.util.stream.Collectors; +import javax.ws.rs.core.Response; + import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -363,4 +392,46 @@ public void testListTablesStatement() throws Exception { assertEquals(expectedTable, testTables.get(0)); } + + @Test + public void shouldFailForIncorrectCSASStatementResultType() throws Exception { + KsqlResource testResource = TestKsqlResourceUtil.get(); + String ksqlString1 = "CREATE STREAM s1 AS SELECT * FROM test_table;"; + + Response response1 = testResource.handleKsqlStatements(new KsqlRequest(ksqlString1, Collections + .emptyMap())); + KsqlEntityList result1 = (KsqlEntityList) response1.getEntity(); + assertTrue("Incorrect response size.", result1.size() == 1); + assertThat(result1.get(0), instanceOf(ErrorMessageEntity.class)); + ErrorMessageEntity errorMessageEntity1 = (ErrorMessageEntity) result1.get(0); + assertThat("", errorMessageEntity1.getErrorMessage().getMessage(), equalTo("Invalid result type. Your SELECT query produces a TABLE. Please use CREATE TABLE AS SELECT statement instead.")); + + String ksqlString2 = "CREATE STREAM s2 AS SELECT S2_F1 , count(S2_F1) FROM test_stream group by " + + "s2_f1;"; + + Response response2 = testResource.handleKsqlStatements(new KsqlRequest(ksqlString2, Collections + .emptyMap())); + KsqlEntityList result2 = (KsqlEntityList) response2.getEntity(); + assertThat("Incorrect response size.", result2.size(), equalTo(1)); + assertThat(result2.get(0), instanceOf(ErrorMessageEntity.class)); + ErrorMessageEntity errorMessageEntity2 = (ErrorMessageEntity) result2.get(0); + assertThat("", errorMessageEntity2.getErrorMessage().getMessage(), equalTo("Invalid " + + "result type. Your SELECT query produces a TABLE. Please use CREATE TABLE AS SELECT statement instead.")); + } + + @Test + public void shouldFailForIncorrectCTASStatementResultType() throws Exception { + KsqlResource testResource = TestKsqlResourceUtil.get(); + final String ksqlString = "CREATE TABLE s1 AS SELECT * FROM test_stream;"; + + Response response = testResource.handleKsqlStatements(new KsqlRequest(ksqlString, Collections + .emptyMap())); + KsqlEntityList result = (KsqlEntityList) response.getEntity(); + assertThat("Incorrect response size.", result.size(), equalTo(1)); + assertThat(result.get(0), instanceOf(ErrorMessageEntity.class)); + ErrorMessageEntity errorMessageEntity = (ErrorMessageEntity) result.get(0); + assertThat(errorMessageEntity.getErrorMessage().getMessage(), equalTo("Invalid result type. Your " + + "SELECT query produces a STREAM. Please use CREATE STREAM AS SELECT statement instead.")); + } + }