Skip to content

Commit

Permalink
#3084 use factories for client utility classes initialization, improv…
Browse files Browse the repository at this point in the history
…e shutdown behavior
  • Loading branch information
SergeyRyabinin committed Aug 26, 2024
1 parent 9514d1a commit 11c124f
Show file tree
Hide file tree
Showing 22 changed files with 195 additions and 122 deletions.
23 changes: 14 additions & 9 deletions src/aws-cpp-sdk-core/include/aws/core/client/AWSClientAsyncCRTP.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ namespace Client
std::unique_lock<std::mutex> lock(pClient->m_shutdownMutex);

pClient->m_isInitialized = false;

pClient->DisableRequestProcessing();

if (timeoutMs == -1)
{
Expand All @@ -107,10 +107,15 @@ namespace Client
std::chrono::milliseconds(timeoutMs),
[&](){ return pClient->m_operationsProcessed.load() == 0; });

pClient->m_endpointProvider.reset();
pClient->m_executor.reset();
if (pClient->m_operationsProcessed.load())
{
AWS_LOGSTREAM_FATAL(AwsServiceClientT::GetAllocationTag(), "Service client "
<< AwsServiceClientT::GetServiceName() << " is shutting down while async tasks are present.");
}

pClient->m_clientConfiguration.executor.reset();
pClient->m_clientConfiguration.retryStrategy.reset();
pClient->m_endpointProvider.reset();
}

/**
Expand All @@ -124,7 +129,7 @@ namespace Client
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context = nullptr) const
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
Aws::Client::MakeAsyncOperation(operationFunc, clientThis, request, handler, context, clientThis->m_executor.get());
Aws::Client::MakeAsyncOperation(operationFunc, clientThis, request, handler, context, clientThis->m_clientConfiguration.executor.get());
}

/**
Expand All @@ -139,7 +144,7 @@ namespace Client
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context = nullptr) const
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
Aws::Client::MakeAsyncStreamingOperation(operationFunc, clientThis, request, handler, context, clientThis->m_executor.get());
Aws::Client::MakeAsyncStreamingOperation(operationFunc, clientThis, request, handler, context, clientThis->m_clientConfiguration.executor.get());
}

/**
Expand All @@ -152,7 +157,7 @@ namespace Client
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context = nullptr) const
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
Aws::Client::MakeAsyncOperation(operationFunc, clientThis, handler, context, clientThis->m_executor.get());
Aws::Client::MakeAsyncOperation(operationFunc, clientThis, handler, context, clientThis->m_clientConfiguration.executor.get());
}

/**
Expand All @@ -165,7 +170,7 @@ namespace Client
-> std::future<decltype((static_cast<const AwsServiceClientT*>(nullptr)->*operationFunc)(request))>
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
return Aws::Client::MakeCallableOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, request, clientThis->m_executor.get());
return Aws::Client::MakeCallableOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, request, clientThis->m_clientConfiguration.executor.get());
}

/**
Expand All @@ -178,7 +183,7 @@ namespace Client
-> std::future<decltype((static_cast<const AwsServiceClientT*>(nullptr)->*operationFunc)(request))>
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
return Aws::Client::MakeCallableStreamingOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, request, clientThis->m_executor.get());
return Aws::Client::MakeCallableStreamingOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, request, clientThis->m_clientConfiguration.executor.get());
}

/**
Expand All @@ -191,7 +196,7 @@ namespace Client
-> std::future<decltype((static_cast<const AwsServiceClientT*>(nullptr)->*operationFunc)())>
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
return Aws::Client::MakeCallableOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, clientThis->m_executor.get());
return Aws::Client::MakeCallableOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, clientThis->m_clientConfiguration.executor.get());
}
protected:
std::atomic<bool> m_isInitialized;
Expand Down
66 changes: 55 additions & 11 deletions src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <aws/core/utils/Array.h>
#include <aws/crt/Optional.h>
#include <smithy/tracing/TelemetryProvider.h>
#include <smithy/tracing/NoopTelemetryProvider.h>
#include <memory>

namespace Aws
Expand Down Expand Up @@ -69,12 +68,48 @@ namespace Aws
bool shouldDisableIMDS = false;
};

/**
* Utility struct to hold default factories used in ClientConfiguration.
*/
struct AWS_CORE_API ClientConfigDefaultFactories
{
static std::shared_ptr<Aws::Utils::Threading::Executor> CreateExecutor();
};

