From 467f94f9ecef91b671ebbdc4774f2b690f4fa713 Mon Sep 17 00:00:00 2001 From: jectpro7 Date: Thu, 23 May 2024 20:30:37 +0800 Subject: [PATCH] [FLINK-35355][State] Internal async aggregating state and corresponding state descriptor This closes #24810 --- .../api/common/state/v2/AggregatingState.java | 44 ++++++++++ .../asyncprocessing/StateRequestType.java | 9 +- .../state/v2/AggregatingStateDescriptor.java | 84 +++++++++++++++++++ .../state/v2/DefaultKeyedStateStoreV2.java | 12 +++ .../state/v2/InternalAggregatingState.java | 63 ++++++++++++++ .../runtime/state/v2/KeyedStateStoreV2.java | 16 ++++ .../v2/AggregatingStateDescriptorTest.java | 80 ++++++++++++++++++ .../v2/InternalAggregatingStateTest.java | 72 ++++++++++++++++ .../operators/StreamingRuntimeContext.java | 9 ++ .../StreamingRuntimeContextTest.java | 35 ++++++++ 10 files changed, 423 insertions(+), 1 deletion(-) create mode 100644 flink-core-api/src/main/java/org/apache/flink/api/common/state/v2/AggregatingState.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AggregatingStateDescriptor.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalAggregatingState.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AggregatingStateDescriptorTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java diff --git a/flink-core-api/src/main/java/org/apache/flink/api/common/state/v2/AggregatingState.java b/flink-core-api/src/main/java/org/apache/flink/api/common/state/v2/AggregatingState.java new file mode 100644 index 0000000000000..cd72e13afee3c --- /dev/null +++ b/flink-core-api/src/main/java/org/apache/flink/api/common/state/v2/AggregatingState.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.state.v2; + +import org.apache.flink.annotation.Experimental; + +/** + * {@link State} interface for aggregating state, based on an {@link + * org.apache.flink.api.common.functions.AggregateFunction}. Elements that are added to this type of + * state will be eagerly pre-aggregated using a given {@code AggregateFunction}. + * + *

The state internally always holds the accumulator type of the {@code AggregateFunction}. When + * accessing the result of the state, the function's {@link + * org.apache.flink.api.common.functions.AggregateFunction#getResult(Object)} method is invoked. + * + *

The state is accessed and modified by user functions, and checkpointed consistently by the + * system as part of the distributed snapshots. + * + *

The state is only accessible by functions applied on a {@code KeyedStream}. The key is + * automatically supplied by the system, so the function always sees the value mapped to the key of + * the current element. That way, the system can handle stream and state partitioning consistently + * together. + * + * @param Type of the value added to the state. + * @param Type of the value extracted from the state. + */ +@Experimental +public interface AggregatingState extends MergingState {} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java index 39edbb8bda3a0..6382dff695878 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.asyncprocessing; +import org.apache.flink.api.common.state.v2.AggregatingState; import org.apache.flink.api.common.state.v2.ListState; import org.apache.flink.api.common.state.v2.MapState; import org.apache.flink.api.common.state.v2.ReducingState; @@ -105,5 +106,11 @@ public enum StateRequestType { REDUCING_GET, /** Add element into reducing state, {@link ReducingState#asyncAdd(Object)}. */ - REDUCING_ADD + REDUCING_ADD, + + /** Get value from aggregating state by {@link AggregatingState#asyncGet()}. */ + AGGREGATING_GET, + + /** Add element to aggregating state by {@link AggregatingState#asyncAdd(Object)}. 
*/ + AGGREGATING_ADD } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AggregatingStateDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AggregatingStateDescriptor.java new file mode 100644 index 0000000000000..42ff2dd24b1fd --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AggregatingStateDescriptor.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.v2; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.serialization.SerializerConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; + +import javax.annotation.Nonnull; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A {@link StateDescriptor} for {@link org.apache.flink.api.common.state.v2.AggregatingState}. + * + *

