diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/schema/Operator.java b/ksqldb-common/src/main/java/io/confluent/ksql/schema/Operator.java index b45ec3524129..5e0e9ad2a344 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/schema/Operator.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/schema/Operator.java @@ -28,7 +28,9 @@ public enum Operator { ADD("+", SqlDecimal::add) { @Override - public SqlType resultType(final SqlType left, final SqlType right) { + public SqlType resultType(final SqlType left, final SqlType right) throws KsqlException { + checkForNullTypes(left, right); + if (left.baseType() == SqlBaseType.STRING && right.baseType() == SqlBaseType.STRING) { return SqlTypes.STRING; } @@ -60,7 +62,9 @@ public String getSymbol() { * @param right the right side of the operation. * @return the result schema. */ - public SqlType resultType(final SqlType left, final SqlType right) { + public SqlType resultType(final SqlType left, final SqlType right) throws KsqlException { + checkForNullTypes(left, right); + if (left.baseType().isNumber() && right.baseType().isNumber()) { if (left.baseType().canImplicitlyCast(right.baseType())) { if (right.baseType() != SqlBaseType.DECIMAL) { @@ -82,4 +86,13 @@ public SqlType resultType(final SqlType left, final SqlType right) { throw new KsqlException( "Unsupported arithmetic types. " + left.baseType() + " " + right.baseType()); } + + private static void checkForNullTypes(final SqlType left, final SqlType right) + throws KsqlException { + if (left == null || right == null) { + throw new KsqlException( + String.format("Arithmetic on types %s and %s are not supported.", left, right)); + } + return; + } } diff --git a/ksqldb-common/src/test/java/io/confluent/ksql/schema/OperatorTest.java b/ksqldb-common/src/test/java/io/confluent/ksql/schema/OperatorTest.java index d84ac03578df..5553aa9e2c71 100644 --- a/ksqldb-common/src/test/java/io/confluent/ksql/schema/OperatorTest.java +++ b/ksqldb-common/src/test/java/io/confluent/ksql/schema/OperatorTest.java @@ -29,6 +29,8 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; @@ -92,6 +94,33 @@ public void shouldResolveModulusReturnType() { assertConversionRule(MODULUS, SqlDecimal::modulus); } + @Test + public void shouldThrowExceptionWhenNullType() { + allOperations().forEach(op -> { + for (final SqlBaseType leftBaseType : SqlBaseType.values()) { + // When: + final Throwable exception = assertThrows(KsqlException.class, + () -> op.resultType(getType(leftBaseType), null)); + + // Then: + assertEquals(String.format("Arithmetic on types %s and null are not supported.", + getType(leftBaseType)), exception.getMessage()); + } + }); + + allOperations().forEach(op -> { + for (final SqlBaseType rightBaseType : SqlBaseType.values()) { + // When: + final Throwable exception = assertThrows(KsqlException.class, + () -> op.resultType(null, getType(rightBaseType))); + + // Then: + assertEquals(String.format("Arithmetic on types null and %s are not supported.", + getType(rightBaseType)), exception.getMessage()); + } + }); + } + @Test public void shouldWorkUsingSameRulesAsBaseTypeUpCastRules() { allOperations().forEach(op -> { diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/execution/ExpressionEvaluatorParityTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/execution/ExpressionEvaluatorParityTest.java index b270a459775e..62483acedef4 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/execution/ExpressionEvaluatorParityTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/execution/ExpressionEvaluatorParityTest.java @@ -174,10 +174,35 @@ public void shouldDoArithmetic() throws Exception { @Test public void shouldDoArithmetic_nulls() throws Exception { ordersRow = GenericRow.genericRow(null, null, null, null, null, null, null, null, null); - assertOrdersError("1 + null", compileTime("Unexpected error generating code for Test"), - compileTime("Unexpected error generating code for expression: (1 + null)")); - assertOrdersError("'a' + null", compileTime("Unexpected error generating code for Test"), - compileTime("Unexpected error generating code for expression: ('a' + null)")); + + //The error message coming from the compiler and the interpreter should be the same + assertOrdersError("1 + null", compileTime("Error processing expression: (1 + null). Arithmetic on types INTEGER and null are not supported."), + compileTime("Error processing expression: (1 + null). Arithmetic on types INTEGER and null are not supported.")); + + assertOrdersError("'a' + null", compileTime("Error processing expression: ('a' + null). Arithmetic on types STRING and null are not supported."), + compileTime("Error processing expression: ('a' + null). Arithmetic on types STRING and null are not supported.")); + + assertOrdersError("MAP(1 := 'cat') + null", compileTime("Error processing expression: (MAP(1:='cat') + null). Arithmetic on types MAP and null are not supported."), + compileTime("Error processing expression: (MAP(1:='cat') + null). Arithmetic on types MAP and null are not supported.")); + + assertOrdersError("Array[1,2,3] + null", compileTime("Error processing expression: (ARRAY[1, 2, 3] + null). Arithmetic on types ARRAY and null are not supported."), + compileTime("Error processing expression: (ARRAY[1, 2, 3] + null). Arithmetic on types ARRAY and null are not supported.")); + + assertOrdersError("1 - null", compileTime("Error processing expression: (1 - null). Arithmetic on types INTEGER and null are not supported."), + compileTime("Error processing expression: (1 - null). Arithmetic on types INTEGER and null are not supported.")); + + assertOrdersError("1 * null", compileTime("Error processing expression: (1 * null). Arithmetic on types INTEGER and null are not supported."), + compileTime("Error processing expression: (1 * null). Arithmetic on types INTEGER and null are not supported.")); + + assertOrdersError("1 / null", compileTime("Error processing expression: (1 / null). Arithmetic on types INTEGER and null are not supported."), + compileTime("Error processing expression: (1 / null). Arithmetic on types INTEGER and null are not supported.")); + + assertOrdersError("null + null", compileTime("Error processing expression: (null + null). Arithmetic on types null and null are not supported."), + compileTime("Error processing expression: (null + null). Arithmetic on types null and null are not supported.")); + + assertOrdersError("null / 0", compileTime("Error processing expression: (null / 0). Arithmetic on types null and INTEGER are not supported."), + compileTime("Error processing expression: (null / 0). Arithmetic on types null and INTEGER are not supported.")); + assertOrdersError("1 + ORDERID", evalLogger(null)); } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index 513ed6ffb7e1..276eea5555be 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -153,14 +153,20 @@ private class Visitor implements ExpressionVisitor { public Void visitArithmeticBinary( final ArithmeticBinaryExpression node, final Context context - ) { + ) throws KsqlException { process(node.getLeft(), context); final SqlType leftType = context.getSqlType(); process(node.getRight(), context); final SqlType rightType = context.getSqlType(); - final SqlType resultType = node.getOperator().resultType(leftType, rightType); + final SqlType resultType; + try { + resultType = node.getOperator().resultType(leftType, rightType); + } catch (KsqlException e) { + throw new KsqlException(String.format( + "Error processing expression: %s. %s", node.toString(), e.getMessage()), e); + } context.setSqlType(resultType); return null; diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/comparison-expression.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/comparison-expression.json index 36485c945815..da4d70cfe8c5 100644 --- a/ksqldb-functional-tests/src/test/resources/query-validation-tests/comparison-expression.json +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/comparison-expression.json @@ -67,7 +67,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Error in WHERE expression: Unsupported arithmetic types. BOOLEAN DECIMAL" + "message": "Error in WHERE expression: Error processing expression: (true + 1.5). Unsupported arithmetic types. BOOLEAN DECIMAL" } }, { diff --git a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/push-queries.json b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/push-queries.json index e4086a1949ca..a4e624a112e3 100644 --- a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/push-queries.json +++ b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/push-queries.json @@ -927,6 +927,66 @@ {"finalMessage":"Limit Reached"} ]} ] + }, + { + "name": "NULL Arithmetic Behavior - INTEGER addition", + "statements": [ + "CREATE STREAM INPUT (ID INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "SELECT 1 + NULL FROM INPUT EMIT CHANGES;" + ], + "expectedError": { + "type": "io.confluent.ksql.rest.entity.KsqlErrorMessage", + "message": "Error processing expression: (1 + null). Arithmetic on types INTEGER and null are not supported.", + "status": 400 + } + }, + { + "name": "NULL Arithmetic Behavior - MAP addition", + "statements": [ + "CREATE STREAM INPUT (ID INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "SELECT MAP(1 := 'cat') + NULL FROM INPUT EMIT CHANGES;" + ], + "expectedError": { + "type": "io.confluent.ksql.rest.entity.KsqlErrorMessage", + "message": "Error processing expression: (MAP(1:='cat') + null). Arithmetic on types MAP and null are not supported.", + "status": 400 + } + }, + { + "name": "NULL Arithmetic Behavior - ARRAY addition", + "statements": [ + "CREATE STREAM INPUT (ID INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "SELECT Array[1,2,3] + NULL FROM INPUT EMIT CHANGES;" + ], + "expectedError": { + "type": "io.confluent.ksql.rest.entity.KsqlErrorMessage", + "message": "Error processing expression: (ARRAY[1, 2, 3] + null). Arithmetic on types ARRAY and null are not supported.", + "status": 400 + } + }, + { + "name": "NULL Arithmetic Behavior - DECIMAL division", + "statements": [ + "CREATE STREAM INPUT (ID INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "SELECT 5.0 / NULL FROM INPUT EMIT CHANGES;" + ], + "expectedError": { + "type": "io.confluent.ksql.rest.entity.KsqlErrorMessage", + "message": "Error processing expression: (5.0 / null). Arithmetic on types DECIMAL(2, 1) and null are not supported.", + "status": 400 + } + }, + { + "name": "NULL Arithmetic Behavior - NULL NULL multiplication", + "statements": [ + "CREATE STREAM INPUT (ID INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "SELECT NULL * NULL FROM INPUT EMIT CHANGES;" + ], + "expectedError": { + "type": "io.confluent.ksql.rest.entity.KsqlErrorMessage", + "message": "Error processing expression: (null * null). Arithmetic on types null and null are not supported.", + "status": 400 + } } ] } \ No newline at end of file