/**
* This mutable structure is used to configure any of the AWS clients.
* Default values can only be overwritten prior to passing to the client constructors.
*/
struct AWS_CORE_API ClientConfiguration
{
struct ProviderFactories
{
/**
* Retry Strategy factory method. Default is DefaultRetryStrategy (i.e. exponential backoff).
*/
std::function<std::shared_ptr<RetryStrategy>()> retryStrategyCreateFn;
/**
* Threading Executor factory method. Default creates a factory that creates DefaultExecutor
* (i.e. spawn a separate thread for each task) for backward compatibility reasons.
* Please switch to a better executor such as PooledThreadExecutor.
*/
std::function<std::shared_ptr<Utils::Threading::Executor>()> executorCreateFn;
/**
* Rate Limiter factory for outgoing bandwidth. Default is wide-open.
*/
std::function<std::shared_ptr<Utils::RateLimits::RateLimiterInterface>()> writeRateLimiterCreateFn;
/**
* Rate Limiter factory for incoming bandwidth. Default is wide-open.
*/
std::function<std::shared_ptr<Utils::RateLimits::RateLimiterInterface>()> readRateLimiterCreateFn;
/**
* TelemetryProvider factory. Defaults to Noop provider.
*/
std::function<std::shared_ptr<smithy::components::tracing::TelemetryProvider>()> telemetryProviderCreateFn;

static ProviderFactories defaultFactories;
};

ClientConfiguration();

/**
Expand Down Expand Up @@ -104,6 +139,11 @@ namespace Aws
*/
virtual ~ClientConfiguration() = default;

/**
* Client configuration factory methods to init client utility classes such as Executor, Retry Strategy
*/
ProviderFactories configFactories = ProviderFactories::defaultFactories;

/**
* User Agent string user for http calls. This is filled in for you in the constructor. Don't override this unless you have a really good reason.
*/
Expand Down Expand Up @@ -165,9 +205,10 @@ namespace Aws
*/
unsigned long lowSpeedLimit = 1;
/**
* Strategy to use in case of failed requests. Default is DefaultRetryStrategy (i.e. exponential backoff)
* Strategy to use in case of failed requests. Default is DefaultRetryStrategy (i.e. exponential backoff).
* Provide retry strategy here or via a factory method.
*/
std::shared_ptr<RetryStrategy> retryStrategy;
std::shared_ptr<RetryStrategy> retryStrategy = nullptr;
/**
* Override the http endpoint used to talk to a service.
*/
Expand Down Expand Up @@ -227,9 +268,10 @@ namespace Aws
*/
Aws::Utils::Array<Aws::String> nonProxyHosts;
/**
* Threading Executor implementation. Default uses std::thread::detach()
*/
std::shared_ptr<Aws::Utils::Threading::Executor> executor;
* Threading Executor implementation. Default uses std::thread::detach()
* Provide executor here or via a factory method.
*/
std::shared_ptr<Aws::Utils::Threading::Executor> executor = nullptr;
/**
* If you need to test and want to get around TLS validation errors, do that here.
* You probably shouldn't use this flag in a production scenario.
Expand Down Expand Up @@ -263,12 +305,14 @@ namespace Aws
Aws::String proxyCaFile;
/**
* Rate Limiter implementation for outgoing bandwidth. Default is wide-open.
* Provide limiter here or via a factory method.
*/
std::shared_ptr<Aws::Utils::RateLimits::RateLimiterInterface> writeRateLimiter;
std::shared_ptr<Aws::Utils::RateLimits::RateLimiterInterface> writeRateLimiter = nullptr;
/**
* Rate Limiter implementation for incoming bandwidth. Default is wide-open.
* Provide limiter here or via a factory method.
*/
std::shared_ptr<Aws::Utils::RateLimits::RateLimiterInterface> readRateLimiter;
std::shared_ptr<Aws::Utils::RateLimits::RateLimiterInterface> readRateLimiter = nullptr;
/**
* Override the http implementation the default factory returns.
*/
Expand Down Expand Up @@ -374,10 +418,10 @@ namespace Aws
const Aws::String& defaultValue);

/**
* A wrapper for interfacing with telemetry functionality.
* A wrapper for interfacing with telemetry functionality. Defaults to Noop provider.
* Provide TelemetryProvider here or via a factory method.
*/
std::shared_ptr<smithy::components::tracing::TelemetryProvider> telemetryProvider =
smithy::components::tracing::NoopTelemetryProvider::CreateProvider();
std::shared_ptr<smithy::components::tracing::TelemetryProvider> telemetryProvider;
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,10 @@ if(!m_isInitialized) \
} \
Aws::Utils::RAIICounter(this->m_operationsProcessed, &this->m_shutdownSignal)

