Skip to content

Commit

Permalink
SEE OTHER redirect handling fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ok2c committed Nov 19, 2023
1 parent 180d90c commit e2cff33
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ private static class State {
volatile URI redirectURI;
volatile int maxRedirects;
volatile int redirectCount;
volatile HttpRequest originalRequest;
volatile HttpRequest currentRequest;
volatile AsyncEntityProducer currentEntityProducer;
volatile RedirectLocations redirectLocations;
Expand Down Expand Up @@ -150,23 +151,24 @@ public AsyncDataConsumer handleResponse(
redirectBuilder = BasicRequestBuilder.get();
state.currentEntityProducer = null;
} else {
redirectBuilder = BasicRequestBuilder.copy(scope.originalRequest);
redirectBuilder = BasicRequestBuilder.copy(state.originalRequest);
}
break;
case HttpStatus.SC_SEE_OTHER:
if (!Method.GET.isSame(request.getMethod()) && !Method.HEAD.isSame(request.getMethod())) {
redirectBuilder = BasicRequestBuilder.get();
state.currentEntityProducer = null;
} else {
redirectBuilder = BasicRequestBuilder.copy(scope.originalRequest);
redirectBuilder = BasicRequestBuilder.copy(state.originalRequest);
}
break;
default:
redirectBuilder = BasicRequestBuilder.copy(scope.originalRequest);
redirectBuilder = BasicRequestBuilder.copy(state.originalRequest);
}
redirectBuilder.setUri(redirectUri);
state.reroute = false;
state.redirectURI = redirectUri;
state.originalRequest = redirectBuilder.build();
state.currentRequest = redirectBuilder.build();

if (!Objects.equals(currentRoute.getTargetHost(), newTarget)) {
Expand Down Expand Up @@ -270,6 +272,7 @@ public void execute(
final State state = new State();
state.maxRedirects = config.getMaxRedirects() > 0 ? config.getMaxRedirects() : 50;
state.redirectCount = 0;
state.originalRequest = scope.originalRequest;
state.currentRequest = request;
state.currentEntityProducer = entityProducer;
state.redirectLocations = redirectLocations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public ClassicHttpResponse execute(

final RequestConfig config = context.getRequestConfig();
final int maxRedirects = config.getMaxRedirects() > 0 ? config.getMaxRedirects() : 50;
ClassicHttpRequest originalRequest = scope.originalRequest;
ClassicHttpRequest currentRequest = request;
ExecChain.Scope currentScope = scope;
for (int redirectCount = 0;;) {
Expand Down Expand Up @@ -153,18 +154,18 @@ public ClassicHttpResponse execute(
if (Method.POST.isSame(request.getMethod())) {
redirectBuilder = ClassicRequestBuilder.get();
} else {
redirectBuilder = ClassicRequestBuilder.copy(scope.originalRequest);
redirectBuilder = ClassicRequestBuilder.copy(originalRequest);
}
break;
case HttpStatus.SC_SEE_OTHER:
if (!Method.GET.isSame(request.getMethod()) && !Method.HEAD.isSame(request.getMethod())) {
redirectBuilder = ClassicRequestBuilder.get();
} else {
redirectBuilder = ClassicRequestBuilder.copy(scope.originalRequest);
redirectBuilder = ClassicRequestBuilder.copy(originalRequest);
}
break;
default:
redirectBuilder = ClassicRequestBuilder.copy(scope.originalRequest);
redirectBuilder = ClassicRequestBuilder.copy(originalRequest);
}
redirectBuilder.setUri(redirectUri);

Expand Down Expand Up @@ -201,6 +202,7 @@ public ClassicHttpResponse execute(
if (LOG.isDebugEnabled()) {
LOG.debug("{} redirecting to '{}' via {}", exchangeId, redirectUri, currentRoute);
}
originalRequest = redirectBuilder.build();
currentRequest = redirectBuilder.build();
RequestEntityProxy.enhance(currentRequest);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
import org.apache.hc.core5.http.HttpHost;
import org.apache.hc.core5.http.HttpStatus;
import org.apache.hc.core5.http.ProtocolException;
import org.apache.hc.core5.http.io.support.ClassicRequestBuilder;
import org.apache.hc.core5.http.io.support.ClassicResponseBuilder;
import org.apache.hc.core5.http.message.BasicClassicHttpResponse;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
Expand Down Expand Up @@ -357,6 +359,56 @@ public void testRedirectProtocolException() throws Exception {
Mockito.verify(response1).close();
}

@Test
public void testPutSeeOtherRedirect() throws Exception {
final HttpRoute route = new HttpRoute(target);
final URI targetUri = new URI("http://localhost:80/stuff");
final ClassicHttpRequest request = ClassicRequestBuilder.put()
.setUri(targetUri)
.setEntity("stuff")
.build();
final HttpClientContext context = HttpClientContext.create();

final URI redirect1 = new URI("http://localhost:80/see-something-else");
final ClassicHttpResponse response1 = ClassicResponseBuilder.create(HttpStatus.SC_SEE_OTHER)
.addHeader(HttpHeaders.LOCATION, redirect1.toASCIIString())
.build();
final URI redirect2 = new URI("http://localhost:80/other-stuff");
final ClassicHttpResponse response2 = ClassicResponseBuilder.create(HttpStatus.SC_MOVED_PERMANENTLY)
.addHeader(HttpHeaders.LOCATION, redirect2.toASCIIString())
.build();
final ClassicHttpResponse response3 = ClassicResponseBuilder.create(HttpStatus.SC_OK)
.build();

Mockito.when(chain.proceed(
HttpRequestMatcher.matchesRequestUri(targetUri),
ArgumentMatchers.any())).thenReturn(response1);
Mockito.when(chain.proceed(
HttpRequestMatcher.matchesRequestUri(redirect1),
ArgumentMatchers.any())).thenReturn(response2);
Mockito.when(chain.proceed(
HttpRequestMatcher.matchesRequestUri(redirect2),
ArgumentMatchers.any())).thenReturn(response3);

final ExecChain.Scope scope = new ExecChain.Scope("test", route, request, endpoint, context);
final ClassicHttpResponse finalResponse = redirectExec.execute(request, scope, chain);
Assertions.assertEquals(200, finalResponse.getCode());

final ArgumentCaptor<ClassicHttpRequest> reqCaptor = ArgumentCaptor.forClass(ClassicHttpRequest.class);
Mockito.verify(chain, Mockito.times(3)).proceed(reqCaptor.capture(), ArgumentMatchers.same(scope));

final List<ClassicHttpRequest> allValues = reqCaptor.getAllValues();
Assertions.assertNotNull(allValues);
Assertions.assertEquals(3, allValues.size());
final ClassicHttpRequest request1 = allValues.get(0);
final ClassicHttpRequest request2 = allValues.get(1);
final ClassicHttpRequest request3 = allValues.get(2);
Assertions.assertSame(request, request1);
Assertions.assertEquals(request1.getMethod(), "PUT");
Assertions.assertEquals(request2.getMethod(), "GET");
Assertions.assertEquals(request3.getMethod(), "GET");
}

private static class HttpRequestMatcher implements ArgumentMatcher<ClassicHttpRequest> {

private final URI expectedRequestUri;
Expand Down

0 comments on commit e2cff33

Please sign in to comment.