Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flink: switch to use SortKey for data statistics #9212

Merged
merged 6 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package org.apache.iceberg.flink.sink.shuffle;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.data.RowData;
import org.apache.iceberg.SortKey;

/**
* DataStatistics defines the interface to collect data distribution information.
Expand All @@ -29,7 +29,7 @@
* (sketching) can be used.
*/
@Internal
interface DataStatistics<D extends DataStatistics, S> {
interface DataStatistics<D extends DataStatistics<D, S>, S> {

/**
* Check if data statistics contains any statistics information.
Expand All @@ -38,12 +38,8 @@ interface DataStatistics<D extends DataStatistics, S> {
*/
boolean isEmpty();

/**
* Add data key to data statistics.
*
* @param key generate from data by applying key selector
*/
void add(RowData key);
/** Add row sortKey to data statistics. */
void add(SortKey sortKey);

/**
* Merge current statistics with other statistics.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.runtime.operators.coordination.OperatorEvent;
import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
Expand All @@ -32,6 +31,12 @@
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.RowData;
import org.apache.iceberg.Schema;
import org.apache.iceberg.SortKey;
import org.apache.iceberg.SortOrder;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.flink.FlinkSchemaUtil;
import org.apache.iceberg.flink.RowDataWrapper;
import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;

Expand All @@ -45,11 +50,12 @@
class DataStatisticsOperator<D extends DataStatistics<D, S>, S>
extends AbstractStreamOperator<DataStatisticsOrRecord<D, S>>
implements OneInputStreamOperator<RowData, DataStatisticsOrRecord<D, S>>, OperatorEventHandler {

private static final long serialVersionUID = 1L;

private final String operatorName;
// keySelector will be used to generate key from data for collecting data statistics
private final KeySelector<RowData, RowData> keySelector;
private final RowDataWrapper rowDataWrapper;
private final SortKey sortKey;
private final OperatorEventGateway operatorEventGateway;
private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
private transient volatile DataStatistics<D, S> localStatistics;
Expand All @@ -58,11 +64,13 @@ class DataStatisticsOperator<D extends DataStatistics<D, S>, S>

DataStatisticsOperator(
String operatorName,
KeySelector<RowData, RowData> keySelector,
Schema schema,
SortOrder sortOrder,
OperatorEventGateway operatorEventGateway,
TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
this.operatorName = operatorName;
this.keySelector = keySelector;
this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct());
this.sortKey = new SortKey(schema, sortOrder);
this.operatorEventGateway = operatorEventGateway;
this.statisticsSerializer = statisticsSerializer;
}
Expand Down Expand Up @@ -126,10 +134,11 @@ public void handleOperatorEvent(OperatorEvent event) {
}

@Override
public void processElement(StreamRecord<RowData> streamRecord) throws Exception {
public void processElement(StreamRecord<RowData> streamRecord) {
RowData record = streamRecord.getValue();
RowData key = keySelector.getKey(record);
localStatistics.add(key);
StructLike struct = rowDataWrapper.wrap(record);
sortKey.wrap(struct);
localStatistics.add(sortKey);
output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromRecord(record)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ static <D extends DataStatistics<D, S>, S> byte[] serializeAggregatedStatistics(
return bytes.toByteArray();
}

@SuppressWarnings("unchecked")
static <D extends DataStatistics<D, S>, S>
AggregatedStatistics<D, S> deserializeAggregatedStatistics(
byte[] bytes, TypeSerializer<DataStatistics<D, S>> statisticsSerializer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@

import java.util.Map;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.data.RowData;
import org.apache.iceberg.SortKey;
import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;

/** MapDataStatistics uses map to count key frequency */
@Internal
class MapDataStatistics implements DataStatistics<MapDataStatistics, Map<RowData, Long>> {
private final Map<RowData, Long> statistics;
class MapDataStatistics implements DataStatistics<MapDataStatistics, Map<SortKey, Long>> {
private final Map<SortKey, Long> statistics;

MapDataStatistics() {
this.statistics = Maps.newHashMap();
}

MapDataStatistics(Map<RowData, Long> statistics) {
MapDataStatistics(Map<SortKey, Long> statistics) {
this.statistics = statistics;
}

Expand All @@ -43,9 +43,14 @@ public boolean isEmpty() {
}

@Override
public void add(RowData key) {
// increase count of occurrence by one in the dataStatistics map
statistics.merge(key, 1L, Long::sum);
public void add(SortKey sortKey) {
if (statistics.containsKey(sortKey)) {
statistics.merge(sortKey, 1L, Long::sum);
} else {
// clone the sort key before adding to map because input sortKey object can be reused
SortKey copiedKey = sortKey.copy();
statistics.put(copiedKey, 1L);
}
}

@Override
Expand All @@ -54,7 +59,7 @@ public void merge(MapDataStatistics otherStatistics) {
}

@Override
public Map<RowData, Long> statistics() {
public Map<SortKey, Long> statistics() {
return statistics;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,22 @@
import org.apache.flink.api.common.typeutils.base.MapSerializer;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.table.data.RowData;
import org.apache.flink.util.Preconditions;
import org.apache.iceberg.SortKey;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;

@Internal
class MapDataStatisticsSerializer
extends TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> {
private final MapSerializer<RowData, Long> mapSerializer;
extends TypeSerializer<DataStatistics<MapDataStatistics, Map<SortKey, Long>>> {
private final MapSerializer<SortKey, Long> mapSerializer;

static TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> fromKeySerializer(
TypeSerializer<RowData> keySerializer) {
static MapDataStatisticsSerializer fromSortKeySerializer(
TypeSerializer<SortKey> sortKeySerializer) {
return new MapDataStatisticsSerializer(
new MapSerializer<>(keySerializer, LongSerializer.INSTANCE));
new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE));
}

MapDataStatisticsSerializer(MapSerializer<RowData, Long> mapSerializer) {
MapDataStatisticsSerializer(MapSerializer<SortKey, Long> mapSerializer) {
this.mapSerializer = mapSerializer;
}

Expand All @@ -55,28 +55,28 @@ public boolean isImmutableType() {

@SuppressWarnings("ReferenceEquality")
@Override
public TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> duplicate() {
MapSerializer<RowData, Long> duplicateMapSerializer =
(MapSerializer<RowData, Long>) mapSerializer.duplicate();
public TypeSerializer<DataStatistics<MapDataStatistics, Map<SortKey, Long>>> duplicate() {
MapSerializer<SortKey, Long> duplicateMapSerializer =
(MapSerializer<SortKey, Long>) mapSerializer.duplicate();
return (duplicateMapSerializer == mapSerializer)
? this
: new MapDataStatisticsSerializer(duplicateMapSerializer);
}

@Override
public DataStatistics<MapDataStatistics, Map<RowData, Long>> createInstance() {
public MapDataStatistics createInstance() {
return new MapDataStatistics();
}

@Override
public DataStatistics<MapDataStatistics, Map<RowData, Long>> copy(DataStatistics obj) {
public MapDataStatistics copy(DataStatistics<MapDataStatistics, Map<SortKey, Long>> obj) {
Preconditions.checkArgument(
obj instanceof MapDataStatistics, "Invalid data statistics type: " + obj.getClass());
MapDataStatistics from = (MapDataStatistics) obj;
TypeSerializer<RowData> keySerializer = mapSerializer.getKeySerializer();
Map<RowData, Long> newMap = Maps.newHashMapWithExpectedSize(from.statistics().size());
for (Map.Entry<RowData, Long> entry : from.statistics().entrySet()) {
RowData newKey = keySerializer.copy(entry.getKey());
TypeSerializer<SortKey> keySerializer = mapSerializer.getKeySerializer();
Map<SortKey, Long> newMap = Maps.newHashMapWithExpectedSize(from.statistics().size());
for (Map.Entry<SortKey, Long> entry : from.statistics().entrySet()) {
SortKey newKey = keySerializer.copy(entry.getKey());
// no need to copy value since it is just a Long
newMap.put(newKey, entry.getValue());
}
Expand All @@ -85,8 +85,9 @@ public DataStatistics<MapDataStatistics, Map<RowData, Long>> copy(DataStatistics
}

@Override
public DataStatistics<MapDataStatistics, Map<RowData, Long>> copy(
DataStatistics from, DataStatistics reuse) {
public DataStatistics<MapDataStatistics, Map<SortKey, Long>> copy(
DataStatistics<MapDataStatistics, Map<SortKey, Long>> from,
DataStatistics<MapDataStatistics, Map<SortKey, Long>> reuse) {
// not much benefit to reuse
return copy(from);
}
Expand All @@ -97,22 +98,25 @@ public int getLength() {
}

@Override
public void serialize(DataStatistics obj, DataOutputView target) throws IOException {
public void serialize(
DataStatistics<MapDataStatistics, Map<SortKey, Long>> obj, DataOutputView target)
throws IOException {
Preconditions.checkArgument(
obj instanceof MapDataStatistics, "Invalid data statistics type: " + obj.getClass());
MapDataStatistics mapStatistics = (MapDataStatistics) obj;
mapSerializer.serialize(mapStatistics.statistics(), target);
}

@Override
public DataStatistics<MapDataStatistics, Map<RowData, Long>> deserialize(DataInputView source)
public DataStatistics<MapDataStatistics, Map<SortKey, Long>> deserialize(DataInputView source)
throws IOException {
return new MapDataStatistics(mapSerializer.deserialize(source));
}

@Override
public DataStatistics<MapDataStatistics, Map<RowData, Long>> deserialize(
DataStatistics reuse, DataInputView source) throws IOException {
public DataStatistics<MapDataStatistics, Map<SortKey, Long>> deserialize(
DataStatistics<MapDataStatistics, Map<SortKey, Long>> reuse, DataInputView source)
throws IOException {
// not much benefit to reuse
return deserialize(source);
}
Expand All @@ -138,14 +142,14 @@ public int hashCode() {
}

@Override
public TypeSerializerSnapshot<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
public TypeSerializerSnapshot<DataStatistics<MapDataStatistics, Map<SortKey, Long>>>
snapshotConfiguration() {
return new MapDataStatisticsSerializerSnapshot(this);
}

public static class MapDataStatisticsSerializerSnapshot
extends CompositeTypeSerializerSnapshot<
DataStatistics<MapDataStatistics, Map<RowData, Long>>, MapDataStatisticsSerializer> {
DataStatistics<MapDataStatistics, Map<SortKey, Long>>, MapDataStatisticsSerializer> {
private static final int CURRENT_VERSION = 1;

// constructors need to public. Otherwise, Flink state restore would complain
Expand Down Expand Up @@ -175,8 +179,8 @@ protected TypeSerializer<?>[] getNestedSerializers(
protected MapDataStatisticsSerializer createOuterSerializerWithNestedSerializers(
TypeSerializer<?>[] nestedSerializers) {
@SuppressWarnings("unchecked")
MapSerializer<RowData, Long> mapSerializer =
(MapSerializer<RowData, Long>) nestedSerializers[0];
MapSerializer<SortKey, Long> mapSerializer =
(MapSerializer<SortKey, Long>) nestedSerializers[0];
return new MapDataStatisticsSerializer(mapSerializer);
}
}
Expand Down
Loading