diff --git a/ksql-common/src/main/java/io/confluent/ksql/schema/Operator.java b/ksql-common/src/main/java/io/confluent/ksql/schema/Operator.java index f2cb38b1f59a..d6550a64003f 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/schema/Operator.java +++ b/ksql-common/src/main/java/io/confluent/ksql/schema/Operator.java @@ -15,20 +15,84 @@ package io.confluent.ksql.schema; +import static java.util.Objects.requireNonNull; + +import io.confluent.ksql.schema.ksql.SqlBaseType; +import io.confluent.ksql.schema.ksql.types.SqlDecimal; +import io.confluent.ksql.schema.ksql.types.SqlType; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.util.KsqlException; +import java.util.function.BinaryOperator; + public enum Operator { - ADD("+"), - SUBTRACT("-"), - MULTIPLY("*"), - DIVIDE("/"), - MODULUS("%"); + ADD("+", SqlDecimal::add) { + @Override + public SqlType resultType(final SqlType left, final SqlType right) { + if (left.baseType() == SqlBaseType.STRING && right.baseType() == SqlBaseType.STRING) { + return SqlTypes.STRING; + } + + return super.resultType(left, right); + } + }, + SUBTRACT("-", SqlDecimal::subtract), + MULTIPLY("*", SqlDecimal::multiply), + DIVIDE("/", SqlDecimal::divide), + MODULUS("%", SqlDecimal::modulus); private final String symbol; + private final BinaryOperator binaryResolver; - Operator(final String symbol) { - this.symbol = symbol; + Operator(final String symbol, final BinaryOperator binaryResolver) { + this.symbol = requireNonNull(symbol, "symbol"); + this.binaryResolver = requireNonNull(binaryResolver, "binaryResolver"); } public String getSymbol() { return symbol; } + + /** + * Determine the result type for the given parameters. + * + * @param left the left side of the operation. + * @param right the right side of the operation. + * @return the result schema. + */ + public SqlType resultType(final SqlType left, final SqlType right) { + if (left.baseType().isNumber() && right.baseType().isNumber()) { + if (left.baseType().canUpCast(right.baseType())) { + if (right.baseType() != SqlBaseType.DECIMAL) { + return right; + } + + return binaryResolver.apply(toDecimal(left), (SqlDecimal) right); + } + + if (right.baseType().canUpCast(left.baseType())) { + if (left.baseType() != SqlBaseType.DECIMAL) { + return left; + } + + return binaryResolver.apply((SqlDecimal) left, toDecimal(right)); + } + } + + throw new KsqlException( + "Unsupported arithmetic types. " + left.baseType() + " " + right.baseType()); + } + + private static SqlDecimal toDecimal(final SqlType type) { + switch (type.baseType()) { + case DECIMAL: + return (SqlDecimal) type; + case INTEGER: + return SqlTypes.INT_UPCAST_TO_DECIMAL; + case BIGINT: + return SqlTypes.BIGINT_UPCAST_TO_DECIMAL; + default: + throw new KsqlException( + "Cannot convert " + type.baseType() + " to " + SqlBaseType.DECIMAL + "."); + } + } } diff --git a/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/SqlBaseType.java b/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/SqlBaseType.java index c1f679be4459..f7e5bf15db8f 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/SqlBaseType.java +++ b/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/SqlBaseType.java @@ -19,14 +19,28 @@ * The SQL types supported by KSQL. */ public enum SqlBaseType { - BOOLEAN, INTEGER, BIGINT, DOUBLE, DECIMAL, STRING, ARRAY, MAP, STRUCT; + BOOLEAN, INTEGER, BIGINT, DECIMAL, DOUBLE, STRING, ARRAY, MAP, STRUCT; + /** + * @return {@code true} if numeric type. 
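The pieces above fit together as follows: each `Operator` constant now carries the `SqlDecimal` rule to apply once either operand widens to a decimal, `ADD` overrides `resultType` to keep supporting string concatenation, and the widened `isNumber`/`canUpCast` in `SqlBaseType` drive the numeric promotion. A minimal sketch of the call shape, using only APIs visible in this diff (the example class name is illustrative; the commented results follow from the rules above, and exact `toString` forms are approximate):

```java
import io.confluent.ksql.schema.Operator;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;

public final class OperatorResultTypeExample {

  public static void main(final String[] args) {
    // INT + BIGINT: INT up-casts to BIGINT, so the result is BIGINT.
    final SqlType a = Operator.ADD.resultType(SqlTypes.INTEGER, SqlTypes.BIGINT);

    // INT + DECIMAL(2, 1): INT first widens to DECIMAL(10, 0), then the
    // operator's resolver (SqlDecimal::add) computes DECIMAL(12, 1).
    final SqlType b = Operator.ADD.resultType(SqlTypes.INTEGER, SqlTypes.decimal(2, 1));

    // STRING + STRING is concatenation, supported only through ADD's override.
    final SqlType c = Operator.ADD.resultType(SqlTypes.STRING, SqlTypes.STRING);

    System.out.println(a); // BIGINT
    System.out.println(b); // DECIMAL(12, 1)
    System.out.println(c); // STRING
  }
}
```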
+ */ public boolean isNumber() { - // for now, conversions between DECIMAL and other numeric types is not supported - return this == INTEGER || this == BIGINT || this == DOUBLE; + return this == INTEGER || this == BIGINT || this == DECIMAL || this == DOUBLE; } + /** + * Test to see if this type can be up-cast to another. + * + * <p>This defines if KSQL supports implicitly converting one numeric type to another. + * + * <p>
Types can always be upcast to themselves. Only numeric types can be upcast to different + * numeric types. Note: STRING to DECIMAL handling is not seen as up-casting, it's parsing. + * + * @param to the target type. + * @return true if this type can be upcast to the supplied type. + */ public boolean canUpCast(final SqlBaseType to) { - return isNumber() && to.isNumber() && this.ordinal() <= to.ordinal(); + return this.equals(to) + || (isNumber() && to.isNumber() && this.ordinal() <= to.ordinal()); } } diff --git a/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/types/SqlDecimal.java b/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/types/SqlDecimal.java index 66c4b2b35246..5ef1b6402a4f 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/types/SqlDecimal.java +++ b/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/types/SqlDecimal.java @@ -105,4 +105,75 @@ public String toString() { public String toString(final FormatOptions formatOptions) { return toString(); } + + /** + * Determine the decimal type should two decimals be added together. + * + * @param left the left side decimal. + * @param right the right side decimal. + * @return the resulting decimal type. + */ + public static SqlDecimal add(final SqlDecimal left, final SqlDecimal right) { + final int precision = Math.max(left.scale, right.scale) + + Math.max(left.precision - left.scale, right.precision - right.scale) + + 1; + + final int scale = Math.max(left.scale, right.scale); + return SqlDecimal.of(precision, scale); + } + + /** + * Determine the decimal type should one decimal be subtracted from another. + * + * @param left the left side decimal. + * @param right the right side decimal. + * @return the resulting decimal type. + */ + public static SqlDecimal subtract(final SqlDecimal left, final SqlDecimal right) { + return add(left, right); + } + + /** + * Determine the decimal type should one decimal be multiplied by another. + * + * @param left the left side decimal. + * @param right the right side decimal. + * @return the resulting decimal type. + */ + public static SqlDecimal multiply(final SqlDecimal left, final SqlDecimal right) { + final int precision = left.precision + right.precision + 1; + final int scale = left.scale + right.scale; + return SqlDecimal.of(precision, scale); + } + + /** + * Determine the decimal type should one decimal be divided by another. + * + * @param left the left side decimal. + * @param right the right side decimal. + * @return the resulting decimal type. + */ + public static SqlDecimal divide(final SqlDecimal left, final SqlDecimal right) { + final int precision = left.precision - left.scale + right.scale + + Math.max(6, left.scale + right.precision + 1); + + final int scale = Math.max(6, left.scale + right.precision + 1); + return SqlDecimal.of(precision, scale); + } + + /** + * Determine the decimal result type when calculating the remainder of dividing one decimal by + * another. + * + * @param left the left side decimal. + * @param right the right side decimal. + * @return the resulting decimal type. 
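For context, these precision/scale formulas match the ones SQL Server documents for decimal arithmetic. A small worked check (an illustrative harness around the `SqlDecimal` statics added above), consistent with the unit tests added later in this diff:

```java
import io.confluent.ksql.schema.ksql.types.SqlDecimal;
import io.confluent.ksql.schema.ksql.types.SqlTypes;

public final class DecimalRuleCheck {

  public static void main(final String[] args) {
    final SqlDecimal left = SqlTypes.decimal(2, 1);   // values up to 9.9
    final SqlDecimal right = SqlTypes.decimal(2, 2);  // values up to 0.99

    // add:      precision = max(1, 2) + max(2 - 1, 2 - 2) + 1 = 4; scale = max(1, 2) = 2
    System.out.println(SqlDecimal.add(left, right));      // DECIMAL(4, 2)

    // multiply: precision = 2 + 2 + 1 = 5; scale = 1 + 2 = 3
    System.out.println(SqlDecimal.multiply(left, right)); // DECIMAL(5, 3)

    // divide:   precision = (2 - 1) + 2 + max(6, 1 + 2 + 1) = 9; scale = max(6, 4) = 6
    System.out.println(SqlDecimal.divide(left, right));   // DECIMAL(9, 6)
  }
}
```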
+ */ + public static SqlDecimal modulus(final SqlDecimal left, final SqlDecimal right) { + final int precision = Math.min(left.precision - left.scale, right.precision - right.scale) + + Math.max(left.scale, right.scale); + + final int scale = Math.max(left.scale, right.scale); + return SqlDecimal.of(precision, scale); + } } diff --git a/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/types/SqlTypes.java b/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/types/SqlTypes.java index 9093dca164b8..8beb72a14b62 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/types/SqlTypes.java +++ b/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/types/SqlTypes.java @@ -43,4 +43,14 @@ public static SqlMap map(final SqlType valueType) { public static SqlStruct.Builder struct() { return SqlStruct.builder(); } + + /** + * Schema of an INT up-cast to a DECIMAL + */ + public static final SqlDecimal INT_UPCAST_TO_DECIMAL = SqlDecimal.of(10, 0); + + /** + * Schema of an BIGINT up-cast to a DECIMAL + */ + public static final SqlDecimal BIGINT_UPCAST_TO_DECIMAL = SqlDecimal.of(19, 0); } diff --git a/ksql-common/src/main/java/io/confluent/ksql/util/DecimalUtil.java b/ksql-common/src/main/java/io/confluent/ksql/util/DecimalUtil.java index 624fc1b7e627..97e582f231b1 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/util/DecimalUtil.java +++ b/ksql-common/src/main/java/io/confluent/ksql/util/DecimalUtil.java @@ -16,6 +16,7 @@ package io.confluent.ksql.util; import io.confluent.ksql.schema.ksql.types.SqlDecimal; +import io.confluent.ksql.schema.ksql.types.SqlTypes; import java.math.BigDecimal; import java.math.MathContext; import java.math.RoundingMode; @@ -154,28 +155,6 @@ public static BigDecimal ensureFit(final BigDecimal value, final Schema schema) } } - /** - * Converts a schema to a decimal schema with set precision/scale without losing - * scale or precision. - * - * @param schema the schema - * @return the decimal schema - * @throws KsqlException if the schema cannot safely be converted to decimal - */ - public static Schema toDecimal(final Schema schema) { - switch (schema.type()) { - case BYTES: - requireDecimal(schema); - return schema; - case INT32: - return builder(10, 0).build(); - case INT64: - return builder(19, 0).build(); - default: - throw new KsqlException("Cannot convert schema of type " + schema.type() + " to decimal."); - } - } - /** * Converts a schema to a sql decimal with set precision/scale without losing * scale or precision. 
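The (10, 0) and (19, 0) shapes shared through `SqlTypes` above (and reused by `toSqlDecimal` below) are simply the digit counts of the widest INT and BIGINT values, which a two-line check confirms:

```java
import java.math.BigDecimal;

public final class UpcastWidths {

  public static void main(final String[] args) {
    // Every INT fits in DECIMAL(10, 0): Integer.MAX_VALUE has 10 digits.
    System.out.println(new BigDecimal(Integer.MAX_VALUE).precision()); // 10

    // Every BIGINT fits in DECIMAL(19, 0): Long.MAX_VALUE has 19 digits.
    System.out.println(new BigDecimal(Long.MAX_VALUE).precision());    // 19
  }
}
```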
@@ -190,9 +169,9 @@ public static SqlDecimal toSqlDecimal(final Schema schema) { requireDecimal(schema); return SqlDecimal.of(precision(schema), scale(schema)); case INT32: - return SqlDecimal.of(10, 0); + return SqlTypes.INT_UPCAST_TO_DECIMAL; case INT64: - return SqlDecimal.of(19, 0); + return SqlTypes.BIGINT_UPCAST_TO_DECIMAL; default: throw new KsqlException("Cannot convert schema of type " + schema.type() + " to decimal."); } diff --git a/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java b/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java index c213f0c5c526..36d87885d8ce 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java +++ b/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java @@ -15,20 +15,14 @@ package io.confluent.ksql.util; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.ImmutableSortedMap; -import com.google.common.collect.Ordering; import io.confluent.ksql.function.GenericsUtil; import io.confluent.ksql.name.ColumnName; -import io.confluent.ksql.schema.Operator; import io.confluent.ksql.schema.ksql.SchemaConverters; import java.util.List; import java.util.Map; -import java.util.NavigableMap; import java.util.Objects; -import java.util.Optional; import java.util.Set; import java.util.function.BiPredicate; import org.apache.kafka.connect.data.Field; @@ -44,31 +38,15 @@ public final class SchemaUtil { public static final int ROWKEY_INDEX = 1; - private static final List ARITHMETIC_TYPES_LIST = - ImmutableList.of( - Schema.Type.INT8, - Schema.Type.INT16, - Schema.Type.INT32, - Schema.Type.INT64, - Schema.Type.FLOAT32, - Schema.Type.FLOAT64 - ); - - private static final Set ARITHMETIC_TYPES = - ImmutableSet.copyOf(ARITHMETIC_TYPES_LIST); - - private static final Ordering ARITHMETIC_TYPE_ORDERING = Ordering.explicit( - ARITHMETIC_TYPES_LIST + private static final Set ARITHMETIC_TYPES = ImmutableSet.of( + Type.INT8, + Type.INT16, + Type.INT32, + Type.INT64, + Type.FLOAT32, + Type.FLOAT64 ); - private static final NavigableMap TYPE_TO_SCHEMA = - ImmutableSortedMap.orderedBy(ARITHMETIC_TYPE_ORDERING) - .put(Schema.Type.INT32, Schema.OPTIONAL_INT32_SCHEMA) - .put(Schema.Type.INT64, Schema.OPTIONAL_INT64_SCHEMA) - .put(Schema.Type.FLOAT32, Schema.OPTIONAL_FLOAT64_SCHEMA) - .put(Schema.Type.FLOAT64, Schema.OPTIONAL_FLOAT64_SCHEMA) - .build(); - private static final char FIELD_NAME_DELIMITER = '.'; private static final Map> CUSTOM_SCHEMA_EQ = @@ -121,85 +99,8 @@ public static String getFieldNameWithNoAlias(final String fieldName) { return fieldName.substring(idx + 1); } - public static Optional getFieldNameAlias(final String fieldName) { - final int idx = fieldName.indexOf(FIELD_NAME_DELIMITER); - if (idx < 0) { - return Optional.empty(); - } - - return Optional.of(fieldName.substring(0, idx)); - } - - public static Schema resolveBinaryOperatorResultType( - final Schema left, - final Schema right, - final Operator operator - ) { - if (left.type() == Schema.Type.STRING && right.type() == Schema.Type.STRING) { - return Schema.OPTIONAL_STRING_SCHEMA; - } - - if (DecimalUtil.isDecimal(left) || DecimalUtil.isDecimal(right)) { - if (left.type() != Schema.Type.FLOAT64 && right.type() != Schema.Type.FLOAT64) { - return resolveDecimalOperatorResultType( - DecimalUtil.toDecimal(left), DecimalUtil.toDecimal(right), operator); - } - return Schema.OPTIONAL_FLOAT64_SCHEMA; - } - - if 
(!TYPE_TO_SCHEMA.containsKey(left.type()) || !TYPE_TO_SCHEMA.containsKey(right.type())) { - throw new KsqlException("Unsupported arithmetic types. " + left.type() + " " + right.type()); - } - - return TYPE_TO_SCHEMA.ceilingEntry( - ARITHMETIC_TYPE_ORDERING.max(left.type(), right.type())).getValue(); - } - - private static Schema resolveDecimalOperatorResultType( - final Schema left, - final Schema right, - final Operator operator - ) { - final int lPrecision = DecimalUtil.precision(left); - final int rPrecision = DecimalUtil.precision(right); - final int lScale = DecimalUtil.scale(left); - final int rScale = DecimalUtil.scale(right); - - final int precision; - final int scale; - switch (operator) { - case ADD: - case SUBTRACT: - precision = Math.max(lScale, rScale) - + Math.max(lPrecision - lScale, rPrecision - rScale) - + 1; - scale = Math.max(lScale, rScale); - break; - case MULTIPLY: - precision = lPrecision + rPrecision + 1; - scale = lScale + rScale; - break; - case DIVIDE: - precision = lPrecision - lScale + rScale + Math.max(6, lScale + rPrecision + 1); - scale = Math.max(6, lScale + rPrecision + 1); - break; - case MODULUS: - precision = Math.min(lPrecision - lScale, rPrecision - rScale) + Math.max(lScale, rScale); - scale = Math.max(lScale, rScale); - break; - default: - throw new KsqlException("Unexpected operator type: " + operator); - } - - return DecimalUtil.builder(precision, scale).build(); - } - - static boolean isNumber(final Schema.Type type) { - return ARITHMETIC_TYPES.contains(type); - } - public static boolean isNumber(final Schema schema) { - return isNumber(schema.type()) || DecimalUtil.isDecimal(schema); + return ARITHMETIC_TYPES.contains(schema.type()) || DecimalUtil.isDecimal(schema); } public static Schema ensureOptional(final Schema schema) { diff --git a/ksql-common/src/test/java/io/confluent/ksql/schema/OperatorTest.java b/ksql-common/src/test/java/io/confluent/ksql/schema/OperatorTest.java new file mode 100644 index 000000000000..abd29c764b3f --- /dev/null +++ b/ksql-common/src/test/java/io/confluent/ksql/schema/OperatorTest.java @@ -0,0 +1,188 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package io.confluent.ksql.schema; + +import static io.confluent.ksql.schema.Operator.ADD; +import static io.confluent.ksql.schema.Operator.DIVIDE; +import static io.confluent.ksql.schema.Operator.MODULUS; +import static io.confluent.ksql.schema.Operator.MULTIPLY; +import static io.confluent.ksql.schema.Operator.SUBTRACT; +import static io.confluent.ksql.schema.ksql.types.SqlTypes.BIGINT; +import static io.confluent.ksql.schema.ksql.types.SqlTypes.BOOLEAN; +import static io.confluent.ksql.schema.ksql.types.SqlTypes.DOUBLE; +import static io.confluent.ksql.schema.ksql.types.SqlTypes.INTEGER; +import static io.confluent.ksql.schema.ksql.types.SqlTypes.STRING; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.fail; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.confluent.ksql.schema.ksql.SqlBaseType; +import io.confluent.ksql.schema.ksql.types.SqlDecimal; +import io.confluent.ksql.schema.ksql.types.SqlType; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.util.KsqlException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.function.BinaryOperator; +import java.util.stream.Collectors; +import org.junit.Test; + +public class OperatorTest { + + private static final SqlDecimal DECIMAL = SqlTypes.decimal(2, 1); + private static final SqlDecimal INT_AS_DECIMAL = SqlTypes.decimal(10, 0); + private static final SqlDecimal BIGINT_AS_DECIMAL = SqlTypes.decimal(19, 0); + + private static final Map TYPES = ImmutableMap + .builder() + .put(SqlBaseType.BOOLEAN, BOOLEAN) + .put(SqlBaseType.INTEGER, INTEGER) + .put(SqlBaseType.BIGINT, BIGINT) + .put(SqlBaseType.DECIMAL, SqlTypes.decimal(2, 1)) + .put(SqlBaseType.DOUBLE, DOUBLE) + .put(SqlBaseType.STRING, STRING) + .put(SqlBaseType.ARRAY, SqlTypes.array(BIGINT)) + .put(SqlBaseType.MAP, SqlTypes.map(INTEGER)) + .put(SqlBaseType.STRUCT, SqlTypes.struct().field("f", INTEGER).build()) + .build(); + + @Test + public void shouldResolveValidAddReturnType() { + assertThat(ADD.resultType(STRING, STRING), is(STRING)); + + assertConversionRule(ADD, SqlDecimal::add); + } + + @Test + public void shouldResolveSubtractReturnType() { + assertConversionRule(SUBTRACT, SqlDecimal::subtract); + } + + @Test + public void shouldResolveMultiplyReturnType() { + assertConversionRule(MULTIPLY, SqlDecimal::multiply); + } + + @Test + public void shouldResolveDivideReturnType() { + assertConversionRule(DIVIDE, SqlDecimal::divide); + } + + @Test + public void shouldResolveModulusReturnType() { + assertConversionRule(MODULUS, SqlDecimal::modulus); + } + + @Test + public void shouldWorkUsingSameRulesAsBaseTypeUpCastRules() { + allOperations().forEach(op -> { + + for (final SqlBaseType leftBaseType : SqlBaseType.values()) { + // Given: + final Map> partitioned = Arrays + .stream(SqlBaseType.values()) + .collect(Collectors.partitioningBy( + rightBaseType -> shouldBeSupported(op, leftBaseType, rightBaseType))); + + final List shouldUpCast = partitioned.getOrDefault(true, ImmutableList.of()); + final List shouldNotUpCast = partitioned + .getOrDefault(false, ImmutableList.of()); + + // Then: + shouldUpCast.forEach(rightBaseType -> + assertThat( + "should support " + op + " on (" + leftBaseType + ", " + rightBaseType + ")", + op.resultType(getType(leftBaseType), getType(rightBaseType)), + is(notNullValue()) + ) + ); + + 
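+ // Collectors.partitioningBy is documented to always populate both the TRUE
+ // and FALSE keys, so the getOrDefault fallbacks above are purely defensive.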
shouldNotUpCast.forEach(rightBaseType -> { + try { + op.resultType(getType(leftBaseType), getType(rightBaseType)); + fail("should not support " + op + " on (" + leftBaseType + ", " + rightBaseType + ")"); + } catch (final KsqlException e) { + // Expected + } + }); + } + }); + } + + private static void assertConversionRule( + final Operator op, + final BinaryOperator binaryResolver + ) { + assertThat(op.resultType(INTEGER, INTEGER), is(INTEGER)); + assertThat(op.resultType(INTEGER, BIGINT), is(BIGINT)); + assertThat(op.resultType(BIGINT, INTEGER), is(BIGINT)); + assertThat(op.resultType(INTEGER, DECIMAL), is(binaryResolver.apply(INT_AS_DECIMAL, DECIMAL))); + assertThat(op.resultType(DECIMAL, INTEGER), is(binaryResolver.apply(DECIMAL, INT_AS_DECIMAL))); + assertThat(op.resultType(INTEGER, DOUBLE), is(DOUBLE)); + assertThat(op.resultType(DOUBLE, INTEGER), is(DOUBLE)); + + assertThat(op.resultType(BIGINT, BIGINT), is(BIGINT)); + assertThat(op.resultType(BIGINT, DECIMAL), is(binaryResolver.apply(BIGINT_AS_DECIMAL, DECIMAL))); + assertThat(op.resultType(DECIMAL, BIGINT), is(binaryResolver.apply(DECIMAL, BIGINT_AS_DECIMAL))); + assertThat(op.resultType(BIGINT, DOUBLE), is(DOUBLE)); + assertThat(op.resultType(DOUBLE, BIGINT), is(DOUBLE)); + + assertThat(op.resultType(DECIMAL, DECIMAL), is(binaryResolver.apply(DECIMAL, DECIMAL))); + assertThat(op.resultType(DECIMAL, DOUBLE), is(DOUBLE)); + assertThat(op.resultType(DOUBLE, DECIMAL), is(DOUBLE)); + + assertThat(op.resultType(DOUBLE, DOUBLE), is(DOUBLE)); + } + + private static boolean shouldBeSupported( + final Operator op, + final SqlBaseType leftBaseType, + final SqlBaseType rightBaseType + ) { + return (isNumeric(leftBaseType) && isNumeric(rightBaseType)) + || (op == ADD && leftBaseType == SqlBaseType.STRING && rightBaseType == SqlBaseType.STRING); + } + + private static boolean isNumeric(final SqlBaseType baseType) { + switch (baseType) { + case INTEGER: + case BIGINT: + case DECIMAL: + case DOUBLE: + return true; + default: + return false; + } + } + + private static List allOperations() { + return ImmutableList.copyOf(Operator.values()); + } + + private static SqlType getType(final SqlBaseType baseType) { + final SqlType type = TYPES.get(baseType); + assertThat( + "invalid test: need type for base type:" + baseType, + type, + is(notNullValue()) + ); + return type; + } +} \ No newline at end of file diff --git a/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/SqlBaseTypeTest.java b/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/SqlBaseTypeTest.java index b341b1020aa4..2e4464261c20 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/SqlBaseTypeTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/SqlBaseTypeTest.java @@ -19,6 +19,7 @@ import static org.hamcrest.Matchers.is; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; import java.util.Arrays; import java.util.Set; import java.util.stream.Stream; @@ -26,53 +27,89 @@ public class SqlBaseTypeTest { - private static final Set NUMBER_TYPES = - ImmutableSet.of(SqlBaseType.INTEGER, SqlBaseType.BIGINT, SqlBaseType.DOUBLE); + private static final Set NUMBER_TYPES = ImmutableSet.of( + SqlBaseType.INTEGER, + SqlBaseType.BIGINT, + SqlBaseType.DECIMAL, + SqlBaseType.DOUBLE + ); @Test public void shouldNotBeNumber() { - nonNumberTypes().forEach(sqlType -> - assertThat(sqlType + " should not be number", sqlType.isNumber(), is(false))); + nonNumberTypes().forEach(sqlType -> assertThat( + sqlType + " should not be 
number", + sqlType.isNumber(), + is(false) + )); } @Test public void shouldBeNumber() { - numberTypes().forEach(sqlType -> - assertThat(sqlType + " should be number", sqlType.isNumber(), is(true))); + numberTypes().forEach(sqlType -> assertThat( + sqlType + " should be number", + sqlType.isNumber(), + is(true) + )); } @Test public void shouldNotUpCastIfNotNumber() { - nonNumberTypes().forEach(sqlType -> - assertThat(sqlType + " should not upcast", sqlType.canUpCast(sqlType), is(false))); + nonNumberTypes().forEach(sqlType -> assertThat( + sqlType + " should not upcast", + sqlType.canUpCast(SqlBaseType.DOUBLE), + is(false)) + ); } @Test - public void shouldUpCastToSelfIfNumber() { - numberTypes().forEach(sqlType -> + public void shouldUpCastIfNumber() { + numberTypes().forEach(sqlType -> assertThat( + sqlType + " should upcast", + sqlType.canUpCast(SqlBaseType.DOUBLE), + is(true)) + ); + } + + @Test + public void shouldUpCastToSelf() { + allTypes().forEach(sqlType -> assertThat(sqlType + " should upcast to self", sqlType.canUpCast(sqlType), is(true))); } @Test public void shouldUpCastInt() { assertThat(SqlBaseType.INTEGER.canUpCast(SqlBaseType.BIGINT), is(true)); + assertThat(SqlBaseType.INTEGER.canUpCast(SqlBaseType.DECIMAL), is(true)); assertThat(SqlBaseType.INTEGER.canUpCast(SqlBaseType.DOUBLE), is(true)); } @Test public void shouldUpCastBigInt() { + assertThat(SqlBaseType.BIGINT.canUpCast(SqlBaseType.DECIMAL), is(true)); assertThat(SqlBaseType.BIGINT.canUpCast(SqlBaseType.DOUBLE), is(true)); } + @Test + public void shouldUpCastDecimal() { + assertThat(SqlBaseType.DECIMAL.canUpCast(SqlBaseType.DOUBLE), is(true)); + } + @Test public void shouldNotDownCastBigInt() { assertThat(SqlBaseType.BIGINT.canUpCast(SqlBaseType.INTEGER), is(false)); } + @Test + public void shouldNotDownCastDecimal() { + assertThat(SqlBaseType.DECIMAL.canUpCast(SqlBaseType.INTEGER), is(false)); + assertThat(SqlBaseType.DECIMAL.canUpCast(SqlBaseType.BIGINT), is(false)); + } + @Test public void shouldNotDownCastDouble() { assertThat(SqlBaseType.DOUBLE.canUpCast(SqlBaseType.INTEGER), is(false)); assertThat(SqlBaseType.DOUBLE.canUpCast(SqlBaseType.BIGINT), is(false)); + assertThat(SqlBaseType.DOUBLE.canUpCast(SqlBaseType.DECIMAL), is(false)); } private static Stream numberTypes() { @@ -83,4 +120,9 @@ private static Stream nonNumberTypes() { return Arrays.stream(SqlBaseType.values()) .filter(sqlType -> !NUMBER_TYPES.contains(sqlType)); } + + @SuppressWarnings("UnstableApiUsage") + private static Stream allTypes() { + return Streams.concat(numberTypes(), nonNumberTypes()); + } } \ No newline at end of file diff --git a/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/types/SqlDecimalTest.java b/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/types/SqlDecimalTest.java index 12f7c0f55c55..9d7c7cd2871f 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/types/SqlDecimalTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/types/SqlDecimalTest.java @@ -18,11 +18,14 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; +import com.google.common.collect.ImmutableMap; import com.google.common.testing.EqualsTester; import io.confluent.ksql.schema.ksql.DataException; import io.confluent.ksql.schema.ksql.SqlBaseType; import io.confluent.ksql.util.KsqlException; +import io.confluent.ksql.util.Pair; import java.math.BigDecimal; +import java.util.Map; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.Rule; @@ 
-134,4 +137,102 @@ public void shouldValidateValue() { // When: schema.validateValue(new BigDecimal("123.0")); } + + @Test + public void shouldResolveDecimalAddition() { + final Map, SqlDecimal> testCases = + ImmutableMap., SqlDecimal>builder() + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(2, 1)), SqlTypes.decimal(3, 1)) + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(2, 2)), SqlTypes.decimal(4, 2)) + .put(Pair.of(SqlTypes.decimal(2, 2), SqlTypes.decimal(2, 1)), SqlTypes.decimal(4, 2)) + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(3, 2)), SqlTypes.decimal(4, 2)) + .put(Pair.of(SqlTypes.decimal(3, 2), SqlTypes.decimal(2, 1)), SqlTypes.decimal(4, 2)) + .build(); + + testCases.forEach((in, expected) -> { + // When: + final SqlDecimal result = SqlDecimal.add(in.left, in.right); + + // Then: + assertThat(result, is(expected)); + }); + } + + @Test + public void shouldResolveDecimalSubtraction() { + final Map, SqlDecimal> inputToExpected = + ImmutableMap., SqlDecimal>builder() + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(2, 1)), SqlTypes.decimal(3, 1)) + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(2, 2)), SqlTypes.decimal(4, 2)) + .put(Pair.of(SqlTypes.decimal(2, 2), SqlTypes.decimal(2, 1)), SqlTypes.decimal(4, 2)) + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(3, 2)), SqlTypes.decimal(4, 2)) + .put(Pair.of(SqlTypes.decimal(3, 2), SqlTypes.decimal(2, 1)), SqlTypes.decimal(4, 2)) + .build(); + + inputToExpected.forEach((in, expected) -> { + // When: + final SqlDecimal result = SqlDecimal.subtract(in.left, in.right); + + // Then: + assertThat(result, is(expected)); + }); + } + + @Test + public void shouldResolveDecimalMultiply() { + final Map, SqlDecimal> inputToExpected = + ImmutableMap., SqlDecimal>builder() + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(2, 1)), SqlTypes.decimal(5, 2)) + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(2, 2)), SqlTypes.decimal(5, 3)) + .put(Pair.of(SqlTypes.decimal(2, 2), SqlTypes.decimal(2, 1)), SqlTypes.decimal(5, 3)) + .put(Pair.of(SqlTypes.decimal(3, 2), SqlTypes.decimal(2, 1)), SqlTypes.decimal(6, 3)) + .build(); + + inputToExpected.forEach((in, expected) -> { + // When: + final SqlDecimal result = SqlDecimal.multiply(in.left, in.right); + + // Then: + assertThat(result, is(expected)); + }); + } + + @Test + public void shouldResolveDecimalDivide() { + final Map, SqlDecimal> inputToExpected = + ImmutableMap., SqlDecimal>builder() + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(2, 1)), SqlTypes.decimal(8, 6)) + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(2, 2)), SqlTypes.decimal(9, 6)) + .put(Pair.of(SqlTypes.decimal(2, 2), SqlTypes.decimal(2, 1)), SqlTypes.decimal(7, 6)) + .put(Pair.of(SqlTypes.decimal(3, 3), SqlTypes.decimal(3, 3)), SqlTypes.decimal(10, 7)) + .put(Pair.of(SqlTypes.decimal(3, 3), SqlTypes.decimal(3, 2)), SqlTypes.decimal(9, 7)) + .build(); + + inputToExpected.forEach((in, expected) -> { + // When: + final SqlDecimal result = SqlDecimal.divide(in.left, in.right); + + // Then: + assertThat(result, is(expected)); + }); + } + + @Test + public void shouldResolveDecimalMod() { + final Map, SqlDecimal> inputToExpected = + ImmutableMap., SqlDecimal>builder() + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(2, 1)), SqlTypes.decimal(2, 1)) + .put(Pair.of(SqlTypes.decimal(2, 2), SqlTypes.decimal(2, 1)), SqlTypes.decimal(2, 2)) + .put(Pair.of(SqlTypes.decimal(2, 1), SqlTypes.decimal(2, 2)), SqlTypes.decimal(2, 2)) + .put(Pair.of(SqlTypes.decimal(3, 1), 
SqlTypes.decimal(2, 2)), SqlTypes.decimal(2, 2)) + .build(); + + inputToExpected.forEach((in, expected) -> { + // When: + final SqlDecimal result = SqlDecimal.modulus(in.left, in.right); + + // Then: + assertThat(result, is(expected)); + }); + } } \ No newline at end of file diff --git a/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/types/SqlTypesTest.java b/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/types/SqlTypesTest.java index 07e5012ddf5b..f0a4149a31ca 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/types/SqlTypesTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/schema/ksql/types/SqlTypesTest.java @@ -42,6 +42,7 @@ public class SqlTypesTest { private static final ImmutableMap, Object> DEFAULTS = ImmutableMap ., Object>builder() + .put(SqlDecimal.class, SqlDecimal.of(1, 0)) .build(); private final Class modelClass; diff --git a/ksql-common/src/test/java/io/confluent/ksql/util/DecimalUtilTest.java b/ksql-common/src/test/java/io/confluent/ksql/util/DecimalUtilTest.java index 1f56cccb0b17..9b2caab63752 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/util/DecimalUtilTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/util/DecimalUtilTest.java @@ -236,46 +236,6 @@ public void shouldCastStringRoundUp() { assertThat(decimal, is(new BigDecimal("1.2"))); } - @Test - public void shouldConvertInteger() { - // When: - final Schema decimal = DecimalUtil.toDecimal(Schema.OPTIONAL_INT32_SCHEMA); - - // Then: - assertThat(decimal, is(DecimalUtil.builder(10, 0).build())); - } - - @Test - public void shouldConvertLong() { - // When: - final Schema decimal = DecimalUtil.toDecimal(Schema.OPTIONAL_INT64_SCHEMA); - - // Then: - assertThat(decimal, is(DecimalUtil.builder(19, 0).build())); - } - - @Test - public void shouldConvertDecimal() { - // Given: - final Schema given = DecimalUtil.builder(2, 2); - - // When: - final Schema decimal = DecimalUtil.toDecimal(given); - - // Then: - assertThat(decimal, sameInstance(given)); - } - - @Test - public void shouldThrowIfConvertString() { - // Expect: - expectedException.expect(KsqlException.class); - expectedException.expectMessage("Cannot convert schema of type STRING to decimal"); - - // When: - DecimalUtil.toDecimal(Schema.OPTIONAL_STRING_SCHEMA); - } - @Test public void shouldConvertIntegerToSqlDecimal() { // When: @@ -306,16 +266,6 @@ public void shouldConvertDecimalToSqlDecimal() { assertThat(decimal, is(SqlTypes.decimal(2, 2))); } - @Test - public void shouldThrowIfConvertStringToSqlDecimal() { - // Expect: - expectedException.expect(KsqlException.class); - expectedException.expectMessage("Cannot convert schema of type STRING to decimal"); - - // When: - DecimalUtil.toDecimal(Schema.OPTIONAL_STRING_SCHEMA); - } - @Test public void shouldEnsureFitIfExactMatch() { // No Exception When: diff --git a/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java b/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java index 7d2369f1b87f..db286c1a68df 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java @@ -19,14 +19,11 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; -import com.google.common.collect.ImmutableMap; import io.confluent.ksql.function.GenericsUtil; -import io.confluent.ksql.schema.Operator; import io.confluent.ksql.schema.ksql.PersistenceSchema; import java.math.BigDecimal; import java.util.List; import 
java.util.Map; -import java.util.Optional; import org.apache.kafka.connect.data.ConnectSchema; import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.SchemaBuilder; @@ -226,259 +223,20 @@ public void shouldReturnFieldNameWithoutAliasAsIs() { assertThat(result, is("some-field-name")); } - @Test - public void shouldReturnNoAlias() { - assertThat(SchemaUtil.getFieldNameAlias("not-aliased"), is(Optional.empty())); - } - - @Test - public void shouldReturnAlias() { - assertThat(SchemaUtil.getFieldNameAlias("is.aliased"), is(Optional.of("is"))); - } - - @Test - public void shouldResolveIntAndLongSchemaToLong() { - assertThat( - SchemaUtil.resolveBinaryOperatorResultType(Schema.INT64_SCHEMA, Schema.INT32_SCHEMA, Operator.ADD).type(), - equalTo(Schema.Type.INT64)); - } - - @Test - public void shouldResolveIntAndIntSchemaToInt() { - assertThat( - SchemaUtil.resolveBinaryOperatorResultType(Schema.INT32_SCHEMA, Schema.INT32_SCHEMA, Operator.ADD).type(), - equalTo(Schema.Type.INT32)); - } - - @Test - public void shouldResolveFloat64AndAnyNumberTypeToFloat() { - assertThat( - SchemaUtil.resolveBinaryOperatorResultType(Schema.INT32_SCHEMA, Schema.FLOAT64_SCHEMA, Operator.ADD).type(), - equalTo(Schema.Type.FLOAT64)); - assertThat( - SchemaUtil.resolveBinaryOperatorResultType(Schema.FLOAT64_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA, Operator.ADD).type(), - equalTo(Schema.Type.FLOAT64)); - assertThat( - SchemaUtil.resolveBinaryOperatorResultType(Schema.FLOAT32_SCHEMA, Schema.FLOAT64_SCHEMA, Operator.ADD).type(), - equalTo(Schema.Type.FLOAT64)); - } - - @Test - public void shouldResolveDecimalAddition() { - final Map, PrecisionScale> inputToExpected = - ImmutableMap., PrecisionScale>builder() - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(2, 1)), PrecisionScale.of(3, 1)) - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(2, 2)), PrecisionScale.of(4, 2)) - .put(Pair.of(PrecisionScale.of(2, 2), PrecisionScale.of(2, 1)), PrecisionScale.of(4, 2)) - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(3, 2)), PrecisionScale.of(4, 2)) - .put(Pair.of(PrecisionScale.of(3, 2), PrecisionScale.of(2, 1)), PrecisionScale.of(4, 2)) - .build(); - - inputToExpected.forEach((in, out) -> { - // Given: - final Schema d1 = DecimalUtil.builder(in.left.precision, in.left.scale).build(); - final Schema d2 = DecimalUtil.builder(in.right.precision, in.right.scale).build(); - - // When: - final Schema result = SchemaUtil.resolveBinaryOperatorResultType(d1, d2, Operator.ADD); - - // Then: - assertThat(String.format("precision: %s", in), DecimalUtil.precision(result), is(out.precision)); - assertThat(String.format("scale: %s", in), DecimalUtil.scale(result), is(out.scale)); - }); - } - - @Test - public void shouldResolveDecimalSubtraction() { - final Map, PrecisionScale> inputToExpected = - ImmutableMap., PrecisionScale>builder() - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(2, 1)), PrecisionScale.of(3, 1)) - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(2, 2)), PrecisionScale.of(4, 2)) - .put(Pair.of(PrecisionScale.of(2, 2), PrecisionScale.of(2, 1)), PrecisionScale.of(4, 2)) - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(3, 2)), PrecisionScale.of(4, 2)) - .put(Pair.of(PrecisionScale.of(3, 2), PrecisionScale.of(2, 1)), PrecisionScale.of(4, 2)) - .build(); - - inputToExpected.forEach((in, out) -> { - // Given: - final Schema d1 = DecimalUtil.builder(in.left.precision, in.left.scale).build(); - final Schema d2 = DecimalUtil.builder(in.right.precision, 
in.right.scale).build(); - - // When: - final Schema result = SchemaUtil.resolveBinaryOperatorResultType(d1, d2, Operator.SUBTRACT); - - // Then: - assertThat(String.format("precision: %s", in), DecimalUtil.precision(result), is(out.precision)); - assertThat(String.format("scale: %s", in), DecimalUtil.scale(result), is(out.scale)); - }); - } - - @Test - public void shouldResolveDecimalMultiply() { - final Map, PrecisionScale> inputToExpected = - ImmutableMap., PrecisionScale>builder() - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(2, 1)), PrecisionScale.of(5, 2)) - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(2, 2)), PrecisionScale.of(5, 3)) - .put(Pair.of(PrecisionScale.of(2, 2), PrecisionScale.of(2, 1)), PrecisionScale.of(5, 3)) - .put(Pair.of(PrecisionScale.of(3, 2), PrecisionScale.of(2, 1)), PrecisionScale.of(6, 3)) - .build(); - - inputToExpected.forEach((in, out) -> { - // Given: - final Schema d1 = DecimalUtil.builder(in.left.precision, in.left.scale).build(); - final Schema d2 = DecimalUtil.builder(in.right.precision, in.right.scale).build(); - - // When: - final Schema result = SchemaUtil.resolveBinaryOperatorResultType(d1, d2, Operator.MULTIPLY); - - // Then: - assertThat(String.format("precision: %s", in), DecimalUtil.precision(result), is(out.precision)); - assertThat(String.format("scale: %s", in), DecimalUtil.scale(result), is(out.scale)); - }); - } - - @Test - public void shouldResolveDecimalDivide() { - final Map, PrecisionScale> inputToExpected = - ImmutableMap., PrecisionScale>builder() - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(2, 1)), PrecisionScale.of(8, 6)) - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(2, 2)), PrecisionScale.of(9, 6)) - .put(Pair.of(PrecisionScale.of(2, 2), PrecisionScale.of(2, 1)), PrecisionScale.of(7, 6)) - .put(Pair.of(PrecisionScale.of(3, 3), PrecisionScale.of(3, 3)), PrecisionScale.of(10, 7)) - .put(Pair.of(PrecisionScale.of(3, 3), PrecisionScale.of(3, 2)), PrecisionScale.of(9, 7)) - .build(); - - inputToExpected.forEach((in, out) -> { - // Given: - final Schema d1 = DecimalUtil.builder(in.left.precision, in.left.scale).build(); - final Schema d2 = DecimalUtil.builder(in.right.precision, in.right.scale).build(); - - // When: - final Schema result = SchemaUtil.resolveBinaryOperatorResultType(d1, d2, Operator.DIVIDE); - - // Then: - assertThat(String.format("precision: %s", in), DecimalUtil.precision(result), is(out.precision)); - assertThat(String.format("scale: %s", in), DecimalUtil.scale(result), is(out.scale)); - }); - } - - @Test - public void shouldResolveDecimalMod() { - final Map, PrecisionScale> inputToExpected = - ImmutableMap., PrecisionScale>builder() - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(2, 1)), PrecisionScale.of(2, 1)) - .put(Pair.of(PrecisionScale.of(2, 2), PrecisionScale.of(2, 1)), PrecisionScale.of(2, 2)) - .put(Pair.of(PrecisionScale.of(2, 1), PrecisionScale.of(2, 2)), PrecisionScale.of(2, 2)) - .put(Pair.of(PrecisionScale.of(3, 1), PrecisionScale.of(2, 2)), PrecisionScale.of(2, 2)) - .build(); - - inputToExpected.forEach((in, out) -> { - // Given: - final Schema d1 = DecimalUtil.builder(in.left.precision, in.left.scale).build(); - final Schema d2 = DecimalUtil.builder(in.right.precision, in.right.scale).build(); - - // When: - final Schema result = SchemaUtil.resolveBinaryOperatorResultType(d1, d2, Operator.MODULUS); - - // Then: - assertThat(String.format("precision: %s", in), DecimalUtil.precision(result), is(out.precision)); - assertThat(String.format("scale: 
%s", in), DecimalUtil.scale(result), is(out.scale)); - }); - } - - @Test - public void shouldResolveDecimalLongAdd() { - final Map inputToExpected = - ImmutableMap.builder() - .put(PrecisionScale.of(2, 1), PrecisionScale.of(21, 1)) - .put(PrecisionScale.of(3, 3), PrecisionScale.of(23, 3)) - .put(PrecisionScale.of(23, 0), PrecisionScale.of(24, 0)) - .build(); - - inputToExpected.forEach((in, out) -> { - // Given: - final Schema d1 = DecimalUtil.builder(in.precision, in.scale).build(); - final Schema d2 = Schema.OPTIONAL_INT64_SCHEMA; - - // When: - final Schema result = SchemaUtil.resolveBinaryOperatorResultType(d1, d2, Operator.ADD); - - // Then: - assertThat(String.format("precision: %s", in), DecimalUtil.precision(result), is(out.precision)); - assertThat(String.format("scale: %s", in), DecimalUtil.scale(result), is(out.scale)); - }); - } - - @Test - public void shouldResolveDecimalDoubleMath() { - // Given: - final Schema d1 = DecimalUtil.builder(15, 10).build(); - final Schema d2 = Schema.OPTIONAL_FLOAT64_SCHEMA; - - // When: - final Schema result = SchemaUtil.resolveBinaryOperatorResultType(d1, d2, Operator.ADD); - - // Then: - assertThat(result, is(Schema.OPTIONAL_FLOAT64_SCHEMA)); - } - - private static class PrecisionScale { - final int precision; - final int scale; - - static PrecisionScale of(final int precision, final int scale) { - return new PrecisionScale(precision, scale); - } - - private PrecisionScale(final int precision, final int scale) { - this.precision = precision; - this.scale = scale; - } - - @Override - public String toString() { - return "PrecisionScale{" - + "precision=" + precision - + ", scale=" + scale - + '}'; - } - } - - @Test - public void shouldResolveStringAndStringToString() { - assertThat( - SchemaUtil.resolveBinaryOperatorResultType(Schema.STRING_SCHEMA, Schema.STRING_SCHEMA, Operator.ADD).type(), - equalTo(Schema.Type.STRING)); - } - - @Test(expected = KsqlException.class) - public void shouldThrowExceptionWhenResolvingStringWithAnythingElse() { - SchemaUtil.resolveBinaryOperatorResultType(Schema.STRING_SCHEMA, Schema.FLOAT64_SCHEMA, Operator.ADD); - } - - @Test(expected = KsqlException.class) - public void shouldThrowExceptionWhenResolvingUnkonwnType() { - SchemaUtil.resolveBinaryOperatorResultType(Schema.BOOLEAN_SCHEMA, Schema.FLOAT64_SCHEMA, Operator.ADD); - } - - @Test public void shouldPassIsNumberForInt() { - assertThat(SchemaUtil.isNumber(Schema.Type.INT32), is(true)); assertThat(SchemaUtil.isNumber(Schema.OPTIONAL_INT32_SCHEMA), is(true)); assertThat(SchemaUtil.isNumber(Schema.INT32_SCHEMA), is(true)); } @Test public void shouldPassIsNumberForBigint() { - assertThat(SchemaUtil.isNumber(Schema.Type.INT64), is(true)); assertThat(SchemaUtil.isNumber(Schema.OPTIONAL_INT64_SCHEMA), is(true)); assertThat(SchemaUtil.isNumber(Schema.INT64_SCHEMA), is(true)); } @Test public void shouldPassIsNumberForDouble() { - assertThat(SchemaUtil.isNumber(Schema.Type.FLOAT64), is(true)); assertThat(SchemaUtil.isNumber(Schema.OPTIONAL_FLOAT64_SCHEMA), is(true)); assertThat(SchemaUtil.isNumber(Schema.FLOAT64_SCHEMA), is(true)); } @@ -490,14 +248,12 @@ public void shouldPassIsNumberForDecimal() { @Test public void shouldFailIsNumberForBoolean() { - assertThat(SchemaUtil.isNumber(Schema.Type.BOOLEAN), is(false)); assertThat(SchemaUtil.isNumber(Schema.OPTIONAL_BOOLEAN_SCHEMA), is(false)); assertThat(SchemaUtil.isNumber(Schema.BOOLEAN_SCHEMA), is(false)); } @Test public void shouldFailIsNumberForString() { - assertThat(SchemaUtil.isNumber(Schema.Type.STRING), is(false)); 
assertThat(SchemaUtil.isNumber(Schema.OPTIONAL_STRING_SCHEMA), is(false)); assertThat(SchemaUtil.isNumber(Schema.STRING_SCHEMA), is(false)); } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/engine/EngineExecutor.java b/ksql-engine/src/main/java/io/confluent/ksql/engine/EngineExecutor.java index 53f68d5e78d7..8a69f95293c9 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/engine/EngineExecutor.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/engine/EngineExecutor.java @@ -106,7 +106,7 @@ ExecuteResult execute(final ConfiguredStatement statement) { } ExecuteResult execute(final KsqlPlan plan) { - final Optional ddlResult = plan.getDdlCommand().map(ddl -> executeDDL(plan)); + final Optional ddlResult = plan.getDdlCommand().map(ddl -> executeDdl(plan)); final Optional queryMetadata = plan.getQueryPlan().map(qp -> executePersistentQuery(plan)); return queryMetadata.map(ExecuteResult::of).orElseGet(() -> ExecuteResult.of(ddlResult.get())); @@ -154,7 +154,7 @@ KsqlPlan plan(final ConfiguredStatement statement) { ); final KsqlStructuredDataOutputNode outputNode = (KsqlStructuredDataOutputNode) plans.logicalPlan.getNode().get(); - final Optional ddlCommand = maybeCreateSinkDDL( + final Optional ddlCommand = maybeCreateSinkDdl( statement.getStatementText(), outputNode, plans.physicalPlan.getKeyField()); @@ -213,7 +213,7 @@ private ExecutorPlans( } } - private Optional maybeCreateSinkDDL( + private Optional maybeCreateSinkDdl( final String sql, final KsqlStructuredDataOutputNode outputNode, final KeyField keyField) { @@ -348,7 +348,7 @@ private static Set getSourceNames(final PlanNode outputNode) { return visitor.getSourceNames(); } - private String executeDDL(final KsqlPlan ksqlPlan) { + private String executeDdl(final KsqlPlan ksqlPlan) { final DdlCommand ddlCommand = ksqlPlan.getDdlCommand().get(); final Optional keyField = ksqlPlan.getQueryPlan() .map(QueryPlan::getPhysicalPlan) diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index f13302d8923d..021d288dc2b2 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -582,6 +582,7 @@ private Pair visitArithmeticPlus( } } + @SuppressWarnings("deprecation") @Override public Pair visitArithmeticBinary( final ArithmeticBinaryExpression node, @@ -590,9 +591,7 @@ public Pair visitArithmeticBinary( final Pair left = process(node.getLeft(), context); final Pair right = process(node.getRight(), context); - final Schema schema = - SchemaUtil.resolveBinaryOperatorResultType( - left.getRight(), right.getRight(), node.getOperator()); + final Schema schema = expressionTypeManager.getExpressionSchema(node); if (DecimalUtil.isDecimal(schema)) { final String leftExpr = CastVisitor.getCast( @@ -633,6 +632,7 @@ public Pair visitArithmeticBinary( } } + @SuppressWarnings("deprecation") @Override public Pair visitSearchedCaseExpression( final SearchedCaseExpression node, @@ -646,7 +646,8 @@ public Pair visitSearchedCaseExpression( process(whenClause.getResult(), context) )) .collect(Collectors.toList()); - final Schema resultSchema = whenClauses.get(0).thenProcessResult.getRight(); + + final Schema resultSchema = expressionTypeManager.getExpressionSchema(node); final String resultSchemaString = SchemaUtil.getJavaType(resultSchema).getCanonicalName(); final List 
lazyWhenClause = whenClauses diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index f28b4328e95a..60a8195862d0 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -52,18 +52,14 @@ import io.confluent.ksql.function.KsqlAggregateFunction; import io.confluent.ksql.function.KsqlFunctionException; import io.confluent.ksql.function.UdfFactory; -import io.confluent.ksql.schema.Operator; import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.SchemaConverters; -import io.confluent.ksql.schema.ksql.SchemaConverters.ConnectToSqlTypeConverter; -import io.confluent.ksql.schema.ksql.SchemaConverters.SqlToConnectTypeConverter; import io.confluent.ksql.schema.ksql.types.SqlArray; import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.util.KsqlException; -import io.confluent.ksql.util.SchemaUtil; import io.confluent.ksql.util.VisitorUtil; import java.util.ArrayList; import java.util.List; @@ -74,12 +70,6 @@ @SuppressWarnings("deprecation") // Need to migrate away from Connect Schema use. public class ExpressionTypeManager { - private static final SqlToConnectTypeConverter SQL_TO_CONNECT_SCHEMA_CONVERTER = - SchemaConverters.sqlToConnectConverter(); - - private static final ConnectToSqlTypeConverter CONNECT_TO_SQL_SCHEMA_CONVERTER = - SchemaConverters.connectToSqlConverter(); - private final LogicalSchema schema; private final FunctionRegistry functionRegistry; @@ -108,25 +98,26 @@ public SqlType getExpressionSqlType(final Expression expression) { return expressionTypeContext.getSqlType(); } - static class ExpressionTypeContext { + private static class ExpressionTypeContext { - private Schema schema; private SqlType sqlType; - public Schema getSchema() { - return schema; - } - SqlType getSqlType() { return sqlType; } - void setSchema( - final SqlType sqlType, - final Schema schema - ) { + Schema getSchema() { + return sqlType == null + ? 
null + : SchemaConverters.sqlToConnectConverter().toConnectSchema(sqlType); + } + + void setSqlType(final SqlType sqlType) { this.sqlType = sqlType; - this.schema = schema; + } + + void setSchema(final Schema schema) { + this.sqlType = SchemaConverters.connectToSqlConverter().toSqlType(schema); } } @@ -138,14 +129,14 @@ public Void visitArithmeticBinary( final ExpressionTypeContext expressionTypeContext ) { process(node.getLeft(), expressionTypeContext); - final Schema leftType = expressionTypeContext.getSchema(); + final SqlType leftType = expressionTypeContext.getSqlType(); + process(node.getRight(), expressionTypeContext); - final Schema rightType = expressionTypeContext.getSchema(); + final SqlType rightType = expressionTypeContext.getSqlType(); - final Schema schema = resolveArithmeticType(leftType, rightType, node.getOperator()); - final SqlType sqlType = CONNECT_TO_SQL_SCHEMA_CONVERTER.toSqlType(schema); + final SqlType resultType = node.getOperator().resultType(leftType, rightType); - expressionTypeContext.setSchema(sqlType, schema); + expressionTypeContext.setSqlType(resultType); return null; } @@ -163,7 +154,7 @@ public Void visitNotExpression( final NotExpression node, final ExpressionTypeContext expressionTypeContext ) { - expressionTypeContext.setSchema(SqlTypes.BOOLEAN, Schema.OPTIONAL_BOOLEAN_SCHEMA); + expressionTypeContext.setSqlType(SqlTypes.BOOLEAN); return null; } @@ -178,9 +169,7 @@ public Void visitCast( + "are supported: " + sqlType); } - final Schema castType = SQL_TO_CONNECT_SCHEMA_CONVERTER.toConnectSchema(sqlType); - - expressionTypeContext.setSchema(sqlType, castType); + expressionTypeContext.setSqlType(sqlType); return null; } @@ -194,7 +183,7 @@ public Void visitComparisonExpression( process(node.getRight(), expressionTypeContext); final Schema rightSchema = expressionTypeContext.getSchema(); ComparisonUtil.isValidComparison(leftSchema, node.getType(), rightSchema); - expressionTypeContext.setSchema(SqlTypes.BOOLEAN, Schema.OPTIONAL_BOOLEAN_SCHEMA); + expressionTypeContext.setSqlType(SqlTypes.BOOLEAN); return null; } @@ -203,7 +192,7 @@ public Void visitBetweenPredicate( final BetweenPredicate node, final ExpressionTypeContext context ) { - context.setSchema(SqlTypes.BOOLEAN, Schema.OPTIONAL_BOOLEAN_SCHEMA); + context.setSqlType(SqlTypes.BOOLEAN); return null; } @@ -216,8 +205,7 @@ public Void visitColumnReference( .orElseThrow(() -> new KsqlException(String.format("Invalid Expression %s.", node.toString()))); - final Schema schema = SQL_TO_CONNECT_SCHEMA_CONVERTER.toConnectSchema(schemaColumn.type()); - expressionTypeContext.setSchema(schemaColumn.type(), schema); + expressionTypeContext.setSqlType(schemaColumn.type()); return null; } @@ -235,7 +223,7 @@ public Void visitStringLiteral( final StringLiteral node, final ExpressionTypeContext expressionTypeContext ) { - expressionTypeContext.setSchema(SqlTypes.STRING, Schema.OPTIONAL_STRING_SCHEMA); + expressionTypeContext.setSqlType(SqlTypes.STRING); return null; } @@ -244,7 +232,7 @@ public Void visitBooleanLiteral( final BooleanLiteral node, final ExpressionTypeContext expressionTypeContext ) { - expressionTypeContext.setSchema(SqlTypes.BOOLEAN, Schema.OPTIONAL_BOOLEAN_SCHEMA); + expressionTypeContext.setSqlType(SqlTypes.BOOLEAN); return null; } @@ -253,7 +241,7 @@ public Void visitLongLiteral( final LongLiteral node, final ExpressionTypeContext expressionTypeContext ) { - expressionTypeContext.setSchema(SqlTypes.BIGINT, Schema.OPTIONAL_INT64_SCHEMA); + expressionTypeContext.setSqlType(SqlTypes.BIGINT); return 
null; } @@ -262,7 +250,7 @@ public Void visitIntegerLiteral( final IntegerLiteral node, final ExpressionTypeContext expressionTypeContext ) { - expressionTypeContext.setSchema(SqlTypes.INTEGER, Schema.OPTIONAL_INT32_SCHEMA); + expressionTypeContext.setSqlType(SqlTypes.INTEGER); return null; } @@ -271,7 +259,7 @@ public Void visitDoubleLiteral( final DoubleLiteral node, final ExpressionTypeContext expressionTypeContext ) { - expressionTypeContext.setSchema(SqlTypes.DOUBLE, Schema.OPTIONAL_FLOAT64_SCHEMA); + expressionTypeContext.setSqlType(SqlTypes.DOUBLE); return null; } @@ -280,7 +268,7 @@ public Void visitNullLiteral( final NullLiteral node, final ExpressionTypeContext context ) { - context.setSchema(null, null); + context.setSqlType(null); return null; } @@ -289,7 +277,7 @@ public Void visitLikePredicate( final LikePredicate node, final ExpressionTypeContext expressionTypeContext ) { - expressionTypeContext.setSchema(SqlTypes.BOOLEAN, Schema.OPTIONAL_BOOLEAN_SCHEMA); + expressionTypeContext.setSqlType(SqlTypes.BOOLEAN); return null; } @@ -298,7 +286,7 @@ public Void visitIsNotNullPredicate( final IsNotNullPredicate node, final ExpressionTypeContext expressionTypeContext ) { - expressionTypeContext.setSchema(SqlTypes.BOOLEAN, Schema.OPTIONAL_BOOLEAN_SCHEMA); + expressionTypeContext.setSqlType(SqlTypes.BOOLEAN); return null; } @@ -307,7 +295,7 @@ public Void visitIsNullPredicate( final IsNullPredicate node, final ExpressionTypeContext expressionTypeContext ) { - expressionTypeContext.setSchema(SqlTypes.BOOLEAN, Schema.OPTIONAL_BOOLEAN_SCHEMA); + expressionTypeContext.setSqlType(SqlTypes.BOOLEAN); return null; } @@ -327,7 +315,6 @@ public Void visitSubscriptExpression( final ExpressionTypeContext expressionTypeContext ) { process(node.getBase(), expressionTypeContext); - final Schema arrayMapSchema = expressionTypeContext.getSchema(); final SqlType arrayMapType = expressionTypeContext.getSqlType(); final SqlType valueType; @@ -339,7 +326,7 @@ public Void visitSubscriptExpression( throw new UnsupportedOperationException("Unsupported container type: " + arrayMapType); } - expressionTypeContext.setSchema(valueType, arrayMapSchema.valueSchema()); + expressionTypeContext.setSqlType(valueType); return null; } @@ -361,11 +348,7 @@ public Void visitFunctionCall( final KsqlAggregateFunction aggFunc = functionRegistry .getAggregateFunction(node.getName().name(), schema, args); - final Schema returnSchema = aggFunc.getReturnType(); - - final SqlType returnType = CONNECT_TO_SQL_SCHEMA_CONVERTER.toSqlType(returnSchema); - - expressionTypeContext.setSchema(returnType, returnSchema); + expressionTypeContext.setSchema(aggFunc.getReturnType()); return null; } @@ -379,8 +362,7 @@ public Void visitFunctionCall( node.getArguments().get(0).toString())); } final Schema returnSchema = firstArgSchema.field(fieldName).schema(); - final SqlType returnType = CONNECT_TO_SQL_SCHEMA_CONVERTER.toSqlType(returnSchema); - expressionTypeContext.setSchema(returnType, returnSchema); + expressionTypeContext.setSchema(returnSchema); } else { final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName().name()); final List argTypes = new ArrayList<>(); @@ -389,8 +371,7 @@ public Void visitFunctionCall( argTypes.add(expressionTypeContext.getSchema()); } final Schema returnSchema = udfFactory.getFunction(argTypes).getReturnType(argTypes); - final SqlType returnType = CONNECT_TO_SQL_SCHEMA_CONVERTER.toSqlType(returnSchema); - expressionTypeContext.setSchema(returnType, returnSchema); + 
+      expressionTypeContext.setSchema(returnSchema);
     }
     return null;
   }

@@ -468,13 +449,6 @@ public Void visitWhenClause(
     throw VisitorUtil.illegalState(this, whenClause);
   }

-  private Schema resolveArithmeticType(
-      final Schema leftSchema,
-      final Schema rightSchema,
-      final Operator operator) {
-    return SchemaUtil.resolveBinaryOperatorResultType(leftSchema, rightSchema, operator);
-  }
-
   private void validateSearchedCaseExpression(
       final SearchedCaseExpression searchedCaseExpression) {
     final Schema firstResultSchema = getExpressionSchema(
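The ExpressionTypeManager changes above collapse the dual (SqlType, Schema) bookkeeping into a single SqlType, and binary arithmetic now gets its type from `Operator.resultType` rather than round-tripping through Connect schemas. A minimal sketch of that resolution from a caller's point of view; the operand types are illustrative, and the exact precision/scale of a decimal result is whatever the operator's decimal rules produce:

```java
import io.confluent.ksql.schema.Operator;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.util.KsqlException;

public final class ResultTypeSketch {

  public static void main(final String[] args) {
    // Mixed numeric operands resolve to a single result type; with a decimal
    // operand involved, expect a DECIMAL wide enough for either side.
    final SqlType sum = Operator.ADD.resultType(SqlTypes.INTEGER, SqlTypes.decimal(4, 2));
    System.out.println(sum);

    // Operand combinations with no common numeric type are rejected:
    try {
      Operator.ADD.resultType(SqlTypes.BOOLEAN, SqlTypes.INTEGER);
    } catch (final KsqlException e) {
      System.out.println("rejected: " + e.getMessage());
    }
  }
}
```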
diff --git a/ksql-functional-tests/src/main/java/io/confluent/ksql/test/loader/JsonTestLoader.java b/ksql-functional-tests/src/main/java/io/confluent/ksql/test/loader/JsonTestLoader.java
index c327a19bfc94..dc24b220140a 100644
--- a/ksql-functional-tests/src/main/java/io/confluent/ksql/test/loader/JsonTestLoader.java
+++ b/ksql-functional-tests/src/main/java/io/confluent/ksql/test/loader/JsonTestLoader.java
@@ -18,6 +18,7 @@
 import static java.nio.charset.StandardCharsets.UTF_8;

 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
 import io.confluent.ksql.test.TestFrameworkException;
 import io.confluent.ksql.test.tools.Test;
 import java.io.BufferedReader;
@@ -46,6 +47,10 @@ public final class JsonTestLoader<T extends Test> implements TestLoader<T> {

   private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

+  static {
+    OBJECT_MAPPER.registerModule(new Jdk8Module());
+  }
+
   private final Path testDir;
   private final Class<? extends TestFile<T>> testFileType;

diff --git a/ksql-functional-tests/src/main/java/io/confluent/ksql/test/model/RecordNode.java b/ksql-functional-tests/src/main/java/io/confluent/ksql/test/model/RecordNode.java
index 48735f0dca82..15841a577577 100644
--- a/ksql-functional-tests/src/main/java/io/confluent/ksql/test/model/RecordNode.java
+++ b/ksql-functional-tests/src/main/java/io/confluent/ksql/test/model/RecordNode.java
@@ -17,40 +17,45 @@
 import static java.util.Objects.requireNonNull;

-import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.JsonDeserializer;
 import com.fasterxml.jackson.databind.JsonNode;
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
 import io.confluent.ksql.test.serde.string.StringSerdeSupplier;
 import io.confluent.ksql.test.tools.Record;
 import io.confluent.ksql.test.tools.Topic;
 import io.confluent.ksql.test.tools.exceptions.InvalidFieldException;
 import io.confluent.ksql.test.tools.exceptions.MissingFieldException;
+import io.confluent.ksql.test.utils.JsonParsingUtil;
 import java.io.IOException;
 import java.util.Map;
 import java.util.Optional;

+@JsonDeserialize(using = RecordNode.Deserializer.class)
 public final class RecordNode {

   private static final ObjectMapper objectMapper = new ObjectMapper();

   private final String topicName;
-  private final String key;
+  private final Optional<String> key;
   private final JsonNode value;
   private final Optional<Long> timestamp;
   private final Optional<WindowData> window;

-  RecordNode(
-      @JsonProperty("topic") final String topicName,
-      @JsonProperty("key") final String key,
-      @JsonProperty("value") final JsonNode value,
-      @JsonProperty("timestamp") final Long timestamp,
-      @JsonProperty("window") final WindowData window
+  private RecordNode(
+      final String topicName,
+      final Optional<String> key,
+      final JsonNode value,
+      final Optional<Long> timestamp,
+      final Optional<WindowData> window
   ) {
     this.topicName = topicName == null ? "" : topicName;
-    this.key = key == null ? "" : key;
+    this.key = requireNonNull(key, "key");
     this.value = requireNonNull(value, "value");
-    this.timestamp = Optional.ofNullable(timestamp);
-    this.window = Optional.ofNullable(window);
+    this.timestamp = requireNonNull(timestamp, "timestamp");
+    this.window = requireNonNull(window, "window");

     if (this.topicName.isEmpty()) {
       throw new MissingFieldException("topic");
@@ -64,12 +69,12 @@ public String topicName() {
   public Record build(final Map<String, Topic> topics) {
     final Topic topic = topics.get(topicName);

-    final Object topicValue = buildValue(topic);
+    final Object recordValue = buildValue(topic);

     return new Record(
         topic,
-        key,
-        topicValue,
+        key.orElse(null),
+        recordValue,
         timestamp,
         window.orElse(null)
     );
@@ -94,4 +99,31 @@ private Object buildValue(final Topic topic) {
       throw new InvalidFieldException("value", "failed to parse", e);
     }
   }
+
+  public static class Deserializer extends JsonDeserializer<RecordNode> {
+
+    @Override
+    public RecordNode deserialize(
+        final JsonParser jp,
+        final DeserializationContext ctxt
+    ) throws IOException {
+      final JsonNode node = jp.getCodec().readTree(jp);
+
+      final String topic = JsonParsingUtil.getRequired("topic", node, jp, String.class);
+
+      final Optional<String> key = node.has("key")
+          ? JsonParsingUtil.getOptional("key", node, jp, String.class)
+          : Optional.of("");
+
+      final JsonNode value = JsonParsingUtil.getRequired("value", node, jp, JsonNode.class);
+
+      final Optional<Long> timestamp = JsonParsingUtil
+          .getOptional("timestamp", node, jp, Long.class);
+
+      final Optional<WindowData> window = JsonParsingUtil
+          .getOptional("window", node, jp, WindowData.class);
+
+      return new RecordNode(topic, key, value, timestamp, window);
+    }
+  }
 }
\ No newline at end of file
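The hand-rolled `RecordNode.Deserializer` above exists chiefly so the framework can tell a missing `key` field (which defaults to `""`) apart from an explicit `"key": null` (which becomes `Optional.empty()`); constructor-based `@JsonProperty` binding hands both cases to the constructor as plain `null`. A self-contained sketch of the distinction it relies on, with illustrative payloads:

```java
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public final class AbsentVsNullExample {

  public static void main(final String[] args) throws Exception {
    final ObjectMapper mapper = new ObjectMapper();

    final JsonNode absent = mapper.readTree("{\"topic\": \"t\", \"value\": {}}");
    final JsonNode explicitNull = mapper.readTree("{\"topic\": \"t\", \"key\": null, \"value\": {}}");

    // has() is true for a field that is present, even when its value is JSON null:
    System.out.println(absent.has("key"));       // false -> deserializer defaults key to ""
    System.out.println(explicitNull.has("key")); // true  -> getOptional yields Optional.empty()
  }
}
```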
diff --git a/ksql-functional-tests/src/main/java/io/confluent/ksql/test/tools/Record.java b/ksql-functional-tests/src/main/java/io/confluent/ksql/test/tools/Record.java
index 93c783a4ad2d..8bf30b0c34d1 100644
--- a/ksql-functional-tests/src/main/java/io/confluent/ksql/test/tools/Record.java
+++ b/ksql-functional-tests/src/main/java/io/confluent/ksql/test/tools/Record.java
@@ -33,20 +33,10 @@ public class Record {

   final Topic topic;
   private final String key;
-  final Object value;
-  final Optional<Long> timestamp;
+  private final Object value;
+  private final Optional<Long> timestamp;
   private final WindowData window;

-  public Record(
-      final Topic topic,
-      final String key,
-      final Object value,
-      final long timestamp,
-      final WindowData window
-  ) {
-    this(topic, key, value, Optional.of(timestamp), window);
-  }
-
   public Record(
       final Topic topic,
       final String key,
diff --git a/ksql-functional-tests/src/main/java/io/confluent/ksql/test/tools/stubs/StubKafkaRecord.java b/ksql-functional-tests/src/main/java/io/confluent/ksql/test/tools/stubs/StubKafkaRecord.java
index 9edd64edc5ca..77b68f0da7ed 100644
--- a/ksql-functional-tests/src/main/java/io/confluent/ksql/test/tools/stubs/StubKafkaRecord.java
+++ b/ksql-functional-tests/src/main/java/io/confluent/ksql/test/tools/stubs/StubKafkaRecord.java
@@ -52,7 +52,7 @@ public static StubKafkaRecord of(
     final SerdeSupplier serdeSupplier = topic.getValueSerdeSupplier();
     final Record testRecord = new Record(
         topic,
-        producerRecord.key().toString(),
+        Objects.toString(producerRecord.key()),
         serdeSupplier instanceof AvroSerdeSupplier
             ? ((ValueSpec) producerRecord.value()).getSpec()
             : producerRecord.value(),
diff --git a/ksql-functional-tests/src/main/java/io/confluent/ksql/test/utils/JsonParsingUtil.java b/ksql-functional-tests/src/main/java/io/confluent/ksql/test/utils/JsonParsingUtil.java
new file mode 100644
index 000000000000..598d8fa6d06d
--- /dev/null
+++ b/ksql-functional-tests/src/main/java/io/confluent/ksql/test/utils/JsonParsingUtil.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2019 Confluent Inc.
+ *
+ * Licensed under the Confluent Community License (the "License"); you may not use
+ * this file except in compliance with the License. You may obtain a copy of the
+ * License at
+ *
+ * http://www.confluent.io/confluent-community-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package io.confluent.ksql.test.utils;
+
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.databind.JsonNode;
+import io.confluent.ksql.test.tools.exceptions.MissingFieldException;
+import java.io.IOException;
+import java.util.Optional;
+
+public final class JsonParsingUtil {
+
+  private JsonParsingUtil() {
+  }
+
+  public static <T> T getRequired(
+      final String name,
+      final JsonNode node,
+      final JsonParser jp,
+      final Class<T> type
+  ) throws IOException {
+    if (!node.has(name)) {
+      throw new MissingFieldException(name);
+    }
+
+    return getNode(name, node, jp, type);
+  }
+
+  public static <T> Optional<T> getOptional(
+      final String name,
+      final JsonNode node,
+      final JsonParser jp,
+      final Class<T> type
+  ) throws IOException {
+    if (!node.has(name)) {
+      return Optional.empty();
+    }
+
+    return Optional.ofNullable(getNode(name, node, jp, type));
+  }
+
+  private static <T> T getNode(
+      final String name,
+      final JsonNode node,
+      final JsonParser jp,
+      final Class<T> type
+  ) throws IOException {
+    return node
+        .get(name)
+        .traverse(jp.getCodec())
+        .readValueAs(type);
+  }
+}
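For reference, a small self-contained sketch of how the new `JsonParsingUtil` helpers behave; the payload is illustrative. `getRequired` throws `MissingFieldException` when the field is absent, while `getOptional` maps both an absent field and an explicit JSON `null` to an empty `Optional`:

```java
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.confluent.ksql.test.utils.JsonParsingUtil;
import java.util.Optional;

public final class JsonParsingUtilExample {

  public static void main(final String[] args) throws Exception {
    final ObjectMapper mapper = new ObjectMapper();
    final JsonParser jp = mapper.getFactory().createParser(
        "{\"topic\": \"t\", \"timestamp\": null}");
    final JsonNode node = mapper.readTree(jp);

    // Present field: returned as-is. A missing field would throw MissingFieldException.
    final String topic = JsonParsingUtil.getRequired("topic", node, jp, String.class);

    // Present-but-null field: Optional.ofNullable collapses it to empty.
    final Optional<Long> ts = JsonParsingUtil.getOptional("timestamp", node, jp, Long.class);

    System.out.println(topic + " / " + ts); // t / Optional.empty
  }
}
```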
diff --git a/ksql-functional-tests/src/test/java/io/confluent/ksql/test/SchemaTranslationTest.java b/ksql-functional-tests/src/test/java/io/confluent/ksql/test/SchemaTranslationTest.java
index cf0dccb83803..00b367b8b27f 100644
--- a/ksql-functional-tests/src/test/java/io/confluent/ksql/test/SchemaTranslationTest.java
+++ b/ksql-functional-tests/src/test/java/io/confluent/ksql/test/SchemaTranslationTest.java
@@ -82,7 +82,7 @@ private static List<Record> generateInputRecords(
             topic,
             "test-key",
             avroToValueSpec(generator.generate(), avroSchema, true),
-            0,
+            Optional.of(0L),
             null
         )
     ).collect(Collectors.toList());
@@ -97,7 +97,7 @@ private static List<Record> getOutputRecords(
             topic,
             "test-key",
             r.value(),
-            0,
+            Optional.of(0L),
             null
         ))
         .collect(Collectors.toList());
diff --git a/ksql-functional-tests/src/test/java/io/confluent/ksql/test/tools/RecordTest.java b/ksql-functional-tests/src/test/java/io/confluent/ksql/test/tools/RecordTest.java
index 2039f337ce6e..ee9c520d680f 100644
--- a/ksql-functional-tests/src/test/java/io/confluent/ksql/test/tools/RecordTest.java
+++ b/ksql-functional-tests/src/test/java/io/confluent/ksql/test/tools/RecordTest.java
@@ -20,6 +20,7 @@
 import static org.hamcrest.MatcherAssert.assertThat;

 import io.confluent.ksql.test.model.WindowData;
+import java.util.Optional;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.serialization.Serializer;
@@ -45,7 +46,13 @@ public class RecordTest {
   @Test
   public void shouldGetCorrectStringKeySerializer() {
     // Given:
-    final Record record = new Record(topic, "foo", "bar", 1000L, null);
+    final Record record = new Record(
+        topic,
+        "foo",
+        "bar",
+        Optional.of(1000L),
+        null
+    );

     // When:
     final Serializer<?> serializer = record.keySerializer();
@@ -60,7 +67,7 @@ public void shouldGetCorrectTimeWondowedKeySerializer() {
     final Record record = new Record(topic,
         "foo",
         "bar",
-        1000L,
+        Optional.of(1000L),
         new WindowData(100L, 1000L, "TIME"));

     // When:
@@ -76,7 +83,7 @@ public void shouldGetCorrectSessionWindowedKeySerializer() {
     final Record record = new Record(topic,
         "foo",
         "bar",
-        1000L,
+        Optional.of(1000L),
         new WindowData(100L, 1000L, "SESSION"));

     // When:
@@ -92,7 +99,7 @@ public void shouldGetCorrectStringKeyDeserializer() {
     final Record record = new Record(topic,
         "foo",
         "bar",
-        1000L,
+        Optional.of(1000L),
         null);

     // When:
@@ -109,7 +116,7 @@ public void shouldGetCorrectTimedWindowKeyDeserializer() {
     final Record record = new Record(topic,
         "foo",
         "bar",
-        1000L,
+        Optional.of(1000L),
         new WindowData(100L, 1000L, "TIME"));

     // When:
@@ -126,7 +133,7 @@ public void shouldGetCorrectSessionedWindowKeyDeserializer() {
     final Record record = new Record(topic,
         "foo",
         "bar",
-        1000L,
+        Optional.of(1000L),
         new WindowData(100L, 1000L, "SESSION"));

     // When:
@@ -143,7 +150,7 @@ public void shouldGetStringKey() {
     final Record record = new Record(topic,
         "foo",
         "bar",
-        1000L,
+        Optional.of(1000L),
         null);

     // When:
@@ -160,7 +167,7 @@ public void shouldGetTimeWindowKey() {
     final Record record = new Record(topic,
         "foo",
         "bar",
-        1000L,
+        Optional.of(1000L),
         new WindowData(100L, 1000L, "TIME"));

     // When:
@@ -180,7 +187,7 @@ public void shouldGetSessionWindowKey() {
     final Record record = new Record(topic,
         "foo",
         "bar",
-        1000L,
+        Optional.of(1000L),
         new WindowData(100L, 1000L, "SESSION"));

     // When:
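With the `long` timestamp overload removed from `Record`, the call sites above all state explicitly whether a timestamp is expected. A short sketch of the two forms; the `Topic` parameter stands in for an instance like the ones these tests build:

```java
import io.confluent.ksql.test.model.WindowData;
import io.confluent.ksql.test.tools.Record;
import io.confluent.ksql.test.tools.Topic;
import java.util.Optional;

final class RecordConstructionExample {

  // A windowed record with an explicit timestamp:
  static Record timed(final Topic topic) {
    return new Record(topic, "foo", "bar", Optional.of(1000L), new WindowData(100L, 1000L, "TIME"));
  }

  // A record that deliberately carries no timestamp:
  static Record untimed(final Topic topic) {
    return new Record(topic, "foo", "bar", Optional.empty(), null);
  }
}
```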
diff --git a/ksql-functional-tests/src/test/java/io/confluent/ksql/test/tools/TestExecutorTest.java b/ksql-functional-tests/src/test/java/io/confluent/ksql/test/tools/TestExecutorTest.java
index ef0f576e964b..77eba4c12ca5 100644
--- a/ksql-functional-tests/src/test/java/io/confluent/ksql/test/tools/TestExecutorTest.java
+++ b/ksql-functional-tests/src/test/java/io/confluent/ksql/test/tools/TestExecutorTest.java
@@ -193,8 +193,8 @@ public void shouldFailOnTwoLittleOutput() {
     final StubKafkaRecord actual_0 = kafkaRecord(sinkTopic, 123456719L, "k1", "v1");
     when(kafkaService.readRecords("sink_topic")).thenReturn(ImmutableList.of(actual_0));

-    final Record expected_0 = new Record(sinkTopic, "k1", "v1", 1L, null);
-    final Record expected_1 = new Record(sinkTopic, "k1", "v1", 1L, null);
+    final Record expected_0 = new Record(sinkTopic, "k1", "v1", Optional.of(1L), null);
+    final Record expected_1 = new Record(sinkTopic, "k1", "v1", Optional.of(1L), null);
     when(testCase.getOutputRecords()).thenReturn(ImmutableList.of(expected_0, expected_1));

     // Expect
@@ -214,7 +214,7 @@ public void shouldFailOnTwoMuchOutput() {
     final StubKafkaRecord actual_1 = kafkaRecord(sinkTopic, 123456789L, "k2", "v2");
     when(kafkaService.readRecords("sink_topic")).thenReturn(ImmutableList.of(actual_0, actual_1));

-    final Record expected_0 = new Record(sinkTopic, "k1", "v1", 1L, null);
+    final Record expected_0 = new Record(sinkTopic, "k1", "v1", Optional.of(1L), null);
     when(testCase.getOutputRecords()).thenReturn(ImmutableList.of(expected_0));

     // Expect
@@ -235,8 +235,8 @@ public void shouldFailOnUnexpectedOutput() {
     final StubKafkaRecord actual_1 = kafkaRecord(sinkTopic, 123456789L, "k2", "v2");
     when(kafkaService.readRecords("sink_topic")).thenReturn(ImmutableList.of(actual_0, actual_1));

-    final Record expected_0 = new Record(sinkTopic, "k1", "v1", 123456719L, null);
-    final Record expected_1 = new Record(sinkTopic, "k2", "different", 123456789L, null);
+    final Record expected_0 = new Record(sinkTopic, "k1", "v1", Optional.of(123456719L), null);
+    final Record expected_1 = new Record(sinkTopic, "k2", "different", Optional.of(123456789L), null);
     when(testCase.getOutputRecords()).thenReturn(ImmutableList.of(expected_0, expected_1));

     // Expect
@@ -255,8 +255,8 @@ public void shouldPassOnExpectedOutput() {
     final StubKafkaRecord actual_1 = kafkaRecord(sinkTopic, 123456789L, "k2", "v2");
     when(kafkaService.readRecords("sink_topic")).thenReturn(ImmutableList.of(actual_0, actual_1));

-    final Record expected_0 = new Record(sinkTopic, "k1", "v1", 123456719L, null);
-    final Record expected_1 = new Record(sinkTopic, "k2", "v2", 123456789L, null);
+    final Record expected_0 = new Record(sinkTopic, "k1", "v1", Optional.of(123456719L), null);
+    final Record expected_1 = new Record(sinkTopic, "k2", "v2", Optional.of(123456789L), null);
     when(testCase.getOutputRecords()).thenReturn(ImmutableList.of(expected_0, expected_1));

     // When:
diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/key-schemas.json b/ksql-functional-tests/src/test/resources/query-validation-tests/key-schemas.json
index fa9e61eac692..99abe2a303ab 100644
--- a/ksql-functional-tests/src/test/resources/query-validation-tests/key-schemas.json
+++ b/ksql-functional-tests/src/test/resources/query-validation-tests/key-schemas.json
@@ -16,7 +16,7 @@
         {"topic": "OUTPUT", "key": "1", "value": {"ID": 1, "KEY": "1"}},
         {"topic": "OUTPUT", "key": "1", "value": {"ID": 2, "KEY": "1"}},
         {"topic": "OUTPUT", "key": "", "value": {"ID": 3, "KEY": ""}},
-        {"topic": "OUTPUT", "key": null, "value": {"ID": 4, "KEY": ""}}
+        {"topic": "OUTPUT", "key": null, "value": {"ID": 4, "KEY": null}}
       ]
     },
     {
@@ -28,14 +28,12 @@
       "inputs": [
        {"topic": "input", "key": 1, "value": {"id": 1}},
        {"topic": "input", "key": "1", "value": {"id": 2}},
-       {"topic": "input", "key": "", "value": {"id": 3}},
-       {"topic": "input", "key": null, "value": {"id": 4}}
+       {"topic": "input", "key": "", "value": {"id": 3}}
      ],
      "outputs": [
        {"topic": "OUTPUT", "key": "1", "value": {"ID": 1, "KEY": "1"}},
        {"topic": "OUTPUT", "key": "1", "value": {"ID": 2, "KEY": "1"}},
-       {"topic": "OUTPUT", "key": "", "value": {"ID": 3, "KEY": ""}},
-       {"topic": "OUTPUT", "key": null, "value": {"ID": 4, "KEY": ""}}
+       {"topic": "OUTPUT", "key": "", "value": {"ID": 3, "KEY": ""}}
      ]
    },
    {
@@ -54,7 +52,7 @@
        {"topic": "OUTPUT", "key": "1", "value": {"ID": 1, "KEY": "1"}},
        {"topic": "OUTPUT", "key": "1", "value": {"ID": 2, "KEY": "1"}},
        {"topic": "OUTPUT", "key": "", "value": {"ID": 3, "KEY": ""}},
-       {"topic": "OUTPUT", "key": null, "value": {"ID": 4, "KEY": ""}}
+       {"topic": "OUTPUT", "key": null, "value": {"ID": 4, "KEY": null}}
      ]
    },
    {
@@ -66,14 +64,12 @@
      "inputs": [
        {"topic": "input", "key": 1, "value": {"id": 1}},
        {"topic": "input", "key": "1", "value": {"id": 2}},
-       {"topic": "input", "key": "", "value": {"id": 3}},
-       {"topic": "input", "key": null, "value": {"id": 4}}
+       {"topic": "input", "key": "", "value": {"id": 3}}
      ],
      "outputs": [
        {"topic": "OUTPUT", "key": "1", "value": {"ID": 1, "KEY": "1"}},
        {"topic": "OUTPUT", "key": "1", "value": {"ID": 2, "KEY": "1"}},
-       {"topic": "OUTPUT", "key": "", "value": {"ID": 3, "KEY": ""}},
-       {"topic": "OUTPUT", "key": null, "value": {"ID": 4, "KEY": ""}}
+       {"topic": "OUTPUT", "key": "", "value": {"ID": 3, "KEY": ""}}
      ]
    },
    {
"key": "", "value": {"ID": 3, "KEY": ""}}, - {"topic": "OUTPUT", "key": null, "value": {"ID": 4, "KEY": ""}} + {"topic": "OUTPUT", "key": "", "value": {"ID": 3, "KEY": ""}} ] }, { diff --git a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json index e65f7ec74adf..8f2df55d6e58 100644 --- a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json +++ b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json @@ -66,6 +66,18 @@ {"topic": "test_topic", "key": "10", "value": {"ID": null}} ] }, + { + "name": "should insert null key", + "statements": [ + "CREATE STREAM TEST (ID INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "INSERT INTO TEST (ID) VALUES (10);" + ], + "inputs": [ + ], + "outputs": [ + {"topic": "test_topic", "key": null, "value": {"ID": 10}} + ] + }, { "name": "rowkey should be set when stream has int key and only key specified in insert", "statements": [ @@ -221,6 +233,18 @@ "message": "Failed to insert values into 'TEST'. Expected ROWKEY and ID to match but got 10 and 5 respectively.", "status": 400 } + }, + { + "name": "should coerce numbers", + "statements": [ + "CREATE STREAM TEST (I INT, BI BIGINT, D DOUBLE) WITH (kafka_topic='test_topic', value_format='JSON');", + "INSERT INTO TEST (I, BI, D) VALUES (1, 2, 3);" + ], + "inputs": [ + ], + "outputs": [ + {"topic": "test_topic", "key": null, "value": {"I": 1, "BI": 2, "D": 3.0}} + ] } ] } \ No newline at end of file diff --git a/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercer.java b/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercer.java index e066a2bd7a42..f3905dbcb9d8 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercer.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercer.java @@ -75,7 +75,7 @@ private static Optional coerceDecimal(final Object value, final SqlDecima } } - if (value instanceof Number) { + if (value instanceof Number && !(value instanceof Double)) { return optional( new BigDecimal( ((Number) value).doubleValue(), diff --git a/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java b/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java index 6a44ea54d6bc..ab02b3335341 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java @@ -17,14 +17,22 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.util.KsqlException; import java.math.BigDecimal; +import java.util.Arrays; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; import org.apache.kafka.connect.data.SchemaBuilder; import org.apache.kafka.connect.data.Struct; import org.junit.Before; @@ -34,6 +42,32 @@ public class DefaultSqlValueCoercerTest 
diff --git a/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java b/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java
index 6a44ea54d6bc..ab02b3335341 100644
--- a/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java
+++ b/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java
@@ -17,14 +17,22 @@
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.notNullValue;

 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import io.confluent.ksql.schema.ksql.types.SqlType;
 import io.confluent.ksql.schema.ksql.types.SqlTypes;
 import io.confluent.ksql.util.KsqlException;
 import java.math.BigDecimal;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
 import org.apache.kafka.connect.data.SchemaBuilder;
 import org.apache.kafka.connect.data.Struct;
 import org.junit.Before;
@@ -34,6 +42,32 @@ public class DefaultSqlValueCoercerTest {

+  private static final Set<SqlBaseType> UNSUPPORTED = ImmutableSet.of(
+      SqlBaseType.ARRAY,
+      SqlBaseType.MAP,
+      SqlBaseType.STRUCT
+  );
+
+  private static final Map<SqlBaseType, SqlType> TYPES = ImmutableMap
+      .<SqlBaseType, SqlType>builder()
+      .put(SqlBaseType.BOOLEAN, SqlTypes.BOOLEAN)
+      .put(SqlBaseType.INTEGER, SqlTypes.INTEGER)
+      .put(SqlBaseType.BIGINT, SqlTypes.BIGINT)
+      .put(SqlBaseType.DECIMAL, SqlTypes.decimal(2, 1))
+      .put(SqlBaseType.DOUBLE, SqlTypes.DOUBLE)
+      .put(SqlBaseType.STRING, SqlTypes.STRING)
+      .build();
+
+  private static final Map<SqlBaseType, Object> INSTANCES = ImmutableMap
+      .<SqlBaseType, Object>builder()
+      .put(SqlBaseType.BOOLEAN, false)
+      .put(SqlBaseType.INTEGER, 1)
+      .put(SqlBaseType.BIGINT, 2L)
+      .put(SqlBaseType.DECIMAL, BigDecimal.ONE)
+      .put(SqlBaseType.DOUBLE, 3.0D)
+      .put(SqlBaseType.STRING, "4.1")
+      .build();
+
   private DefaultSqlValueCoercer coercer;

   @Rule
@@ -69,7 +103,7 @@ public void shouldCoerceToBoolean() {
   public void shouldNotCoerceToBoolean() {
     assertThat(coercer.coerce("true", SqlTypes.BOOLEAN), is(Optional.empty()));
     assertThat(coercer.coerce(1, SqlTypes.BOOLEAN), is(Optional.empty()));
-    assertThat(coercer.coerce(1l, SqlTypes.BOOLEAN), is(Optional.empty()));
+    assertThat(coercer.coerce(1L, SqlTypes.BOOLEAN), is(Optional.empty()));
    assertThat(coercer.coerce(1.0d, SqlTypes.BOOLEAN), is(Optional.empty()));
    assertThat(coercer.coerce(new BigDecimal(123), SqlTypes.BOOLEAN), is(Optional.empty()));
  }
@@ -82,7 +116,7 @@ public void shouldCoerceToInteger() {
   @Test
   public void shouldNotCoerceToInteger() {
     assertThat(coercer.coerce(true, SqlTypes.INTEGER), is(Optional.empty()));
-    assertThat(coercer.coerce(1l, SqlTypes.INTEGER), is(Optional.empty()));
+    assertThat(coercer.coerce(1L, SqlTypes.INTEGER), is(Optional.empty()));
    assertThat(coercer.coerce(1.0d, SqlTypes.INTEGER), is(Optional.empty()));
    assertThat(coercer.coerce("1", SqlTypes.INTEGER), is(Optional.empty()));
    assertThat(coercer.coerce(new BigDecimal(123), SqlTypes.INTEGER), is(Optional.empty()));
@@ -90,8 +124,8 @@ public void shouldNotCoerceToInteger() {

   @Test
   public void shouldCoerceToBigInt() {
-    assertThat(coercer.coerce(1, SqlTypes.BIGINT), is(Optional.of(1l)));
-    assertThat(coercer.coerce(1l, SqlTypes.BIGINT), is(Optional.of(1l)));
+    assertThat(coercer.coerce(1, SqlTypes.BIGINT), is(Optional.of(1L)));
+    assertThat(coercer.coerce(1L, SqlTypes.BIGINT), is(Optional.of(1L)));
   }

   @Test
@@ -102,10 +136,28 @@ public void shouldNotCoerceToBigInt() {
     assertThat(coercer.coerce(new BigDecimal(123), SqlTypes.BIGINT), is(Optional.empty()));
   }

+  @Test
+  public void shouldCoerceToDecimal() {
+    final SqlType decimalType = SqlTypes.decimal(2, 1);
+    assertThat(coercer.coerce(1, decimalType), is(Optional.of(new BigDecimal("1.0"))));
+    assertThat(coercer.coerce(1L, decimalType), is(Optional.of(new BigDecimal("1.0"))));
+    assertThat(coercer.coerce("1.0", decimalType), is(Optional.of(new BigDecimal("1.0"))));
+    assertThat(coercer.coerce(new BigDecimal("1.0"), decimalType),
+        is(Optional.of(new BigDecimal("1.0"))));
+  }
+
+  @Test
+  public void shouldNotCoerceToDecimal() {
+    final SqlType decimalType = SqlTypes.decimal(2, 1);
+    assertThat(coercer.coerce(true, decimalType), is(Optional.empty()));
+    assertThat(coercer.coerce(1.0d, decimalType), is(Optional.empty()));
+  }
+
   @Test
   public void shouldCoerceToDouble() {
     assertThat(coercer.coerce(1, SqlTypes.DOUBLE), is(Optional.of(1.0d)));
-    assertThat(coercer.coerce(1l, SqlTypes.DOUBLE), is(Optional.of(1.0d)));
+    assertThat(coercer.coerce(1L, SqlTypes.DOUBLE), is(Optional.of(1.0d)));
+    assertThat(coercer.coerce(new BigDecimal(123), SqlTypes.DOUBLE), is(Optional.of(123.0d)));
     assertThat(coercer.coerce(1.0d, SqlTypes.DOUBLE), is(Optional.of(1.0d)));
   }

@@ -113,7 +165,6 @@ public void shouldCoerceToDouble() {
   public void shouldNotCoerceToDouble() {
     assertThat(coercer.coerce(true, SqlTypes.DOUBLE), is(Optional.empty()));
     assertThat(coercer.coerce("1", SqlTypes.DOUBLE), is(Optional.empty()));
-    assertThat(coercer.coerce(new BigDecimal(123), SqlTypes.DOUBLE), is(Optional.empty()));
   }

   @Test
@@ -125,31 +176,11 @@ public void shouldCoerceToString() {
   public void shouldNotCoerceToString() {
     assertThat(coercer.coerce(true, SqlTypes.STRING), is(Optional.empty()));
     assertThat(coercer.coerce(1, SqlTypes.STRING), is(Optional.empty()));
-    assertThat(coercer.coerce(1l, SqlTypes.STRING), is(Optional.empty()));
+    assertThat(coercer.coerce(1L, SqlTypes.STRING), is(Optional.empty()));
     assertThat(coercer.coerce(1.0d, SqlTypes.STRING), is(Optional.empty()));
     assertThat(coercer.coerce(new BigDecimal(123), SqlTypes.STRING), is(Optional.empty()));
   }
-
-  @Test
-  public void shouldCoerceToDecimal() {
-    SqlType decimalType = SqlTypes.decimal(2, 1);
-    assertThat(coercer.coerce(1, decimalType), is(Optional.of(new BigDecimal("1.0"))));
-    assertThat(coercer.coerce(1l, decimalType), is(Optional.of(new BigDecimal("1.0"))));
-    assertThat(coercer.coerce(1.0d, decimalType),
-        is(Optional.of(new BigDecimal("1.0"))));
-    assertThat(coercer.coerce("1.0", decimalType),
-        is(Optional.of(new BigDecimal("1.0"))));
-    assertThat(coercer.coerce(new BigDecimal("1.0"), decimalType),
-        is(Optional.of(new BigDecimal("1.0"))));
-  }
-
-  @Test
-  public void shouldNotCoerceToDecimal() {
-    assertThat(coercer.coerce(true, SqlTypes.decimal(2, 1)),
-        is(Optional.empty()));
-  }
-
   @Test
   public void shouldThrowIfInvalidCoercionString() {
     // Given:
@@ -162,4 +193,67 @@ public void shouldThrowIfInvalidCoercionString() {
     // When:
     coercer.coerce(val, SqlTypes.decimal(2, 1));
   }
+
+  @Test
+  public void shouldCoerceUsingSameRulesAsBaseTypeUpCastRules() {
+    for (final SqlBaseType fromBaseType : supportedTypes()) {
+      // Given:
+      final Map<Boolean, List<SqlBaseType>> partitioned = supportedTypes().stream()
+          .collect(Collectors
+              .partitioningBy(toBaseType -> coercionShouldBeSupported(fromBaseType, toBaseType)));
+
+      final List<SqlBaseType> shouldUpCast = partitioned.getOrDefault(true, ImmutableList.of());
+      final List<SqlBaseType> shouldNotUpCast = partitioned.getOrDefault(false, ImmutableList.of());
+
+      // Then:
+      shouldUpCast.forEach(toBaseType -> assertThat(
+          "should coerce " + fromBaseType + " to " + toBaseType,
+          coercer.coerce(getInstance(fromBaseType), getType(toBaseType)),
+          is(not(Optional.empty()))
+      ));
+
+      shouldNotUpCast.forEach(toBaseType -> assertThat(
+          "should not coerce " + fromBaseType + " to " + toBaseType,
+          coercer.coerce(getInstance(fromBaseType), getType(toBaseType)),
+          is(Optional.empty())
+      ));
+    }
+  }
+
+  private static boolean coercionShouldBeSupported(
+      final SqlBaseType fromBaseType,
+      final SqlBaseType toBaseType
+  ) {
+    if (fromBaseType == SqlBaseType.STRING && toBaseType == SqlBaseType.DECIMAL) {
+      // Handled by parsing the string to a decimal:
+      return true;
+    }
+    return fromBaseType.canUpCast(toBaseType);
+  }
+
+  private static List<SqlBaseType> supportedTypes() {
+    return Arrays.stream(SqlBaseType.values())
+        .filter(baseType -> !UNSUPPORTED.contains(baseType))
+        .collect(Collectors.toList());
+  }
+
+  private static Object getInstance(final SqlBaseType baseType) {
+    final Object instance = INSTANCES.get(baseType);
+    assertThat(
+        "invalid test: need instance for base type:" + baseType,
+        instance,
+        is(notNullValue())
+    );
instance; + } + + private static SqlType getType(final SqlBaseType baseType) { + final SqlType type = TYPES.get(baseType); + assertThat( + "invalid test: need type for base type:" + baseType, + type, + is(notNullValue()) + ); + return type; + } } \ No newline at end of file