Skip to content

Commit

Permalink
feat: Add schema resolver method to UDF specification (#3215)
Browse files Browse the repository at this point in the history
* feat: Adding udf schema resolver

* 1. Added jar for testing loading erroneouds udfs that should fail 2. Test Abs udf

* updated documentation, added one extra test case

* Fixed doc comments, added better schema comparison, fixed exception messages

* added tests for SchemaUtil, changed wrong date
  • Loading branch information
vpapavas authored Aug 26, 2019
1 parent 544fc28 commit 08855ad
Show file tree
Hide file tree
Showing 24 changed files with 776 additions and 119 deletions.
18 changes: 16 additions & 2 deletions docs/developer-guide/udf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ Follow these steps to create your custom functions:
For a detailed walkthrough on creating a UDF, see :ref:`implement-a-udf`.

======================
Creating UDF and UDAFs
Creating UDFs and UDAFs
======================

KSQL supports creating User Defined Scalar Functions (UDFs) and User Defined Aggregate Functions (UDAF) via custom jars that are
KSQL supports creating User Defined Scalar Functions (UDFs) and User Defined Aggregate Functions (UDAFs) via custom jars that are
uploaded to the ``ext/`` directory of the KSQL installation.
At start up time KSQL scans the jars in the directory looking for any classes that annotated
with ``@UdfDescription`` (UDF) or ``@UdafDescription`` (UDAF).
Expand Down Expand Up @@ -104,6 +104,20 @@ The KSQL server will check the value being passed to each parameter and report a
log for any null values being passed to a primitive type. The associated column in the output row
will be ``null``.


Dynamic return type
~~~~~~~~~~~~~~~~~~~

UDFs support dynamic return types that are resolved at runtime. This is useful if you want to
implement a UDF with a non-deterministic return type. A UDF which returns ``BigDecimal``,
for example, may vary the precision and scale of the output based on the input schema.

To use this functionality, you need to specify a method with signature
``public Schema <your-method-name>(final List<Schema> params)`` and annotate it with ``@SchemaProvider``.
Also, you need to link it to the corresponding UDF by using the ``schemaProvider=<your-method-name>``
parameter of the ``@Udf`` annotation.


Generics in UDFS
~~~~~~~~~~~~~~~~

Expand Down
114 changes: 71 additions & 43 deletions ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -36,15 +37,52 @@ public final class KsqlFunction implements IndexedFunction {

static final String INTERNAL_PATH = "internal";

private final Schema returnType;
private final Function<List<Schema>,Schema> returnSchemaProvider;
private final Schema javaReturnType;
private final List<Schema> parameters;
private final String functionName;
private final Class<? extends Kudf> kudfClass;
private final Function<KsqlConfig, Kudf> udfFactory;
private final String description;
private final String pathLoadedFrom;
private final boolean isVariadic;
private final boolean hasGenerics;

private KsqlFunction(
final Function<List<Schema>,Schema> returnSchemaProvider,
final Schema javaReturnType,
final List<Schema> arguments,
final String functionName,
final Class<? extends Kudf> kudfClass,
final Function<KsqlConfig, Kudf> udfFactory,
final String description,
final String pathLoadedFrom,
final boolean isVariadic) {

this.returnSchemaProvider = Objects.requireNonNull(returnSchemaProvider, "schemaProvider");
this.javaReturnType = Objects.requireNonNull(javaReturnType, "javaReturnType");
this.parameters = ImmutableList.copyOf(Objects.requireNonNull(arguments, "arguments"));
this.functionName = Objects.requireNonNull(functionName, "functionName");
this.kudfClass = Objects.requireNonNull(kudfClass, "kudfClass");
this.udfFactory = Objects.requireNonNull(udfFactory, "udfFactory");
this.description = Objects.requireNonNull(description, "description");
this.pathLoadedFrom = Objects.requireNonNull(pathLoadedFrom, "pathLoadedFrom");
this.isVariadic = isVariadic;


if (arguments.stream().anyMatch(Objects::isNull)) {
throw new IllegalArgumentException("KSQL Function can't have null argument types");
}
if (isVariadic) {
if (arguments.isEmpty()) {
throw new IllegalArgumentException(
"KSQL variadic functions must have at least one parameter");
}
if (!Iterables.getLast(arguments).type().equals(Type.ARRAY)) {
throw new IllegalArgumentException(
"KSQL variadic functions must have ARRAY type as their last parameter");
}
}
}

/**
* Create built in / legacy function.
Expand All @@ -66,7 +104,8 @@ public static KsqlFunction createLegacyBuiltIn(
};

return create(
returnType, arguments, functionName, kudfClass, udfFactory, "", INTERNAL_PATH, false);
ignored -> returnType, returnType, arguments, functionName, kudfClass, udfFactory, "",
INTERNAL_PATH, false);
}

/**
Expand All @@ -75,7 +114,8 @@ public static KsqlFunction createLegacyBuiltIn(
* <p>Can be either built-in UDF or true user-supplied.
*/
static KsqlFunction create(
final Schema returnType,
final Function<List<Schema>,Schema> schemaProvider,
final Schema javaReturnType,
final List<Schema> arguments,
final String functionName,
final Class<? extends Kudf> kudfClass,
Expand All @@ -85,7 +125,8 @@ static KsqlFunction create(
final boolean isVariadic
) {
return new KsqlFunction(
returnType,
schemaProvider,
javaReturnType,
arguments,
functionName,
kudfClass,
Expand All @@ -95,47 +136,20 @@ static KsqlFunction create(
isVariadic);
}

private KsqlFunction(
final Schema returnType,
final List<Schema> arguments,
final String functionName,
final Class<? extends Kudf> kudfClass,
final Function<KsqlConfig, Kudf> udfFactory,
final String description,
final String pathLoadedFrom,
final boolean isVariadic) {
this.returnType = Objects.requireNonNull(returnType, "returnType");
this.parameters = ImmutableList.copyOf(Objects.requireNonNull(arguments, "arguments"));
this.functionName = Objects.requireNonNull(functionName, "functionName");
this.kudfClass = Objects.requireNonNull(kudfClass, "kudfClass");
this.udfFactory = Objects.requireNonNull(udfFactory, "udfFactory");
this.description = Objects.requireNonNull(description, "description");
this.pathLoadedFrom = Objects.requireNonNull(pathLoadedFrom, "pathLoadedFrom");
this.isVariadic = isVariadic;
this.hasGenerics = GenericsUtil.hasGenerics(returnType);
public Schema getReturnType(final List<Schema> arguments) {

if (arguments.stream().anyMatch(Objects::isNull)) {
throw new IllegalArgumentException("KSQL Function can't have null argument types");
}
if (isVariadic) {
if (arguments.isEmpty()) {
throw new IllegalArgumentException(
"KSQL variadic functions must have at least one parameter");
}
if (!Iterables.getLast(arguments).type().equals(Type.ARRAY)) {
throw new IllegalArgumentException(
"KSQL variadic functions must have ARRAY type as their last parameter");
}
final Schema returnType = returnSchemaProvider.apply(arguments);

if (returnType == null) {
throw new KsqlException(String.format("Return type of UDF %s cannot be null.", functionName));
}

if (!returnType.isOptional()) {
throw new IllegalArgumentException("KSQL only supports optional field types");
}
}


public Schema getReturnType(final List<Schema> arguments) {
if (!hasGenerics) {
if (!GenericsUtil.hasGenerics(returnType)) {
checkMatchingReturnTypes(returnType, javaReturnType);
return returnType;
}

Expand All @@ -152,7 +166,21 @@ public Schema getReturnType(final List<Schema> arguments) {
genericMapping.putAll(GenericsUtil.resolveGenerics(schema, instance));
}

return GenericsUtil.applyResolved(returnType, genericMapping);
final Schema genericSchema = GenericsUtil.applyResolved(returnType, genericMapping);
final Schema genericJavaSchema = GenericsUtil.applyResolved(javaReturnType, genericMapping);
checkMatchingReturnTypes(genericSchema, genericJavaSchema);

return genericSchema;
}

private void checkMatchingReturnTypes(final Schema s1, final Schema s2) {
if (!SchemaUtil.areCompatible(s1, s2)) {
throw new KsqlException(String.format("Return type %s of UDF %s does not match the declared "
+ "return type %s.",
s1.toString(),
functionName,
s2.toString()));
}
}

public List<Schema> getArguments() {
Expand Down Expand Up @@ -188,7 +216,7 @@ public boolean equals(final Object o) {
return false;
}
final KsqlFunction that = (KsqlFunction) o;
return Objects.equals(returnType, that.returnType)
return Objects.equals(javaReturnType, that.javaReturnType)
&& Objects.equals(parameters, that.parameters)
&& Objects.equals(functionName, that.functionName)
&& Objects.equals(kudfClass, that.kudfClass)
Expand All @@ -199,13 +227,13 @@ public boolean equals(final Object o) {
@Override
public int hashCode() {
return Objects.hash(
returnType, parameters, functionName, kudfClass, pathLoadedFrom, isVariadic);
returnSchemaProvider, parameters, functionName, kudfClass, pathLoadedFrom, isVariadic);
}

@Override
public String toString() {
return "KsqlFunction{"
+ "returnType=" + returnType
+ "returnType=" + javaReturnType
+ ", arguments=" + parameters.stream().map(Schema::type).collect(Collectors.toList())
+ ", functionName='" + functionName + '\''
+ ", kudfClass=" + kudfClass
Expand Down
21 changes: 2 additions & 19 deletions ksql-common/src/main/java/io/confluent/ksql/function/UdfIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
package io.confluent.ksql.function;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.confluent.ksql.schema.connect.SqlSchemaFormatter;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
Expand All @@ -29,11 +29,9 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.Schema.Type;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -258,14 +256,6 @@ int compare(final Node other) {
*/
static final class Parameter {

private static final Map<Type, BiPredicate<Schema, Schema>> CUSTOM_SCHEMA_EQ =
ImmutableMap.<Type, BiPredicate<Schema, Schema>>builder()
.put(Type.MAP, Parameter::mapEquals)
.put(Type.ARRAY, Parameter::arrayEquals)
.put(Type.STRUCT, Parameter::structEquals)
.put(Type.BYTES, Parameter::bytesEquals)
.build();

private final Schema schema;
private final boolean isVararg;

Expand Down Expand Up @@ -311,14 +301,7 @@ boolean accepts(final Schema argument, final Map<Schema, Schema> reservedGeneric
return reserveGenerics(schema, argument, reservedGenerics);
}

final Schema.Type type = schema.type();

// we require a custom equals method that ignores certain values (e.g.
// whether or not the schema is optional, and the documentation)
return Objects.equals(type, argument.type())
&& CUSTOM_SCHEMA_EQ.getOrDefault(type, (a, b) -> true).test(schema, argument)
&& Objects.equals(schema.version(), argument.version())
&& Objects.deepEquals(schema.defaultValue(), argument.defaultValue());
return SchemaUtil.areCompatible(schema, argument);
}
// CHECKSTYLE_RULES.ON: BooleanExpressionComplexity

Expand Down
48 changes: 48 additions & 0 deletions ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiPredicate;
import org.apache.avro.LogicalTypes;
import org.apache.avro.SchemaBuilder.FieldAssembler;
import org.apache.kafka.connect.data.Field;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.Schema.Type;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;

Expand Down Expand Up @@ -97,6 +100,15 @@ public final class SchemaUtil {
.put(Schema.Type.BOOLEAN, "(Boolean)")
.build();

private static final Map<Type, BiPredicate<Schema, Schema>> CUSTOM_SCHEMA_EQ =
ImmutableMap.<Type, BiPredicate<Schema, Schema>>builder()
.put(Type.MAP, SchemaUtil::mapEquals)
.put(Type.ARRAY, SchemaUtil::arrayEquals)
.put(Type.STRUCT, SchemaUtil::structEquals)
.put(Type.BYTES, SchemaUtil::bytesEquals)
.build();


private SchemaUtil() {
}

Expand Down Expand Up @@ -363,4 +375,40 @@ public static Schema ensureOptional(final Schema schema) {
.build();
}


public static boolean areCompatible(final Schema arg1, final Schema arg2) {
if (arg2 == null) {
return arg1.isOptional();
}

// we require a custom equals method that ignores certain values (e.g.
// whether or not the schema is optional, and the documentation)
return Objects.equals(arg1.type(), arg2.type())
&& CUSTOM_SCHEMA_EQ.getOrDefault(arg1.type(), (a, b) -> true).test(arg1, arg2)
&& Objects.equals(arg1.version(), arg2.version())
&& Objects.deepEquals(arg1.defaultValue(), arg2.defaultValue());
}

private static boolean mapEquals(final Schema mapA, final Schema mapB) {
return Objects.equals(mapA.keySchema(), mapB.keySchema())
&& Objects.equals(mapA.valueSchema(), mapB.valueSchema());
}

private static boolean arrayEquals(final Schema arrayA, final Schema arrayB) {
return Objects.equals(arrayA.valueSchema(), arrayB.valueSchema());
}

private static boolean structEquals(final Schema structA, final Schema structB) {
return structA.fields().isEmpty()
|| structB.fields().isEmpty()
|| Objects.equals(structA.fields(), structB.fields());
}

private static boolean bytesEquals(final Schema bytesA, final Schema bytesB) {
// from a Java schema perspective, all decimals are the same
// since they can all be cast to BigDecimal - other bytes types
// are not supported in UDFs
return DecimalUtil.isDecimal(bytesA) && DecimalUtil.isDecimal(bytesB);
}

}
Loading

0 comments on commit 08855ad

Please sign in to comment.