-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: improve handling of NULLs #5019
Changes from 6 commits
1a1153b
841c804
5d22ed7
c0d8792
2028a47
746e65d
f6a4144
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ | |
import com.google.common.collect.Multiset; | ||
import io.confluent.ksql.execution.codegen.helpers.ArrayAccess; | ||
import io.confluent.ksql.execution.codegen.helpers.ArrayBuilder; | ||
import io.confluent.ksql.execution.codegen.helpers.MapBuilder; | ||
import io.confluent.ksql.execution.codegen.helpers.SearchedCaseFunction; | ||
import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression; | ||
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression; | ||
|
@@ -124,7 +125,8 @@ public class SqlToJavaVisitor { | |
RoundingMode.class.getCanonicalName(), | ||
SchemaBuilder.class.getCanonicalName(), | ||
Struct.class.getCanonicalName(), | ||
ArrayBuilder.class.getCanonicalName() | ||
ArrayBuilder.class.getCanonicalName(), | ||
MapBuilder.class.getCanonicalName() | ||
); | ||
|
||
private static final Map<Operator, String> DECIMAL_OPERATOR_NAME = ImmutableMap | ||
|
@@ -834,7 +836,9 @@ public Pair<String, SqlType> visitCreateMapExpression( | |
final CreateMapExpression exp, | ||
final Void context | ||
) { | ||
final StringBuilder map = new StringBuilder("ImmutableMap.builder()"); | ||
final StringBuilder map = new StringBuilder("new MapBuilder("); | ||
map.append(exp.getMap().size()); | ||
map.append((')')); | ||
Comment on lines
+818
to
+820
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Switch to a new builder type that won't throw on null keys or values. |
||
|
||
for (Entry<Expression, Expression> entry: exp.getMap().entrySet()) { | ||
map.append(".put("); | ||
|
@@ -925,27 +929,25 @@ private CastVisitor() { | |
} | ||
|
||
static Pair<String, SqlType> getCast(final Pair<String, SqlType> expr, final SqlType sqlType) { | ||
if (!sqlType.supportsCast()) { | ||
throw new KsqlFunctionException( | ||
"Only casts to primitive types and decimal are supported: " + sqlType); | ||
} | ||
|
||
final SqlType rightSchema = expr.getRight(); | ||
if (sqlType.equals(rightSchema) || rightSchema == null) { | ||
final SqlType sourceType = expr.getRight(); | ||
if (sourceType == null || sqlType.equals(sourceType)) { | ||
// sourceType is null if source is SQL NULL | ||
return new Pair<>(expr.getLeft(), sqlType); | ||
} | ||
|
||
return CASTERS.getOrDefault( | ||
sqlType.baseType(), | ||
(e, t, r) -> { | ||
throw new KsqlException("Invalid cast operation: " + t); | ||
} | ||
) | ||
.cast(expr, sqlType, sqlType); | ||
return CASTERS.getOrDefault(sqlType.baseType(), CastVisitor::unsupportedCast) | ||
.cast(expr, sqlType); | ||
} | ||
Comment on lines
+907
to
+915
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. getCast no longer calls sqlType.supportsCast as it is this the getCast method that determines which casts are supported and which are not. Hence supportsCast is superfluous and likely to get out of date with what this method supports. |
||
|
||
private static Pair<String, SqlType> unsupportedCast( | ||
final Pair<String, SqlType> expr, final SqlType returnType | ||
) { | ||
throw new KsqlFunctionException("Cast of " + expr.getRight() | ||
+ " to " + returnType + " is not supported"); | ||
} | ||
|
||
private static Pair<String, SqlType> castString( | ||
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType | ||
final Pair<String, SqlType> expr, final SqlType returnType | ||
) { | ||
final SqlType schema = expr.getRight(); | ||
final String exprStr; | ||
|
@@ -961,13 +963,13 @@ private static Pair<String, SqlType> castString( | |
} | ||
|
||
private static Pair<String, SqlType> castBoolean( | ||
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType | ||
final Pair<String, SqlType> expr, final SqlType returnType | ||
) { | ||
return new Pair<>(getCastToBooleanString(expr.getRight(), expr.getLeft()), returnType); | ||
} | ||
|
||
private static Pair<String, SqlType> castInteger( | ||
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType | ||
final Pair<String, SqlType> expr, final SqlType returnType | ||
) { | ||
final String exprStr = getCastString( | ||
expr.getRight(), | ||
|
@@ -979,7 +981,7 @@ private static Pair<String, SqlType> castInteger( | |
} | ||
|
||
private static Pair<String, SqlType> castLong( | ||
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType | ||
final Pair<String, SqlType> expr, final SqlType returnType | ||
) { | ||
final String exprStr = getCastString( | ||
expr.getRight(), | ||
|
@@ -991,7 +993,7 @@ private static Pair<String, SqlType> castLong( | |
} | ||
|
||
private static Pair<String, SqlType> castDouble( | ||
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType | ||
final Pair<String, SqlType> expr, final SqlType returnType | ||
) { | ||
final String exprStr = getCastString( | ||
expr.getRight(), | ||
|
@@ -1003,13 +1005,13 @@ private static Pair<String, SqlType> castDouble( | |
} | ||
|
||
private static Pair<String, SqlType> castDecimal( | ||
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType | ||
final Pair<String, SqlType> expr, final SqlType returnType | ||
) { | ||
if (!(sqltype instanceof SqlDecimal)) { | ||
throw new KsqlException("Expected decimal type: " + sqltype); | ||
if (!(returnType instanceof SqlDecimal)) { | ||
throw new KsqlException("Expected decimal type: " + returnType); | ||
Comment on lines
+983
to
+986
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note: these methods were always called with the same param for |
||
} | ||
|
||
final SqlDecimal sqlDecimal = (SqlDecimal) sqltype; | ||
final SqlDecimal sqlDecimal = (SqlDecimal) returnType; | ||
|
||
if (expr.getRight().baseType() == SqlBaseType.DECIMAL && expr.right.equals(sqlDecimal)) { | ||
return expr; | ||
|
@@ -1049,7 +1051,6 @@ private static String getCastString( | |
return "(new Double(" + exprStr + ")." + javaTypeMethod + ")"; | ||
case STRING: | ||
return javaStringParserMethod + "(" + exprStr + ")"; | ||
|
||
default: | ||
throw new KsqlFunctionException( | ||
"Invalid cast operation: Cannot cast " | ||
|
@@ -1086,7 +1087,6 @@ private interface CastFunction { | |
|
||
Pair<String, SqlType> cast( | ||
Pair<String, SqlType> expr, | ||
SqlType sqltype, | ||
SqlType returnType | ||
); | ||
} | ||
|
@@ -1104,5 +1104,4 @@ private CaseWhenProcessed( | |
this.thenProcessResult = thenProcessResult; | ||
} | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/* | ||
* Copyright 2020 Confluent Inc. | ||
* | ||
* Licensed under the Confluent Community License (the "License"; you may not use | ||
* this file except in compliance with the License. You may obtain a copy of the | ||
* License at | ||
* | ||
* http://www.confluent.io/confluent-community-license | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OF ANY KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations under the License. | ||
*/ | ||
|
||
package io.confluent.ksql.execution.codegen.helpers; | ||
|
||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
/** | ||
* Used to construct maps using the builder pattern. Note that we cannot use {@link | ||
* com.google.common.collect.ImmutableMap} because it does not accept null values. | ||
*/ | ||
public class MapBuilder { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Builder of maps with nulls |
||
|
||
private final HashMap<Object, Object> map; | ||
|
||
public MapBuilder(final int size) { | ||
map = new HashMap<>(size); | ||
} | ||
|
||
public MapBuilder put(final Object key, final Object value) { | ||
map.put(key, value); | ||
return this; | ||
} | ||
|
||
public Map<Object, Object> build() { | ||
return Collections.unmodifiableMap(map); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,7 +55,6 @@ | |
import io.confluent.ksql.function.AggregateFunctionInitArguments; | ||
import io.confluent.ksql.function.FunctionRegistry; | ||
import io.confluent.ksql.function.KsqlAggregateFunction; | ||
import io.confluent.ksql.function.KsqlFunctionException; | ||
import io.confluent.ksql.function.KsqlTableFunction; | ||
import io.confluent.ksql.function.UdfFactory; | ||
import io.confluent.ksql.schema.ksql.Column; | ||
|
@@ -146,13 +145,7 @@ public Void visitNotExpression( | |
|
||
@Override | ||
public Void visitCast(final Cast node, final ExpressionTypeContext expressionTypeContext) { | ||
final SqlType sqlType = node.getType().getSqlType(); | ||
if (!sqlType.supportsCast()) { | ||
throw new KsqlFunctionException("Only casts to primitive types or decimals " | ||
+ "are supported: " + sqlType); | ||
} | ||
|
||
expressionTypeContext.setSqlType(sqlType); | ||
expressionTypeContext.setSqlType(node.getType().getSqlType()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed this as its not needed. It's |
||
return null; | ||
} | ||
|
||
|
@@ -404,7 +397,11 @@ public Void visitCreateMapExpression( | |
.collect(Collectors.toList()); | ||
|
||
if (keyTypes.stream().anyMatch(type -> !SqlTypes.STRING.equals(type))) { | ||
throw new KsqlException("Only STRING keys are supported in maps but got: " + keyTypes); | ||
final String types = keyTypes.stream() | ||
.map(type -> type == null ? "NULL" : type.toString()) | ||
.collect(Collectors.joining(", ", "[", "]")); | ||
|
||
throw new KsqlException("Only STRING keys are supported in maps but got: " + types); | ||
} | ||
|
||
final List<SqlType> valueTypes = exp.getMap() | ||
|
@@ -414,9 +411,16 @@ public Void visitCreateMapExpression( | |
process(val, context); | ||
return context.getSqlType(); | ||
}) | ||
.filter(Objects::nonNull) | ||
.distinct() | ||
.collect(Collectors.toList()); | ||
|
||
if (valueTypes.size() == 0) { | ||
throw new KsqlException("Cannot construct a map with all NULL values " | ||
+ "(see https://github.com/confluentinc/ksql/issues/4239). As a workaround, you may " | ||
+ "cast a NULL value to the desired type."); | ||
} | ||
|
||
if (valueTypes.size() != 1) { | ||
throw new KsqlException( | ||
String.format( | ||
|
@@ -425,11 +429,6 @@ public Void visitCreateMapExpression( | |
exp)); | ||
} | ||
|
||
if (valueTypes.get(0) == null) { | ||
throw new KsqlException("Cannot construct MAP with NULL values. As a workaround, you " | ||
+ "may cast a NULL value to the desired type."); | ||
} | ||
|
||
context.setSqlType(SqlMap.of(valueTypes.get(0))); | ||
return null; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This avoids any compiling of code etc to handle a NULL literal.