Skip to content

Commit

Permalink
use factories for client utility classes initialization, improve shut…
Browse files Browse the repository at this point in the history
…down behavior #3084
  • Loading branch information
SergeyRyabinin committed Aug 27, 2024
1 parent f52653c commit 9e8e71d
Show file tree
Hide file tree
Showing 27 changed files with 247 additions and 155 deletions.
2 changes: 2 additions & 0 deletions src/aws-cpp-sdk-core/include/aws/core/client/AWSClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ namespace Aws
const Aws::Http::QueryStringParameterCollection& extraParams = Aws::Http::QueryStringParameterCollection(), long long expirationInSeconds = 0,
const std::shared_ptr<Aws::Http::ServiceSpecificParameters> serviceSpecificParameter = {}) const;

const std::shared_ptr<Aws::Http::HttpClient>& GetHttpClient() const { return m_httpClient; }

/**
* Stop all requests immediately.
* In flight requests will likely fail.
Expand Down
28 changes: 19 additions & 9 deletions src/aws-cpp-sdk-core/include/aws/core/client/AWSClientAsyncCRTP.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ namespace Client
{
if (&other != this)
{
ShutdownSdkClient(static_cast<AwsServiceClientT*>(this));
m_operationsProcessed = 0;
m_isInitialized = other.m_isInitialized.load();
}

Expand Down Expand Up @@ -97,7 +99,10 @@ namespace Client
std::unique_lock<std::mutex> lock(pClient->m_shutdownMutex);

pClient->m_isInitialized = false;

if (pClient->GetHttpClient().use_count() == 1)
{
pClient->DisableRequestProcessing();
}

if (timeoutMs == -1)
{
Expand All @@ -107,10 +112,15 @@ namespace Client
std::chrono::milliseconds(timeoutMs),
[&](){ return pClient->m_operationsProcessed.load() == 0; });

pClient->m_endpointProvider.reset();
pClient->m_executor.reset();
if (pClient->m_operationsProcessed.load())
{
AWS_LOGSTREAM_FATAL(AwsServiceClientT::GetAllocationTag(), "Service client "
<< AwsServiceClientT::GetServiceName() << " is shutting down while async tasks are present.");
}

pClient->m_clientConfiguration.executor.reset();
pClient->m_clientConfiguration.retryStrategy.reset();
pClient->m_endpointProvider.reset();
}

/**
Expand All @@ -124,7 +134,7 @@ namespace Client
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context = nullptr) const
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
Aws::Client::MakeAsyncOperation(operationFunc, clientThis, request, handler, context, clientThis->m_executor.get());
Aws::Client::MakeAsyncOperation(operationFunc, clientThis, request, handler, context, clientThis->m_clientConfiguration.executor.get());
}

/**
Expand All @@ -139,7 +149,7 @@ namespace Client
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context = nullptr) const
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
Aws::Client::MakeAsyncStreamingOperation(operationFunc, clientThis, request, handler, context, clientThis->m_executor.get());
Aws::Client::MakeAsyncStreamingOperation(operationFunc, clientThis, request, handler, context, clientThis->m_clientConfiguration.executor.get());
}

/**
Expand All @@ -152,7 +162,7 @@ namespace Client
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context = nullptr) const
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
Aws::Client::MakeAsyncOperation(operationFunc, clientThis, handler, context, clientThis->m_executor.get());
Aws::Client::MakeAsyncOperation(operationFunc, clientThis, handler, context, clientThis->m_clientConfiguration.executor.get());
}

/**
Expand All @@ -165,7 +175,7 @@ namespace Client
-> std::future<decltype((static_cast<const AwsServiceClientT*>(nullptr)->*operationFunc)(request))>
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
return Aws::Client::MakeCallableOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, request, clientThis->m_executor.get());
return Aws::Client::MakeCallableOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, request, clientThis->m_clientConfiguration.executor.get());
}

/**
Expand All @@ -178,7 +188,7 @@ namespace Client
-> std::future<decltype((static_cast<const AwsServiceClientT*>(nullptr)->*operationFunc)(request))>
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
return Aws::Client::MakeCallableStreamingOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, request, clientThis->m_executor.get());
return Aws::Client::MakeCallableStreamingOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, request, clientThis->m_clientConfiguration.executor.get());
}

/**
Expand All @@ -191,7 +201,7 @@ namespace Client
-> std::future<decltype((static_cast<const AwsServiceClientT*>(nullptr)->*operationFunc)())>
{
const AwsServiceClientT* clientThis = static_cast<const AwsServiceClientT*>(this);
return Aws::Client::MakeCallableOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, clientThis->m_executor.get());
return Aws::Client::MakeCallableOperation(AwsServiceClientT::GetAllocationTag(), operationFunc, clientThis, clientThis->m_clientConfiguration.executor.get());
}
protected:
std::atomic<bool> m_isInitialized;
Expand Down
58 changes: 47 additions & 11 deletions src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <aws/core/utils/Array.h>
#include <aws/crt/Optional.h>
#include <smithy/tracing/TelemetryProvider.h>
#include <smithy/tracing/NoopTelemetryProvider.h>
#include <memory>

namespace Aws
Expand Down Expand Up @@ -75,6 +74,34 @@ namespace Aws
*/
struct AWS_CORE_API ClientConfiguration
{
struct ProviderFactories
{
/**
* Retry Strategy factory method. Default is DefaultRetryStrategy (i.e. exponential backoff).
*/
std::function<std::shared_ptr<RetryStrategy>()> retryStrategyCreateFn;
/**
* Threading Executor factory method. Default creates a factory that creates DefaultExecutor
* (i.e. spawn a separate thread for each task) for backward compatibility reasons.
* Please switch to a better executor such as PooledThreadExecutor.
*/
std::function<std::shared_ptr<Utils::Threading::Executor>()> executorCreateFn;
/**
* Rate Limiter factory for outgoing bandwidth. Default is wide-open.
*/
std::function<std::shared_ptr<Utils::RateLimits::RateLimiterInterface>()> writeRateLimiterCreateFn;
/**
* Rate Limiter factory for incoming bandwidth. Default is wide-open.
*/
std::function<std::shared_ptr<Utils::RateLimits::RateLimiterInterface>()> readRateLimiterCreateFn;
/**
* TelemetryProvider factory. Defaults to Noop provider.
*/
std::function<std::shared_ptr<smithy::components::tracing::TelemetryProvider>()> telemetryProviderCreateFn;

static ProviderFactories defaultFactories;
};

ClientConfiguration();

/**
Expand Down Expand Up @@ -104,6 +131,11 @@ namespace Aws
*/
virtual ~ClientConfiguration() = default;

/**
* Client configuration factory methods to init client utility classes such as Executor, Retry Strategy
*/
ProviderFactories configFactories = ProviderFactories::defaultFactories;

/**
* User Agent string user for http calls. This is filled in for you in the constructor. Don't override this unless you have a really good reason.
*/
Expand Down Expand Up @@ -165,9 +197,10 @@ namespace Aws
*/
unsigned long lowSpeedLimit = 1;
/**
* Strategy to use in case of failed requests. Default is DefaultRetryStrategy (i.e. exponential backoff)
* Strategy to use in case of failed requests. Default is DefaultRetryStrategy (i.e. exponential backoff).
* Provide retry strategy here or via a factory method.
*/
std::shared_ptr<RetryStrategy> retryStrategy;
std::shared_ptr<RetryStrategy> retryStrategy = nullptr;
/**
* Override the http endpoint used to talk to a service.
*/
Expand Down Expand Up @@ -227,9 +260,10 @@ namespace Aws
*/
Aws::Utils::Array<Aws::String> nonProxyHosts;
/**
* Threading Executor implementation. Default uses std::thread::detach()
*/
std::shared_ptr<Aws::Utils::Threading::Executor> executor;
* Threading Executor implementation. Default uses std::thread::detach()
* Provide executor here or via a factory method.
*/
std::shared_ptr<Aws::Utils::Threading::Executor> executor = nullptr;
/**
* If you need to test and want to get around TLS validation errors, do that here.
* You probably shouldn't use this flag in a production scenario.
Expand Down Expand Up @@ -263,12 +297,14 @@ namespace Aws
Aws::String proxyCaFile;
/**
* Rate Limiter implementation for outgoing bandwidth. Default is wide-open.
* Provide limiter here or via a factory method.
*/
std::shared_ptr<Aws::Utils::RateLimits::RateLimiterInterface> writeRateLimiter;
std::shared_ptr<Aws::Utils::RateLimits::RateLimiterInterface> writeRateLimiter = nullptr;
/**
* Rate Limiter implementation for incoming bandwidth. Default is wide-open.
* Provide limiter here or via a factory method.
*/
std::shared_ptr<Aws::Utils::RateLimits::RateLimiterInterface> readRateLimiter;
std::shared_ptr<Aws::Utils::RateLimits::RateLimiterInterface> readRateLimiter = nullptr;
/**
* Override the http implementation the default factory returns.
*/
Expand Down Expand Up @@ -379,10 +415,10 @@ namespace Aws
const Aws::String& defaultValue);

