Skip to content

Commit

Permalink
v2 sdk changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Trianz-Akshay committed Sep 27, 2024
1 parent f4d1090 commit 78fa87f
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@
*/
package com.amazonaws.athena.connectors.elasticsearch;

import com.amazonaws.DefaultRequest;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.Signer;
import com.amazonaws.http.HttpMethodName;
import org.apache.http.Header;
import org.apache.http.HttpEntityEnclosingRequest;
import org.apache.http.HttpException;
Expand All @@ -34,9 +30,15 @@
import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.message.BasicHeader;
import org.apache.http.protocol.HttpContext;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
import software.amazon.awssdk.http.auth.spi.signer.SignedRequest;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
Expand All @@ -47,8 +49,8 @@
import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST;

/**
* An {@link HttpRequestInterceptor} that signs requests using any AWS {@link Signer}
* and {@link AWSCredentialsProvider}.
* An {@link HttpRequestInterceptor} that signs requests using any AWS {@link AwsV4HttpSigner}
* and {@link AwsCredentialsProvider}.
*/
public class AWSRequestSigningApacheInterceptor implements HttpRequestInterceptor
{
Expand All @@ -61,34 +63,35 @@ public class AWSRequestSigningApacheInterceptor implements HttpRequestIntercepto
/**
* The particular signer implementation.
*/
private final Signer signer;
private final AwsV4HttpSigner signer;

/**
* The source of AWS credentials for signing.
*/
private final AWSCredentialsProvider awsCredentialsProvider;
private final AwsCredentialsProvider awsCredentialsProvider;
private final String region;

/**
*
* @param service service that we're connecting to
* @param signer particular signer implementation
* @param service service that we're connecting to
* @param signer particular signer implementation
* @param awsCredentialsProvider source of AWS credentials for signing
*/
public AWSRequestSigningApacheInterceptor(final String service,
final Signer signer,
final AWSCredentialsProvider awsCredentialsProvider)
final AwsV4HttpSigner signer,
final AwsCredentialsProvider awsCredentialsProvider,
final String region)
{
this.service = service;
this.signer = signer;
this.awsCredentialsProvider = awsCredentialsProvider;
this.region = region;
}

/**
* {@inheritDoc}
*/
@Override
public void process(final HttpRequest request, final HttpContext context)
throws HttpException, IOException
public void process(final HttpRequest request, final HttpContext context) throws HttpException, IOException
{
URIBuilder uriBuilder;
try {
Expand All @@ -98,55 +101,61 @@ public void process(final HttpRequest request, final HttpContext context)
throw new IOException("Invalid URI", e);
}

// Copy Apache HttpRequest to AWS DefaultRequest
DefaultRequest<?> signableRequest = new DefaultRequest<>(service);

HttpHost host = (HttpHost) context.getAttribute(HTTP_TARGET_HOST);
if (host != null) {
signableRequest.setEndpoint(URI.create(host.toURI()));
}
final HttpMethodName httpMethod =
HttpMethodName.fromValue(request.getRequestLine().getMethod());
signableRequest.setHttpMethod(httpMethod);
// Build the SdkHttpFullRequest
SdkHttpFullRequest.Builder signableRequest = null;
try {
signableRequest.setResourcePath(uriBuilder.build().getRawPath());
signableRequest = SdkHttpFullRequest.builder()
.method(SdkHttpMethod.fromValue(request.getRequestLine().getMethod())) // Set HTTP Method
.encodedPath(uriBuilder.build().getRawPath()) // Set Resource Path
.rawQueryParameters(nvpToMapParams(uriBuilder.getQueryParams())) // Set Query Parameters
.headers(headerArrayToMap(request.getAllHeaders()));
}
catch (URISyntaxException e) {
throw new IOException("Invalid URI", e);
throw new RuntimeException(e);
}

// Set the endpoint (host) if present in the context
HttpHost host = (HttpHost) context.getAttribute(HTTP_TARGET_HOST);
if (host != null) {
signableRequest.uri(URI.create(host.toURI())); // Set the base endpoint URL
}

// Handle content/body if it's an HttpEntityEnclosingRequest
if (request instanceof HttpEntityEnclosingRequest) {
HttpEntityEnclosingRequest httpEntityEnclosingRequest =
(HttpEntityEnclosingRequest) request;
HttpEntityEnclosingRequest httpEntityEnclosingRequest = (HttpEntityEnclosingRequest) request;
if (httpEntityEnclosingRequest.getEntity() != null) {
signableRequest.setContent(httpEntityEnclosingRequest.getEntity().getContent());
InputStream contentStream = httpEntityEnclosingRequest.getEntity().getContent();
signableRequest.contentStreamProvider(() -> contentStream); // Set content provider
}
else {
// This is a workaround from here: https://github.com/aws/aws-sdk-java/issues/2078
signableRequest.setContent(new ByteArrayInputStream(new byte[0]));
// Workaround: provide an empty stream if no entity is present
signableRequest.contentStreamProvider(() -> new ByteArrayInputStream(new byte[0]));
}
}
signableRequest.setParameters(nvpToMapParams(uriBuilder.getQueryParams()));
signableRequest.setHeaders(headerArrayToMap(request.getAllHeaders()));

// Sign it
signer.sign(signableRequest, awsCredentialsProvider.getCredentials());

// Now copy everything back
request.setHeaders(mapToHeaderArray(signableRequest.getHeaders()));
// Sign the request
SdkHttpFullRequest.Builder finalSignableRequest = signableRequest;
SignedRequest signedRequest =
signer.sign(r -> r.identity(awsCredentialsProvider.resolveCredentials())
.request(finalSignableRequest.build())
.payload(finalSignableRequest.contentStreamProvider())
.putProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, service)
.putProperty(AwsV4HttpSigner.REGION_NAME, region)); // Required for S3 only
// Now copy everything back to the original request (including signed headers)
request.setHeaders(mapToHeaderArray(signedRequest.request().headers()));