#define AWS_ASYNC_OPERATION_GUARD(OPERATION) \
if(!m_isInitialized) \
{ \
AWS_LOGSTREAM_ERROR(#OPERATION, "Unable to call " #OPERATION ": client is not initialized (or already terminated)"); \
return handler(this, request, Aws::Client::AWSError<CoreErrors>(CoreErrors::NOT_INITIALIZED, "NOT_INITIALIZED", "Client is not initialized or already terminated", false), handlerContext); \
} \
Aws::Utils::RAIICounter(this->m_operationsProcessed, &this->m_shutdownSignal)
32 changes: 22 additions & 10 deletions src/aws-cpp-sdk-core/source/client/AWSClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,19 @@ AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration,
const std::shared_ptr<Aws::Client::AWSAuthSigner>& signer,
const std::shared_ptr<AWSErrorMarshaller>& errorMarshaller) :
m_region(configuration.region),
m_telemetryProvider(configuration.telemetryProvider),
m_telemetryProvider(configuration.telemetryProvider ? configuration.telemetryProvider : configuration.configFactories.telemetryProviderCreateFn()),
m_signerProvider(Aws::MakeUnique<Aws::Auth::DefaultAuthSignerProvider>(AWS_CLIENT_LOG_TAG, signer)),
m_httpClient(CreateHttpClient(configuration)),
m_httpClient(CreateHttpClient(
[&configuration, this]()
{
ClientConfiguration tempConfig(configuration);
tempConfig.telemetryProvider = m_telemetryProvider;
return tempConfig;
}())),
m_errorMarshaller(errorMarshaller),
m_retryStrategy(configuration.retryStrategy),
m_writeRateLimiter(configuration.writeRateLimiter),
m_readRateLimiter(configuration.readRateLimiter),
m_retryStrategy(configuration.retryStrategy ? configuration.retryStrategy : configuration.configFactories.retryStrategyCreateFn()),
m_writeRateLimiter(configuration.writeRateLimiter ? configuration.writeRateLimiter : configuration.configFactories.writeRateLimiterCreateFn()),
m_readRateLimiter(configuration.readRateLimiter ? configuration.readRateLimiter : configuration.configFactories.readRateLimiterCreateFn()),
m_userAgent(Aws::Client::ComputeUserAgentString(&configuration)),
m_hash(Aws::Utils::Crypto::CreateMD5Implementation()),
m_requestTimeoutMs(configuration.requestTimeoutMs),
Expand All @@ -138,13 +144,19 @@ AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration,
const std::shared_ptr<Aws::Auth::AWSAuthSignerProvider>& signerProvider,
const std::shared_ptr<AWSErrorMarshaller>& errorMarshaller) :
m_region(configuration.region),
m_telemetryProvider(configuration.telemetryProvider),
m_telemetryProvider(configuration.telemetryProvider ? configuration.telemetryProvider : configuration.configFactories.telemetryProviderCreateFn()),
m_signerProvider(signerProvider),
m_httpClient(CreateHttpClient(configuration)),
m_httpClient(CreateHttpClient(
[&configuration, this]()
{
ClientConfiguration tempConfig(configuration);
tempConfig.telemetryProvider = m_telemetryProvider;
return tempConfig;
}())),
m_errorMarshaller(errorMarshaller),
m_retryStrategy(configuration.retryStrategy),
m_writeRateLimiter(configuration.writeRateLimiter),
m_readRateLimiter(configuration.readRateLimiter),
m_retryStrategy(configuration.retryStrategy ? configuration.retryStrategy : configuration.configFactories.retryStrategyCreateFn()),
m_writeRateLimiter(configuration.writeRateLimiter ? configuration.writeRateLimiter : configuration.configFactories.writeRateLimiterCreateFn()),
m_readRateLimiter(configuration.readRateLimiter ? configuration.readRateLimiter : configuration.configFactories.readRateLimiterCreateFn()),
m_userAgent(Aws::Client::ComputeUserAgentString(&configuration)),
m_hash(Aws::Utils::Crypto::CreateMD5Implementation()),
m_requestTimeoutMs(configuration.requestTimeoutMs),
Expand Down
15 changes: 14 additions & 1 deletion src/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <aws/core/Version.h>
#include <aws/core/config/AWSProfileConfigLoader.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <smithy/tracing/NoopTelemetryProvider.h>

#include <aws/crt/Config.h>

Expand All @@ -39,6 +40,19 @@ static const char* AWS_EXECUTION_ENV = "AWS_EXECUTION_ENV";
static const char* DISABLE_IMDSV1_CONFIG_VAR = "AWS_EC2_METADATA_V1_DISABLED";
static const char* DISABLE_IMDSV1_ENV_VAR = "ec2_metadata_v1_disabled";