/**
* A wrapper for interfacing with telemetry functionality.
* A wrapper for interfacing with telemetry functionality. Defaults to Noop provider.
* Provide TelemetryProvider here or via a factory method.
*/
std::shared_ptr<smithy::components::tracing::TelemetryProvider> telemetryProvider =
smithy::components::tracing::NoopTelemetryProvider::CreateProvider();
std::shared_ptr<smithy::components::tracing::TelemetryProvider> telemetryProvider;
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,10 @@ if(!m_isInitialized) \
} \
Aws::Utils::RAIICounter(this->m_operationsProcessed, &this->m_shutdownSignal)

#define AWS_ASYNC_OPERATION_GUARD(OPERATION) \
if(!m_isInitialized) \
{ \
AWS_LOGSTREAM_ERROR(#OPERATION, "Unable to call " #OPERATION ": client is not initialized (or already terminated)"); \
return handler(this, request, Aws::Client::AWSError<CoreErrors>(CoreErrors::NOT_INITIALIZED, "NOT_INITIALIZED", "Client is not initialized or already terminated", false), handlerContext); \
} \
Aws::Utils::RAIICounter(this->m_operationsProcessed, &this->m_shutdownSignal)
6 changes: 3 additions & 3 deletions src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ namespace client
const std::shared_ptr<EndpointProviderT> endpointProvider,
const std::shared_ptr<ServiceAuthSchemeResolverT>& authSchemeResolver,
const Aws::UnorderedMap<Aws::String, AuthSchemesVariantT>& authSchemes)
: AwsSmithyClientBase(clientConfig, serviceName, httpClient, errorMarshaller),
m_clientConfig(clientConfig),
: AwsSmithyClientBase(Aws::MakeUnique<ServiceClientConfigurationT>(ServiceNameT, clientConfig), serviceName, httpClient, errorMarshaller),
m_clientConfig(*AwsSmithyClientBase::m_clientConfig.get()),
m_endpointProvider(endpointProvider),
m_authSchemeResolver(authSchemeResolver),
m_authSchemes(authSchemes)
Expand Down Expand Up @@ -119,7 +119,7 @@ namespace client
}

