diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topkdistinct/TopkDistinctAggFunctionFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topkdistinct/TopkDistinctAggFunctionFactory.java index c34dbac3b951..0d73ae2696af 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topkdistinct/TopkDistinctAggFunctionFactory.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topkdistinct/TopkDistinctAggFunctionFactory.java @@ -15,13 +15,11 @@ package io.confluent.ksql.function.udaf.topkdistinct; -import com.google.common.collect.ImmutableList; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import io.confluent.ksql.function.AggregateFunctionFactory; import io.confluent.ksql.function.AggregateFunctionInitArguments; import io.confluent.ksql.function.KsqlAggregateFunction; import io.confluent.ksql.function.types.ParamType; -import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.schema.ksql.SchemaConverters; import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.types.SqlType; @@ -32,15 +30,6 @@ public class TopkDistinctAggFunctionFactory extends AggregateFunctionFactory { private static final String NAME = "TOPKDISTINCT"; - private static final ImmutableList> SUPPORTED_TYPES = ImmutableList - .>builder() - .add(ImmutableList.of(ParamTypes.INTEGER)) - .add(ImmutableList.of(ParamTypes.LONG)) - .add(ImmutableList.of(ParamTypes.DOUBLE)) - .add(ImmutableList.of(ParamTypes.STRING)) - .add(ImmutableList.of(ParamTypes.DECIMAL)) - .build(); - public TopkDistinctAggFunctionFactory() { super(NAME); } @@ -48,7 +37,7 @@ public TopkDistinctAggFunctionFactory() { private static final AggregateFunctionInitArguments DEFAULT_INIT_ARGS = new AggregateFunctionInitArguments(0, 1); - @SuppressWarnings("unchecked") + @SuppressWarnings({"unchecked", "CyclomaticComplexity"}) @Override public KsqlAggregateFunction createAggregateFunction( final List argTypeList, @@ -64,6 +53,10 @@ public KsqlAggregateFunction createAggregateFunction( case DOUBLE: case STRING: case DECIMAL: + case BYTES: + case DATE: + case TIME: + case TIMESTAMP: return new TopkDistinctKudaf( NAME, initArgs.udafIndex(), tkValFromArg, argSchema, SchemaConverters.sqlToFunctionConverter().toFunctionType(argSchema), @@ -80,9 +73,9 @@ public KsqlAggregateFunction createAggregateFunction( } @Override - @SuppressFBWarnings(value = "EI_EXPOSE_REP", justification = "SUPPORTED_TYPES is ImmutableList") + @SuppressFBWarnings(value = "EI_EXPOSE_REP", justification = "COMPARABLE_ARGS is ImmutableList") public List> supportedArgs() { - return SUPPORTED_TYPES; + return COMPARABLE_ARGS; } @Override diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Greatest.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Greatest.java index 54761a1ed37f..fdb2bc4c3783 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Greatest.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Greatest.java @@ -26,6 +26,10 @@ import io.confluent.ksql.util.DecimalUtil; import io.confluent.ksql.util.KsqlConstants; import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Date; +import java.sql.Time; +import java.sql.Timestamp; import java.util.Arrays; import java.util.Comparator; import java.util.List; @@ -70,13 +74,50 @@ public Double greatest(@UdfParameter final Double val, @UdfParameter final Doubl @Udf public String greatest(@UdfParameter final String val, @UdfParameter final String... vals) { - return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals)) .filter(Objects::nonNull) .max(String::compareTo) .orElse(null); } + @Udf + public ByteBuffer greatest(@UdfParameter final ByteBuffer val, + @UdfParameter final ByteBuffer... vals) { + + return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals)) + .filter(Objects::nonNull) + .max(ByteBuffer::compareTo) + .orElse(null); + } + + @Udf + public Date greatest(@UdfParameter final Date val, @UdfParameter final Date... vals) { + + return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals)) + .filter(Objects::nonNull) + .max(Date::compareTo) + .orElse(null); + } + + @Udf + public Time greatest(@UdfParameter final Time val, @UdfParameter final Time... vals) { + + return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals)) + .filter(Objects::nonNull) + .max(Time::compareTo) + .orElse(null); + } + + @Udf + public Timestamp greatest(@UdfParameter final Timestamp val, + @UdfParameter final Timestamp... vals) { + + return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals)) + .filter(Objects::nonNull) + .max(Timestamp::compareTo) + .orElse(null); + } + @Udf(schemaProvider = "greatestDecimalProvider") public BigDecimal greatest(@UdfParameter final BigDecimal val, @UdfParameter final BigDecimal... vals) { diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Least.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Least.java index e8a3743862fa..2f17b80db600 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Least.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Least.java @@ -26,6 +26,10 @@ import io.confluent.ksql.util.DecimalUtil; import io.confluent.ksql.util.KsqlConstants; import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Date; +import java.sql.Time; +import java.sql.Timestamp; import java.util.Arrays; import java.util.Comparator; import java.util.List; @@ -76,6 +80,43 @@ public String least(@UdfParameter final String val, @UdfParameter final String.. .orElse(null); } + @Udf + public ByteBuffer least(@UdfParameter final ByteBuffer val, + @UdfParameter final ByteBuffer... vals) { + + return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals)) + .filter(Objects::nonNull) + .min(ByteBuffer::compareTo) + .orElse(null); + } + + @Udf + public Date least(@UdfParameter final Date val, @UdfParameter final Date... vals) { + + return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals)) + .filter(Objects::nonNull) + .min(Date::compareTo) + .orElse(null); + } + + @Udf + public Time least(@UdfParameter final Time val, @UdfParameter final Time... vals) { + + return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals)) + .filter(Objects::nonNull) + .min(Time::compareTo) + .orElse(null); + } + + @Udf + public Timestamp least(@UdfParameter final Timestamp val, @UdfParameter final Timestamp... vals) { + + return (vals == null) ? null : Streams.concat(Stream.of(val), Arrays.stream(vals)) + .filter(Objects::nonNull) + .min(Timestamp::compareTo) + .orElse(null); + } + @Udf(schemaProvider = "leastDecimalProvider") public BigDecimal least(@UdfParameter final BigDecimal val, @UdfParameter final BigDecimal... vals) { @@ -90,10 +131,10 @@ public BigDecimal least(@UdfParameter final BigDecimal val, public SqlType leastDecimalProvider(final List params) { return params.stream() - .filter(s -> s.getSqlType().isPresent()) - .map(SqlArgument::getSqlTypeOrThrow) - .reduce(DecimalUtil::widen) - .orElse(null); + .filter(s -> s.getSqlType().isPresent()) + .map(SqlArgument::getSqlTypeOrThrow) + .reduce(DecimalUtil::widen) + .orElse(null); } } \ No newline at end of file diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/topkdistinct/BytesTopKDistinctKudafTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/topkdistinct/BytesTopKDistinctKudafTest.java new file mode 100644 index 000000000000..6b7d51bb583d --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/topkdistinct/BytesTopKDistinctKudafTest.java @@ -0,0 +1,106 @@ +/* + * Copyright 2022 Confluent Inc. + * + * Licensed under the Confluent Community License; you may not use this file + * except in compliance with the License. You may obtain a copy of the License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.function.udaf.topkdistinct; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.function.udf.string.ToBytes; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.util.BytesUtils; +import org.junit.Before; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; + +public class BytesTopKDistinctKudafTest { + private final List valuesArray = ImmutableList.of("A", "D", "F", "A", "G", "H", "B", "H", + "I", "E", "C", "H", "I"); + private final TopkDistinctKudaf bytesTopkDistinctKudaf + = TopKDistinctTestUtils.getTopKDistinctKudaf(3, SqlTypes.BYTES); + private ToBytes toBytesUDF; + + @Before + public void setUp() { + toBytesUDF = new ToBytes(); + } + + @Test + public void shouldAggregateTopK() { + List currentVal = new ArrayList<>(); + for (final String d : valuesArray) { + currentVal = bytesTopkDistinctKudaf.aggregate(toBytes(d), currentVal); + } + + List expected = toBytes(ImmutableList.of("I", "H", "G")); + assertThat("Invalid results.", currentVal, equalTo(expected)); + } + + @Test + public void shouldAggregateTopKWithLessThanKValues() { + List currentVal = new ArrayList<>(); + currentVal = bytesTopkDistinctKudaf.aggregate(toBytes("I"), currentVal); + + assertThat("Invalid results.", currentVal, equalTo(toBytes(ImmutableList.of("I")))); + } + + @Test + public void shouldMergeTopK() { + final List array1 = toBytes(ImmutableList.of("D", "B", "A")); + final List array2 = toBytes(ImmutableList.of("E", "D", "C")); + + assertThat("Invalid results.", bytesTopkDistinctKudaf.getMerger().apply(null, array1, array2), + equalTo(toBytes(ImmutableList.of("E", "D", "C")))); + } + + @Test + public void shouldMergeTopKWithNulls() { + final List array1 = toBytes(ImmutableList.of("B", "A")); + final List array2 = toBytes(ImmutableList.of("C")); + + assertThat("Invalid results.", bytesTopkDistinctKudaf.getMerger().apply(null, array1, array2), + equalTo(toBytes(ImmutableList.of("C", "B", "A")))); + } + + @Test + public void shouldMergeTopKWithNullsDuplicates() { + final List array1 = toBytes(ImmutableList.of("B", "A")); + final List array2 = toBytes(ImmutableList.of("C", "B")); + + assertThat("Invalid results.", bytesTopkDistinctKudaf.getMerger().apply(null, array1, array2), + equalTo(toBytes(ImmutableList.of("C", "B", "A")))); + } + + @Test + public void shouldMergeTopKWithMoreNulls() { + final List array1 = toBytes(ImmutableList.of("A")); + final List array2 = toBytes(ImmutableList.of("A")); + + assertThat("Invalid results.", bytesTopkDistinctKudaf.getMerger().apply(null, array1, array2), + equalTo(toBytes(ImmutableList.of("A")))); + } + + private ByteBuffer toBytes(final String val) { + return toBytesUDF.toBytes(val, BytesUtils.Encoding.ASCII.toString()); + } + + private List toBytes(final List vals) { + return vals.stream().map(this::toBytes).collect(Collectors.toList()); + } +} diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/topkdistinct/DateTopKDistinctKudafTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/topkdistinct/DateTopKDistinctKudafTest.java new file mode 100644 index 000000000000..d7682024887d --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/topkdistinct/DateTopKDistinctKudafTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2022 Confluent Inc. + * + * Licensed under the Confluent Community License; you may not use this file + * except in compliance with the License. You may obtain a copy of the License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.function.udaf.topkdistinct; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import org.junit.Test; + +import java.sql.Date; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; + +public class DateTopKDistinctKudafTest { + + private final List valuesArray = ImmutableList.of(new Date(10), new Date(30), new Date(45), + new Date(10), new Date(50), new Date(60), new Date(20), new Date(60), new Date(80), + new Date(35), new Date(25), new Date(60), new Date(80)); + private final TopkDistinctKudaf dateTopkDistinctKudaf + = TopKDistinctTestUtils.getTopKDistinctKudaf(3, SqlTypes.DATE); + + @Test + public void shouldAggregateTopK() { + List currentVal = new ArrayList<>(); + for (final Date d : valuesArray) { + currentVal = dateTopkDistinctKudaf.aggregate(d, currentVal); + } + + assertThat("Invalid results.", currentVal, + equalTo(ImmutableList.of(new Date(80), new Date(60), new Date(50)))); + } + + @Test + public void shouldAggregateTopKWithLessThanKValues() { + List currentVal = new ArrayList<>(); + currentVal = dateTopkDistinctKudaf.aggregate(new Date(80), currentVal); + + assertThat("Invalid results.", currentVal, equalTo(ImmutableList.of(new Date(80)))); + } + + @Test + public void shouldMergeTopK() { + final List array1 = ImmutableList.of(new Date(50), new Date(45), new Date(25)); + final List array2 = ImmutableList.of(new Date(60), new Date(50), new Date(48)); + + assertThat("Invalid results.", dateTopkDistinctKudaf.getMerger().apply(null, array1, array2), + equalTo(ImmutableList.of(new Date(60), new Date(50), new Date(48)))); + } + + @Test + public void shouldMergeTopKWithNulls() { + final List array1 = ImmutableList.of(new Date(50), new Date(45)); + final List array2 = ImmutableList.of(new Date(60)); + + assertThat("Invalid results.", dateTopkDistinctKudaf.getMerger().apply(null, array1, array2), + equalTo(ImmutableList.of(new Date(60), new Date(50), new Date(45)))); + } + + @Test + public void shouldMergeTopKWithNullsDuplicates() { + final List array1 = ImmutableList.of(new Date(50), new Date(45)); + final List array2 = ImmutableList.of(new Date(60), new Date(50)); + + assertThat("Invalid results.", dateTopkDistinctKudaf.getMerger().apply(null, array1, array2), + equalTo(ImmutableList.of(new Date(60), new Date(50), new Date(45)))); + } + + @Test + public void shouldMergeTopKWithMoreNulls() { + final List array1 = ImmutableList.of(new Date(60)); + final List array2 = ImmutableList.of(new Date(60)); + + assertThat("Invalid results.", dateTopkDistinctKudaf.getMerger().apply(null, array1, array2), + equalTo(ImmutableList.of(new Date(60)))); + } +} diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/topkdistinct/TimeTopKDistinctKudafTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/topkdistinct/TimeTopKDistinctKudafTest.java new file mode 100644 index 000000000000..1464d2f719a9 --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/topkdistinct/TimeTopKDistinctKudafTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2022 Confluent Inc. + * + * Licensed under the Confluent Community License; you may not use this file + * except in compliance with the License. You may obtain a copy of the License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.function.udaf.topkdistinct; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import org.junit.Test; + +import java.sql.Time; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; + +public class TimeTopKDistinctKudafTest { + private final List