Skip to content

Commit

Permalink
[FLINK-34974][state] Support getOrCreateKeyedState for AsyncKeyedStateBackend (#25745)
Browse files Browse the repository at this point in the history
  • Loading branch information
fredia authored Dec 11, 2024
1 parent 6951686 commit 6a85f80
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,22 @@ public interface AsyncKeyedStateBackend<K>
void setup(@Nonnull StateRequestHandler stateRequestHandler);

/**
* Creates and returns a new state.
* Creates or retrieves a keyed state backed by this state backend.
*
* @param <N> the type of namespace for partitioning.
* @param <S> The type of the public API state.
* @param <SV> The type of the stored state value.
* @param defaultNamespace the default namespace for this state.
* @param namespaceSerializer the serializer for namespace.
* @param stateDesc The {@code StateDescriptor} that contains the name of the state.
* @throws Exception Exceptions may occur during initialization of the state.
* @return A new key/value state backed by this backend.
* @throws Exception Exceptions may occur during initialization of the state and should be
* forwarded.
*/
@Nonnull
<N, S extends State, SV> S createState(
@Nonnull N defaultNamespace,
@Nonnull TypeSerializer<N> namespaceSerializer,
@Nonnull StateDescriptor<SV> stateDesc)
<N, S extends State, SV> S getOrCreateKeyedState(
N defaultNamespace,
TypeSerializer<N> namespaceSerializer,
StateDescriptor<SV> stateDesc)
throws Exception;

/**
Expand Down Expand Up @@ -122,6 +123,20 @@ default boolean requiresLegacySynchronousTimerSnapshots(SnapshotType checkpointT
return true;
}

/**
* Whether it's safe to reuse key-values from the state-backend, e.g. for the purpose of
* optimization.
*
* <p>NOTE: this method should not be used to check for {@link InternalPriorityQueue}, as the
* priority queue could be stored in different locations, e.g. the ForSt state-backend could
* store it on the JVM heap if HEAP is configured as the time-service factory.
*
* @return true if it is safe to reuse the key-values from the state-backend; this default
*     implementation returns false.
*/
default boolean isSafeToReuseKVState() {
return false;
}

@Override
void dispose();
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public DefaultKeyedStateStore(@Nonnull AsyncKeyedStateBackend asyncKeyedStateBac
public <T> ValueState<T> getValueState(@Nonnull ValueStateDescriptor<T> stateProperties) {
Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
try {
return asyncKeyedStateBackend.createState(
return asyncKeyedStateBackend.getOrCreateKeyedState(
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
Expand All @@ -54,7 +54,7 @@ public <T> ValueState<T> getValueState(@Nonnull ValueStateDescriptor<T> statePro
public <T> ListState<T> getListState(@Nonnull ListStateDescriptor<T> stateProperties) {
Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
try {
return asyncKeyedStateBackend.createState(
return asyncKeyedStateBackend.getOrCreateKeyedState(
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
Expand All @@ -66,7 +66,7 @@ public <UK, UV> MapState<UK, UV> getMapState(
@Nonnull MapStateDescriptor<UK, UV> stateProperties) {
Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
try {
return asyncKeyedStateBackend.createState(
return asyncKeyedStateBackend.getOrCreateKeyedState(
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
Expand All @@ -78,7 +78,7 @@ public <T> ReducingState<T> getReducingState(
@Nonnull ReducingStateDescriptor<T> stateProperties) {
Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
try {
return asyncKeyedStateBackend.createState(
return asyncKeyedStateBackend.getOrCreateKeyedState(
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
Expand All @@ -90,7 +90,7 @@ public <IN, ACC, OUT> AggregatingState<IN, OUT> getAggregatingState(
@Nonnull AggregatingStateDescriptor<IN, ACC, OUT> stateProperties) {
Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
try {
return asyncKeyedStateBackend.createState(
return asyncKeyedStateBackend.getOrCreateKeyedState(
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,11 @@ public AsyncKeyedStateBackendAdaptor(CheckpointableKeyedStateBackend<K> keyedSta
@Override
public void setup(@Nonnull StateRequestHandler stateRequestHandler) {}

@Nonnull
@Override
@SuppressWarnings({"rawtypes", "unchecked"})
public <N, S extends State, SV> S createState(
@Nonnull N defaultNamespace,
@Nonnull TypeSerializer<N> namespaceSerializer,
@Nonnull StateDescriptor<SV> stateDesc)
public <N, S extends State, SV> S getOrCreateKeyedState(
N defaultNamespace,
TypeSerializer<N> namespaceSerializer,
StateDescriptor<SV> stateDesc)
throws Exception {
return createStateInternal(defaultNamespace, namespaceSerializer, stateDesc);
}
Expand Down Expand Up @@ -191,4 +189,9 @@ public boolean requiresLegacySynchronousTimerSnapshots(SnapshotType checkpointTy
}
return false;
}

/** Delegates to the wrapped {@code keyedStateBackend} and returns whatever it reports. */
@Override
public boolean isSafeToReuseKVState() {
return keyedStateBackend.isSafeToReuseKVState();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ public <N, S extends org.apache.flink.api.common.state.v2.State, T> S getOrCreat
throws Exception {

if (asyncKeyedStateBackend != null) {
return asyncKeyedStateBackend.createState(
return asyncKeyedStateBackend.getOrCreateKeyedState(
defaultNamespace, namespaceSerializer, stateDescriptor);
} else {
throw new IllegalStateException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void setup(

try {
valueState =
asyncKeyedStateBackend.createState(
asyncKeyedStateBackend.getOrCreateKeyedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
stateDescriptor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ public void setup(@Nonnull StateRequestHandler stateRequestHandler) {
// do nothing
}

@Nonnull
@Override
@SuppressWarnings("unchecked")
public <N, S extends org.apache.flink.api.common.state.v2.State, SV> S createState(
@Nonnull N defaultNamespace,
@Nonnull TypeSerializer<N> namespaceSerializer,
@Nonnull org.apache.flink.runtime.state.v2.StateDescriptor<SV> stateDesc) {
public <N, S extends org.apache.flink.api.common.state.v2.State, SV>
S getOrCreateKeyedState(
N defaultNamespace,
TypeSerializer<N> namespaceSerializer,
org.apache.flink.runtime.state.v2.StateDescriptor<SV> stateDesc)
throws Exception {
return (S) innerStateSupplier.get();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,11 @@ public void close() throws IOException {}
@Override
public void setup(@Nonnull StateRequestHandler stateRequestHandler) {}

@Nonnull
@Override
public <N, S extends State, SV> S createState(
@Nonnull N defaultNamespace,
@Nonnull TypeSerializer<N> namespaceSerializer,
@Nonnull StateDescriptor<SV> stateDesc)
public <N, S extends State, SV> S getOrCreateKeyedState(
N defaultNamespace,
TypeSerializer<N> namespaceSerializer,
StateDescriptor<SV> stateDesc)
throws Exception {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public void testValueStateAdaptor() throws Exception {
new ValueStateDescriptor<>("testState", BasicTypeInfo.INT_TYPE_INFO);

org.apache.flink.api.common.state.v2.ValueState<Integer> valueState =
adaptor.createState(
adaptor.getOrCreateKeyedState(
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);

// test synchronous interfaces.
Expand Down Expand Up @@ -102,7 +102,7 @@ public void testListStateAdaptor() throws Exception {
new ListStateDescriptor<>("testState", BasicTypeInfo.INT_TYPE_INFO);

org.apache.flink.api.common.state.v2.ListState<Integer> listState =
adaptor.createState(
adaptor.getOrCreateKeyedState(
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);

// test synchronous interfaces.
Expand Down Expand Up @@ -154,7 +154,7 @@ public void testMapStateAdaptor() throws Exception {
"testState", BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO);

org.apache.flink.api.common.state.v2.MapState<Integer, Integer> mapState =
adaptor.createState(
adaptor.getOrCreateKeyedState(
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);

final HashMap<Integer, Integer> groundTruth =
Expand Down Expand Up @@ -247,7 +247,7 @@ public void testReducingStateAdaptor() throws Exception {
"testState", Integer::sum, BasicTypeInfo.INT_TYPE_INFO);

InternalReducingState<String, Long, Integer> reducingState =
adaptor.createState(0L, LongSerializer.INSTANCE, descriptor);
adaptor.getOrCreateKeyedState(0L, LongSerializer.INSTANCE, descriptor);

// test synchronous interfaces.
reducingState.clear();
Expand Down Expand Up @@ -353,7 +353,7 @@ public Integer merge(Integer a, Integer b) {
BasicTypeInfo.INT_TYPE_INFO);

InternalAggregatingState<String, Long, Integer, Integer, String> aggState =
adaptor.createState(0L, LongSerializer.INSTANCE, descriptor);
adaptor.getOrCreateKeyedState(0L, LongSerializer.INSTANCE, descriptor);

// test synchronous interfaces.
aggState.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ void testAsyncKeyedStateBackendSnapshot() throws Exception {
new ValueStateDescriptor<>("test", BasicTypeInfo.INT_TYPE_INFO);

ValueState<Integer> valueState =
backend.createState(
backend.getOrCreateKeyedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
stateDescriptor);
Expand Down Expand Up @@ -318,7 +318,7 @@ void testAsyncKeyedStateBackendSnapshot() throws Exception {
new ValueStateDescriptor<>("test", BasicTypeInfo.INT_TYPE_INFO);

ValueState<Integer> valueState =
backend.createState(
backend.getOrCreateKeyedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
stateDescriptor);
Expand Down Expand Up @@ -432,7 +432,7 @@ void testValueStateWorkWithTtl() throws Exception {
kvId.enableTimeToLive(StateTtlConfig.newBuilder(Duration.ofSeconds(1)).build());

ValueState<Long> state =
backend.createState(
backend.getOrCreateKeyedState(
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
RecordContext recordContext = aec.buildContext("record-1", 1L);
recordContext.retain();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ public <T> TypeSerializer<T> createSerializer(
return null;
})
.when(asyncKeyedStateBackend)
.createState(
.getOrCreateKeyedState(
any(),
any(TypeSerializer.class),
any(org.apache.flink.runtime.state.v2.StateDescriptor.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import javax.annotation.concurrent.GuardedBy;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
Expand All @@ -81,6 +82,7 @@
import java.util.function.Supplier;

import static org.apache.flink.runtime.state.SnapshotExecutionType.ASYNCHRONOUS;
import static org.apache.flink.util.Preconditions.checkNotNull;

/**
* A KeyedStateBackend that stores its state in {@code ForSt}. This state backend can store very
Expand Down Expand Up @@ -158,6 +160,9 @@ public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> {
*/
private final LinkedHashMap<String, ForStOperationUtils.ForStKvStateInfo> kvStateInformation;

/** Already-created states by state ID, so that a repeated request with the same ID gets the same state. */
private final HashMap<String, InternalKeyedState<K, ?, ?>> keyValueStatesByName;

/** Lock guarding the {@code managedStateExecutors} and {@code disposed}. */
private final Object lock = new Object();

Expand Down Expand Up @@ -201,6 +206,7 @@ public ForStKeyedStateBackend(
this.valueDeserializerView = valueDeserializerView;
this.db = db;
this.kvStateInformation = kvStateInformation;
this.keyValueStatesByName = new HashMap<>();
this.columnFamilyOptionsFactory = columnFamilyOptionsFactory;
this.defaultColumnFamily = defaultColumnFamilyHandle;
this.snapshotStrategy = snapshotStrategy;
Expand All @@ -227,10 +233,24 @@ public void setup(@Nonnull StateRequestHandler stateRequestHandler) {
this.stateRequestHandler = stateRequestHandler;
}

/**
 * Returns the keyed state registered under the descriptor's state ID, creating and caching it
 * on first request.
 *
 * <p>When a state with the same ID is already cached, that instance is returned as-is; the
 * {@code defaultNamespace} and {@code namespaceSerializer} passed to this call are not applied
 * to it.
 *
 * @param <N> the type of namespace for partitioning.
 * @param <S> the type of the public API state.
 * @param <SV> the type of the stored state value.
 * @param defaultNamespace the default namespace, forwarded to {@code createState} on creation.
 * @param namespaceSerializer the serializer for the namespace; must not be null.
 * @param stateDesc the descriptor whose state ID identifies the state; must not be null.
 * @return the cached or newly created state.
 * @throws Exception if {@code createState} fails while creating a new state.
 */
@Nonnull
@Override
@SuppressWarnings("unchecked")
public <N, S extends State, SV> S getOrCreateKeyedState(
        N defaultNamespace,
        TypeSerializer<N> namespaceSerializer,
        StateDescriptor<SV> stateDesc)
        throws Exception {
    checkNotNull(namespaceSerializer, "Namespace serializer");
    // Fail with a descriptive message instead of a bare NPE from the dereference below.
    checkNotNull(stateDesc, "State descriptor");

    final String stateId = stateDesc.getStateId();
    InternalKeyedState<K, ?, ?> kvState = keyValueStatesByName.get(stateId);
    if (kvState == null) {
        // createState throws a checked exception, so Map#computeIfAbsent cannot be used here.
        kvState = createState(defaultNamespace, namespaceSerializer, stateDesc);
        keyValueStatesByName.put(stateId, kvState);
    }
    // Unchecked: the caller chooses S; the cache only knows the state as InternalKeyedState.
    return (S) kvState;
}

@Nonnull
@SuppressWarnings("unchecked")
public <N, S extends State, SV> S createState(
protected <N, S extends State, SV> S createState(
@Nonnull N defaultNamespace,
@Nonnull TypeSerializer<N> namespaceSerializer,
@Nonnull StateDescriptor<SV> stateDesc)
Expand Down Expand Up @@ -523,6 +543,11 @@ public void dispose() {
}
}

/** This backend always reports that key-values obtained from it are safe to reuse. */
@Override
public boolean isSafeToReuseKVState() {
return true;
}

@VisibleForTesting
Path getLocalBasePath() {
return optionsContainer.getLocalBasePath();
Expand Down

0 comments on commit 6a85f80

Please sign in to comment.