// If the request has an entity (body), copy it back to the original request
if (request instanceof HttpEntityEnclosingRequest) {
HttpEntityEnclosingRequest httpEntityEnclosingRequest =
(HttpEntityEnclosingRequest) request;
HttpEntityEnclosingRequest httpEntityEnclosingRequest = (HttpEntityEnclosingRequest) request;
if (httpEntityEnclosingRequest.getEntity() != null) {
BasicHttpEntity basicHttpEntity = new BasicHttpEntity();
basicHttpEntity.setContent(signableRequest.getContent());
basicHttpEntity.setContent(signableRequest.contentStreamProvider().newStream());
httpEntityEnclosingRequest.setEntity(basicHttpEntity);
}
}
}

/**
*
* @param params list of HTTP query params as NameValuePairs
* @return a multimap of HTTP query params
*/
Expand All @@ -165,12 +174,13 @@ private static Map<String, List<String>> nvpToMapParams(final List<NameValuePair
* @param headers modeled Header objects
* @return a Map of header entries
*/
private static Map<String, String> headerArrayToMap(final Header[] headers)
private static Map<String, List<String>> headerArrayToMap(final Header[] headers)
{
Map<String, String> headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
Map<String, List<String>> headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
for (Header header : headers) {
if (!skipHeader(header)) {
headersMap.put(header.getName(), header.getValue());
// If the header name already exists, add the new value to the list
headersMap.computeIfAbsent(header.getName(), k -> new ArrayList<>()).add(header.getValue());
}
}
return headersMap;
Expand All @@ -191,12 +201,12 @@ private static boolean skipHeader(final Header header)
* @param mapHeaders Map of header entries
* @return modeled Header objects
*/
private static Header[] mapToHeaderArray(final Map<String, String> mapHeaders)
private static Header[] mapToHeaderArray(final Map<String, List<String>> mapHeaders)
{
Header[] headers = new Header[mapHeaders.size()];
int i = 0;
for (Map.Entry<String, String> headerEntry : mapHeaders.entrySet()) {
headers[i++] = new BasicHeader(headerEntry.getKey(), headerEntry.getValue());
for (Map.Entry<String, List<String>> headerEntry : mapHeaders.entrySet()) {
headers[i++] = new BasicHeader(headerEntry.getKey(), headerEntry.getValue().get(0));
}
return headers;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
*/
package com.amazonaws.athena.connectors.elasticsearch;

import com.amazonaws.auth.AWS4Signer;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.google.common.base.Splitter;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequestInterceptor;
Expand All @@ -46,6 +44,9 @@
import org.elasticsearch.search.SearchHit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
import software.amazon.awssdk.services.elasticsearch.ElasticsearchClient;

import java.io.IOException;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -196,7 +197,7 @@ public static class Builder
{
private final String endpoint;
private final RestClientBuilder clientBuilder;
private final AWS4Signer signer;
private final AwsV4HttpSigner signer;
private final Splitter domainSplitter;

/**
Expand All @@ -207,7 +208,7 @@ public Builder(String endpoint)
{
this.endpoint = endpoint;
this.clientBuilder = RestClient.builder(HttpHost.create(this.endpoint));
this.signer = new AWS4Signer();
this.signer = AwsV4HttpSigner.create();
this.domainSplitter = Splitter.on(".");
}

Expand All @@ -216,7 +217,7 @@ public Builder(String endpoint)
* @param credentialsProvider is the AWS credentials provider.
* @return self.
*/
public Builder withCredentials(AWSCredentialsProvider credentialsProvider)
public Builder withCredentials(AwsCredentialsProvider credentialsProvider)
{
/**
* endpoint:
Expand All @@ -231,16 +232,13 @@ public Builder withCredentials(AWSCredentialsProvider credentialsProvider)
*/
List<String> domainSplits = domainSplitter.splitToList(endpoint);

HttpRequestInterceptor interceptor;
if (domainSplits.size() > 1) {
signer.setRegionName(domainSplits.get(1));
signer.setServiceName("es");
}

HttpRequestInterceptor interceptor =
new AWSRequestSigningApacheInterceptor(signer.getServiceName(), signer, credentialsProvider);
interceptor = new AWSRequestSigningApacheInterceptor(ElasticsearchClient.SERVICE_NAME, signer, credentialsProvider, domainSplits.get(1));

clientBuilder.setHttpClientConfigCallback(httpClientBuilder -> httpClientBuilder
.addInterceptorLast(interceptor));
clientBuilder.setHttpClientConfigCallback(httpClientBuilder -> httpClientBuilder
.addInterceptorLast(interceptor));
}

return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
*/
package com.amazonaws.athena.connectors.elasticsearch;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;

import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -100,7 +100,7 @@ private AwsRestHighLevelClient createClient(String endpoint)
{
if (useAwsCredentials) {
return new AwsRestHighLevelClient.Builder(endpoint)
.withCredentials(new DefaultAWSCredentialsProviderChain()).build();
.withCredentials(DefaultCredentialsProvider.create()).build();
}
else {
Matcher credentials = credentialsPattern.matcher(endpoint);
Expand Down

0 comments on commit 78fa87f

Please sign in to comment.