The type internally stored in the state is the type of the {@code Accumulator} of the {@code + * AggregateFunction}. + * + * @param The type of the values that are added to the state. + * @param The type of the accumulator (intermediate aggregation state). + * @param The type of the values that are returned from the state. + */ +public class AggregatingStateDescriptor extends StateDescriptor { + + private final AggregateFunction aggregateFunction; + + /** + * Create a new state descriptor with the given name, function, and type. + * + * @param stateId The (unique) name for the state. + * @param aggregateFunction The {@code AggregateFunction} used to aggregate the state. + * @param typeInfo The type of the accumulator. The accumulator is stored in the state. + */ + public AggregatingStateDescriptor( + @Nonnull String stateId, + @Nonnull AggregateFunction aggregateFunction, + @Nonnull TypeInformation typeInfo) { + super(stateId, typeInfo); + this.aggregateFunction = checkNotNull(aggregateFunction); + } + + /** + * Create a new state descriptor with the given name, function, and type. + * + * @param stateId The (unique) name for the state. + * @param aggregateFunction The {@code AggregateFunction} used to aggregate the state. + * @param typeInfo The type of the accumulator. The accumulator is stored in the state. + * @param serializerConfig The serializer related config used to generate TypeSerializer. + */ + public AggregatingStateDescriptor( + @Nonnull String stateId, + @Nonnull AggregateFunction aggregateFunction, + @Nonnull TypeInformation typeInfo, + SerializerConfig serializerConfig) { + super(stateId, typeInfo, serializerConfig); + this.aggregateFunction = checkNotNull(aggregateFunction); + } + + /** Returns the Aggregate function for this state. 
*/ + public AggregateFunction getAggregateFunction() { + return aggregateFunction; + } + + @Override + public Type getType() { + return Type.AGGREGATING; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStoreV2.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStoreV2.java index 083678a244a16..002d596378657 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStoreV2.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStoreV2.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.state.v2; +import org.apache.flink.api.common.state.v2.AggregatingState; import org.apache.flink.api.common.state.v2.ListState; import org.apache.flink.api.common.state.v2.MapState; import org.apache.flink.api.common.state.v2.ReducingState; @@ -77,4 +78,15 @@ public ReducingState getReducingState( throw new RuntimeException("Error while getting state", e); } } + + @Override + public AggregatingState getAggregatingState( + @Nonnull AggregatingStateDescriptor stateProperties) { + Preconditions.checkNotNull(stateProperties, "The state properties must not be null"); + try { + return asyncKeyedStateBackend.createState(stateProperties); + } catch (Exception e) { + throw new RuntimeException("Error while getting state", e); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalAggregatingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalAggregatingState.java new file mode 100644 index 0000000000000..7f05cdbd4d204 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalAggregatingState.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.v2; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.state.v2.AggregatingState; +import org.apache.flink.api.common.state.v2.StateFuture; +import org.apache.flink.runtime.asyncprocessing.StateRequestHandler; +import org.apache.flink.runtime.asyncprocessing.StateRequestType; + +/** + * The default implementation of {@link AggregatingState}, which delegates all async requests to + * {@link StateRequestHandler}. + * + * @param The type of key the state is associated to. + * @param The type of the values that are added into the state. + * @param The type of the accumulator (intermediate aggregation state). + * @param The type of the values that are returned from the state. + */ +public class InternalAggregatingState extends InternalKeyedState + implements AggregatingState { + + protected final AggregateFunction aggregateFunction; + + /** + * Creates a new InternalAggregatingState with the given request handler and descriptor. + * + * @param stateRequestHandler The async request handler for handling all requests. + * @param stateDescriptor The properties of the state. 
+ */ + public InternalAggregatingState( + StateRequestHandler stateRequestHandler, + AggregatingStateDescriptor stateDescriptor) { + super(stateRequestHandler, stateDescriptor); + this.aggregateFunction = stateDescriptor.getAggregateFunction(); + } + + @Override + public StateFuture asyncGet() { + return handleRequest(StateRequestType.AGGREGATING_GET, null); + } + + @Override + public StateFuture asyncAdd(IN value) { + return handleRequest(StateRequestType.AGGREGATING_ADD, value); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/KeyedStateStoreV2.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/KeyedStateStoreV2.java index a8dc7a0750038..4fda283b30a1c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/KeyedStateStoreV2.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/KeyedStateStoreV2.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.state.v2; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.state.v2.AggregatingState; import org.apache.flink.api.common.state.v2.ListState; import org.apache.flink.api.common.state.v2.MapState; import org.apache.flink.api.common.state.v2.ReducingState; @@ -86,4 +87,19 @@ public interface KeyedStateStoreV2 { * function (function is not part of a KeyedStream). */ ReducingState getReducingState(@Nonnull ReducingStateDescriptor stateProperties); + + /** + * Gets a handle to the system's key/value aggregating state. This state is only accessible if + * the function is executed on a KeyedStream. + * + * @param stateProperties The descriptor defining the properties of the state. + * @param The type of the values that are added to the state. + * @param The type of the accumulator (intermediate aggregation state). + * @param The type of the values that are returned from the state. + * @return The partitioned state object. 
+ * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the + * function (function is not part of a KeyedStream). + */ + AggregatingState getAggregatingState( + @Nonnull AggregatingStateDescriptor stateProperties); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AggregatingStateDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AggregatingStateDescriptorTest.java new file mode 100644 index 0000000000000..930a0c69124b3 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AggregatingStateDescriptorTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.v2; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.core.testutils.CommonTestUtils; + +import org.junit.jupiter.api.Test; + +import java.io.Serializable; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link AggregatingStateDescriptor}. 
*/ +class AggregatingStateDescriptorTest implements Serializable { + + @Test + void testHashCodeAndEquals() throws Exception { + final String name = "testName"; + AggregateFunction aggregator = + new AggregateFunction() { + @Override + public Integer createAccumulator() { + return 0; + } + + @Override + public Integer add(Integer value, Integer accumulator) { + return accumulator + value; + } + + @Override + public Integer getResult(Integer accumulator) { + return accumulator; + } + + @Override + public Integer merge(Integer a, Integer b) { + return a + b; + } + }; + + AggregatingStateDescriptor original = + new AggregatingStateDescriptor<>(name, aggregator, BasicTypeInfo.INT_TYPE_INFO); + AggregatingStateDescriptor same = + new AggregatingStateDescriptor<>(name, aggregator, BasicTypeInfo.INT_TYPE_INFO); + AggregatingStateDescriptor sameBySerializer = + new AggregatingStateDescriptor<>(name, aggregator, BasicTypeInfo.INT_TYPE_INFO); + + // test that hashCode() works on state descriptors with initialized and uninitialized + // serializers + assertThat(same).hasSameHashCodeAs(original); + assertThat(sameBySerializer).hasSameHashCodeAs(original); + + assertThat(same).isEqualTo(original); + assertThat(sameBySerializer).isEqualTo(original); + + // equality with a clone + AggregatingStateDescriptor clone = + CommonTestUtils.createCopySerializable(original); + assertThat(clone).isEqualTo(original); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java new file mode 100644 index 0000000000000..6e9724a61551d --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.v2; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.runtime.asyncprocessing.StateRequestType; + +import org.junit.jupiter.api.Test; + +/** Tests for {@link InternalAggregatingState}. 
*/ +class InternalAggregatingStateTest extends InternalKeyedStateTestBase { + + @Test + @SuppressWarnings({"unchecked"}) + public void testAggregating() { + AggregateFunction aggregator = + new AggregateFunction() { + @Override + public Integer createAccumulator() { + return 0; + } + + @Override + public Integer add(Integer value, Integer accumulator) { + return accumulator + value; + } + + @Override + public Integer getResult(Integer accumulator) { + return accumulator; + } + + @Override + public Integer merge(Integer a, Integer b) { + return a + b; + } + }; + AggregatingStateDescriptor descriptor = + new AggregatingStateDescriptor<>( + "testAggState", aggregator, BasicTypeInfo.INT_TYPE_INFO); + InternalAggregatingState state = + new InternalAggregatingState<>(aec, descriptor); + + aec.setCurrentContext(aec.buildContext("test", "test")); + + state.asyncClear(); + validateRequestRun(state, StateRequestType.CLEAR, null); + + state.asyncGet(); + validateRequestRun(state, StateRequestType.AGGREGATING_GET, null); + + state.asyncAdd(1); + validateRequestRun(state, StateRequestType.AGGREGATING_ADD, 1); + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java index 4ad10d3e9bd2a..a3db48cbd7188 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java @@ -282,6 +282,15 @@ public org.apache.flink.api.common.state.v2.ReducingState getReducingStat return keyedStateStoreV2.getReducingState(stateProperties); } + public + org.apache.flink.api.common.state.v2.AggregatingState getAggregatingState( + org.apache.flink.runtime.state.v2.AggregatingStateDescriptor + stateProperties) { + KeyedStateStoreV2 keyedStateStoreV2 = + 
checkPreconditionsAndGetKeyedStateStoreV2(stateProperties); + return keyedStateStoreV2.getAggregatingState(stateProperties); + } + private KeyedStateStoreV2 checkPreconditionsAndGetKeyedStateStoreV2( org.apache.flink.runtime.state.v2.StateDescriptor stateDescriptor) { checkNotNull(stateDescriptor, "The state properties must not be null"); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java index f18fb87e42cce..144a22a21f569 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java @@ -355,6 +355,41 @@ void testV2ReducingStateInstantiation() throws Exception { .isPositive(); } + @Test + void testV2AggregatingStateInstantiation() throws Exception { + final ExecutionConfig config = new ExecutionConfig(); + SerializerConfig serializerConfig = config.getSerializerConfig(); + serializerConfig.registerKryoType(Path.class); + + final AtomicReference descriptorCapture = new AtomicReference<>(); + + StreamingRuntimeContext context = createRuntimeContext(descriptorCapture, config); + + @SuppressWarnings("unchecked") + AggregateFunction aggregate = + (AggregateFunction) mock(AggregateFunction.class); + + org.apache.flink.runtime.state.v2.AggregatingStateDescriptor + descr = + new org.apache.flink.runtime.state.v2.AggregatingStateDescriptor<>( + "name", + aggregate, + TypeInformation.of(TaskInfo.class), + serializerConfig); + + context.getAggregatingState(descr); + + org.apache.flink.runtime.state.v2.AggregatingStateDescriptor descrIntercepted = + (org.apache.flink.runtime.state.v2.AggregatingStateDescriptor) + descriptorCapture.get(); + TypeSerializer serializer = descrIntercepted.getSerializer(); + + // check that the Path 
class is really registered, i.e., the execution config was applied + assertThat(serializer).isInstanceOf(KryoSerializer.class); + assertThat(((KryoSerializer) serializer).getKryo().getRegistration(Path.class).getId()) + .isPositive(); + } + // ------------------------------------------------------------------------ // // ------------------------------------------------------------------------