Skip to content

Commit

Permalink
feat: Add support for average aggregate function (#3302)
Browse files Browse the repository at this point in the history
  • Loading branch information
vpapavas authored Sep 20, 2019
1 parent bfbdc20 commit 6757d9f
Show file tree
Hide file tree
Showing 401 changed files with 5,574 additions and 3,177 deletions.
3 changes: 3 additions & 0 deletions docs/developer-guide/syntax-reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1880,6 +1880,9 @@ Aggregate functions
| | | | late-arriving record, then the records from the second window in |
| | | | the order they were originally processed. |
+------------------------+---------------------------+------------+---------------------------------------------------------------------+
| AVERAGE | ``AVG(col1)`` | Stream, | Return the average value for a given column. |
| | | Table | Note: rows where ``col1`` is null are ignored. |
+------------------------+---------------------------+------------+---------------------------------------------------------------------+
| MAX | ``MAX(col1)`` | Stream | Return the maximum value for a given column and window. |
| | | | Note: rows where ``col1`` is null will be ignored. |
+------------------------+---------------------------+------------+---------------------------------------------------------------------+
Expand Down
339 changes: 228 additions & 111 deletions docs/developer-guide/udf.rst

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public AggregateFunctionFactory(final UdfMetadata metadata) {
this.metadata = Objects.requireNonNull(metadata, "metadata can't be null");
}

public abstract KsqlAggregateFunction<?, ?> getProperAggregateFunction(List<Schema> argTypeList);
public abstract KsqlAggregateFunction<?, ?, ?> getProperAggregateFunction(
List<Schema> argTypeList);

protected abstract List<List<Schema>> supportedArgs();

Expand All @@ -69,7 +70,7 @@ public String getVersion() {
return metadata.getVersion();
}

public void eachFunction(final Consumer<KsqlAggregateFunction<?, ?>> consumer) {
public void eachFunction(final Consumer<KsqlAggregateFunction<?, ?, ?>> consumer) {
supportedArgs().forEach(args -> consumer.accept(getProperAggregateFunction(args)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public interface FunctionRegistry {
* @return the function instance.
* @throws KsqlException on unknown UDAF, or on unsupported {@code argumentType}.
*/
KsqlAggregateFunction<?, ?> getAggregate(String functionName, Schema argumentType);
KsqlAggregateFunction<?, ?, ?> getAggregate(String functionName, Schema argumentType);

/**
* @return all UDF factories.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,25 @@

import io.confluent.ksql.schema.ksql.types.SqlType;
import java.util.List;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.kstream.Merger;


public interface KsqlAggregateFunction<V, A> extends IndexedFunction {
public interface KsqlAggregateFunction<I, A, O> extends IndexedFunction {

KsqlAggregateFunction<V, A> getInstance(AggregateFunctionArguments aggregateFunctionArguments);
KsqlAggregateFunction<I, A, O> getInstance(AggregateFunctionArguments aggregateFunctionArguments);

Supplier<A> getInitialValueSupplier();

int getArgIndexInValue();

Schema getAggregateType();

SqlType aggregateType();

Schema getReturnType();

SqlType returnType();
Expand All @@ -41,13 +46,15 @@ public interface KsqlAggregateFunction<V, A> extends IndexedFunction {
* Merges values inside the window.
* @return the aggregate value after {@code currentValue} has been merged in, of type {@code A}
*/
A aggregate(V currentValue, A aggregateValue);
A aggregate(I currentValue, A aggregateValue);

/**
* Merges two session windows together with the same merge key.
*/
Merger<Struct, A> getMerger();

Function<A, O> getResultMapper();

String getDescription();

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.kstream.Aggregator;
import org.apache.kafka.streams.kstream.Merger;
import org.apache.kafka.streams.kstream.ValueMapper;

/**
 * An {@link Aggregator} over {@link Struct} keys and {@link GenericRow} values that also
 * exposes a merger and a result mapper working on the same row type.
 */
public interface UdafAggregator extends Aggregator<Struct, GenericRow, GenericRow> {
/**
 * @return the merger that combines two {@link GenericRow} aggregates for a {@link Struct} key
 */
Merger<Struct, GenericRow> getMerger();

/**
 * @return a mapper from one {@link GenericRow} to another; presumably it turns the
 *     intermediate aggregate row into the final result row (TODO confirm against the
 *     implementing class)
 */
ValueMapper<GenericRow, GenericRow> getResultMapper();
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.Struct;

public abstract class BaseAggregateFunction<V, A> implements KsqlAggregateFunction<V, A> {
public abstract class BaseAggregateFunction<I, A, O> implements KsqlAggregateFunction<I, A, O> {

private static final ConnectToSqlTypeConverter CONNECT_TO_SQL_CONVERTER
= SchemaConverters.connectToSqlConverter();
Expand All @@ -36,8 +36,10 @@ public abstract class BaseAggregateFunction<V, A> implements KsqlAggregateFuncti
**/
private final int argIndexInValue;
private final Supplier<A> initialValueSupplier;
private final Schema returnSchema;
private final SqlType returnType;
private final Schema aggregateSchema;
private final SqlType aggregateType;
private final Schema outputSchema;
private final SqlType outputType;
private final List<Schema> arguments;

protected final String functionName;
Expand All @@ -47,7 +49,8 @@ public BaseAggregateFunction(
final String functionName,
final int argIndexInValue,
final Supplier<A> initialValueSupplier,
final Schema returnType,
final Schema aggregateType,
final Schema outputType,
final List<Schema> arguments,
final String description
) {
Expand All @@ -60,13 +63,15 @@ public BaseAggregateFunction(
}
return val;
};
this.returnSchema = Objects.requireNonNull(returnType, "returnType");
this.returnType = CONNECT_TO_SQL_CONVERTER.toSqlType(returnType);
this.aggregateSchema = Objects.requireNonNull(aggregateType, "aggregateType");
this.aggregateType = CONNECT_TO_SQL_CONVERTER.toSqlType(aggregateType);
this.outputSchema = Objects.requireNonNull(outputType, "outputType");
this.outputType = CONNECT_TO_SQL_CONVERTER.toSqlType(outputType);
this.arguments = Objects.requireNonNull(arguments, "arguments");
this.functionName = Objects.requireNonNull(functionName, "functionName");
this.description = Objects.requireNonNull(description, "description");

if (!returnType.isOptional()) {
if (!outputType.isOptional() || !aggregateType.isOptional()) {
throw new IllegalArgumentException("KSQL only supports optional field types");
}
}
Expand All @@ -90,13 +95,22 @@ public int getArgIndexInValue() {
return argIndexInValue;
}

public Schema getAggregateType() {
return aggregateSchema;
}

@Override
public SqlType aggregateType() {
return aggregateType;
}

public Schema getReturnType() {
return returnSchema;
return outputSchema;
}

@Override
public SqlType returnType() {
return returnType;
return outputType;
}

public List<Schema> getArguments() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import io.confluent.ksql.function.udaf.Udaf;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.metrics.Sensor;
Expand All @@ -32,46 +34,55 @@
import org.apache.kafka.streams.kstream.Merger;

@SuppressWarnings({"unused", "WeakerAccess"}) // used in generated code
public abstract class GeneratedAggregateFunction<V, A> extends BaseAggregateFunction<V, A> {
public abstract class GeneratedAggregateFunction<I, A, O> extends BaseAggregateFunction<I, A, O> {

protected final Sensor aggregateSensor;
protected final Sensor mergeSensor;
protected Udaf<V, A> udaf;
protected Optional<Sensor> aggregateSensor;
protected Optional<Sensor> mapSensor;
protected Optional<Sensor> mergeSensor;
protected Udaf<I, A, O> udaf;

public GeneratedAggregateFunction(
final String functionName,
final Schema returnType,
final Schema aggregateType,
final Schema outputType,
final List<Schema> arguments,
final String description,
final Optional<Metrics> metrics) {
super(functionName, -1, null, returnType, arguments, description);

this(functionName, -1, null, aggregateType, outputType, arguments, description,
Optional.empty(), Optional.empty(), Optional.empty());

final String method = getSourceMethodName();
final String aggSensorName = String.format("aggregate-%s-%s", functionName, method);
final String mapSensorName = String.format("map-%s-%s", functionName, method);
final String mergeSensorName = String.format("merge-%s-%s", functionName, method);

initMetrics(metrics, functionName, method, aggSensorName, mergeSensorName);
this.aggregateSensor = metrics.map(m -> m.getSensor(aggSensorName)).orElse(null);
this.mergeSensor = metrics.map(m -> m.getSensor(mergeSensorName)).orElse(null);
// NOTE(review): the delegating this(...) call starts every sensor as Optional.empty(), and
// the visible parts of initMetrics assign a sensor field only inside its
// `metrics.getSensor(name) == null` branch. If a sensor with the same name is already
// registered, the field appears to stay empty and nothing is recorded, whereas the
// unconditional getSensor(...) lookups used previously would have found it -- confirm.
initMetrics(metrics, functionName, method, aggSensorName, mapSensorName, mergeSensorName);

}

protected GeneratedAggregateFunction(
final String functionName,
final int udafIndex,
final Supplier<A> udafSupplier,
final Schema returnType,
final Schema aggregateType,
final Schema outputType,
final List<Schema> arguments,
final String description,
final Sensor aggregateSensor,
final Sensor mergeSensor) {
super(functionName, udafIndex, udafSupplier, returnType, arguments, description);
this.aggregateSensor = aggregateSensor;
this.mergeSensor = mergeSensor;
final Optional<Sensor> aggregateSensor,
final Optional<Sensor> mapSensor,
final Optional<Sensor> mergeSensor) {

super(functionName, udafIndex, udafSupplier, aggregateType, outputType, arguments, description);

this.aggregateSensor = Objects.requireNonNull(aggregateSensor);
this.mapSensor = Objects.requireNonNull(mapSensor);
this.mergeSensor = Objects.requireNonNull(mergeSensor);
}

protected abstract String getSourceMethodName();

protected Udaf<V, A> getUdaf() {
protected Udaf<I, A, O> getUdaf() {
return udaf;
}

Expand All @@ -80,6 +91,7 @@ private void initMetrics(
final String name,
final String method,
final String aggSensorName,
final String mapSensorName,
final String mergeSensorName) {
if (maybeMetrics.isPresent()) {
final String groupName = String.format("ksql-udaf-%s-%s", name, method);
Expand Down Expand Up @@ -108,6 +120,33 @@ private void initMetrics(
String.format("The average number of occurrences of aggregate "
+ "%s %s operation per second udaf", name, method)),
new Rate(TimeUnit.SECONDS, new WindowedCount()));
this.aggregateSensor = Optional.of(sensor);
}

if (metrics.getSensor(mapSensorName) == null) {
final Sensor sensor = metrics.sensor(mapSensorName);
sensor.add(metrics.metricName(
mapSensorName + "-avg",
groupName,
String.format("Average time for a map invocation of %s %s udaf", name, method)),
new Avg());
sensor.add(metrics.metricName(
mapSensorName + "-max",
groupName,
String.format("Max time for a map invocation of %s %s udaf", name, method)),
new Max());
sensor.add(metrics.metricName(
mapSensorName + "-count",
groupName,
String.format("Total number of map invocations of %s %s udaf", name, method)),
new WindowedCount());
sensor.add(metrics.metricName(
mapSensorName + "-rate",
groupName,
String.format("The average number of occurrences of map "
+ "%s %s operation per second udaf", name, method)),
new Rate(TimeUnit.SECONDS, new WindowedCount()));
this.mapSensor = Optional.of(sensor);
}

if (metrics.getSensor(mergeSensorName) == null) {
Expand All @@ -134,40 +173,41 @@ private void initMetrics(
"The average number of occurrences of merge %s %s operation per second udaf",
name, method)),
new Rate(TimeUnit.SECONDS, new WindowedCount()));
this.mergeSensor = Optional.of(sensor);
}
} else {
this.aggregateSensor = Optional.empty();
this.mapSensor = Optional.empty();
this.mergeSensor = Optional.empty();
}

}

@Override
public A aggregate(final V currentValue, final A aggregateValue) {
final long start = Time.SYSTEM.nanoseconds();
try {
//noinspection unchecked
return udaf.aggregate(currentValue,aggregateValue);
} finally {
if (aggregateSensor != null) {
aggregateSensor.record(Time.SYSTEM.nanoseconds() - start);
}
}
public A aggregate(final I currentValue, final A aggregateValue) {
return timed(aggregateSensor, () -> udaf.aggregate(currentValue,aggregateValue));
}

@Override
public Merger<Struct, A> getMerger() {
return (key, v1, v2) -> {
final long start = Time.SYSTEM.nanoseconds();
try {
//noinspection unchecked
return udaf.merge(v1, v2);
} finally {
if (mergeSensor != null) {
mergeSensor.record(Time.SYSTEM.nanoseconds() - start);
}
}
};
return (key, v1, v2) -> timed(mergeSensor, () -> udaf.merge(v1, v2));
}

protected static <V,A> Supplier<A> supplier(final Udaf<V, A> udaf) {
/**
 * Returns a function that applies the UDAF's {@code map} step to an aggregate value, timing
 * each call against the map sensor when one is present.
 *
 * @return the timed result mapper
 */
@Override
public Function<A, O> getResultMapper() {
  final Function<A, O> timedMapper =
      aggregate -> timed(mapSensor, () -> udaf.map(aggregate));
  return timedMapper;
}

/**
 * Adapts a UDAF's {@code initialize} method to a {@link Supplier}.
 *
 * @param udaf the UDAF whose {@code initialize} method is called on every {@code get()}
 * @return a supplier bound to {@code udaf}
 */
protected static <I, A, O> Supplier<A> supplier(final Udaf<I, A, O> udaf) {
// Bound method reference: the receiver is evaluated and null-checked here, so a null udaf
// fails with NullPointerException at this call rather than on the first get().
return udaf::initialize;
}

/**
 * Runs {@code task} and, when a sensor is present, records the elapsed time in nanoseconds
 * on it. The recording happens in a {@code finally} block, so it also runs when the task
 * throws.
 *
 * @param maybeSensor the sensor to record on, or empty to skip recording
 * @param task the work to run
 * @param <T> the task's result type
 * @return whatever {@code task} returns
 */
private static <T> T timed(final Optional<Sensor> maybeSensor, final Supplier<T> task) {
  final long startNanos = Time.SYSTEM.nanoseconds();
  try {
    return task.get();
  } finally {
    if (maybeSensor.isPresent()) {
      final long elapsedNanos = Time.SYSTEM.nanoseconds() - startNanos;
      maybeSensor.get().record(elapsedNanos);
    }
  }
}
}
Loading

0 comments on commit 6757d9f

Please sign in to comment.