Skip to content

Commit

Permalink
fix: scale of ROUND() return value (#6236)
Browse files Browse the repository at this point in the history
fixes: #6233

Fixes the scale of the decimal values returned by the variant of the `ROUND` method that takes a decimal and a required number of decimal places to round to.

This fixes a bug where the output could not be serialized to Avro.

Co-authored-by: Andy Coates <[email protected]>
  • Loading branch information
big-andy-coates and big-andy-coates authored Sep 18, 2020
1 parent 0bf0dca commit 42ab721
Show file tree
Hide file tree
Showing 12 changed files with 836 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ in order to provide compatibility with the previous ROUND() implementation which
we need to use different rounding modes on BigDecimal depending on whether the value
is positive or negative to get consistent behaviour.
*/
@SuppressWarnings("MethodMayBeStatic")
@UdfDescription(
name = "Round",
category = FunctionCategory.MATHEMATICAL,
Expand Down Expand Up @@ -100,19 +101,28 @@ public Double round(@UdfParameter final Double val, @UdfParameter final Integer

@Udf(schemaProvider = "provideDecimalSchema")
public BigDecimal round(@UdfParameter final BigDecimal val) {
return round(val, 0);
if (val == null) {
return null;
}
return roundBigDecimal(val, 0);
}

@Udf(schemaProvider = "provideDecimalSchemaWithDecimalPlaces")
public BigDecimal round(
@UdfParameter final BigDecimal val,
@UdfParameter final Integer decimalPlaces
) {
return val == null ? null : roundBigDecimal(val, decimalPlaces);
if (val == null) {
return null;
}
return roundBigDecimal(val, decimalPlaces)
// Must maintain source scale for now. See https://github.com/confluentinc/ksql/issues/6235.
.setScale(val.scale(), RoundingMode.UNNECESSARY);
}

@SuppressWarnings("unused") // Invoked via reflection
@UdfSchemaProvider
public SqlType provideDecimalSchemaWithDecimalPlaces(final List<SqlType> params) {
public static SqlType provideDecimalSchemaWithDecimalPlaces(final List<SqlType> params) {
final SqlType s0 = params.get(0);
if (s0.baseType() != SqlBaseType.DECIMAL) {
throw new KsqlException("The schema provider method for round expects a BigDecimal parameter"
Expand All @@ -123,11 +133,15 @@ public SqlType provideDecimalSchemaWithDecimalPlaces(final List<SqlType> params)
throw new KsqlException("The schema provider method for round expects an Integer parameter"
+ "type as second parameter.");
}

// While the user requested a certain number of decimal places, this can't be used to change
// the scale of the return type. See https://github.com/confluentinc/ksql/issues/6235.
return s0;
}

@SuppressWarnings("unused") // Invoked via reflection
@UdfSchemaProvider
public SqlType provideDecimalSchema(final List<SqlType> params) {
public static SqlType provideDecimalSchema(final List<SqlType> params) {
final SqlType s0 = params.get(0);
if (s0.baseType() != SqlBaseType.DECIMAL) {
throw new KsqlException("The schema provider method for round expects a BigDecimal parameter"
Expand All @@ -137,7 +151,10 @@ public SqlType provideDecimalSchema(final List<SqlType> params) {
return SqlDecimal.of(param.getPrecision() - param.getScale(), 0);
}

private BigDecimal roundBigDecimal(final BigDecimal val, final int decimalPlaces) {
private static BigDecimal roundBigDecimal(
final BigDecimal val,
final int decimalPlaces
) {
final RoundingMode roundingMode = val.compareTo(BigDecimal.ZERO) > 0
? RoundingMode.HALF_UP : RoundingMode.HALF_DOWN;
return val.setScale(decimalPlaces, roundingMode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,67 +152,66 @@ public void shouldRoundDoubleWithDecimalPlacesNegative() {
@Test
public void shouldRoundBigDecimalWithDecimalPlacesPositive() {
assertThat(udf.round(new BigDecimal("0"), 0), is(new BigDecimal("0")));
assertThat(udf.round(new BigDecimal("1.0"), 0), is(new BigDecimal("1")));
assertThat(udf.round(new BigDecimal("1.1"), 0), is(new BigDecimal("1")));
assertThat(udf.round(new BigDecimal("1.5"), 0), is(new BigDecimal("2")));
assertThat(udf.round(new BigDecimal("1.75"), 0), is(new BigDecimal("2")));
assertThat(udf.round(new BigDecimal("100.1"), 0),is(new BigDecimal("100")));
assertThat(udf.round(new BigDecimal("100.5"), 0), is(new BigDecimal("101")));
assertThat(udf.round(new BigDecimal("100.75"), 0), is(new BigDecimal("101")));
assertThat(udf.round(new BigDecimal("100.10"), 1), is(new BigDecimal("100.1")));
assertThat(udf.round(new BigDecimal("100.11"), 1), is(new BigDecimal("100.1")));
assertThat(udf.round(new BigDecimal("100.15"), 1), is(new BigDecimal("100.2")));
assertThat(udf.round(new BigDecimal("100.17"), 1), is(new BigDecimal("100.2")));
assertThat(udf.round(new BigDecimal("100.110"), 2), is(new BigDecimal("100.11")));
assertThat(udf.round(new BigDecimal("100.111"), 2), is(new BigDecimal("100.11")));
assertThat(udf.round(new BigDecimal("100.115"), 2), is(new BigDecimal("100.12")));
assertThat(udf.round(new BigDecimal("100.117"), 2), is(new BigDecimal("100.12")));
assertThat(udf.round(new BigDecimal("100.1110"), 3), is(new BigDecimal("100.111")));
assertThat(udf.round(new BigDecimal("100.1111"), 3), is(new BigDecimal("100.111")));
assertThat(udf.round(new BigDecimal("100.1115"), 3), is(new BigDecimal("100.112")));
assertThat(udf.round(new BigDecimal("100.1117"), 3), is(new BigDecimal("100.112")));
assertThat(udf.round(new BigDecimal("12345.67"), -1), is(new BigDecimal("1.235E4")));
assertThat(udf.round(new BigDecimal("12345.67"), -2), is(new BigDecimal("1.23E4")));
assertThat(udf.round(new BigDecimal("12345.67"), -3), is(new BigDecimal("1.2E4")));
assertThat(udf.round(new BigDecimal("12345.67"), -4), is(new BigDecimal("1E4")));
assertThat(udf.round(new BigDecimal("12345.67"), -5), is(new BigDecimal("0E5")));
assertThat(udf.round(new BigDecimal("1.0"), 0), is(new BigDecimal("1.0")));
assertThat(udf.round(new BigDecimal("1.1"), 0), is(new BigDecimal("1.0")));
assertThat(udf.round(new BigDecimal("1.5"), 0), is(new BigDecimal("2.0")));
assertThat(udf.round(new BigDecimal("1.75"), 0), is(new BigDecimal("2.00")));
assertThat(udf.round(new BigDecimal("100.1"), 0),is(new BigDecimal("100.0")));
assertThat(udf.round(new BigDecimal("100.5"), 0), is(new BigDecimal("101.0")));
assertThat(udf.round(new BigDecimal("100.75"), 0), is(new BigDecimal("101.00")));
assertThat(udf.round(new BigDecimal("100.10"), 1), is(new BigDecimal("100.10")));
assertThat(udf.round(new BigDecimal("100.11"), 1), is(new BigDecimal("100.10")));
assertThat(udf.round(new BigDecimal("100.15"), 1), is(new BigDecimal("100.20")));
assertThat(udf.round(new BigDecimal("100.17"), 1), is(new BigDecimal("100.20")));
assertThat(udf.round(new BigDecimal("100.110"), 2), is(new BigDecimal("100.110")));
assertThat(udf.round(new BigDecimal("100.111"), 2), is(new BigDecimal("100.110")));
assertThat(udf.round(new BigDecimal("100.115"), 2), is(new BigDecimal("100.120")));
assertThat(udf.round(new BigDecimal("100.117"), 2), is(new BigDecimal("100.120")));
assertThat(udf.round(new BigDecimal("100.1110"), 3), is(new BigDecimal("100.1110")));
assertThat(udf.round(new BigDecimal("100.1111"), 3), is(new BigDecimal("100.1110")));
assertThat(udf.round(new BigDecimal("100.1115"), 3), is(new BigDecimal("100.1120")));
assertThat(udf.round(new BigDecimal("100.1117"), 3), is(new BigDecimal("100.1120")));
assertThat(udf.round(new BigDecimal("12345.67"), -1), is(new BigDecimal("12350.00")));
assertThat(udf.round(new BigDecimal("12345.67"), -2), is(new BigDecimal("12300.00")));
assertThat(udf.round(new BigDecimal("12345.67"), -3), is(new BigDecimal("12000.00")));
assertThat(udf.round(new BigDecimal("12345.67"), -4), is(new BigDecimal("10000.00")));
assertThat(udf.round(new BigDecimal("12345.67"), -5), is(new BigDecimal("0.00")));
}

@Test
public void shouldRoundBigDecimalWithDecimalPlacesNegative() {
assertThat(udf.round(new BigDecimal("-1.0"), 0), is(new BigDecimal("-1")));
assertThat(udf.round(new BigDecimal("-1.1"), 0), is(new BigDecimal("-1")));
assertThat(udf.round(new BigDecimal("-1.5"), 0), is(new BigDecimal("-1")));
assertThat(udf.round(new BigDecimal("-1.75"), 0), is(new BigDecimal("-2")));
assertThat(udf.round(new BigDecimal("-100.1"), 0), is(new BigDecimal("-100")));
assertThat(udf.round(new BigDecimal("-100.5"), 0), is(new BigDecimal("-100")));
assertThat(udf.round(new BigDecimal("-100.75"), 0), is(new BigDecimal("-101")));
assertThat(udf.round(new BigDecimal("-100.10"), 1), is(new BigDecimal("-100.1")));
assertThat(udf.round(new BigDecimal("-100.11"), 1), is(new BigDecimal("-100.1")));
assertThat(udf.round(new BigDecimal("-100.15"), 1), is(new BigDecimal("-100.1")));
assertThat(udf.round(new BigDecimal("-100.17"), 1), is(new BigDecimal("-100.2")));
assertThat(udf.round(new BigDecimal("-100.110"), 2), is(new BigDecimal("-100.11")));
assertThat(udf.round(new BigDecimal("-100.111"), 2), is(new BigDecimal("-100.11")));
assertThat(udf.round(new BigDecimal("-100.115"), 2), is(new BigDecimal("-100.11")));
assertThat(udf.round(new BigDecimal("-100.117"), 2), is(new BigDecimal("-100.12")));
assertThat(udf.round(new BigDecimal("-100.1110"), 3), is(new BigDecimal("-100.111")));
assertThat(udf.round(new BigDecimal("-100.1111"), 3), is(new BigDecimal("-100.111")));
assertThat(udf.round(new BigDecimal("-100.1115"), 3), is(new BigDecimal("-100.111")));
assertThat(udf.round(new BigDecimal("-100.1117"), 3), is(new BigDecimal("-100.112")));
assertThat(udf.round(new BigDecimal("-12345.67"), -1), is(new BigDecimal("-1.235E4")));
assertThat(udf.round(new BigDecimal("-12345.67"), -2), is(new BigDecimal("-1.23E4")));
assertThat(udf.round(new BigDecimal("-12345.67"), -3), is(new BigDecimal("-1.2E4")));
assertThat(udf.round(new BigDecimal("-12345.67"), -4), is(new BigDecimal("-1E4")));
assertThat(udf.round(new BigDecimal("-12345.67"), -5), is(new BigDecimal("-0E5")));
assertThat(udf.round(new BigDecimal("-1.0"), 0), is(new BigDecimal("-1.0")));
assertThat(udf.round(new BigDecimal("-1.1"), 0), is(new BigDecimal("-1.0")));
assertThat(udf.round(new BigDecimal("-1.5"), 0), is(new BigDecimal("-1.0")));
assertThat(udf.round(new BigDecimal("-1.75"), 0), is(new BigDecimal("-2.00")));
assertThat(udf.round(new BigDecimal("-100.1"), 0), is(new BigDecimal("-100.0")));
assertThat(udf.round(new BigDecimal("-100.5"), 0), is(new BigDecimal("-100.0")));
assertThat(udf.round(new BigDecimal("-100.75"), 0), is(new BigDecimal("-101.00")));
assertThat(udf.round(new BigDecimal("-100.10"), 1), is(new BigDecimal("-100.10")));
assertThat(udf.round(new BigDecimal("-100.11"), 1), is(new BigDecimal("-100.10")));
assertThat(udf.round(new BigDecimal("-100.15"), 1), is(new BigDecimal("-100.10")));
assertThat(udf.round(new BigDecimal("-100.17"), 1), is(new BigDecimal("-100.20")));
assertThat(udf.round(new BigDecimal("-100.110"), 2), is(new BigDecimal("-100.110")));
assertThat(udf.round(new BigDecimal("-100.111"), 2), is(new BigDecimal("-100.110")));
assertThat(udf.round(new BigDecimal("-100.115"), 2), is(new BigDecimal("-100.110")));
assertThat(udf.round(new BigDecimal("-100.117"), 2), is(new BigDecimal("-100.120")));
assertThat(udf.round(new BigDecimal("-100.1110"), 3), is(new BigDecimal("-100.1110")));
assertThat(udf.round(new BigDecimal("-100.1111"), 3), is(new BigDecimal("-100.1110")));
assertThat(udf.round(new BigDecimal("-100.1115"), 3), is(new BigDecimal("-100.1110")));
assertThat(udf.round(new BigDecimal("-100.1117"), 3), is(new BigDecimal("-100.1120")));
assertThat(udf.round(new BigDecimal("-12345.67"), -2), is(new BigDecimal("-12300.00")));
assertThat(udf.round(new BigDecimal("-12345.67"), -3), is(new BigDecimal("-12000.00")));
assertThat(udf.round(new BigDecimal("-12345.67"), -4), is(new BigDecimal("-10000.00")));
assertThat(udf.round(new BigDecimal("-12345.67"), -5), is(new BigDecimal("0.00")));
}

@Test
public void shouldHandleDoubleLiteralsEndingWith5ThatCannotBeRepresentedExactylyAsDoubles() {
assertThat(udf.round(new BigDecimal("265.335"), 2), is(new BigDecimal("265.34")));
assertThat(udf.round(new BigDecimal("-265.335"), 2), is(new BigDecimal("-265.33")));
public void shouldHandleDoubleLiteralsEndingWith5ThatCannotBeRepresentedExactlyAsDoubles() {
assertThat(udf.round(new BigDecimal("265.335"), 2), is(new BigDecimal("265.340")));
assertThat(udf.round(new BigDecimal("-265.335"), 2), is(new BigDecimal("-265.330")));

assertThat(udf.round(new BigDecimal("265.365"), 2), is(new BigDecimal("265.37")));
assertThat(udf.round(new BigDecimal("-265.365"), 2), is(new BigDecimal("-265.36")));
assertThat(udf.round(new BigDecimal("265.365"), 2), is(new BigDecimal("265.370")));
assertThat(udf.round(new BigDecimal("-265.365"), 2), is(new BigDecimal("-265.360")));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
{
"plan" : [ {
"@type" : "ksqlPlanV1",
"statementText" : "CREATE STREAM TEST (ID STRING KEY, V DECIMAL(33, 16)) WITH (KAFKA_TOPIC='test_topic', VALUE_FORMAT='AVRO');",
"ddlCommand" : {
"@type" : "createStreamV1",
"sourceName" : "TEST",
"schema" : "`ID` STRING KEY, `V` DECIMAL(33, 16)",
"topicName" : "test_topic",
"formats" : {
"keyFormat" : {
"format" : "KAFKA"
},
"valueFormat" : {
"format" : "AVRO"
}
},
"orReplace" : false
}
}, {
"@type" : "ksqlPlanV1",
"statementText" : "CREATE STREAM OUTPUT AS SELECT\n TEST.ID ID,\n ROUND(TEST.V) R0,\n ROUND(TEST.V, 0) R00\nFROM TEST TEST\nEMIT CHANGES",
"ddlCommand" : {
"@type" : "createStreamV1",
"sourceName" : "OUTPUT",
"schema" : "`ID` STRING KEY, `R0` DECIMAL(17, 0), `R00` DECIMAL(33, 16)",
"topicName" : "OUTPUT",
"formats" : {
"keyFormat" : {
"format" : "KAFKA"
},
"valueFormat" : {
"format" : "AVRO"
}
},
"orReplace" : false
},
"queryPlan" : {
"sources" : [ "TEST" ],
"sink" : "OUTPUT",
"physicalPlan" : {
"@type" : "streamSinkV1",
"properties" : {
"queryContext" : "OUTPUT"
},
"source" : {
"@type" : "streamSelectV1",
"properties" : {
"queryContext" : "Project"
},
"source" : {
"@type" : "streamSourceV1",
"properties" : {
"queryContext" : "KsqlTopic/Source"
},
"topicName" : "test_topic",
"formats" : {
"keyFormat" : {
"format" : "KAFKA"
},
"valueFormat" : {
"format" : "AVRO"
}
},
"sourceSchema" : "`ID` STRING KEY, `V` DECIMAL(33, 16)"
},
"keyColumnNames" : [ "ID" ],
"selectExpressions" : [ "ROUND(V) AS R0", "ROUND(V, 0) AS R00" ]
},
"formats" : {
"keyFormat" : {
"format" : "KAFKA"
},
"valueFormat" : {
"format" : "AVRO"
}
},
"topicName" : "OUTPUT"
},
"queryId" : "CSAS_OUTPUT_0"
}
} ],
"configs" : {
"ksql.extension.dir" : "ext",
"ksql.streams.cache.max.bytes.buffering" : "0",
"ksql.security.extension.class" : null,
"metric.reporters" : "",
"ksql.transient.prefix" : "transient_",
"ksql.query.status.running.threshold.seconds" : "300",
"ksql.streams.default.deserialization.exception.handler" : "io.confluent.ksql.errors.LogMetricAndContinueExceptionHandler",
"ksql.output.topic.name.prefix" : "",
"ksql.query.pull.enable.standby.reads" : "false",
"ksql.streams.max.task.idle.ms" : "0",
"ksql.query.error.max.queue.size" : "10",
"ksql.internal.topic.min.insync.replicas" : "1",
"ksql.streams.shutdown.timeout.ms" : "300000",
"ksql.internal.topic.replicas" : "1",
"ksql.insert.into.values.enabled" : "true",
"ksql.query.pull.max.allowed.offset.lag" : "9223372036854775807",
"ksql.query.pull.max.qps" : "2147483647",
"ksql.access.validator.enable" : "auto",
"ksql.streams.bootstrap.servers" : "localhost:0",
"ksql.query.pull.metrics.enabled" : "false",
"ksql.create.or.replace.enabled" : "true",
"ksql.metrics.extension" : null,
"ksql.hidden.topics" : "_confluent.*,__confluent.*,_schemas,__consumer_offsets,__transaction_state,connect-configs,connect-offsets,connect-status,connect-statuses",
"ksql.cast.strings.preserve.nulls" : "true",
"ksql.authorization.cache.max.entries" : "10000",
"ksql.pull.queries.enable" : "true",
"ksql.suppress.enabled" : "false",
"ksql.sink.window.change.log.additional.retention" : "1000000",
"ksql.readonly.topics" : "_confluent.*,__confluent.*,_schemas,__consumer_offsets,__transaction_state,connect-configs,connect-offsets,connect-status,connect-statuses",
"ksql.query.persistent.active.limit" : "2147483647",
"ksql.persistence.wrap.single.values" : null,
"ksql.authorization.cache.expiry.time.secs" : "30",
"ksql.query.retry.backoff.initial.ms" : "15000",
"ksql.schema.registry.url" : "",
"ksql.properties.overrides.denylist" : "",
"ksql.streams.auto.offset.reset" : "earliest",
"ksql.connect.url" : "http://localhost:8083",
"ksql.service.id" : "some.ksql.service.id",
"ksql.streams.default.production.exception.handler" : "io.confluent.ksql.errors.ProductionExceptionHandlerUtil$LogAndFailProductionExceptionHandler",
"ksql.enable.metastore.backup" : "false",
"ksql.streams.commit.interval.ms" : "2000",
"ksql.streams.auto.commit.interval.ms" : "0",
"ksql.streams.topology.optimization" : "all",
"ksql.query.retry.backoff.max.ms" : "900000",
"ksql.streams.num.stream.threads" : "4",
"ksql.timestamp.throw.on.invalid" : "false",
"ksql.metrics.tags.custom" : "",
"ksql.udfs.enabled" : "true",
"ksql.udf.enable.security.manager" : "true",
"ksql.connect.worker.config" : "",
"ksql.udf.collect.metrics" : "false",
"ksql.persistent.prefix" : "query_",
"ksql.metastore.backup.location" : "",
"ksql.error.classifier.regex" : "",
"ksql.suppress.buffer.size.bytes" : "-1"
}
}
Loading

0 comments on commit 42ab721

Please sign in to comment.