Skip to content

Commit

Permalink
fix: Add null handling to functions (#8726)
Browse files Browse the repository at this point in the history
Address #8545 by
adding null to each function which does not have them.
  • Loading branch information
jnh5y authored Feb 8, 2022
1 parent 593f25b commit 6117604
Show file tree
Hide file tree
Showing 53 changed files with 459 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public <T extends Comparable<? super T>> List<T> arraySortWithDirection(@UdfPara
description = "The array to sort") final List<T> input,
@UdfParameter(
description = "Marks the end of the series (inclusive)") final String direction) {
if (input == null) {
if (input == null || direction == null) {
return null;
}
if (SORT_DIRECTION_ASC.contains(direction.toUpperCase())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public Timestamp convertTz(
description = "The toTimeZone in java.util.TimeZone ID format. For example: \"UTC\","
+ " \"America/Los_Angeles\", \"PST\", \"Europe/London\"") final String toTimeZone
) {
if (timestamp == null) {
if (timestamp == null || fromTimeZone == null || toTimeZone == null) {
return null;
}
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ public String dateToString(
@UdfParameter(
description = "The format pattern should be in the format expected by"
+ " java.time.format.DateTimeFormatter.") final String formatPattern) {
if (formatPattern == null) {
return null;
}
try {
final DateTimeFormatter formatter = formatters.get(formatPattern);
return LocalDate.ofEpochDay(epochDays).format(formatter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public String formatDate(
@UdfParameter(
description = "The format pattern should be in the format expected by"
+ " java.time.format.DateTimeFormatter.") final String formatPattern) {
if (date == null) {
if (date == null || formatPattern == null) {
return null;
}
try {
Expand All @@ -66,5 +66,4 @@ public String formatDate(
+ "': " + e.getMessage(), e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public String formatTime(
@UdfParameter(
description = "The format pattern should be in the format expected by"
+ " java.time.format.DateTimeFormatter.") final String formatPattern) {
if (time == null) {
if (time == null || formatPattern == null) {
return null;
}
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public String formatTimestamp(
@UdfParameter(
description = " timeZone is a java.util.TimeZone ID format, for example: \"UTC\","
+ " \"America/Los_Angeles\", \"PST\", \"Europe/London\"") final String timeZone) {
if (timestamp == null) {
if (timestamp == null || formatPattern == null || timeZone == null) {
return null;
}
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ public Date parseDate(
@UdfParameter(
description = "The format pattern should be in the format expected by"
+ " java.text.SimpleDateFormat.") final String formatPattern) {
if (formattedDate == null || formatPattern == null) {
return null;
}
try {
final long time = formatters.get(formatPattern).parse(formattedDate).getTime();
if (time % MILLIS_IN_DAY != 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ public Time parseTime(
@UdfParameter(
description = "The format pattern should be in the format expected by"
+ " java.time.format.DateTimeFormatter.") final String formatPattern) {
if (formattedTime == null | formatPattern == null) {
return null;
}
try {
final TemporalAccessor ta = formatters.get(formatPattern).parse(formattedTime);
final Optional<ChronoField> dateField = Arrays.stream(ChronoField.values())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ public Timestamp parseTimestamp(
@UdfParameter(
description = " timeZone is a java.util.TimeZone ID format, for example: \"UTC\","
+ " \"America/Los_Angeles\", \"PST\", \"Europe/London\"") final String timeZone) {
if (formattedTimestamp == null || formatPattern == null || timeZone == null) {
return null;
}
try {
final StringToTimestampParser timestampParser = parsers.get(formatPattern);
final ZoneId zoneId = ZoneId.of(timeZone);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ public String timestampToString(
@UdfParameter(
description = "The format pattern should be in the format expected by"
+ " java.time.format.DateTimeFormatter.") final String formatPattern) {
if (formatPattern == null) {
return null;
}
try {
final Timestamp timestamp = new Timestamp(epochMilli);
final DateTimeFormatter formatter = formatters.get(formatPattern);
Expand Down Expand Up @@ -84,6 +87,9 @@ public String timestampToString(
@UdfParameter(
description = " timeZone is a java.util.TimeZone ID format, for example: \"UTC\","
+ " \"America/Los_Angeles\", \"PST\", \"Europe/London\"") final String timeZone) {
if (formatPattern == null || timeZone == null) {
return null;
}
try {
final Timestamp timestamp = new Timestamp(epochMilli);
final DateTimeFormatter formatter = formatters.get(formatPattern);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ public <T> Boolean contains(

parser.skipChildren();
if (TOKEN_COMPAT.getOrDefault(token, foo -> false).test(val)) {
if (token == VALUE_NULL || Objects.equals(parser.readValueAs(val.getClass()), val)) {
if (token == VALUE_NULL
|| (val != null && Objects.equals(parser.readValueAs(val.getClass()), val))) {
return true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
public class JsonConcat {
@Udf
public String concat(@UdfParameter final String... jsonStrings) {
if (jsonStrings == null) {
return null;
}
final List<JsonNode> nodes = new ArrayList<>(jsonStrings.length);
boolean allObjects = true;
for (final String jsonString : jsonStrings) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public <T> List<T> filterArray(
@UdfParameter(description = "The array") final List<T> array,
@UdfParameter(description = "The lambda function") final Function<T, Boolean> function
) {
if (array == null) {
if (array == null || function == null) {
return null;
}
return array.stream().filter(function::apply).collect(Collectors.toList());
Expand All @@ -62,7 +62,7 @@ public <K, V> Map<K, V> filterMap(
@UdfParameter(description = "The map") final Map<K, V> map,
@UdfParameter(description = "The lambda function") final BiFunction<K, V, Boolean> biFunction
) {
if (map == null) {
if (map == null || biFunction == null) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public <T,S> S reduceArray(
@UdfParameter(description = "The initial state.") final S initialState,
@UdfParameter(description = "The reduce function.") final BiFunction<S, T, S> biFunction
) {
if (initialState == null) {
if (initialState == null || biFunction == null) {
return null;
}

Expand All @@ -79,7 +79,7 @@ public <K,V,S> S reduceMap(
@UdfParameter(description = "The initial state.") final S initialState,
@UdfParameter(description = "The reduce function.") final TriFunction<S, K, V, S> triFunction
) {
if (initialState == null) {
if (initialState == null || triFunction == null) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public <T, R> List<R> transformArray(
@UdfParameter(description = "The array") final List<T> array,
@UdfParameter(description = "The lambda function") final Function<T, R> function
) {
if (array == null) {
if (array == null || function == null) {
return null;
}
return array.stream().map(function).collect(Collectors.toList());
Expand All @@ -67,7 +67,7 @@ public <K,V,R,T> Map<R,T> transformMap(
@UdfParameter(description = "The key lambda function") final BiFunction<K, V, R> biFunction1,
@UdfParameter(description = "The value lambda function") final BiFunction<K, V, T> biFunction2
) {
if (map == null) {
if (map == null || biFunction1 == null || biFunction2 == null) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ public class AsMap {
public final <T> Map<String, T> asMap(
@UdfParameter final List<String> keys,
@UdfParameter final List<T> values) {
final Map<String, T> map = new HashMap<>();
if (keys == null || values == null) {
return null;
}
final Map<String, T> map = new HashMap<>(keys.size());
for (int i = 0; i < keys.size(); i++) {
final String key = keys.get(i);
final T value = i >= values.size() ? null : values.get(i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class Greatest {
@Udf
public Integer greatest(@UdfParameter final Integer val, @UdfParameter final Integer... vals) {

return Stream.concat(Stream.of(val), Arrays.stream(vals))
return (vals == null) ? null : Stream.concat(Stream.of(val), Arrays.stream(vals))
.filter(Objects::nonNull)
.max(Integer::compareTo)
.orElse(null);
Expand All @@ -53,7 +53,7 @@ public Integer greatest(@UdfParameter final Integer val, @UdfParameter final Int
@Udf
public Long greatest(@UdfParameter final Long val, @UdfParameter final Long... vals) {

return Stream.concat(Stream.of(val), Arrays.stream(vals))
return (vals == null) ? null : Stream.concat(Stream.of(val), Arrays.stream(vals))
.filter(Objects::nonNull)
.max(Long::compareTo)
.orElse(null);
Expand All @@ -62,7 +62,7 @@ public Long greatest(@UdfParameter final Long val, @UdfParameter final Long... v
@Udf
public Double greatest(@UdfParameter final Double val, @UdfParameter final Double... vals) {

return Streams.concat(Stream.of(val), Arrays.stream(vals))
return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals))
.filter(Objects::nonNull)
.max(Double::compareTo)
.orElse(null);
Expand All @@ -71,7 +71,7 @@ public Double greatest(@UdfParameter final Double val, @UdfParameter final Doubl
@Udf
public String greatest(@UdfParameter final String val, @UdfParameter final String... vals) {

return Streams.concat(Stream.of(val), Arrays.stream(vals))
return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals))
.filter(Objects::nonNull)
.max(String::compareTo)
.orElse(null);
Expand All @@ -81,7 +81,7 @@ public String greatest(@UdfParameter final String val, @UdfParameter final Strin
public BigDecimal greatest(@UdfParameter final BigDecimal val,
@UdfParameter final BigDecimal... vals) {

return Streams.concat(Stream.of(val), Arrays.stream(vals))
return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals))
.filter(Objects::nonNull)
.max(Comparator.naturalOrder())
.orElse(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public class Least {
@Udf
public Integer least(@UdfParameter final Integer val, @UdfParameter final Integer... vals) {

return Stream.concat(Stream.of(val), Arrays.stream(vals))
return (vals == null) ? null : Stream.concat(Stream.of(val), Arrays.stream(vals))
.filter(Objects::nonNull)
.min(Integer::compareTo)
.orElse(null);
Expand All @@ -52,7 +52,7 @@ public Integer least(@UdfParameter final Integer val, @UdfParameter final Intege
@Udf
public Long least(@UdfParameter final Long val, @UdfParameter final Long... vals) {

return Stream.concat(Stream.of(val), Arrays.stream(vals))
return (vals == null) ? null : Stream.concat(Stream.of(val), Arrays.stream(vals))
.filter(Objects::nonNull)
.min(Long::compareTo)
.orElse(null);
Expand All @@ -61,7 +61,7 @@ public Long least(@UdfParameter final Long val, @UdfParameter final Long... vals
@Udf
public Double least(@UdfParameter final Double val, @UdfParameter final Double... vals) {

return Streams.concat(Stream.of(val), Arrays.stream(vals))
return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals))
.filter(Objects::nonNull)
.min(Double::compareTo)
.orElse(null);
Expand All @@ -70,7 +70,7 @@ public Double least(@UdfParameter final Double val, @UdfParameter final Double..
@Udf
public String least(@UdfParameter final String val, @UdfParameter final String... vals) {

return Streams.concat(Stream.of(val), Arrays.stream(vals))
return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals))
.filter(Objects::nonNull)
.min(String::compareTo)
.orElse(null);
Expand All @@ -80,7 +80,7 @@ public String least(@UdfParameter final String val, @UdfParameter final String..
public BigDecimal least(@UdfParameter final BigDecimal val,
@UdfParameter final BigDecimal... vals) {

return Streams.concat(Stream.of(val), Arrays.stream(vals))
return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals))
.filter(Objects::nonNull)
.min(Comparator.naturalOrder())
.orElse(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public Long round(@UdfParameter final Double val) {

@Udf
public Double round(@UdfParameter final Double val, @UdfParameter final Integer decimalPlaces) {
return val == null
return (val == null || decimalPlaces == null)
? null
: roundBigDecimal(BigDecimal.valueOf(val), decimalPlaces).doubleValue();
}
Expand All @@ -113,10 +113,7 @@ public BigDecimal round(
@UdfParameter final BigDecimal val,
@UdfParameter final Integer decimalPlaces
) {
if (val == null) {
return null;
}
return roundBigDecimal(val, decimalPlaces)
return (val == null || decimalPlaces == null) ? null : roundBigDecimal(val, decimalPlaces)
// Must maintain source scale for now. See https://github.com/confluentinc/ksql/issues/6235.
.setScale(val.scale(), RoundingMode.UNNECESSARY);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ public String elt(
@UdfParameter(description = "the nth element to extract") final int n,
@UdfParameter(description = "the strings of which to extract the nth") final String... args
) {
if (args == null) {
return null;
}
if (n < 1 || n > args.length) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public int field(
@UdfParameter final String str,
@UdfParameter final String... args
) {
if (str == null) {
if (str == null || args == null) {
return 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ public class FromBytes {
public String fromBytes(
@UdfParameter(description = "The bytes value to convert.") final ByteBuffer value,
@UdfParameter(description = "The encoding to use on conversion.") final String encoding) {
return (value == null) ? null : BytesUtils.encode(BytesUtils.getByteArray(value),
return (value == null || encoding == null)
? null : BytesUtils.encode(BytesUtils.getByteArray(value),
BytesUtils.Encoding.from(encoding));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class ToBytes {
public ByteBuffer toBytes(
@UdfParameter(description = "The string to convert.") final String value,
@UdfParameter(description = "The type of encoding.") final String encoding) {
return (value == null) ? null : ByteBuffer.wrap(BytesUtils.decode(value,
return (value == null || encoding == null) ? null : ByteBuffer.wrap(BytesUtils.decode(value,
BytesUtils.Encoding.from(encoding)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ public class UrlDecodeParam {
@Udf
public String decodeParam(
@UdfParameter(description = "the value to decode") final String input) {
if (input == null) {
return null;
}
try {
return URLDecoder.decode(input, UTF_8.name());
} catch (final UnsupportedEncodingException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public class UrlEncodeParam {
@Udf
public String encodeParam(
@UdfParameter(description = "the value to encode") final String input) {
if (input == null) {
return null;
}
final Escaper escaper = UrlEscapers.urlFormParameterEscaper();
return escaper.escape(input);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,20 @@ public void shouldSortIntsAscending() {
assertThat(output, contains(-2, 1, 3));
}

@Test
public void shouldReturnNullWithBadSortDirection() {
final List<Integer> input = Arrays.asList(1, 3, -2);
final List<Integer> output = udf.arraySortWithDirection(input, "ASCDESC");
assertThat(output, is(nullValue()));
}

@Test
public void shouldReturnNullWithNullSortDirection() {
final List<Integer> input = Arrays.asList(1, 3, -2);
final List<Integer> output = udf.arraySortWithDirection(input, null);
assertThat(output, is(nullValue()));
}

@Test
public void shouldSortIntsDescending() {
final List<Integer> input = Arrays.asList(1, 3, -2);
Expand Down
Loading

0 comments on commit 6117604

Please sign in to comment.