Skip to content

Commit

Permalink
Align RestClient parameters with WebClient
Browse files Browse the repository at this point in the history
  • Loading branch information
sjohnr committed Sep 27, 2024
1 parent 680f29c commit b101d6d
Show file tree
Hide file tree
Showing 6 changed files with 353 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.security.oauth2.client.endpoint;

import java.util.function.Consumer;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.converter.FormHttpMessageConverter;
Expand Down Expand Up @@ -75,8 +77,14 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend

private Converter<T, HttpHeaders> headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>();

private Consumer<HttpHeaders> headersCustomizer = (headers) -> {
};

private Converter<T, MultiValueMap<String, String>> parametersConverter = this::createParameters;

private Consumer<MultiValueMap<String, String>> parametersCustomizer = (parameters) -> {
};

AbstractRestClientOAuth2AccessTokenResponseClient() {
}

Expand Down Expand Up @@ -124,15 +132,21 @@ private void validateClientAuthenticationMethod(T grantRequest) {
}

private RequestHeadersSpec<?> populateRequest(T grantRequest) {
MultiValueMap<String, String> parameters = this.parametersConverter.convert(grantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
this.parametersCustomizer.accept(parameters);
return this.restClient.post()
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
.headers((headers) -> {
HttpHeaders headersToAdd = this.headersConverter.convert(grantRequest);
if (headersToAdd != null) {
headers.addAll(headersToAdd);
}
this.headersCustomizer.accept(headers);
})
.body(this.parametersConverter.convert(grantRequest));
.body(parameters);
}

/**
Expand Down Expand Up @@ -207,6 +221,17 @@ public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter
this.requestEntityConverter = this::populateRequest;
}

/**
* Sets the {@link Consumer} used for customizing all of the OAuth 2.0 Access Token
* headers, which allows for headers to be added, overwritten or removed.
* @param headersCustomizer the {@link Consumer} to customize the headers
* @since 6.4
*/
public final void setHeadersCustomizer(Consumer<HttpHeaders> headersCustomizer) {
Assert.notNull(headersCustomizer, "headersCustomizer cannot be null");
this.headersCustomizer = headersCustomizer;
}

/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
Expand All @@ -216,7 +241,18 @@ public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter
*/
public final void setParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
this.parametersConverter = parametersConverter;
Converter<T, MultiValueMap<String, String>> defaultParametersConverter = this::createParameters;
this.parametersConverter = (authorizationGrantRequest) -> {
MultiValueMap<String, String> parameters = defaultParametersConverter.convert(authorizationGrantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
MultiValueMap<String, String> parametersToSet = parametersConverter.convert(authorizationGrantRequest);
if (parametersToSet != null) {
parameters.putAll(parametersToSet);
}
return parameters;
};
this.requestEntityConverter = this::populateRequest;
}

Expand Down Expand Up @@ -246,4 +282,15 @@ public final void addParametersConverter(Converter<T, MultiValueMap<String, Stri
this.requestEntityConverter = this::populateRequest;
}

/**
* Sets the {@link Consumer} used for customizing all of the OAuth 2.0 Access Token
* parameters, which allows for parameters to be added, overwritten or removed.
* @param parametersCustomizer the {@link Consumer} to customize the parameters
* @since 6.4
*/
public final void setParametersCustomizer(Consumer<MultiValueMap<String, String>> parametersCustomizer) {
Assert.notNull(parametersCustomizer, "parametersCustomizer cannot be null");
this.parametersCustomizer = parametersCustomizer;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.function.Consumer;

import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
Expand Down Expand Up @@ -54,6 +55,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
Expand Down Expand Up @@ -135,6 +137,15 @@ public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
// @formatter:on
}

@Test
public void setHeadersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setHeadersCustomizer(null))
.withMessage("headersCustomizer cannot be null");
// @formatter:on
}

@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
Expand All @@ -153,6 +164,15 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
// @formatter:on
}

@Test
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
.withMessage("parametersCustomizer cannot be null");
// @formatter:on
}

@Test
public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() {
// @formatter:off
Expand Down Expand Up @@ -419,6 +439,25 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception
assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value");
}

@Test
public void getTokenResponseWhenHeadersCustomizerSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
Consumer<HttpHeaders> headersCustomizer = mock(Consumer.class);
this.tokenResponseClient.setHeadersCustomizer(headersCustomizer);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(headersCustomizer).accept(any(HttpHeaders.class));
}

@Test
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
// @formatter:off
Expand Down Expand Up @@ -463,18 +502,18 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.CODE, "custom-code");
parameters.set(OAuth2ParameterNames.REDIRECT_URI, "custom-uri");
// The client_id parameter is omitted for testing purposes
this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters);
this.tokenResponseClient.getTokenResponse(grantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.CODE, "custom-code"),
param(OAuth2ParameterNames.REDIRECT_URI, "custom-uri"));
param(OAuth2ParameterNames.REDIRECT_URI, "custom-uri")
);
// @formatter:on
assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID);
}

@Test
Expand Down Expand Up @@ -509,6 +548,25 @@ public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exce
// @formatter:on
}

@Test
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock(Consumer.class);
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}

@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
// @formatter:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.time.Instant;
import java.util.Collections;
import java.util.Set;
import java.util.function.Consumer;

import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
Expand Down Expand Up @@ -53,6 +54,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
Expand Down Expand Up @@ -119,6 +121,15 @@ public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
// @formatter:on
}

@Test
public void setHeadersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setHeadersCustomizer(null))
.withMessage("headersCustomizer cannot be null");
// @formatter:on
}

@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
Expand All @@ -137,6 +148,15 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
// @formatter:on
}

@Test
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
.withMessage("parametersCustomizer cannot be null");
// @formatter:on
}

@Test
public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() {
// @formatter:off
Expand Down Expand Up @@ -428,6 +448,24 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception
assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value");
}

@Test
public void getTokenResponseWhenHeadersCustomizerSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
Consumer<HttpHeaders> headersCustomizer = mock(Consumer.class);
this.tokenResponseClient.setHeadersCustomizer(headersCustomizer);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(headersCustomizer).accept(any(HttpHeaders.class));
}

@Test
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
// @formatter:off
Expand Down Expand Up @@ -471,7 +509,6 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.SCOPE, "one two");
// The client_id parameter is omitted for testing purposes
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters);
this.tokenResponseClient.getTokenResponse(grantRequest);
Expand All @@ -480,9 +517,10 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP
// @formatter:off
assertThat(formParameters).contains(
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
param(OAuth2ParameterNames.SCOPE, "one two"));
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
param(OAuth2ParameterNames.SCOPE, "one two")
);
// @formatter:on
assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID);
}

@Test
Expand Down Expand Up @@ -517,6 +555,24 @@ public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exce
// @formatter:on
}

@Test
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock(Consumer.class);
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}

@Test
public void getTokenResponseWhenRestClientSetThenCalled() {
// @formatter:off
Expand Down
Loading

0 comments on commit b101d6d

Please sign in to comment.