Skip to content

Commit b101d6d

Browse files
committed
Align RestClient parameters with WebClient
Issue spring-projectsgh-11298
1 parent 680f29c commit b101d6d

File tree

6 files changed

+353
-18
lines changed

6 files changed

+353
-18
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractRestClientOAuth2AccessTokenResponseClient.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.security.oauth2.client.endpoint;
1818

19+
import java.util.function.Consumer;
20+
1921
import org.springframework.core.convert.converter.Converter;
2022
import org.springframework.http.HttpHeaders;
2123
import org.springframework.http.converter.FormHttpMessageConverter;
@@ -75,8 +77,14 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend
7577

7678
private Converter<T, HttpHeaders> headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>();
7779

80+
private Consumer<HttpHeaders> headersCustomizer = (headers) -> {
81+
};
82+
7883
private Converter<T, MultiValueMap<String, String>> parametersConverter = this::createParameters;
7984

85+
private Consumer<MultiValueMap<String, String>> parametersCustomizer = (parameters) -> {
86+
};
87+
8088
AbstractRestClientOAuth2AccessTokenResponseClient() {
8189
}
8290

@@ -124,15 +132,21 @@ private void validateClientAuthenticationMethod(T grantRequest) {
124132
}
125133

126134
private RequestHeadersSpec<?> populateRequest(T grantRequest) {
135+
MultiValueMap<String, String> parameters = this.parametersConverter.convert(grantRequest);
136+
if (parameters == null) {
137+
parameters = new LinkedMultiValueMap<>();
138+
}
139+
this.parametersCustomizer.accept(parameters);
127140
return this.restClient.post()
128141
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
129142
.headers((headers) -> {
130143
HttpHeaders headersToAdd = this.headersConverter.convert(grantRequest);
131144
if (headersToAdd != null) {
132145
headers.addAll(headersToAdd);
133146
}
147+
this.headersCustomizer.accept(headers);
134148
})
135-
.body(this.parametersConverter.convert(grantRequest));
149+
.body(parameters);
136150
}
137151

138152
/**
@@ -207,6 +221,17 @@ public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter
207221
this.requestEntityConverter = this::populateRequest;
208222
}
209223

224+
/**
225+
* Sets the {@link Consumer} used for customizing all of the OAuth 2.0 Access Token
226+
* headers, which allows for headers to be added, overwritten or removed.
227+
* @param headersCustomizer the {@link Consumer} to customize the headers
228+
* @since 6.4
229+
*/
230+
public final void setHeadersCustomizer(Consumer<HttpHeaders> headersCustomizer) {
231+
Assert.notNull(headersCustomizer, "headersCustomizer cannot be null");
232+
this.headersCustomizer = headersCustomizer;
233+
}
234+
210235
/**
211236
* Sets the {@link Converter} used for converting the
212237
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
@@ -216,7 +241,18 @@ public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter
216241
*/
217242
public final void setParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
218243
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
219-
this.parametersConverter = parametersConverter;
244+
Converter<T, MultiValueMap<String, String>> defaultParametersConverter = this::createParameters;
245+
this.parametersConverter = (authorizationGrantRequest) -> {
246+
MultiValueMap<String, String> parameters = defaultParametersConverter.convert(authorizationGrantRequest);
247+
if (parameters == null) {
248+
parameters = new LinkedMultiValueMap<>();
249+
}
250+
MultiValueMap<String, String> parametersToSet = parametersConverter.convert(authorizationGrantRequest);
251+
if (parametersToSet != null) {
252+
parameters.putAll(parametersToSet);
253+
}
254+
return parameters;
255+
};
220256
this.requestEntityConverter = this::populateRequest;
221257
}
222258

@@ -246,4 +282,15 @@ public final void addParametersConverter(Converter<T, MultiValueMap<String, Stri
246282
this.requestEntityConverter = this::populateRequest;
247283
}
248284

285+
/**
286+
* Sets the {@link Consumer} used for customizing all of the OAuth 2.0 Access Token
287+
* parameters, which allows for parameters to be added, overwritten or removed.
288+
* @param parametersCustomizer the {@link Consumer} to customize the parameters
289+
* @since 6.4
290+
*/
291+
public final void setParametersCustomizer(Consumer<MultiValueMap<String, String>> parametersCustomizer) {
292+
Assert.notNull(parametersCustomizer, "parametersCustomizer cannot be null");
293+
this.parametersCustomizer = parametersCustomizer;
294+
}
295+
249296
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClientTests.java

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.nio.charset.StandardCharsets;
2222
import java.time.Instant;
2323
import java.util.Collections;
24+
import java.util.function.Consumer;
2425

