Flink: backport PR #9212 to 1.16 for switching to SortKey for data statistics
stevenzwu committed Dec 9, 2023
1 parent 4d0b69b commit 2152269
Showing 16 changed files with 980 additions and 315 deletions.
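In short: instead of asking a Flink `KeySelector` to project a `RowData` key out of each record, the operator now wraps every row and lets an Iceberg `SortKey` extract the sort-order fields. A minimal sketch of the new extraction path, using a hypothetical single-column schema (the column name and value are illustrative, not taken from this commit):

```java
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.StringData;
import org.apache.iceberg.Schema;
import org.apache.iceberg.SortKey;
import org.apache.iceberg.SortOrder;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.flink.FlinkSchemaUtil;
import org.apache.iceberg.flink.RowDataWrapper;
import org.apache.iceberg.types.Types;

public class SortKeyExtractionSketch {
  public static void main(String[] args) {
    // Hypothetical table schema with one string column to sort on.
    Schema schema =
        new Schema(Types.NestedField.required(1, "event_type", Types.StringType.get()));
    SortOrder sortOrder = SortOrder.builderFor(schema).asc("event_type").build();

    // Same pattern the operator uses below: a wrapper adapts RowData to StructLike,
    // and one reusable SortKey pulls the sort fields out of each row.
    RowDataWrapper wrapper =
        new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct());
    SortKey sortKey = new SortKey(schema, sortOrder);

    RowData row = GenericRowData.of(StringData.fromString("click"));
    StructLike struct = wrapper.wrap(row);
    sortKey.wrap(struct); // sortKey now holds this row's sort-order fields
  }
}
```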
DataStatistics.java
@@ -19,7 +19,7 @@
 package org.apache.iceberg.flink.sink.shuffle;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.SortKey;
 
 /**
  * DataStatistics defines the interface to collect data distribution information.
@@ -29,7 +29,7 @@
  * (sketching) can be used.
  */
 @Internal
-interface DataStatistics<D extends DataStatistics, S> {
+interface DataStatistics<D extends DataStatistics<D, S>, S> {
 
   /**
    * Check if data statistics contains any statistics information.
@@ -38,12 +38,8 @@ interface DataStatistics<D extends DataStatistics, S> {
    */
   boolean isEmpty();
 
-  /**
-   * Add data key to data statistics.
-   *
-   * @param key generate from data by applying key selector
-   */
-  void add(RowData key);
+  /** Add row sortKey to data statistics. */
+  void add(SortKey sortKey);
 
   /**
    * Merge current statistics with other statistics.
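The tightened bound `D extends DataStatistics<D, S>` is the recursive ("F-bounded") generic idiom: each implementation names itself as `D`, so methods declared against `D` accept exactly the concrete subtype. A toy illustration of what the self-referential bound buys (types here are invented for the example):

```java
// Self-referential bound: merge(D other) is typed to the concrete implementation.
interface Stats<D extends Stats<D, S>, S> {
  void merge(D other);

  S result();
}

// Merging another CountStats compiles; merging an unrelated Stats subtype does not.
class CountStats implements Stats<CountStats, Long> {
  private long count;

  void add() {
    count++;
  }

  @Override
  public void merge(CountStats other) {
    count += other.count;
  }

  @Override
  public Long result() {
    return count;
  }
}
```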
DataStatisticsCoordinator.java
@@ -172,6 +172,7 @@ private void handleDataStatisticRequest(int subtask, DataStatisticsEvent<D, S> e
     }
   }
 
+  @SuppressWarnings("FutureReturnValueIgnored")
   private void sendDataStatisticsToSubtasks(
       long checkpointId, DataStatistics<D, S> globalDataStatistics) {
     callInCoordinatorThread(
@@ -339,7 +340,7 @@ private void unregisterSubtaskGateway(int subtaskIndex, int attemptNumber) {
 
   private OperatorCoordinator.SubtaskGateway getSubtaskGateway(int subtaskIndex) {
     Preconditions.checkState(
-        gateways[subtaskIndex].size() > 0,
+        !gateways[subtaskIndex].isEmpty(),
         "Coordinator of %s subtask %d is not ready yet to receive events",
         operatorName,
         subtaskIndex);
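`FutureReturnValueIgnored` is the name of an Error Prone check that flags calls whose returned `Future` is silently dropped; the new annotation documents that the coordinator deliberately sends events fire-and-forget. A toy sketch of the pattern, with a stand-in gateway interface rather than the coordinator's real types:

```java
import java.util.concurrent.CompletableFuture;

class FireAndForgetSketch {
  // Stand-in for an event gateway whose send returns an acknowledgement future.
  interface Gateway {
    CompletableFuture<Void> sendEvent(String event);
  }

  // Without the annotation, Error Prone's FutureReturnValueIgnored check
  // would flag the discarded future below.
  @SuppressWarnings("FutureReturnValueIgnored")
  static void broadcast(Gateway[] gateways, String event) {
    for (Gateway gateway : gateways) {
      gateway.sendEvent(event); // intentionally not awaited
    }
  }
}
```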
DataStatisticsOperator.java
@@ -22,7 +22,6 @@
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
 import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
@@ -32,6 +31,12 @@
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.flink.FlinkSchemaUtil;
+import org.apache.iceberg.flink.RowDataWrapper;
 import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 
@@ -45,11 +50,12 @@
 class DataStatisticsOperator<D extends DataStatistics<D, S>, S>
     extends AbstractStreamOperator<DataStatisticsOrRecord<D, S>>
     implements OneInputStreamOperator<RowData, DataStatisticsOrRecord<D, S>>, OperatorEventHandler {
+
   private static final long serialVersionUID = 1L;
 
   private final String operatorName;
-  // keySelector will be used to generate key from data for collecting data statistics
-  private final KeySelector<RowData, RowData> keySelector;
+  private final RowDataWrapper rowDataWrapper;
+  private final SortKey sortKey;
   private final OperatorEventGateway operatorEventGateway;
   private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
   private transient volatile DataStatistics<D, S> localStatistics;
@@ -58,11 +64,13 @@ class DataStatisticsOperator<D extends DataStatistics<D, S>, S>
 
   DataStatisticsOperator(
       String operatorName,
-      KeySelector<RowData, RowData> keySelector,
+      Schema schema,
+      SortOrder sortOrder,
       OperatorEventGateway operatorEventGateway,
       TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
     this.operatorName = operatorName;
-    this.keySelector = keySelector;
+    this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct());
+    this.sortKey = new SortKey(schema, sortOrder);
     this.operatorEventGateway = operatorEventGateway;
     this.statisticsSerializer = statisticsSerializer;
   }
@@ -126,10 +134,11 @@ public void handleOperatorEvent(OperatorEvent event) {
   }
 
   @Override
-  public void processElement(StreamRecord<RowData> streamRecord) throws Exception {
+  public void processElement(StreamRecord<RowData> streamRecord) {
     RowData record = streamRecord.getValue();
-    RowData key = keySelector.getKey(record);
-    localStatistics.add(key);
+    StructLike struct = rowDataWrapper.wrap(record);
+    sortKey.wrap(struct);
+    localStatistics.add(sortKey);
     output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromRecord(record)));
   }
 
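Callers now hand the operator an Iceberg `Schema` and `SortOrder` instead of a `KeySelector`; the wrapper and the single reusable `SortKey` are derived inside the constructor. A sketch of the new wiring, which would have to live in the same package since these classes are package-private (the operator name and factory helper are illustrative):

```java
import java.util.Map;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
import org.apache.iceberg.Schema;
import org.apache.iceberg.SortKey;
import org.apache.iceberg.SortOrder;

class OperatorWiringSketch {
  // Hypothetical factory method: schema + sort order replace the old KeySelector.
  static DataStatisticsOperator<MapDataStatistics, Map<SortKey, Long>> create(
      Schema schema,
      SortOrder sortOrder,
      OperatorEventGateway gateway,
      TypeSerializer<DataStatistics<MapDataStatistics, Map<SortKey, Long>>> serializer) {
    return new DataStatisticsOperator<>(
        "data-statistics-collector", schema, sortOrder, gateway, serializer);
  }
}
```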
DataStatisticsUtil.java
@@ -76,7 +76,6 @@ static <D extends DataStatistics<D, S>, S> byte[] serializeAggregatedStatistics(
     return bytes.toByteArray();
   }
 
-  @SuppressWarnings("unchecked")
   static <D extends DataStatistics<D, S>, S>
       AggregatedStatistics<D, S> deserializeAggregatedStatistics(
           byte[] bytes, TypeSerializer<DataStatistics<D, S>> statisticsSerializer)
MapDataStatistics.java
@@ -20,20 +20,20 @@
 
 import java.util.Map;
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.SortKey;
 import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 
 /** MapDataStatistics uses map to count key frequency */
 @Internal
-class MapDataStatistics implements DataStatistics<MapDataStatistics, Map<RowData, Long>> {
-  private final Map<RowData, Long> statistics;
+class MapDataStatistics implements DataStatistics<MapDataStatistics, Map<SortKey, Long>> {
+  private final Map<SortKey, Long> statistics;
 
   MapDataStatistics() {
     this.statistics = Maps.newHashMap();
   }
 
-  MapDataStatistics(Map<RowData, Long> statistics) {
+  MapDataStatistics(Map<SortKey, Long> statistics) {
     this.statistics = statistics;
   }
 
@@ -43,9 +43,14 @@ public boolean isEmpty() {
   }
 
   @Override
-  public void add(RowData key) {
-    // increase count of occurrence by one in the dataStatistics map
-    statistics.merge(key, 1L, Long::sum);
+  public void add(SortKey sortKey) {
+    if (statistics.containsKey(sortKey)) {
+      statistics.merge(sortKey, 1L, Long::sum);
+    } else {
+      // clone the sort key before adding to map because input sortKey object can be reused
+      SortKey copiedKey = sortKey.copy();
+      statistics.put(copiedKey, 1L);
+    }
   }
 
   @Override
@@ -54,7 +59,7 @@ public void merge(MapDataStatistics otherStatistics) {
   }
 
   @Override
-  public Map<RowData, Long> statistics() {
+  public Map<SortKey, Long> statistics() {
     return statistics;
   }
 
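The branchy `add` exists because the operator mutates one `SortKey` instance for every record: putting that instance into the map directly would alias each entry's key to the same changing object, while a key that is already present never escapes into the map and needs no copy. A self-contained demonstration of the hazard, using an invented `MutableKey` in place of Iceberg's `SortKey`:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

class CopyOnInsertSketch {
  // Stand-in for a reused, mutable key such as the operator's single SortKey.
  static final class MutableKey {
    String value;

    MutableKey copy() {
      MutableKey copied = new MutableKey();
      copied.value = value;
      return copied;
    }

    @Override
    public boolean equals(Object o) {
      return o instanceof MutableKey && Objects.equals(value, ((MutableKey) o).value);
    }

    @Override
    public int hashCode() {
      return Objects.hashCode(value);
    }
  }

  public static void main(String[] args) {
    Map<MutableKey, Long> counts = new HashMap<>();
    MutableKey reused = new MutableKey();

    for (String v : new String[] {"a", "b", "a"}) {
      reused.value = v; // caller mutates the same key instance for every record
      if (counts.containsKey(reused)) {
        counts.merge(reused, 1L, Long::sum); // existing entry: no copy needed
      } else {
        counts.put(reused.copy(), 1L); // first sighting: store a defensive copy
      }
      // Storing `reused` itself would leave every map entry pointing at one
      // object whose value keeps changing as new records arrive.
    }

    System.out.println(counts.size()); // 2 distinct keys, "a" counted twice
  }
}
```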
MapDataStatisticsSerializer.java
@@ -29,22 +29,22 @@
 import org.apache.flink.api.common.typeutils.base.MapSerializer;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.table.data.RowData;
 import org.apache.flink.util.Preconditions;
+import org.apache.iceberg.SortKey;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 
 @Internal
 class MapDataStatisticsSerializer
-    extends TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> {
-  private final MapSerializer<RowData, Long> mapSerializer;
+    extends TypeSerializer<DataStatistics<MapDataStatistics, Map<SortKey, Long>>> {
+  private final MapSerializer<SortKey, Long> mapSerializer;
 
-  static TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> fromKeySerializer(
-      TypeSerializer<RowData> keySerializer) {
+  static MapDataStatisticsSerializer fromSortKeySerializer(
+      TypeSerializer<SortKey> sortKeySerializer) {
     return new MapDataStatisticsSerializer(
-        new MapSerializer<>(keySerializer, LongSerializer.INSTANCE));
+        new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE));
   }
 
-  MapDataStatisticsSerializer(MapSerializer<RowData, Long> mapSerializer) {
+  MapDataStatisticsSerializer(MapSerializer<SortKey, Long> mapSerializer) {
     this.mapSerializer = mapSerializer;
   }
 
@@ -55,28 +55,28 @@ public boolean isImmutableType() {
 
   @SuppressWarnings("ReferenceEquality")
   @Override
-  public TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> duplicate() {
-    MapSerializer<RowData, Long> duplicateMapSerializer =
-        (MapSerializer<RowData, Long>) mapSerializer.duplicate();
+  public TypeSerializer<DataStatistics<MapDataStatistics, Map<SortKey, Long>>> duplicate() {
+    MapSerializer<SortKey, Long> duplicateMapSerializer =
+        (MapSerializer<SortKey, Long>) mapSerializer.duplicate();
     return (duplicateMapSerializer == mapSerializer)
         ? this
        : new MapDataStatisticsSerializer(duplicateMapSerializer);
   }
 
   @Override
-  public DataStatistics<MapDataStatistics, Map<RowData, Long>> createInstance() {
+  public MapDataStatistics createInstance() {
     return new MapDataStatistics();
   }
 
   @Override
-  public DataStatistics<MapDataStatistics, Map<RowData, Long>> copy(DataStatistics obj) {
+  public MapDataStatistics copy(DataStatistics<MapDataStatistics, Map<SortKey, Long>> obj) {
     Preconditions.checkArgument(
         obj instanceof MapDataStatistics, "Invalid data statistics type: " + obj.getClass());
     MapDataStatistics from = (MapDataStatistics) obj;
-    TypeSerializer<RowData> keySerializer = mapSerializer.getKeySerializer();
-    Map<RowData, Long> newMap = Maps.newHashMapWithExpectedSize(from.statistics().size());
-    for (Map.Entry<RowData, Long> entry : from.statistics().entrySet()) {
-      RowData newKey = keySerializer.copy(entry.getKey());
+    TypeSerializer<SortKey> keySerializer = mapSerializer.getKeySerializer();
+    Map<SortKey, Long> newMap = Maps.newHashMapWithExpectedSize(from.statistics().size());
+    for (Map.Entry<SortKey, Long> entry : from.statistics().entrySet()) {
+      SortKey newKey = keySerializer.copy(entry.getKey());
       // no need to copy value since it is just a Long
       newMap.put(newKey, entry.getValue());
     }
@@ -85,8 +85,9 @@ public DataStatistics<MapDataStatistics, Map<RowData, Long>> copy(DataStatistics
 
   @Override
-  public DataStatistics<MapDataStatistics, Map<RowData, Long>> copy(
-      DataStatistics from, DataStatistics reuse) {
+  public DataStatistics<MapDataStatistics, Map<SortKey, Long>> copy(
+      DataStatistics<MapDataStatistics, Map<SortKey, Long>> from,
+      DataStatistics<MapDataStatistics, Map<SortKey, Long>> reuse) {
     // not much benefit to reuse
     return copy(from);
   }
@@ -97,22 +98,25 @@ public int getLength() {
   }
 
   @Override
-  public void serialize(DataStatistics obj, DataOutputView target) throws IOException {
+  public void serialize(
+      DataStatistics<MapDataStatistics, Map<SortKey, Long>> obj, DataOutputView target)
+      throws IOException {
     Preconditions.checkArgument(
         obj instanceof MapDataStatistics, "Invalid data statistics type: " + obj.getClass());
     MapDataStatistics mapStatistics = (MapDataStatistics) obj;
     mapSerializer.serialize(mapStatistics.statistics(), target);
   }
 
   @Override
-  public DataStatistics<MapDataStatistics, Map<RowData, Long>> deserialize(DataInputView source)
+  public DataStatistics<MapDataStatistics, Map<SortKey, Long>> deserialize(DataInputView source)
       throws IOException {
     return new MapDataStatistics(mapSerializer.deserialize(source));
   }
 
   @Override
-  public DataStatistics<MapDataStatistics, Map<RowData, Long>> deserialize(
-      DataStatistics reuse, DataInputView source) throws IOException {
+  public DataStatistics<MapDataStatistics, Map<SortKey, Long>> deserialize(
+      DataStatistics<MapDataStatistics, Map<SortKey, Long>> reuse, DataInputView source)
+      throws IOException {
     // not much benefit to reuse
     return deserialize(source);
   }
@@ -138,14 +142,14 @@ public int hashCode() {
   }
 
   @Override
-  public TypeSerializerSnapshot<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
+  public TypeSerializerSnapshot<DataStatistics<MapDataStatistics, Map<SortKey, Long>>>
       snapshotConfiguration() {
     return new MapDataStatisticsSerializerSnapshot(this);
   }
 
   public static class MapDataStatisticsSerializerSnapshot
       extends CompositeTypeSerializerSnapshot<
-          DataStatistics<MapDataStatistics, Map<RowData, Long>>, MapDataStatisticsSerializer> {
+          DataStatistics<MapDataStatistics, Map<SortKey, Long>>, MapDataStatisticsSerializer> {
     private static final int CURRENT_VERSION = 1;
 
     // constructors need to public. Otherwise, Flink state restore would complain
@@ -175,8 +179,8 @@ protected TypeSerializer<?>[] getNestedSerializers(
     protected MapDataStatisticsSerializer createOuterSerializerWithNestedSerializers(
         TypeSerializer<?>[] nestedSerializers) {
       @SuppressWarnings("unchecked")
-      MapSerializer<RowData, Long> mapSerializer =
-          (MapSerializer<RowData, Long>) nestedSerializers[0];
+      MapSerializer<SortKey, Long> mapSerializer =
+          (MapSerializer<SortKey, Long>) nestedSerializers[0];
       return new MapDataStatisticsSerializer(mapSerializer);
     }
   }
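For completeness, a hedged sketch of how the retyped serializer round-trips statistics through Flink memory views. Like the wiring sketch above it would live in the same package, and it assumes the caller supplies a `TypeSerializer<SortKey>` (the full PR introduces one, but it is not part of this excerpt):

```java
import java.io.IOException;
import java.util.Map;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.iceberg.SortKey;

class SerializerRoundTripSketch {
  // Serializes statistics to bytes and reads them back, exercising the new
  // SortKey-keyed map serializer end to end.
  static DataStatistics<MapDataStatistics, Map<SortKey, Long>> roundTrip(
      TypeSerializer<SortKey> sortKeySerializer,
      DataStatistics<MapDataStatistics, Map<SortKey, Long>> statistics)
      throws IOException {
    MapDataStatisticsSerializer serializer =
        MapDataStatisticsSerializer.fromSortKeySerializer(sortKeySerializer);

    DataOutputSerializer out = new DataOutputSerializer(64);
    serializer.serialize(statistics, out);

    DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
    return serializer.deserialize(in);
  }
}
```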
(Diffs for the remaining 10 of the 16 changed files were not loaded in this view.)
