diff --git a/source/extensions/common/aws/BUILD b/source/extensions/common/aws/BUILD index 31bd0c8fa72a..96382e2095c2 100644 --- a/source/extensions/common/aws/BUILD +++ b/source/extensions/common/aws/BUILD @@ -29,7 +29,7 @@ envoy_cc_library( "//source/common/common:logger_lib", "//source/common/common:utility_lib", "//source/common/crypto:utility_lib", - "//source/common/http:headers_lib", + "//source/common/http:message_lib", "//source/common/singleton:const_singleton", ], ) @@ -47,6 +47,7 @@ envoy_cc_library( external_deps = ["abseil_time"], deps = [ ":credentials_provider_interface", + ":utility_lib", "//envoy/api:api_interface", "//source/common/common:logger_lib", "//source/common/common:thread_lib", @@ -61,6 +62,7 @@ envoy_cc_library( hdrs = ["utility.h"], external_deps = ["curl"], deps = [ + "//envoy/http:message_interface", "//source/common/common:empty_string", "//source/common/common:matchers_lib", "//source/common/common:utility_lib", diff --git a/source/extensions/common/aws/credentials_provider.h b/source/extensions/common/aws/credentials_provider.h index 7e1da55fb182..9de0fe8b7a4d 100644 --- a/source/extensions/common/aws/credentials_provider.h +++ b/source/extensions/common/aws/credentials_provider.h @@ -21,9 +21,10 @@ namespace Aws { */ class Credentials { public: - Credentials(absl::string_view access_key_id = absl::string_view(), - absl::string_view secret_access_key = absl::string_view(), - absl::string_view session_token = absl::string_view()) { + explicit Credentials(absl::string_view access_key_id = absl::string_view(), + absl::string_view secret_access_key = absl::string_view(), + absl::string_view session_token = absl::string_view()) { + // TODO(suniltheta): Move credential expiration date in here if (!access_key_id.empty()) { access_key_id_ = std::string(access_key_id); if (!secret_access_key.empty()) { diff --git a/source/extensions/common/aws/credentials_provider_impl.cc b/source/extensions/common/aws/credentials_provider_impl.cc index 4dd5bfed47ae..8d79ac693e62 100644 --- a/source/extensions/common/aws/credentials_provider_impl.cc +++ b/source/extensions/common/aws/credentials_provider_impl.cc @@ -3,8 +3,10 @@ #include "envoy/common/exception.h" #include "source/common/common/lock_guard.h" +#include "source/common/http/message_impl.h" #include "source/common/http/utility.h" #include "source/common/json/json_loader.h" +#include "source/extensions/common/aws/utility.h" namespace Envoy { namespace Extensions { @@ -71,20 +73,29 @@ void InstanceProfileCredentialsProvider::refresh() { ENVOY_LOG(debug, "Getting AWS credentials from the instance metadata"); // First discover the Role of this instance - const auto instance_role_string = - metadata_fetcher_(EC2_METADATA_HOST, SECURITY_CREDENTIALS_PATH, ""); + Http::RequestMessageImpl message; + message.headers().setMethod(Http::Headers::get().MethodValues.Get); + message.headers().setHost(EC2_METADATA_HOST); + message.headers().setPath(SECURITY_CREDENTIALS_PATH); + const auto instance_role_string = metadata_fetcher_(message); if (!instance_role_string) { ENVOY_LOG(error, "Could not retrieve credentials listing from the instance metadata"); return; } + fetchCredentialFromInstanceRole(instance_role_string.value()); +} - const auto instance_role_list = - StringUtil::splitToken(StringUtil::trim(instance_role_string.value()), "\n"); +void InstanceProfileCredentialsProvider::fetchCredentialFromInstanceRole( + const std::string& instance_role) { + if (instance_role.empty()) { + return; + } + const auto instance_role_list = StringUtil::splitToken(StringUtil::trim(instance_role), "\n"); if (instance_role_list.empty()) { ENVOY_LOG(error, "No AWS credentials were found in the instance metadata"); return; } - ENVOY_LOG(debug, "AWS credentials list:\n{}", instance_role_string.value()); + ENVOY_LOG(debug, "AWS credentials list:\n{}", instance_role); // Only one Role can be associated with an instance: // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html @@ -94,15 +105,27 @@ void InstanceProfileCredentialsProvider::refresh() { ENVOY_LOG(debug, "AWS credentials path: {}", credential_path); // Then fetch and parse the credentials - const auto credential_document = metadata_fetcher_(EC2_METADATA_HOST, credential_path, ""); + Http::RequestMessageImpl message; + message.headers().setMethod(Http::Headers::get().MethodValues.Get); + message.headers().setHost(EC2_METADATA_HOST); + message.headers().setPath(credential_path); + + const auto credential_document = metadata_fetcher_(message); if (!credential_document) { ENVOY_LOG(error, "Could not load AWS credentials document from the instance metadata"); return; } + extractCredentials(credential_document.value()); +} +void InstanceProfileCredentialsProvider::extractCredentials( + const std::string& credential_document_value) { + if (credential_document_value.empty()) { + return; + } Json::ObjectSharedPtr document_json; try { - document_json = Json::Factory::loadFromString(credential_document.value()); + document_json = Json::Factory::loadFromString(credential_document_value); } catch (EnvoyException& e) { ENVOY_LOG(error, "Could not parse AWS credentials document: {}", e.what()); return; @@ -133,17 +156,27 @@ void TaskRoleCredentialsProvider::refresh() { absl::string_view host; absl::string_view path; Http::Utility::extractHostPathFromUri(credential_uri_, host, path); - const auto credential_document = - metadata_fetcher_(std::string(host.data(), host.size()), - std::string(path.data(), path.size()), authorization_token_); + + Http::RequestMessageImpl message; + message.headers().setMethod(Http::Headers::get().MethodValues.Get); + message.headers().setHost(host); + message.headers().setPath(path); + message.headers().setCopy(Http::CustomHeaders::get().Authorization, authorization_token_); + const auto credential_document = metadata_fetcher_(message); if (!credential_document) { ENVOY_LOG(error, "Could not load AWS credentials document from the task role"); return; } + extractCredentials(credential_document.value()); +} +void TaskRoleCredentialsProvider::extractCredentials(const std::string& credential_document_value) { + if (credential_document_value.empty()) { + return; + } Json::ObjectSharedPtr document_json; try { - document_json = Json::Factory::loadFromString(credential_document.value()); + document_json = Json::Factory::loadFromString(credential_document_value); } catch (EnvoyException& e) { ENVOY_LOG(error, "Could not parse AWS credentials document from the task role: {}", e.what()); return; diff --git a/source/extensions/common/aws/credentials_provider_impl.h b/source/extensions/common/aws/credentials_provider_impl.h index b15637282aa8..b26e67694ff2 100644 --- a/source/extensions/common/aws/credentials_provider_impl.h +++ b/source/extensions/common/aws/credentials_provider_impl.h @@ -4,6 +4,7 @@ #include "envoy/api/api.h" #include "envoy/event/timer.h" +#include "envoy/http/message.h" #include "source/common/common/logger.h" #include "source/common/common/thread.h" @@ -31,8 +32,7 @@ class EnvironmentCredentialsProvider : public CredentialsProvider, class MetadataCredentialsProviderBase : public CredentialsProvider, public Logger::Loggable { public: - using MetadataFetcher = std::function( - const std::string& host, const std::string& path, const std::string& auth_token)>; + using MetadataFetcher = std::function(Http::RequestMessage&)>; MetadataCredentialsProviderBase(Api::Api& api, const MetadataFetcher& metadata_fetcher) : api_(api), metadata_fetcher_(metadata_fetcher) {} @@ -68,6 +68,8 @@ class InstanceProfileCredentialsProvider : public MetadataCredentialsProviderBas private: bool needsRefresh() override; void refresh() override; + void extractCredentials(const std::string& credential_document_value); + void fetchCredentialFromInstanceRole(const std::string& instance_role); }; /** @@ -90,6 +92,7 @@ class TaskRoleCredentialsProvider : public MetadataCredentialsProviderBase { bool needsRefresh() override; void refresh() override; + void extractCredentials(const std::string& credential_document_value); }; /** diff --git a/source/extensions/common/aws/signer_impl.h b/source/extensions/common/aws/signer_impl.h index 208133ce2679..b5dd73a703b4 100644 --- a/source/extensions/common/aws/signer_impl.h +++ b/source/extensions/common/aws/signer_impl.h @@ -48,7 +48,7 @@ using AwsSigV4HeaderExclusionVector = std::vector { +class SignerImpl : public Signer, public Logger::Loggable { public: SignerImpl(absl::string_view service_name, absl::string_view region, const CredentialsProviderSharedPtr& credentials_provider, TimeSource& time_source, diff --git a/source/extensions/common/aws/utility.cc b/source/extensions/common/aws/utility.cc index b1929086ad6d..20afddd1c8e1 100644 --- a/source/extensions/common/aws/utility.cc +++ b/source/extensions/common/aws/utility.cc @@ -223,9 +223,7 @@ static size_t curlCallback(char* ptr, size_t, size_t nmemb, void* data) { return nmemb; } -absl::optional Utility::metadataFetcher(const std::string& host, - const std::string& path, - const std::string& auth_token) { +absl::optional Utility::fetchMetadata(Http::RequestMessage& message) { static const size_t MAX_RETRIES = 4; static const std::chrono::milliseconds RETRY_DELAY{1000}; static const std::chrono::seconds TIMEOUT{5}; @@ -235,7 +233,10 @@ absl::optional Utility::metadataFetcher(const std::string& host, return absl::nullopt; }; - const std::string url = fmt::format("http://{}/{}", host, path); + const auto host = message.headers().getHostValue(); + const auto path = message.headers().getPathValue(); + + const std::string url = fmt::format("http://{}{}", host, path); curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_TIMEOUT, TIMEOUT.count()); curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); @@ -246,9 +247,17 @@ absl::optional Utility::metadataFetcher(const std::string& host, curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curlCallback); struct curl_slist* headers = nullptr; - if (!auth_token.empty()) { - const std::string auth = fmt::format("Authorization: {}", auth_token); - headers = curl_slist_append(headers, auth.c_str()); + message.headers().iterate([&headers](const Http::HeaderEntry& entry) -> Http::HeaderMap::Iterate { + // Skip pseudo-headers + if (!entry.key().getStringView().empty() && entry.key().getStringView()[0] == ':') { + return Http::HeaderMap::Iterate::Continue; + } + const std::string header = + fmt::format("{}: {}", entry.key().getStringView(), entry.value().getStringView()); + headers = curl_slist_append(headers, header.c_str()); + return Http::HeaderMap::Iterate::Continue; + }); + if (headers != nullptr) { curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); } diff --git a/source/extensions/common/aws/utility.h b/source/extensions/common/aws/utility.h index e7cfdac2abb1..2ec7cae045cd 100644 --- a/source/extensions/common/aws/utility.h +++ b/source/extensions/common/aws/utility.h @@ -1,5 +1,7 @@ #pragma once +#include "envoy/http/message.h" + #include "source/common/common/matchers.h" #include "source/common/http/headers.h" @@ -81,9 +83,7 @@ class Utility { /** * Fetch AWS instance or task metadata. * - * @param host host or ip address of the metadata endpoint. - * @param path path of the metadata document. - * @auth_token authentication token to pass in the request, empty string indicates no auth. + * @param message An HTTP request. * @return Metadata document or nullopt in case if unable to fetch it. * * @note In case of an error, function will log ENVOY_LOG_MISC(debug) message. @@ -91,8 +91,7 @@ class Utility { * @note This is not main loop safe method as it is blocking. It is intended to be used from the * gRPC auth plugins that are able to schedule blocking plugins on a different thread. */ - static absl::optional - metadataFetcher(const std::string& host, const std::string& path, const std::string& auth_token); + static absl::optional fetchMetadata(Http::RequestMessage& message); }; } // namespace Aws diff --git a/source/extensions/filters/http/aws_lambda/config.cc b/source/extensions/filters/http/aws_lambda/config.cc index 41cbbfe9b804..de5a1cfc0a39 100644 --- a/source/extensions/filters/http/aws_lambda/config.cc +++ b/source/extensions/filters/http/aws_lambda/config.cc @@ -37,15 +37,16 @@ Http::FilterFactoryCb AwsLambdaFilterFactory::createFilterFactoryFromProtoTyped( const envoy::extensions::filters::http::aws_lambda::v3::Config& proto_config, const std::string& stat_prefix, Server::Configuration::FactoryContext& context) { - auto credentials_provider = - std::make_shared( - context.api(), Extensions::Common::Aws::Utility::metadataFetcher); - const auto arn = parseArn(proto_config.arn()); if (!arn) { throw EnvoyException(fmt::format("aws_lambda_filter: Invalid ARN: {}", proto_config.arn())); } const std::string region = arn->region(); + + auto credentials_provider = + std::make_shared( + context.api(), Extensions::Common::Aws::Utility::fetchMetadata); + auto signer = std::make_shared( service_name, region, std::move(credentials_provider), context.mainThreadDispatcher().timeSource(), diff --git a/source/extensions/filters/http/aws_request_signing/config.cc b/source/extensions/filters/http/aws_request_signing/config.cc index d9a1a5c74a7b..6743c9473c8b 100644 --- a/source/extensions/filters/http/aws_request_signing/config.cc +++ b/source/extensions/filters/http/aws_request_signing/config.cc @@ -20,7 +20,7 @@ Http::FilterFactoryCb AwsRequestSigningFilterFactory::createFilterFactoryFromPro auto credentials_provider = std::make_shared( - context.api(), Extensions::Common::Aws::Utility::metadataFetcher); + context.api(), Extensions::Common::Aws::Utility::fetchMetadata); const auto matcher_config = Extensions::Common::Aws::AwsSigV4HeaderExclusionVector( config.match_excluded_headers().begin(), config.match_excluded_headers().end()); auto signer = std::make_unique( diff --git a/source/extensions/grpc_credentials/aws_iam/config.cc b/source/extensions/grpc_credentials/aws_iam/config.cc index 6253e6cc5cbc..0dbba166d326 100644 --- a/source/extensions/grpc_credentials/aws_iam/config.cc +++ b/source/extensions/grpc_credentials/aws_iam/config.cc @@ -45,7 +45,7 @@ std::shared_ptr AwsIamGrpcCredentialsFactory::getChann const envoy::config::grpc_credential::v3::AwsIamConfig&>( *config_message, ProtobufMessage::getNullValidationVisitor()); auto credentials_provider = std::make_shared( - api, Common::Aws::Utility::metadataFetcher); + api, Common::Aws::Utility::fetchMetadata); auto signer = std::make_unique( config.service_name(), getRegion(config), credentials_provider, api.timeSource(), // TODO: extend API to allow specifying header exclusion. ref: diff --git a/test/extensions/common/aws/aws_metadata_fetcher_integration_test.cc b/test/extensions/common/aws/aws_metadata_fetcher_integration_test.cc index be22190b31f4..3f4a09820257 100644 --- a/test/extensions/common/aws/aws_metadata_fetcher_integration_test.cc +++ b/test/extensions/common/aws/aws_metadata_fetcher_integration_test.cc @@ -79,9 +79,12 @@ class AwsMetadataIntegrationTestSuccess : public AwsMetadataIntegrationTestBase }; TEST_F(AwsMetadataIntegrationTestSuccess, Success) { - const auto endpoint = fmt::format("{}:{}", Network::Test::getLoopbackAddressUrlString(version_), - lookupPort("listener_0")); - const auto response = Utility::metadataFetcher(endpoint, "", ""); + const auto authority = fmt::format("{}:{}", Network::Test::getLoopbackAddressUrlString(version_), + lookupPort("listener_0")); + auto headers = Http::RequestHeaderMapPtr{new Http::TestRequestHeaderMapImpl{ + {":path", "/"}, {":authority", authority}, {":method", "GET"}}}; + Http::RequestMessageImpl message(std::move(headers)); + const auto response = Utility::fetchMetadata(message); ASSERT_TRUE(response.has_value()); EXPECT_EQ("METADATA_VALUE", *response); @@ -91,9 +94,15 @@ TEST_F(AwsMetadataIntegrationTestSuccess, Success) { } TEST_F(AwsMetadataIntegrationTestSuccess, AuthToken) { - const auto endpoint = fmt::format("{}:{}", Network::Test::getLoopbackAddressUrlString(version_), - lookupPort("listener_0")); - const auto response = Utility::metadataFetcher(endpoint, "", "AUTH_TOKEN"); + const auto authority = fmt::format("{}:{}", Network::Test::getLoopbackAddressUrlString(version_), + lookupPort("listener_0")); + auto headers = Http::RequestHeaderMapPtr{ + new Http::TestRequestHeaderMapImpl{{":path", "/"}, + {":authority", authority}, + {":method", "GET"}, + {"authorization", "AUTH_TOKEN"}}}; + Http::RequestMessageImpl message(std::move(headers)); + const auto response = Utility::fetchMetadata(message); ASSERT_TRUE(response.has_value()); EXPECT_EQ("METADATA_VALUE_WITH_AUTH", *response); @@ -103,9 +112,15 @@ TEST_F(AwsMetadataIntegrationTestSuccess, AuthToken) { } TEST_F(AwsMetadataIntegrationTestSuccess, Redirect) { - const auto endpoint = fmt::format("{}:{}", Network::Test::getLoopbackAddressUrlString(version_), - lookupPort("listener_0")); - const auto response = Utility::metadataFetcher(endpoint, "redirect", "AUTH_TOKEN"); + const auto authority = fmt::format("{}:{}", Network::Test::getLoopbackAddressUrlString(version_), + lookupPort("listener_0")); + auto headers = Http::RequestHeaderMapPtr{ + new Http::TestRequestHeaderMapImpl{{":path", "/redirect"}, + {":authority", authority}, + {":method", "GET"}, + {"authorization", "AUTH_TOKEN"}}}; + Http::RequestMessageImpl message(std::move(headers)); + const auto response = Utility::fetchMetadata(message); ASSERT_TRUE(response.has_value()); EXPECT_EQ("METADATA_VALUE_WITH_AUTH", *response); @@ -124,11 +139,17 @@ class AwsMetadataIntegrationTestFailure : public AwsMetadataIntegrationTestBase }; TEST_F(AwsMetadataIntegrationTestFailure, Failure) { - const auto endpoint = fmt::format("{}:{}", Network::Test::getLoopbackAddressUrlString(version_), - lookupPort("listener_0")); - + const auto authority = fmt::format("{}:{}", Network::Test::getLoopbackAddressUrlString(version_), + lookupPort("listener_0")); + auto headers = Http::RequestHeaderMapPtr{ + new Http::TestRequestHeaderMapImpl{{":path", "/"}, + {":authority", authority}, + {":method", "GET"}, + {"authorization", "AUTH_TOKEN"}}}; + + Http::RequestMessageImpl message(std::move(headers)); const auto start_time = timeSystem().monotonicTime(); - const auto response = Utility::metadataFetcher(endpoint, "", ""); + const auto response = Utility::fetchMetadata(message); const auto end_time = timeSystem().monotonicTime(); EXPECT_FALSE(response.has_value()); @@ -148,16 +169,19 @@ class AwsMetadataIntegrationTestTimeout : public AwsMetadataIntegrationTestBase }; TEST_F(AwsMetadataIntegrationTestTimeout, Timeout) { - const auto endpoint = fmt::format("{}:{}", Network::Test::getLoopbackAddressUrlString(version_), - lookupPort("listener_0")); + const auto authority = fmt::format("{}:{}", Network::Test::getLoopbackAddressUrlString(version_), + lookupPort("listener_0")); + auto headers = Http::RequestHeaderMapPtr{new Http::TestRequestHeaderMapImpl{ + {":path", "/"}, {":authority", authority}, {":method", "GET"}}}; + Http::RequestMessageImpl message(std::move(headers)); const auto start_time = timeSystem().monotonicTime(); - const auto response = Utility::metadataFetcher(endpoint, "", ""); + const auto response = Utility::fetchMetadata(message); const auto end_time = timeSystem().monotonicTime(); EXPECT_FALSE(response.has_value()); - // We do now check http.metadata_test.downstream_rq_completed value here because it's + // We do not check http.metadata_test.downstream_rq_completed value here because it's // behavior is different between Linux and Mac when Curl disconnects on timeout. On Mac it is // incremented, while on Linux it is not. diff --git a/test/extensions/common/aws/credentials_provider_impl_test.cc b/test/extensions/common/aws/credentials_provider_impl_test.cc index 6e4cc0819c9c..035f42e33959 100644 --- a/test/extensions/common/aws/credentials_provider_impl_test.cc +++ b/test/extensions/common/aws/credentials_provider_impl_test.cc @@ -62,35 +62,70 @@ TEST_F(EvironmentCredentialsProviderTest, NoSessionToken) { EXPECT_FALSE(credentials.sessionToken().has_value()); } +class MessageMatcher : public testing::MatcherInterface { +public: + explicit MessageMatcher(const Http::TestRequestHeaderMapImpl& expected_headers) + : expected_headers_(expected_headers) {} + + bool MatchAndExplain(Http::RequestMessage& message, + testing::MatchResultListener* result_listener) const override { + const bool equal = TestUtility::headerMapEqualIgnoreOrder(message.headers(), expected_headers_); + if (!equal) { + *result_listener << "\n" + << TestUtility::addLeftAndRightPadding("Expected header map:") << "\n" + << expected_headers_ + << TestUtility::addLeftAndRightPadding("is not equal to actual header map:") + << "\n" + << message.headers() + << TestUtility::addLeftAndRightPadding("") // line full of padding + << "\n"; + } + return equal; + } + + void DescribeTo(::std::ostream* os) const override { *os << "Message matches"; } + + void DescribeNegationTo(::std::ostream* os) const override { *os << "Message does not match"; } + +private: + const Http::TestRequestHeaderMapImpl expected_headers_; +}; + +testing::Matcher +messageMatches(const Http::TestRequestHeaderMapImpl& expected_headers) { + return testing::MakeMatcher(new MessageMatcher(expected_headers)); +} + class InstanceProfileCredentialsProviderTest : public testing::Test { public: InstanceProfileCredentialsProviderTest() : api_(Api::createApiForTest(time_system_)), - provider_(*api_, - [this](const std::string& host, const std::string& path, - const std::string& auth_token) -> absl::optional { - return this->fetcher_.fetch(host, path, auth_token); - }) {} + provider_(*api_, [this](Http::RequestMessage& message) -> absl::optional { + return this->fetch_metadata_.fetch(message); + }) {} void expectCredentialListing(const absl::optional& listing) { - EXPECT_CALL(fetcher_, - fetch("169.254.169.254:80", "/latest/meta-data/iam/security-credentials", _)) - .WillOnce(Return(listing)); + Http::TestRequestHeaderMapImpl headers{{":path", "/latest/meta-data/iam/security-credentials"}, + {":authority", "169.254.169.254:80"}, + {":method", "GET"}}; + EXPECT_CALL(fetch_metadata_, fetch(messageMatches(headers))).WillOnce(Return(listing)); } void expectDocument(const absl::optional& document) { - EXPECT_CALL(fetcher_, - fetch("169.254.169.254:80", "/latest/meta-data/iam/security-credentials/doc1", _)) - .WillOnce(Return(document)); + Http::TestRequestHeaderMapImpl headers{ + {":path", "/latest/meta-data/iam/security-credentials/doc1"}, + {":authority", "169.254.169.254:80"}, + {":method", "GET"}}; + EXPECT_CALL(fetch_metadata_, fetch(messageMatches(headers))).WillOnce(Return(document)); } Event::SimulatedTimeSystem time_system_; Api::ApiPtr api_; - NiceMock fetcher_; + NiceMock fetch_metadata_; InstanceProfileCredentialsProvider provider_; }; -TEST_F(InstanceProfileCredentialsProviderTest, FailedCredentailListing) { +TEST_F(InstanceProfileCredentialsProviderTest, FailedCredentialListing) { expectCredentialListing(absl::optional()); const auto credentials = provider_.getCredentials(); EXPECT_FALSE(credentials.accessKeyId().has_value()); @@ -195,9 +230,8 @@ class TaskRoleCredentialsProviderTest : public testing::Test { : api_(Api::createApiForTest(time_system_)), provider_( *api_, - [this](const std::string& host, const std::string& path, - const absl::optional& auth_token) -> absl::optional { - return this->fetcher_.fetch(host, path, auth_token); + [this](Http::RequestMessage& message) -> absl::optional { + return this->fetch_metadata_.fetch(message); }, "169.254.170.2:80/path/to/doc", "auth_token") { // Tue Jan 2 03:04:05 UTC 2018 @@ -205,12 +239,16 @@ class TaskRoleCredentialsProviderTest : public testing::Test { } void expectDocument(const absl::optional& document) { - EXPECT_CALL(fetcher_, fetch("169.254.170.2:80", "/path/to/doc", _)).WillOnce(Return(document)); + Http::TestRequestHeaderMapImpl headers{{":path", "/path/to/doc"}, + {":authority", "169.254.170.2:80"}, + {":method", "GET"}, + {"authorization", "auth_token"}}; + EXPECT_CALL(fetch_metadata_, fetch(messageMatches(headers))).WillOnce(Return(document)); } Event::SimulatedTimeSystem time_system_; Api::ApiPtr api_; - NiceMock fetcher_; + NiceMock fetch_metadata_; TaskRoleCredentialsProvider provider_; }; diff --git a/test/extensions/common/aws/mocks.h b/test/extensions/common/aws/mocks.h index d990409b4327..d316f66a30d0 100644 --- a/test/extensions/common/aws/mocks.h +++ b/test/extensions/common/aws/mocks.h @@ -29,21 +29,16 @@ class MockSigner : public Signer { MOCK_METHOD(void, signUnsignedPayload, (Http::RequestHeaderMap&)); }; -class MockMetadataFetcher { +class MockFetchMetadata { public: - virtual ~MockMetadataFetcher() = default; + virtual ~MockFetchMetadata() = default; - MOCK_METHOD(absl::optional, fetch, - (const std::string&, const std::string&, const absl::optional&), - (const)); + MOCK_METHOD(absl::optional, fetch, (Http::RequestMessage&), (const)); }; class DummyMetadataFetcher { public: - absl::optional operator()(const std::string&, const std::string&, - const absl::optional&) { - return absl::nullopt; - } + absl::optional operator()(Http::RequestMessage&) { return absl::nullopt; } }; } // namespace Aws