protected:
ServiceClientConfigurationT m_clientConfig{};
ServiceClientConfigurationT& m_clientConfig;
std::shared_ptr<EndpointProviderT> m_endpointProvider{};
std::shared_ptr<ServiceAuthSchemeResolverT> m_authSchemeResolver{};
Aws::UnorderedMap<Aws::String, AuthSchemesVariantT> m_authSchemes{};
Expand Down
38 changes: 33 additions & 5 deletions src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,44 @@ namespace client
using SelectAuthSchemeOptionOutcome = Aws::Utils::Outcome<AuthSchemeOption, AWSError>;
using ResolveEndpointOutcome = Aws::Utils::Outcome<Aws::Endpoint::AWSEndpoint, AWSError>;

AwsSmithyClientBase(const Aws::Client::ClientConfiguration& clientConfig,
AwsSmithyClientBase(Aws::UniquePtr<Aws::Client::ClientConfiguration>&& clientConfig,
Aws::String serviceName,
std::shared_ptr<Aws::Http::HttpClient> httpClient,
std::shared_ptr<Aws::Client::AWSErrorMarshaller> errorMarshaller) :
m_clientConfig(clientConfig),
m_clientConfig(std::move(clientConfig)),
m_serviceName(std::move(serviceName)),
m_userAgent(Aws::Client::ComputeUserAgentString(&clientConfig)),
m_userAgent(),
m_httpClient(std::move(httpClient)),
m_errorMarshaller(std::move(errorMarshaller))
{}
{
if (!m_clientConfig->retryStrategy)
{
assert(m_clientConfig->configFactories.retryStrategyCreateFn);
m_clientConfig->retryStrategy = m_clientConfig->configFactories.retryStrategyCreateFn();
}
if (!m_clientConfig->executor)
{
assert(m_clientConfig->configFactories.executorCreateFn);
m_clientConfig->executor = m_clientConfig->configFactories.executorCreateFn();
}
if (!m_clientConfig->writeRateLimiter)
{
assert(m_clientConfig->configFactories.writeRateLimiterCreateFn);
m_clientConfig->writeRateLimiter = m_clientConfig->configFactories.writeRateLimiterCreateFn();
}
if (!m_clientConfig->readRateLimiter)
{
assert(m_clientConfig->configFactories.readRateLimiterCreateFn);
m_clientConfig->readRateLimiter = m_clientConfig->configFactories.readRateLimiterCreateFn();
}
if (!m_clientConfig->telemetryProvider)
{
assert(m_clientConfig->configFactories.telemetryProviderCreateFn);
m_clientConfig->telemetryProvider = m_clientConfig->configFactories.telemetryProviderCreateFn();
}

m_userAgent = Aws::Client::ComputeUserAgentString(m_clientConfig.get());
}

AwsSmithyClientBase(const AwsSmithyClientBase&) = delete;
AwsSmithyClientBase(AwsSmithyClientBase&&) = delete;
Expand Down Expand Up @@ -127,7 +155,7 @@ namespace client
virtual bool AdjustClockSkew(HttpResponseOutcome& outcome, const AuthSchemeOption& authSchemeOption) const = 0;

protected:
Aws::Client::ClientConfiguration m_clientConfig;
Aws::UniquePtr<Aws::Client::ClientConfiguration> m_clientConfig;
Aws::String m_serviceName;
Aws::String m_userAgent;

Expand Down
32 changes: 22 additions & 10 deletions src/aws-cpp-sdk-core/source/client/AWSClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,19 @@ AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration,
const std::shared_ptr<Aws::Client::AWSAuthSigner>& signer,
const std::shared_ptr<AWSErrorMarshaller>& errorMarshaller) :
m_region(configuration.region),
m_telemetryProvider(configuration.telemetryProvider),
m_telemetryProvider(configuration.telemetryProvider ? configuration.telemetryProvider : configuration.configFactories.telemetryProviderCreateFn()),
m_signerProvider(Aws::MakeUnique<Aws::Auth::DefaultAuthSignerProvider>(AWS_CLIENT_LOG_TAG, signer)),
m_httpClient(CreateHttpClient(configuration)),
m_httpClient(CreateHttpClient(
[&configuration, this]()
{
ClientConfiguration tempConfig(configuration);
tempConfig.telemetryProvider = m_telemetryProvider;
return tempConfig;
}())),
m_errorMarshaller(errorMarshaller),
m_retryStrategy(configuration.retryStrategy),
m_writeRateLimiter(configuration.writeRateLimiter),
m_readRateLimiter(configuration.readRateLimiter),
m_retryStrategy(configuration.retryStrategy ? configuration.retryStrategy : configuration.configFactories.retryStrategyCreateFn()),
m_writeRateLimiter(configuration.writeRateLimiter ? configuration.writeRateLimiter : configuration.configFactories.writeRateLimiterCreateFn()),
m_readRateLimiter(configuration.readRateLimiter ? configuration.readRateLimiter : configuration.configFactories.readRateLimiterCreateFn()),
m_userAgent(Aws::Client::ComputeUserAgentString(&configuration)),
m_hash(Aws::Utils::Crypto::CreateMD5Implementation()),
m_requestTimeoutMs(configuration.requestTimeoutMs),
Expand All @@ -138,13 +144,19 @@ AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration,
const std::shared_ptr<Aws::Auth::AWSAuthSignerProvider>& signerProvider,
const std::shared_ptr<AWSErrorMarshaller>& errorMarshaller) :
m_region(configuration.region),
m_telemetryProvider(configuration.telemetryProvider),
m_telemetryProvider(configuration.telemetryProvider ? configuration.telemetryProvider : configuration.configFactories.telemetryProviderCreateFn()),
m_signerProvider(signerProvider),
m_httpClient(CreateHttpClient(configuration)),
m_httpClient(CreateHttpClient(
[&configuration, this]()
{
ClientConfiguration tempConfig(configuration);
tempConfig.telemetryProvider = m_telemetryProvider;
return tempConfig;
}())),
m_errorMarshaller(errorMarshaller),
m_retryStrategy(configuration.retryStrategy),
m_writeRateLimiter(configuration.writeRateLimiter),
m_readRateLimiter(configuration.readRateLimiter),
m_retryStrategy(configuration.retryStrategy ? configuration.retryStrategy : configuration.configFactories.retryStrategyCreateFn()),
m_writeRateLimiter(configuration.writeRateLimiter ? configuration.writeRateLimiter : configuration.configFactories.writeRateLimiterCreateFn()),
m_readRateLimiter(configuration.readRateLimiter ? configuration.readRateLimiter : configuration.configFactories.readRateLimiterCreateFn()),
m_userAgent(Aws::Client::ComputeUserAgentString(&configuration)),
m_hash(Aws::Utils::Crypto::CreateMD5Implementation()),
m_requestTimeoutMs(configuration.requestTimeoutMs),
Expand Down
Loading

0 comments on commit 9e8e71d

Please sign in to comment.