2526
import okhttp3.mockwebserver.MockResponse;
2627
import okhttp3.mockwebserver.MockWebServer;
@@ -54,6 +55,7 @@
5455
import static org.assertj.core.api.Assertions.assertThat;
5556
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
5657
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
58+
import static org.mockito.ArgumentMatchers.any;
5759
import static org.mockito.BDDMockito.given;
5860
import static org.mockito.Mockito.mock;
5961
import static org.mockito.Mockito.spy;
@@ -135,6 +137,15 @@ public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
135137
// @formatter:on
136138
}
137139

140+
@Test
141+
public void setHeadersCustomizerWhenNullThenThrowIllegalArgumentException() {
142+
// @formatter:off
143+
assertThatIllegalArgumentException()
144+
.isThrownBy(() -> this.tokenResponseClient.setHeadersCustomizer(null))
145+
.withMessage("headersCustomizer cannot be null");
146+
// @formatter:on
147+
}
148+
138149
@Test
139150
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
140151
// @formatter:off
@@ -153,6 +164,15 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
153164
// @formatter:on
154165
}
155166

167+
@Test
168+
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
169+
// @formatter:off
170+
assertThatIllegalArgumentException()
171+
.isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
172+
.withMessage("parametersCustomizer cannot be null");
173+
// @formatter:on
174+
}
175+
156176
@Test
157177
public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() {
158178
// @formatter:off
@@ -419,6 +439,25 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception
419439
assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value");
420440
}
421441

442+
@Test
443+
public void getTokenResponseWhenHeadersCustomizerSetThenCalled() throws Exception {
444+
// @formatter:off
445+
String accessTokenSuccessResponse = "{\n"
446+
+ " \"access_token\": \"access-token-1234\",\n"
447+
+ " \"token_type\": \"bearer\",\n"
448+
+ " \"expires_in\": \"3600\"\n"
449+
+ "}\n";
450+
// @formatter:on
451+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
452+
ClientRegistration clientRegistration = this.clientRegistration.build();
453+
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
454+
this.authorizationExchange);
455+
Consumer<HttpHeaders> headersCustomizer = mock(Consumer.class);
456+
this.tokenResponseClient.setHeadersCustomizer(headersCustomizer);
457+
this.tokenResponseClient.getTokenResponse(grantRequest);
458+
verify(headersCustomizer).accept(any(HttpHeaders.class));
459+
}
460+
422461
@Test
423462
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
424463
// @formatter:off
@@ -463,18 +502,18 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP
463502
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
464503
parameters.set(OAuth2ParameterNames.CODE, "custom-code");
465504
parameters.set(OAuth2ParameterNames.REDIRECT_URI, "custom-uri");
466-
// The client_id parameter is omitted for testing purposes
467505
this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters);
468506
this.tokenResponseClient.getTokenResponse(grantRequest);
469507
RecordedRequest recordedRequest = this.server.takeRequest();
470508
String formParameters = recordedRequest.getBody().readUtf8();
471509
// @formatter:off
472510
assertThat(formParameters).contains(
473511
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
512+
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
474513
param(OAuth2ParameterNames.CODE, "custom-code"),
475-
param(OAuth2ParameterNames.REDIRECT_URI, "custom-uri"));
514+
param(OAuth2ParameterNames.REDIRECT_URI, "custom-uri")
515+
);
476516
// @formatter:on
477-
assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID);
478517
}
479518

480519
@Test
@@ -509,6 +548,25 @@ public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exce
509548
// @formatter:on
510549
}
511550

