Skip to content

Commit

Permalink
fix: error message when UDAF has STRUCT type with no schema (#3407)
Browse files Browse the repository at this point in the history
Error message was missing the name of the annotation property they needed to set.
  • Loading branch information
big-andy-coates authored Sep 25, 2019
1 parent 366b771 commit 49f456e
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public class UdfCompiler {
private final SqlTypeParser typeParser;

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public UdfCompiler(final Optional<Metrics> metrics) {
UdfCompiler(final Optional<Metrics> metrics) {
this.metrics = Objects.requireNonNull(metrics, "metrics can't be null");
this.typeParser = SqlTypeParser.create(TypeRegistry.EMPTY);
}
Expand Down Expand Up @@ -195,13 +195,13 @@ private class UdafTypes {
}

private void validateTypes(final Type t) {
if (!isTypeSupported((Class<?>)getRawType(t), SUPPORTED_TYPES)) {
if (isUnsupportedType((Class<?>) getRawType(t))) {
throw new KsqlException(String.format(invalidClassErrorMsg, t));
}
}

Schema getInputSchema(final String inSchema) {
validateStructAnnotation(inputType, inSchema, "");
validateStructAnnotation(inputType, inSchema, "paramSchema");
final Schema inputSchema = getSchemaFromType(inputType, inSchema);
//Currently, aggregate functions cannot have reified types as input parameters.
if (!GenericsUtil.constituentGenerics(inputSchema).isEmpty()) {
Expand All @@ -212,19 +212,18 @@ Schema getInputSchema(final String inSchema) {
}

Schema getAggregateSchema(final String aggSchema) {
validateStructAnnotation(aggregateType, aggSchema, "");
validateStructAnnotation(aggregateType, aggSchema, "aggregateSchema");
return getSchemaFromType(aggregateType, aggSchema);
}

Schema getOutputSchema(final String outSchema) {
validateStructAnnotation(outputType, outSchema, "");
validateStructAnnotation(outputType, outSchema, "returnSchema");
return getSchemaFromType(outputType, outSchema);
}

private void validateStructAnnotation(final Type type, final String schema, final String msg) {
if (type.equals(Struct.class) && schema.isEmpty()) {
throw new KsqlException(String.format("Must specify '%s' for STRUCT parameter in "
+ "@UdafFactory.", msg));
throw new KsqlException("Must specify '" + msg + "' for STRUCT parameter in @UdafFactory.");
}
}

Expand Down Expand Up @@ -283,7 +282,7 @@ private static String generateUdafClass(
) {
validateMethodSignature(method);
Arrays.stream(method.getParameterTypes())
.filter(type -> !UdfCompiler.isTypeSupported(type, SUPPORTED_TYPES))
.filter(UdfCompiler::isUnsupportedType)
.findFirst()
.ifPresent(type -> {
throw new KsqlException(
Expand Down Expand Up @@ -319,7 +318,7 @@ private static String generateCode(final Method method) {
continue;
}

if (!UdfCompiler.isTypeSupported(type, SUPPORTED_TYPES)) {
if (isUnsupportedType(type)) {
throw new KsqlException(
String.format(
"Type %s is not supported by UDF methods. "
Expand Down Expand Up @@ -347,11 +346,10 @@ private static void validateMethodSignature(final Method method) {
}
}

@SuppressWarnings("BooleanMethodIsAlwaysInverted")
private static boolean isTypeSupported(final Class<?> type, final Set<Class<?>> supportedTypes) {
return supportedTypes.contains(type)
|| type.isArray() && supportedTypes.contains(type.getComponentType())
|| supportedTypes.stream().anyMatch(supported -> supported.isAssignableFrom(type));
private static boolean isUnsupportedType(final Class<?> type) {
return !SUPPORTED_TYPES.contains(type)
&& (!type.isArray() || !SUPPORTED_TYPES.contains(type.getComponentType()))
&& SUPPORTED_TYPES.stream().noneMatch(supported -> supported.isAssignableFrom(type));
}

private static IScriptEvaluator createScriptEvaluator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.junit.Test;
import org.junit.rules.ExpectedException;

@SuppressWarnings({"MethodMayBeStatic", "WeakerAccess", "unused"}) // UDFs not static / private
public class UdfCompilerTest {

private static final Schema STRUCT_SCHEMA =
Expand Down Expand Up @@ -272,37 +273,59 @@ public void shouldThrowIfUnsupportedInputType() throws Exception {
"");
}

@Test(expected = KsqlException.class)
@Test
public void shouldThrowIfMissingInputTypeSchema() throws Exception {
udfCompiler.compileAggregate(UdfCompilerTest.class.getMethod("missingInputSchemaAnnotationUdaf"),
classLoader,
"test",
"desc",
"",
"",
"");
// Then:
expectedException.expect(KsqlException.class);
expectedException.expectMessage(
"Must specify 'paramSchema' for STRUCT parameter in @UdafFactory.");

// When:
udfCompiler.compileAggregate(
UdfCompilerTest.class.getMethod("missingInputSchemaAnnotationUdaf"),
classLoader,
"test",
"desc",
"",
"",
"");
}

@Test(expected = KsqlException.class)
@Test
public void shouldThrowIfMissingAggregateTypeSchema() throws Exception {
udfCompiler.compileAggregate(UdfCompilerTest.class.getMethod("missingAggregateSchemaAnnotationUdaf"),
classLoader,
"test",
"desc",
"",
"",
"");
// Then:
expectedException.expect(KsqlException.class);
expectedException.expectMessage(
"Must specify 'aggregateSchema' for STRUCT parameter in @UdafFactory.");

// When:
udfCompiler.compileAggregate(
UdfCompilerTest.class.getMethod("missingAggregateSchemaAnnotationUdaf"),
classLoader,
"test",
"desc",
"",
"",
"");
}

@Test(expected = KsqlException.class)
@Test
public void shouldThrowIfMissingOutputTypeSchema() throws Exception {
udfCompiler.compileAggregate(UdfCompilerTest.class.getMethod("missingOutputSchemaAnnotationUdaf"),
classLoader,
"test",
"desc",
"",
"",
"");
// Then:
expectedException.expect(KsqlException.class);
expectedException.expectMessage(
"Must specify 'returnSchema' for STRUCT parameter in @UdafFactory.");

// When:
udfCompiler.compileAggregate(
UdfCompilerTest.class.getMethod("missingOutputSchemaAnnotationUdaf"),
classLoader,
"test",
"desc",
"",
"",
""
);
}

@Test
Expand Down
30 changes: 20 additions & 10 deletions ksql-udf/src/main/java/io/confluent/ksql/function/udaf/Udaf.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,21 @@
/**
* {@code Udaf} represents a custom UDAF (User Defined Aggregate Function)
* that can be used to perform aggregations on KSQL Streams.
* Type support is presently limited to: int, Integer, long, Long, boolean, Boolean, double,
*
* <p>Type support is presently limited to: int, Integer, long, Long, boolean, Boolean, double,
* Double, String, Map, and List.
*
* @param <I> value type
* @param <A> aggregate type
* <p>Sequence of calls is:
* <ol>
* <li>{@code initialize()}: to get the initial value for the aggregate</li>
* <li>{@code aggregate(value, aggregate)}: adds {@code value} to the {@code aggregate}.</li>
* <li>{@code merge(agg1, agg2)}: merges to aggregates together, e.g. on session merges.</li>
* <li>{@code map(agg)}: reduces the intermediate state to the final output type.</li>
* </ol>
*
* @param <I> the input type
* @param <A> the intermediate aggregate type
* @param <O> the final output type
*/
public interface Udaf<I, A, O> {
/**
Expand All @@ -39,18 +49,18 @@ public interface Udaf<I, A, O> {
*/
A aggregate(I current, A aggregate);

/**
* Map the intermediate aggregate value into the actual returned value.
* @param agg aggregate value of current record
* @return new value of current record
*/
O map(A agg);

/**
* Merge two aggregates
* @param aggOne first aggregate
* @param aggTwo second aggregate
* @return new aggregate
*/
A merge(A aggOne, A aggTwo);

/**
* Map the intermediate aggregate value into the actual returned value.
* @param agg aggregate value of current record
* @return new value of current record
*/
O map(A agg);
}

0 comments on commit 49f456e

Please sign in to comment.