Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Generalize the UDAFs earliest_by_offset and latest_by_offset #8878

Merged
merged 5 commits into from
Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@ class UdafFactoryInvoker implements FunctionSignature {
private static final Logger LOG = LoggerFactory.getLogger(UdafFactoryInvoker.class);

private final FunctionName functionName;
private final ParamType aggregateArgType;
private final ParamType aggregateReturnType;
private final Optional<Metrics> metrics;
private final List<ParamType> paramTypes;
private final List<ParameterInfo> params;
private final Method method;
private final String description;
private final UdafTypes types;
private final String aggregateSchema;
private final String outputSchema;
private ParamType aggregateReturnType;

UdafFactoryInvoker(
final Method method,
Expand All @@ -70,10 +72,10 @@ class UdafFactoryInvoker implements FunctionSignature {
if (!Modifier.isStatic(method.getModifiers())) {
throw new KsqlException("UDAF factory methods must be static " + method);
}
final UdafTypes types = new UdafTypes(method, functionName, typeParser);
this.types = new UdafTypes(method, functionName, typeParser);
this.functionName = Objects.requireNonNull(functionName);
this.aggregateArgType = Objects.requireNonNull(types.getAggregateSchema(aggregateSchema));
this.aggregateReturnType = Objects.requireNonNull(types.getOutputSchema(outputSchema));
this.aggregateSchema = aggregateSchema; // This can be null if the annotation is not used.
this.outputSchema = outputSchema; // This can be null if the annotation is not used.
this.metrics = Objects.requireNonNull(metrics);
this.params = types.getInputSchema(Objects.requireNonNull(inputSchema));
this.paramTypes = params.stream().map(ParameterInfo::type).collect(Collectors.toList());
Expand All @@ -95,10 +97,14 @@ KsqlAggregateFunction createFunction(final AggregateFunctionInitArguments initAr
}

final SqlType aggregateSqlType = (SqlType) udaf.getAggregateSqlType()
.orElseGet(() -> SchemaConverters.functionToSqlConverter().toSqlType(aggregateArgType));
.orElseGet(() -> SchemaConverters.functionToSqlConverter()
.toSqlType(types.getAggregateSchema(aggregateSchema)));
final SqlType returnSqlType = (SqlType) udaf.getReturnSqlType()
.orElseGet(() ->
SchemaConverters.functionToSqlConverter().toSqlType(aggregateReturnType));
SchemaConverters.functionToSqlConverter()
.toSqlType(types.getOutputSchema(outputSchema)));
this.aggregateReturnType =
SchemaConverters.sqlToFunctionConverter().toFunctionType(returnSqlType);

final KsqlAggregateFunction function;
if (TableUdaf.class.isAssignableFrom(method.getReturnType())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ private static void validateStructAnnotation(
final String msg
) {
if (type.equals(Struct.class) && schema.isEmpty()) {
throw new KsqlException("Must specify '" + msg + "' for STRUCT parameter in @UdafFactory.");
throw new KsqlException("Must specify '" + msg + "' for STRUCT parameter in @UdafFactory or"
+ " implement getAggregateSqlType()/getReturnSqlType().");
}
}

Expand Down
Loading