From 570f3d2d8986dee77c8e237fb27cfbb162dc4cf0 Mon Sep 17 00:00:00 2001 From: Burak Ozakinci Date: Mon, 16 Sep 2024 15:35:55 +0100 Subject: [PATCH] [FLINK-31922][Connectors/AWS] Port over Kinesis Client configurations for retry and backoff --- .../aws/config/AWSConfigOptions.java | 189 ++++++++++++++++++ .../kinesis/source/KinesisStreamsSource.java | 58 +++++- .../source/KinesisStreamsSourceBuilder.java | 52 ++++- ...s.java => KinesisSourceConfigOptions.java} | 28 ++- .../KinesisStreamsSourceConfigUtil.java | 6 +- .../KinesisStreamsSourceEnumerator.java | 8 +- .../source/proxy/KinesisStreamProxy.java | 15 +- .../kinesis/source/proxy/StreamProxy.java | 6 +- .../PollingKinesisShardSplitReader.java | 12 +- .../KinesisStreamsSourceConfigUtilTest.java | 4 +- .../KinesisStreamsSourceEnumeratorTest.java | 6 +- .../source/proxy/KinesisStreamProxyTest.java | 36 +++- .../KinesisStreamsSourceReaderTest.java | 6 +- .../PollingKinesisShardSplitReaderTest.java | 31 ++- .../util/KinesisStreamProxyProvider.java | 9 +- 15 files changed, 425 insertions(+), 41 deletions(-) create mode 100644 flink-connector-aws-base/src/main/java/org/apache/flink/connector/aws/config/AWSConfigOptions.java rename flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/{KinesisStreamsSourceConfigConstants.java => KinesisSourceConfigOptions.java} (70%) diff --git a/flink-connector-aws-base/src/main/java/org/apache/flink/connector/aws/config/AWSConfigOptions.java b/flink-connector-aws-base/src/main/java/org/apache/flink/connector/aws/config/AWSConfigOptions.java new file mode 100644 index 00000000..fc6b92de --- /dev/null +++ b/flink-connector-aws-base/src/main/java/org/apache/flink/connector/aws/config/AWSConfigOptions.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.aws.config; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.configuration.ConfigOption; +import org.apache.flink.configuration.ConfigOptions; + +import java.time.Duration; + +import static org.apache.flink.connector.aws.config.AWSConfigConstants.AWS_CREDENTIALS_PROVIDER; +import static org.apache.flink.connector.aws.config.AWSConfigConstants.accessKeyId; +import static org.apache.flink.connector.aws.config.AWSConfigConstants.customCredentialsProviderClass; +import static org.apache.flink.connector.aws.config.AWSConfigConstants.externalId; +import static org.apache.flink.connector.aws.config.AWSConfigConstants.profileName; +import static org.apache.flink.connector.aws.config.AWSConfigConstants.profilePath; +import static org.apache.flink.connector.aws.config.AWSConfigConstants.roleArn; +import static org.apache.flink.connector.aws.config.AWSConfigConstants.roleSessionName; +import static org.apache.flink.connector.aws.config.AWSConfigConstants.roleStsEndpoint; +import static org.apache.flink.connector.aws.config.AWSConfigConstants.secretKey; +import static org.apache.flink.connector.aws.config.AWSConfigConstants.webIdentityTokenFile; + +/** Configuration options for AWS service usage. 
*/ +@PublicEvolving +public class AWSConfigOptions { + public static final ConfigOption AWS_REGION_OPTION = + ConfigOptions.key(AWSConfigConstants.AWS_REGION) + .stringType() + .noDefaultValue() + .withDescription( + "The AWS region of the service (\"us-east-1\" is used if not set)."); + + public static final ConfigOption + AWS_CREDENTIALS_PROVIDER_OPTION = + ConfigOptions.key(AWS_CREDENTIALS_PROVIDER) + .enumType(AWSConfigConstants.CredentialProvider.class) + .defaultValue(AWSConfigConstants.CredentialProvider.BASIC) + .withDescription( + "The credential provider type to use when AWS credentials are required (BASIC is used if not set)."); + + public static final ConfigOption AWS_ACCESS_KEY_ID_OPTION = + ConfigOptions.key(accessKeyId(AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "The AWS access key ID to use when setting credentials provider type to BASIC."); + + public static final ConfigOption AWS_SECRET_ACCESS_KEY_OPTION = + ConfigOptions.key(secretKey(AWSConfigConstants.AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "The AWS secret key to use when setting credentials provider type to BASIC."); + + public static final ConfigOption AWS_PROFILE_PATH_OPTION = + ConfigOptions.key(profilePath(AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "Optional configuration for profile path if credential provider type is set to be PROFILE."); + + public static final ConfigOption AWS_PROFILE_NAME_OPTION = + ConfigOptions.key(profileName(AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "Optional configuration for profile name if credential provider type is set to be PROFILE."); + + public static final ConfigOption AWS_ROLE_STS_ENDPOINT_OPTION = + ConfigOptions.key(roleStsEndpoint(AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "The AWS endpoint for the STS (derived from the AWS region setting if 
not set) " + + "to use if credential provider type is set to be ASSUME_ROLE."); + + public static final ConfigOption CUSTOM_CREDENTIALS_PROVIDER_CLASS_OPTION = + ConfigOptions.key(customCredentialsProviderClass(AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "The full path (e.g. org.user_company.auth.CustomAwsCredentialsProvider) to the user provided" + + " class to use if credential provider type is set to be CUSTOM."); + + public static final ConfigOption AWS_ROLE_ARN_OPTION = + ConfigOptions.key(roleArn(AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "The role ARN to use when credential provider type is set to ASSUME_ROLE or" + + " WEB_IDENTITY_TOKEN"); + + public static final ConfigOption AWS_ROLE_SESSION_NAME = + ConfigOptions.key(roleSessionName(AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "The role session name to use when credential provider type is set to ASSUME_ROLE or" + + " WEB_IDENTITY_TOKEN"); + + public static final ConfigOption AWS_ROLE_EXTERNAL_ID_OPTION = + ConfigOptions.key(externalId(AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "The external ID to use when credential provider type is set to ASSUME_ROLE."); + + public static final ConfigOption AWS_WEB_IDENTITY_TOKEN_FILE = + ConfigOptions.key(webIdentityTokenFile(AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "The absolute path to the web identity token file that should be used if provider" + + " type is set to WEB_IDENTITY_TOKEN."); + + public static final ConfigOption AWS_ROLE_CREDENTIALS_PROVIDER_OPTION = + ConfigOptions.key(webIdentityTokenFile(AWS_CREDENTIALS_PROVIDER)) + .stringType() + .noDefaultValue() + .withDescription( + "The credentials provider that provides credentials for assuming the role when" + + " credential provider type is set to ASSUME_ROLE. 
Roles can be nested, so" + + " AWS_ROLE_CREDENTIALS_PROVIDER can again be set to ASSUME_ROLE"); + + public static final ConfigOption AWS_ENDPOINT_OPTION = + ConfigOptions.key(AWSConfigConstants.AWS_ENDPOINT) + .stringType() + .noDefaultValue() + .withDescription( + "The AWS endpoint for the service (derived from the AWS region setting if not set)."); + + public static final ConfigOption TRUST_ALL_CERTIFICATES_OPTION = + ConfigOptions.key(AWSConfigConstants.TRUST_ALL_CERTIFICATES) + .stringType() + .noDefaultValue() + .withDescription("Whether to trust all SSL certificates."); + + public static final ConfigOption HTTP_PROTOCOL_VERSION_OPTION = + ConfigOptions.key(AWSConfigConstants.HTTP_PROTOCOL_VERSION) + .stringType() + .noDefaultValue() + .withDescription("The HTTP protocol version to use."); + + public static final ConfigOption HTTP_CLIENT_MAX_CONCURRENCY_OPTION = + ConfigOptions.key(AWSConfigConstants.HTTP_CLIENT_MAX_CONCURRENCY) + .stringType() + .noDefaultValue() + .withDescription("Maximum request concurrency for SdkAsyncHttpClient."); + + public static final ConfigOption HTTP_CLIENT_READ_TIMEOUT_MILLIS_OPTION = + ConfigOptions.key(AWSConfigConstants.HTTP_CLIENT_READ_TIMEOUT_MILLIS) + .stringType() + .noDefaultValue() + .withDescription("Read Request timeout for SdkAsyncHttpClient."); + + public static final ConfigOption RETRY_STRATEGY_MIN_DELAY_OPTION = + ConfigOptions.key("retry-strategy.delay.min") + .durationType() + .defaultValue(Duration.ofMillis(300)) + .withDescription("Base delay for the exponential backoff retry strategy"); + + public static final ConfigOption RETRY_STRATEGY_MAX_DELAY_OPTION = + ConfigOptions.key("retry-strategy.delay.max") + .durationType() + .defaultValue(Duration.ofMillis(1000)) + .withDescription("Max delay for the exponential backoff retry strategy"); + + public static final ConfigOption RETRY_STRATEGY_MAX_ATTEMPTS_OPTION = + ConfigOptions.key("retry-strategy.attempts.max") + .intType() + .defaultValue(50) + .withDescription( 
+ "Maximum number of attempts for the exponential backoff retry strategy"); +} diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSource.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSource.java index b795f7b0..2977005f 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSource.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSource.java @@ -28,6 +28,7 @@ import org.apache.flink.api.connector.source.SplitEnumeratorContext; import org.apache.flink.configuration.Configuration; import org.apache.flink.connector.aws.config.AWSConfigConstants; +import org.apache.flink.connector.aws.config.AWSConfigOptions; import org.apache.flink.connector.aws.util.AWSClientUtil; import org.apache.flink.connector.aws.util.AWSGeneralUtil; import org.apache.flink.connector.base.source.reader.fetcher.SingleThreadFetcherManager; @@ -49,11 +50,18 @@ import org.apache.flink.util.Preconditions; import org.apache.flink.util.UserCodeClassLoader; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.awscore.internal.AwsErrorCode; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.http.SdkHttpClient; import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.retries.StandardRetryStrategy; +import software.amazon.awssdk.retries.api.BackoffStrategy; +import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.kinesis.KinesisClient; import software.amazon.awssdk.utils.AttributeMap; +import java.time.Duration; import java.util.Map; import java.util.Properties; import 
java.util.concurrent.ConcurrentHashMap; @@ -139,7 +147,9 @@ public SourceReader createReader(SourceReaderContext reade Supplier splitReaderSupplier = () -> new PollingKinesisShardSplitReader( - createKinesisStreamProxy(sourceConfig), shardMetricGroupMap); + createKinesisStreamProxy(sourceConfig), + shardMetricGroupMap, + sourceConfig); KinesisStreamsRecordEmitter recordEmitter = new KinesisStreamsRecordEmitter<>(deserializationSchema); @@ -199,12 +209,25 @@ private KinesisStreamProxy createKinesisStreamProxy(Configuration consumerConfig consumerConfig.addAllToProperties(kinesisClientProperties); kinesisClientProperties.put(AWSConfigConstants.AWS_REGION, region); + final ClientOverrideConfiguration.Builder overrideBuilder = + ClientOverrideConfiguration.builder() + .retryStrategy( + createExpBackoffRetryStrategy( + sourceConfig.get( + AWSConfigOptions.RETRY_STRATEGY_MIN_DELAY_OPTION), + sourceConfig.get( + AWSConfigOptions.RETRY_STRATEGY_MAX_DELAY_OPTION), + sourceConfig.get( + AWSConfigOptions + .RETRY_STRATEGY_MAX_ATTEMPTS_OPTION))); + AWSGeneralUtil.validateAwsCredentials(kinesisClientProperties); KinesisClient kinesisClient = AWSClientUtil.createAwsSyncClient( kinesisClientProperties, httpClient, KinesisClient.builder(), + overrideBuilder, KinesisStreamsConfigConstants.BASE_KINESIS_USER_AGENT_PREFIX_FORMAT, KinesisStreamsConfigConstants.KINESIS_CLIENT_USER_AGENT_PREFIX); return new KinesisStreamProxy(kinesisClient, httpClient); @@ -225,4 +248,37 @@ public UserCodeClassLoader getUserCodeClassLoader() { } }); } + + private RetryStrategy createExpBackoffRetryStrategy( + Duration initialDelay, Duration maxDelay, int maxAttempts) { + final BackoffStrategy backoffStrategy = + BackoffStrategy.exponentialDelayHalfJitter(initialDelay, maxDelay); + + return StandardRetryStrategy.builder() + .backoffStrategy(backoffStrategy) + .throttlingBackoffStrategy(backoffStrategy) + .maxAttempts(maxAttempts) + .retryOnException( + throwable -> { + if (throwable instanceof 
AwsServiceException) { + AwsServiceException exception = (AwsServiceException) throwable; + return (AwsErrorCode.RETRYABLE_ERROR_CODES.contains( + exception.awsErrorDetails().errorCode())) + || (AwsErrorCode.THROTTLING_ERROR_CODES.contains( + exception.awsErrorDetails().errorCode())); + } + return false; + }) + .treatAsThrottling( + throwable -> { + if (throwable instanceof AwsServiceException) { + AwsServiceException exception = (AwsServiceException) throwable; + return AwsErrorCode.THROTTLING_ERROR_CODES.contains( + exception.awsErrorDetails().errorCode()); + } + return false; + }) + .circuitBreakerEnabled(false) + .build(); + } } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSourceBuilder.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSourceBuilder.java index 2e74d052..05a8d77b 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSourceBuilder.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/KinesisStreamsSourceBuilder.java @@ -20,12 +20,16 @@ import org.apache.flink.annotation.Experimental; import org.apache.flink.api.common.serialization.DeserializationSchema; +import org.apache.flink.configuration.ConfigOption; import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions; import org.apache.flink.connector.kinesis.source.enumerator.KinesisShardAssigner; import org.apache.flink.connector.kinesis.source.enumerator.assigner.ShardAssignerFactory; import org.apache.flink.connector.kinesis.source.enumerator.assigner.UniformShardAssigner; import org.apache.flink.connector.kinesis.source.serialization.KinesisDeserializationSchema; +import java.time.Duration; + 
/** * Builder to construct the {@link KinesisStreamsSource}. * @@ -52,10 +56,17 @@ @Experimental public class KinesisStreamsSourceBuilder { private String streamArn; - private Configuration sourceConfig; private KinesisDeserializationSchema deserializationSchema; private KinesisShardAssigner kinesisShardAssigner = ShardAssignerFactory.uniformShardAssigner(); private boolean preserveShardOrder = true; + private Duration retryStrategyMinDelay; + private Duration retryStrategyMaxDelay; + private Integer retryStrategyMaxAttempts; + private final Configuration configuration; + + public KinesisStreamsSourceBuilder() { + this.configuration = new Configuration(); + } public KinesisStreamsSourceBuilder setStreamArn(String streamArn) { this.streamArn = streamArn; @@ -63,7 +74,7 @@ public KinesisStreamsSourceBuilder setStreamArn(String streamArn) { } public KinesisStreamsSourceBuilder setSourceConfig(Configuration sourceConfig) { - this.sourceConfig = sourceConfig; + this.configuration.addAll(sourceConfig); return this; } @@ -90,12 +101,47 @@ public KinesisStreamsSourceBuilder setPreserveShardOrder(boolean preserveShar return this; } + public KinesisStreamsSourceBuilder setRetryStrategyMinDelay(Duration retryStrategyMinDelay) { + this.retryStrategyMinDelay = retryStrategyMinDelay; + return this; + } + + public KinesisStreamsSourceBuilder setRetryStrategyMaxDelay(Duration retryStrategyMaxDelay) { + this.retryStrategyMaxDelay = retryStrategyMaxDelay; + return this; + } + + public KinesisStreamsSourceBuilder setRetryStrategyMaxAttempts( + Integer retryStrategyMaxAttempts) { + this.retryStrategyMaxAttempts = retryStrategyMaxAttempts; + return this; + } + public KinesisStreamsSource build() { + setSourceConfigurations(); return new KinesisStreamsSource<>( streamArn, - sourceConfig, + configuration, deserializationSchema, kinesisShardAssigner, preserveShardOrder); } + + private void setSourceConfigurations() { + overrideIfExists( + 
KinesisSourceConfigOptions.RETRY_STRATEGY_MIN_DELAY_OPTION, + this.retryStrategyMinDelay); + overrideIfExists( + KinesisSourceConfigOptions.RETRY_STRATEGY_MAX_DELAY_OPTION, + this.retryStrategyMaxDelay); + overrideIfExists( + KinesisSourceConfigOptions.RETRY_STRATEGY_MAX_ATTEMPTS_OPTION, + this.retryStrategyMaxAttempts); + } + + private void overrideIfExists(ConfigOption configOption, E value) { + if (value != null) { + this.configuration.set(configOption, value); + } + } } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigConstants.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/KinesisSourceConfigOptions.java similarity index 70% rename from flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigConstants.java rename to flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/KinesisSourceConfigOptions.java index 76ab546d..241e6d37 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigConstants.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/KinesisSourceConfigOptions.java @@ -19,12 +19,17 @@ package org.apache.flink.connector.kinesis.source.config; import org.apache.flink.annotation.Experimental; +import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.configuration.ConfigOption; import org.apache.flink.configuration.ConfigOptions; +import org.apache.flink.connector.aws.config.AWSConfigOptions; + +import java.time.Duration; /** Constants to be used with the KinesisStreamsSource. 
*/ @Experimental -public class KinesisStreamsSourceConfigConstants { +@PublicEvolving +public class KinesisSourceConfigOptions extends AWSConfigOptions { /** Marks the initial position to use when reading from the Kinesis stream. */ public enum InitialPosition { LATEST, @@ -33,28 +38,35 @@ public enum InitialPosition { } public static final ConfigOption STREAM_INITIAL_POSITION = - ConfigOptions.key("flink.stream.initpos") + ConfigOptions.key("kinesis.stream.init.position") .enumType(InitialPosition.class) .defaultValue(InitialPosition.LATEST) .withDescription("The initial position to start reading Kinesis streams."); public static final ConfigOption STREAM_INITIAL_TIMESTAMP = - ConfigOptions.key("flink.stream.initpos.timestamp") + ConfigOptions.key("kinesis.stream.init.timestamp") .stringType() .noDefaultValue() .withDescription( "The initial timestamp at which to start reading from the Kinesis stream. This is used when AT_TIMESTAMP is configured for the STREAM_INITIAL_POSITION."); public static final ConfigOption STREAM_TIMESTAMP_DATE_FORMAT = - ConfigOptions.key("flink.stream.initpos.timestamp.format") + ConfigOptions.key("kinesis.stream.init.timestamp.format") .stringType() .defaultValue("yyyy-MM-dd'T'HH:mm:ss.SSSXXX") .withDescription( "The date format used to parse the initial timestamp at which to start reading from the Kinesis stream. 
This is used when AT_TIMESTAMP is configured for the STREAM_INITIAL_POSITION."); - public static final ConfigOption SHARD_DISCOVERY_INTERVAL_MILLIS = - ConfigOptions.key("flink.shard.discovery.intervalmillis") - .longType() - .defaultValue(10000L) + public static final ConfigOption SHARD_DISCOVERY_INTERVAL = + ConfigOptions.key("kinesis.shard.discovery.interval") + .durationType() + .defaultValue(Duration.ofSeconds(10)) .withDescription("The interval between each attempt to discover new shards."); + + public static final ConfigOption SHARD_GET_RECORDS_MAX = + ConfigOptions.key("kinesis.shard.get-records.max-record-count") + .intType() + .defaultValue(10000) + .withDescription( + "The maximum number of records to try to get each time we fetch records from a AWS Kinesis shard"); } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigUtil.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigUtil.java index 98000efc..47d99fcc 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigUtil.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigUtil.java @@ -27,10 +27,10 @@ import java.time.Instant; import java.util.Optional; -import static org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigConstants.STREAM_INITIAL_TIMESTAMP; -import static org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigConstants.STREAM_TIMESTAMP_DATE_FORMAT; +import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.STREAM_INITIAL_TIMESTAMP; +import static 
org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.STREAM_TIMESTAMP_DATE_FORMAT; -/** Utility functions to use with {@link KinesisStreamsSourceConfigConstants}. */ +/** Utility functions to use with {@link KinesisSourceConfigOptions}. */ @Internal public class KinesisStreamsSourceConfigUtil { diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumerator.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumerator.java index 1f9f707c..245b945a 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumerator.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumerator.java @@ -26,7 +26,7 @@ import org.apache.flink.api.connector.source.SplitEnumeratorContext; import org.apache.flink.api.connector.source.SplitsAssignment; import org.apache.flink.configuration.Configuration; -import org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigConstants.InitialPosition; +import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.InitialPosition; import org.apache.flink.connector.kinesis.source.enumerator.tracker.SplitTracker; import org.apache.flink.connector.kinesis.source.event.SplitsFinishedEvent; import org.apache.flink.connector.kinesis.source.exception.KinesisStreamsSourceException; @@ -56,8 +56,8 @@ import java.util.Optional; import java.util.Set; -import static org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigConstants.SHARD_DISCOVERY_INTERVAL_MILLIS; -import static 
org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigConstants.STREAM_INITIAL_POSITION; +import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.SHARD_DISCOVERY_INTERVAL; +import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.STREAM_INITIAL_POSITION; import static org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigUtil.parseStreamTimestampStartingPosition; /** @@ -111,7 +111,7 @@ public void start() { context.callAsync(this::initialDiscoverSplits, this::processDiscoveredSplits); } - final long shardDiscoveryInterval = sourceConfig.get(SHARD_DISCOVERY_INTERVAL_MILLIS); + final long shardDiscoveryInterval = sourceConfig.get(SHARD_DISCOVERY_INTERVAL).toMillis(); context.callAsync( this::periodicallyDiscoverSplits, this::processDiscoveredSplits, diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/proxy/KinesisStreamProxy.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/proxy/KinesisStreamProxy.java index 9e91df5c..afa3878c 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/proxy/KinesisStreamProxy.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/proxy/KinesisStreamProxy.java @@ -93,13 +93,17 @@ public List listShards(String streamArn, ListShardsStartingPosition start @Override public GetRecordsResponse getRecords( - String streamArn, String shardId, StartingPosition startingPosition) { + String streamArn, + String shardId, + StartingPosition startingPosition, + int maxRecordsToGet) { String shardIterator = shardIdToIteratorStore.computeIfAbsent( shardId, (s) -> getShardIterator(streamArn, s, startingPosition)); try { - GetRecordsResponse getRecordsResponse = 
getRecords(streamArn, shardIterator); + GetRecordsResponse getRecordsResponse = + getRecords(streamArn, shardIterator, maxRecordsToGet); if (getRecordsResponse.nextShardIterator() != null) { shardIdToIteratorStore.put(shardId, getRecordsResponse.nextShardIterator()); } @@ -107,7 +111,8 @@ public GetRecordsResponse getRecords( } catch (ExpiredIteratorException e) { // Eagerly retry getRecords() if the iterator is expired shardIterator = getShardIterator(streamArn, shardId, startingPosition); - GetRecordsResponse getRecordsResponse = getRecords(streamArn, shardIterator); + GetRecordsResponse getRecordsResponse = + getRecords(streamArn, shardIterator, maxRecordsToGet); if (getRecordsResponse.nextShardIterator() != null) { shardIdToIteratorStore.put(shardId, getRecordsResponse.nextShardIterator()); } @@ -152,11 +157,13 @@ private String getShardIterator( return kinesisClient.getShardIterator(requestBuilder.build()).shardIterator(); } - private GetRecordsResponse getRecords(String streamArn, String shardIterator) { + private GetRecordsResponse getRecords( + String streamArn, String shardIterator, int maxRecordsToGet) { return kinesisClient.getRecords( GetRecordsRequest.builder() .streamARN(streamArn) .shardIterator(shardIterator) + .limit(maxRecordsToGet) .build()); } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/proxy/StreamProxy.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/proxy/StreamProxy.java index 86678d87..bb449db7 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/proxy/StreamProxy.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/proxy/StreamProxy.java @@ -55,9 +55,13 @@ public interface StreamProxy extends Closeable { * @param streamArn the ARN of the stream * 
@param shardId the shard to subscribe from * @param startingPosition the starting position to read from + * @param maxRecordsToGet the maximum number of records to retrieve for this batch * @return the response with records. Includes both the returned records and the subsequent * shard iterator to use. */ GetRecordsResponse getRecords( - String streamArn, String shardId, StartingPosition startingPosition); + String streamArn, + String shardId, + StartingPosition startingPosition, + int maxRecordsToGet); } diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/PollingKinesisShardSplitReader.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/PollingKinesisShardSplitReader.java index d93e318c..8b1a144b 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/PollingKinesisShardSplitReader.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/PollingKinesisShardSplitReader.java @@ -19,9 +19,11 @@ package org.apache.flink.connector.kinesis.source.reader; import org.apache.flink.annotation.Internal; +import org.apache.flink.configuration.Configuration; import org.apache.flink.connector.base.source.reader.RecordsWithSplitIds; import org.apache.flink.connector.base.source.reader.splitreader.SplitReader; import org.apache.flink.connector.base.source.reader.splitreader.SplitsChange; +import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions; import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics; import org.apache.flink.connector.kinesis.source.proxy.StreamProxy; import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit; @@ -64,16 +66,21 @@ public class PollingKinesisShardSplitReader implements SplitReader
pausedSplitIds = new HashSet<>(); private final Map shardMetricGroupMap; + private final Configuration sourceConfig; public PollingKinesisShardSplitReader( - StreamProxy kinesisProxy, Map shardMetricGroupMap) { + StreamProxy kinesisProxy, + Map shardMetricGroupMap, + Configuration sourceConfig) { this.kinesis = kinesisProxy; this.shardMetricGroupMap = shardMetricGroupMap; + this.sourceConfig = sourceConfig; } @Override public RecordsWithSplitIds fetch() throws IOException { KinesisShardSplitState splitState = assignedSplits.poll(); + if (splitState == null) { return INCOMPLETE_SHARD_EMPTY_RECORDS; } @@ -90,7 +97,8 @@ public RecordsWithSplitIds fetch() throws IOException { kinesis.getRecords( splitState.getStreamArn(), splitState.getShardId(), - splitState.getNextStartingPosition()); + splitState.getNextStartingPosition(), + sourceConfig.get(KinesisSourceConfigOptions.SHARD_GET_RECORDS_MAX)); } catch (ResourceNotFoundException e) { LOG.warn( "Failed to fetch records from shard {}: shard no longer exists. 
Marking split as complete", diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigUtilTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigUtilTest.java index 766da116..f300bddb 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigUtilTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/config/KinesisStreamsSourceConfigUtilTest.java @@ -25,8 +25,8 @@ import java.text.SimpleDateFormat; import java.time.Instant; -import static org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigConstants.STREAM_INITIAL_TIMESTAMP; -import static org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigConstants.STREAM_TIMESTAMP_DATE_FORMAT; +import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.STREAM_INITIAL_TIMESTAMP; +import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.STREAM_TIMESTAMP_DATE_FORMAT; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumeratorTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumeratorTest.java index 953dbdea..42ff6831 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumeratorTest.java +++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumeratorTest.java @@ -22,7 +22,7 @@ import org.apache.flink.api.connector.source.SplitsAssignment; import org.apache.flink.api.connector.source.mocks.MockSplitEnumeratorContext; import org.apache.flink.configuration.Configuration; -import org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigConstants.InitialPosition; +import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.InitialPosition; import org.apache.flink.connector.kinesis.source.enumerator.assigner.ShardAssignerFactory; import org.apache.flink.connector.kinesis.source.proxy.ListShardsStartingPosition; import org.apache.flink.connector.kinesis.source.proxy.StreamProxy; @@ -45,8 +45,8 @@ import java.util.List; import java.util.stream.Stream; -import static org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigConstants.STREAM_INITIAL_POSITION; -import static org.apache.flink.connector.kinesis.source.config.KinesisStreamsSourceConfigConstants.STREAM_INITIAL_TIMESTAMP; +import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.STREAM_INITIAL_POSITION; +import static org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions.STREAM_INITIAL_TIMESTAMP; import static org.apache.flink.connector.kinesis.source.util.KinesisStreamProxyProvider.TestKinesisStreamProxy; import static org.apache.flink.connector.kinesis.source.util.KinesisStreamProxyProvider.getTestStreamProxy; import static org.apache.flink.connector.kinesis.source.util.TestUtil.generateShardId; diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/proxy/KinesisStreamProxyTest.java 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/proxy/KinesisStreamProxyTest.java index c4bca201..19d8bd67 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/proxy/KinesisStreamProxyTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/proxy/KinesisStreamProxyTest.java @@ -66,11 +66,13 @@ class KinesisStreamProxyTest { private TestingKinesisClient testKinesisClient; private KinesisStreamProxy kinesisStreamProxy; + private int maxRecordsToGet; @BeforeEach public void setUp() { testKinesisClient = new TestingKinesisClient(); kinesisStreamProxy = new KinesisStreamProxy(testKinesisClient, HTTP_CLIENT); + maxRecordsToGet = 5000; } @Test @@ -194,9 +196,12 @@ void testGetRecordsInitialReadFromTrimHorizon() { GetRecordsRequest.builder() .streamARN(STREAM_ARN) .shardIterator(expectedShardIterator) + .limit(maxRecordsToGet) .build())); - assertThat(kinesisStreamProxy.getRecords(STREAM_ARN, shardId, startingPosition)) + assertThat( + kinesisStreamProxy.getRecords( + STREAM_ARN, shardId, startingPosition, maxRecordsToGet)) .isEqualTo(expectedGetRecordsResponse); } @@ -228,9 +233,12 @@ void testGetRecordsInitialReadFromTimestamp() { GetRecordsRequest.builder() .streamARN(STREAM_ARN) .shardIterator(expectedShardIterator) + .limit(maxRecordsToGet) .build())); - assertThat(kinesisStreamProxy.getRecords(STREAM_ARN, shardId, startingPosition)) + assertThat( + kinesisStreamProxy.getRecords( + STREAM_ARN, shardId, startingPosition, maxRecordsToGet)) .isEqualTo(expectedGetRecordsResponse); } @@ -263,9 +271,12 @@ void testGetRecordsInitialReadFromSequenceNumber() { GetRecordsRequest.builder() .streamARN(STREAM_ARN) .shardIterator(expectedShardIterator) + .limit(maxRecordsToGet) .build())); - assertThat(kinesisStreamProxy.getRecords(STREAM_ARN, shardId, startingPosition)) + 
assertThat( + kinesisStreamProxy.getRecords( + STREAM_ARN, shardId, startingPosition, maxRecordsToGet)) .isEqualTo(expectedGetRecordsResponse); } @@ -306,8 +317,11 @@ void testConsecutiveGetRecordsUsesShardIteratorFromResponse() { GetRecordsRequest.builder() .streamARN(streamArn) .shardIterator(firstShardIterator) + .limit(maxRecordsToGet) .build())); - assertThat(kinesisStreamProxy.getRecords(streamArn, shardId, startingPosition)) + assertThat( + kinesisStreamProxy.getRecords( + streamArn, shardId, startingPosition, maxRecordsToGet)) .isEqualTo(firstGetRecordsResponse); // When read for the second time @@ -324,8 +338,11 @@ void testConsecutiveGetRecordsUsesShardIteratorFromResponse() { GetRecordsRequest.builder() .streamARN(streamArn) .shardIterator(secondShardIterator) + .limit(maxRecordsToGet) .build())); - assertThat(kinesisStreamProxy.getRecords(streamArn, shardId, startingPosition)) + assertThat( + kinesisStreamProxy.getRecords( + streamArn, shardId, startingPosition, maxRecordsToGet)) .isEqualTo(secondGetRecordsResponse); } @@ -363,7 +380,9 @@ void testGetRecordsEagerlyRetriesExpiredIterators() { }); // Then getRecords called with second shard iterator - assertThat(kinesisStreamProxy.getRecords(STREAM_ARN, shardId, startingPosition)) + assertThat( + kinesisStreamProxy.getRecords( + STREAM_ARN, shardId, startingPosition, maxRecordsToGet)) .isEqualTo(getRecordsResponse); assertThat(firstGetRecordsCall.get()).isFalse(); } @@ -395,11 +414,14 @@ void testGetRecordsHandlesCompletedShard() { GetRecordsRequest.builder() .streamARN(STREAM_ARN) .shardIterator(expectedShardIterator) + .limit(maxRecordsToGet) .build())); assertThatNoException() .isThrownBy( - () -> kinesisStreamProxy.getRecords(STREAM_ARN, shardId, startingPosition)); + () -> + kinesisStreamProxy.getRecords( + STREAM_ARN, shardId, startingPosition, maxRecordsToGet)); } @Test diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReaderTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReaderTest.java index 9fac4167..b0dcc70a 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReaderTest.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReaderTest.java @@ -52,14 +52,18 @@ class KinesisStreamsSourceReaderTest { private KinesisStreamsSourceReader sourceReader; private MetricListener metricListener; private Map shardMetricGroupMap; + private Configuration sourceConfig; @BeforeEach public void init() { metricListener = new MetricListener(); shardMetricGroupMap = new ConcurrentHashMap<>(); + sourceConfig = new Configuration(); StreamProxy testStreamProxy = getTestStreamProxy(); Supplier splitReaderSupplier = - () -> new PollingKinesisShardSplitReader(testStreamProxy, shardMetricGroupMap); + () -> + new PollingKinesisShardSplitReader( + testStreamProxy, shardMetricGroupMap, sourceConfig); testingReaderContext = KinesisContextProvider.KinesisTestingContext.getKinesisTestingContext( diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/PollingKinesisShardSplitReaderTest.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/PollingKinesisShardSplitReaderTest.java index 13963a9f..e3d231b0 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/PollingKinesisShardSplitReaderTest.java +++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/PollingKinesisShardSplitReaderTest.java @@ -18,8 +18,10 @@ package org.apache.flink.connector.kinesis.source.reader; +import org.apache.flink.configuration.Configuration; import org.apache.flink.connector.base.source.reader.RecordsWithSplitIds; import org.apache.flink.connector.base.source.reader.splitreader.SplitsAddition; +import org.apache.flink.connector.kinesis.source.config.KinesisSourceConfigOptions; import org.apache.flink.connector.kinesis.source.metrics.KinesisShardMetrics; import org.apache.flink.connector.kinesis.source.split.KinesisShardSplit; import org.apache.flink.connector.kinesis.source.util.KinesisStreamProxyProvider.TestKinesisStreamProxy; @@ -54,6 +56,7 @@ class PollingKinesisShardSplitReaderTest { private TestKinesisStreamProxy testStreamProxy; private MetricListener metricListener; private Map shardMetricGroupMap; + private Configuration sourceConfig; private static final String TEST_SHARD_ID = TestUtil.generateShardId(1); @BeforeEach @@ -62,11 +65,16 @@ public void init() { metricListener = new MetricListener(); shardMetricGroupMap = new ConcurrentHashMap<>(); + sourceConfig = new Configuration(); + sourceConfig.set(KinesisSourceConfigOptions.SHARD_GET_RECORDS_MAX, 50); + shardMetricGroupMap.put( TEST_SHARD_ID, new KinesisShardMetrics( TestUtil.getTestSplit(TEST_SHARD_ID), metricListener.getMetricGroup())); - splitReader = new PollingKinesisShardSplitReader(testStreamProxy, shardMetricGroupMap); + splitReader = + new PollingKinesisShardSplitReader( + testStreamProxy, shardMetricGroupMap, sourceConfig); } @Test @@ -362,6 +370,27 @@ void testFetchUpdatesTheMillisBehindLatestMetric() throws IOException { split, TestUtil.MILLIS_BEHIND_LATEST_TEST_VALUE, metricListener); } + @Test + void testMaxRecordsToGetParameterPassed() throws IOException { + int maxRecordsToGet = 2; + 
sourceConfig.set(KinesisSourceConfigOptions.SHARD_GET_RECORDS_MAX, maxRecordsToGet); + testStreamProxy.addShards(TEST_SHARD_ID); + List sentRecords = + Stream.of(getTestRecord("data-1"), getTestRecord("data-2"), getTestRecord("data-3")) + .collect(Collectors.toList()); + + testStreamProxy.addRecords(TestUtil.STREAM_ARN, TEST_SHARD_ID, sentRecords); + + splitReader.handleSplitsChanges( + new SplitsAddition<>(Collections.singletonList(getTestSplit(TEST_SHARD_ID)))); + + RecordsWithSplitIds retrievedRecords = splitReader.fetch(); + List records = new ArrayList<>(readAllRecords(retrievedRecords)); + + assertThat(sentRecords.size() > maxRecordsToGet).isTrue(); + assertThat(records.size()).isEqualTo(maxRecordsToGet); + } + private List readAllRecords(RecordsWithSplitIds recordsWithSplitIds) { List outputRecords = new ArrayList<>(); Record record; diff --git a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/util/KinesisStreamProxyProvider.java b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/util/KinesisStreamProxyProvider.java index ab12164b..235416a0 100644 --- a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/util/KinesisStreamProxyProvider.java +++ b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/util/KinesisStreamProxyProvider.java @@ -116,7 +116,10 @@ public List listShards( @Override public GetRecordsResponse getRecords( - String streamArn, String shardId, StartingPosition startingPosition) { + String streamArn, + String shardId, + StartingPosition startingPosition, + int maxRecordsToGet) { ShardHandle shardHandle = new ShardHandle(streamArn, shardId); if (getRecordsExceptionSupplier != null) { @@ -126,6 +129,10 @@ public GetRecordsResponse getRecords( List records = null; if 
(storedRecords.containsKey(shardHandle)) { records = storedRecords.get(shardHandle).poll(); + + if (records != null) { + records = records.stream().limit(maxRecordsToGet).collect(Collectors.toList()); + } } return GetRecordsResponse.builder()