551+
@Test
552+
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
553+
// @formatter:off
554+
String accessTokenSuccessResponse = "{\n"
555+
+ " \"access_token\": \"access-token-1234\",\n"
556+
+ " \"token_type\": \"bearer\",\n"
557+
+ " \"expires_in\": \"3600\"\n"
558+
+ "}\n";
559+
// @formatter:on
560+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
561+
ClientRegistration clientRegistration = this.clientRegistration.build();
562+
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
563+
this.authorizationExchange);
564+
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock(Consumer.class);
565+
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
566+
this.tokenResponseClient.getTokenResponse(grantRequest);
567+
verify(parametersCustomizer).accept(any());
568+
}
569+
512570
@Test
513571
public void getTokenResponseWhenRestClientSetThenCalled() {
514572
// @formatter:off

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClientTests.java

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.time.Instant;
2323
import java.util.Collections;
2424
import java.util.Set;
25+
import java.util.function.Consumer;
2526

2627
import okhttp3.mockwebserver.MockResponse;
2728
import okhttp3.mockwebserver.MockWebServer;
@@ -53,6 +54,7 @@
5354
import static org.assertj.core.api.Assertions.assertThat;
5455
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
5556
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
57+
import static org.mockito.ArgumentMatchers.any;
5658
import static org.mockito.BDDMockito.given;
5759
import static org.mockito.Mockito.mock;
5860
import static org.mockito.Mockito.spy;
@@ -119,6 +121,15 @@ public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
119121
// @formatter:on
120122
}
121123

124+
@Test
125+
public void setHeadersCustomizerWhenNullThenThrowIllegalArgumentException() {
126+
// @formatter:off
127+
assertThatIllegalArgumentException()
128+
.isThrownBy(() -> this.tokenResponseClient.setHeadersCustomizer(null))
129+
.withMessage("headersCustomizer cannot be null");
130+
// @formatter:on
131+
}
132+
122133
@Test
123134
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
124135
// @formatter:off
@@ -137,6 +148,15 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
137148
// @formatter:on
138149
}
139150

151+
@Test
152+
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
153+
// @formatter:off
154+
assertThatIllegalArgumentException()
155+
.isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
156+
.withMessage("parametersCustomizer cannot be null");
157+
// @formatter:on
158+
}
159+
140160
@Test
141161
public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() {
142162
// @formatter:off
@@ -428,6 +448,24 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception
428448
assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value");
429449
}
430450

451+
@Test
452+
public void getTokenResponseWhenHeadersCustomizerSetThenCalled() throws Exception {
453+
// @formatter:off
454+
String accessTokenSuccessResponse = "{\n"
455+
+ " \"access_token\": \"access-token-1234\",\n"
456+
+ " \"token_type\": \"bearer\",\n"
457+
+ " \"expires_in\": \"3600\"\n"
458+
+ "}\n";
459+
// @formatter:on
460+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
461+
ClientRegistration clientRegistration = this.clientRegistration.build();
462+
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
463+
Consumer<HttpHeaders> headersCustomizer = mock(Consumer.class);
464+
this.tokenResponseClient.setHeadersCustomizer(headersCustomizer);
465+
this.tokenResponseClient.getTokenResponse(grantRequest);
466+
verify(headersCustomizer).accept(any(HttpHeaders.class));
467+
}
468+
431469
@Test
432470
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
433471
// @formatter:off
@@ -471,7 +509,6 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP
471509
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
472510
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
473511
parameters.set(OAuth2ParameterNames.SCOPE, "one two");
474-
// The client_id parameter is omitted for testing purposes
475512
given(parametersConverter.convert(grantRequest)).willReturn(parameters);
476513
this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters);
477514
this.tokenResponseClient.getTokenResponse(grantRequest);
@@ -480,9 +517,10 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP
480517
// @formatter:off
481518
assertThat(formParameters).contains(
482519
param(OAuth2ParameterNames.GRANT_TYPE, "custom"),
483-
param(OAuth2ParameterNames.SCOPE, "one two"));
520+
param(OAuth2ParameterNames.CLIENT_ID, "client-1"),
521+
param(OAuth2ParameterNames.SCOPE, "one two")
522+
);
484523
// @formatter:on
485-
assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID);
486524
}
487525

488526
@Test
@@ -517,6 +555,24 @@ public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exce
517555
// @formatter:on
518556
}
519557

558+
@Test
559+
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
560+
// @formatter:off
561+
String accessTokenSuccessResponse = "{\n"
562+
+ " \"access_token\": \"access-token-1234\",\n"
563+
+ " \"token_type\": \"bearer\",\n"
564+
+ " \"expires_in\": \"3600\"\n"
565+
+ "}\n";
566+
// @formatter:on
567+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
568+
ClientRegistration clientRegistration = this.clientRegistration.build();
569+
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
570+
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock(Consumer.class);
571+
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
572+
this.tokenResponseClient.getTokenResponse(grantRequest);
573+
verify(parametersCustomizer).accept(any());
574+
}
575+
520576
@Test
521577
public void getTokenResponseWhenRestClientSetThenCalled() {
522578
// @formatter:off

0 commit comments

Comments
 (0)