ClientConfiguration::ProviderFactories ClientConfiguration::ProviderFactories::defaultFactories = []()
{
ProviderFactories factories;

factories.retryStrategyCreateFn = [](){return InitRetryStrategy();};
factories.executorCreateFn = [](){return Aws::MakeShared<Aws::Utils::Threading::DefaultExecutor>(CLIENT_CONFIG_TAG);};
factories.writeRateLimiterCreateFn = [](){return nullptr;};
factories.readRateLimiterCreateFn = [](){return nullptr;};
factories.telemetryProviderCreateFn = [](){return smithy::components::tracing::NoopTelemetryProvider::CreateProvider();};

return factories;
}();

Aws::String FilterUserAgentToken(char const * const source)
{
// Tokens are short textual identifiers that do not include whitespace or delimiters.
Expand Down Expand Up @@ -131,7 +145,6 @@ void setLegacyClientConfigurationParameters(ClientConfiguration& clientConfig)
clientConfig.lowSpeedLimit = 1;
clientConfig.proxyScheme = Aws::Http::Scheme::HTTP;
clientConfig.proxyPort = 0;
clientConfig.executor = Aws::MakeShared<Aws::Utils::Threading::DefaultExecutor>(CLIENT_CONFIG_TAG);
clientConfig.verifySSL = true;
clientConfig.writeRateLimiter = nullptr;
clientConfig.readRateLimiter = nullptr;
Expand Down
17 changes: 13 additions & 4 deletions src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ namespace Aws
*/
struct TransferManagerConfiguration
{
TransferManagerConfiguration(Aws::Utils::Threading::Executor* executor) : s3Client(nullptr), transferExecutor(executor), computeContentMD5(false), transferBufferMaxHeapSize(10 * MB5), bufferSize(MB5)
{
}

TransferManagerConfiguration(Aws::Utils::Threading::Executor* executor);
/**
* S3 Client to use for transfers. You are responsible for setting this.
*/
Expand All @@ -53,6 +50,18 @@ namespace Aws
* It is not a bug to use the same executor, but at least be aware that this is how the manager will be used.
*/
Aws::Utils::Threading::Executor* transferExecutor = nullptr;

/**
* Threading Executor shared pointer.
* Created and owned by Transfer manager if no raw pointer `transferExecutor` is provided.
*/
std::shared_ptr<Aws::Utils::Threading::Executor> spExecutor = nullptr;

/**
* Threading Executor factory method. Default creates a factory that creates DefaultExecutor
*/
std::function<std::shared_ptr<Utils::Threading::Executor>()> executorCreateFn;

/**
* When true, TransferManager will calculate the MD5 digest of the content being uploaded.
* The digest is sent to S3 via an HTTP header enabling the service to perform integrity checks.
Expand Down
21 changes: 21 additions & 0 deletions src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ namespace Aws
{
namespace Transfer
{
TransferManagerConfiguration::TransferManagerConfiguration(Aws::Utils::Threading::Executor* executor)
: s3Client(nullptr),
transferExecutor(executor),
computeContentMD5(false),
transferBufferMaxHeapSize(10 * MB5),
bufferSize(MB5)
{
}

static inline bool IsS3KeyPrefix(const Aws::String& path)
{
return (path.find_last_of('/') == path.size() - 1 || path.find_last_of('\\') == path.size() - 1);
Expand Down Expand Up @@ -57,6 +66,18 @@ namespace Aws
TransferManager::TransferManager(const TransferManagerConfiguration& configuration) : m_transferConfig(configuration)
{
assert(m_transferConfig.s3Client);
if (!m_transferConfig.transferExecutor)
{
if(!m_transferConfig.spExecutor && m_transferConfig.executorCreateFn)
{
m_transferConfig.spExecutor = m_transferConfig.executorCreateFn();
}
m_transferConfig.transferExecutor = m_transferConfig.spExecutor.get();
}
if (!m_transferConfig.transferExecutor)
{
AWS_LOGSTREAM_FATAL(CLASS_TAG, "Failed to init TransferManager: transferExecutor is null");
}
assert(m_transferConfig.transferExecutor);
m_transferConfig.s3Client->AppendToUserAgent("ft/s3-transfer");
for (uint64_t i = 0; i < m_transferConfig.transferBufferMaxHeapSize; i += m_transferConfig.bufferSize)
Expand Down
Loading

0 comments on commit 11c124f

Please sign in to comment.