Skip to content

Commit

Permalink
Customize the strategy for resolving the principal
Browse files Browse the repository at this point in the history
  • Loading branch information
sjohnr committed Sep 20, 2024
1 parent c1a303b commit d725e7b
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
Expand Down Expand Up @@ -121,40 +119,24 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque

private final OAuth2AuthorizedClientManager authorizedClientManager;

private final ClientRegistrationIdResolver clientRegistrationIdResolver;
private ClientRegistrationIdResolver clientRegistrationIdResolver = new RequestAttributeClientRegistrationIdResolver();

private PrincipalResolver principalResolver = new SecurityContextHolderPrincipalResolver();

// @formatter:off
private OAuth2AuthorizationFailureHandler authorizationFailureHandler =
(clientRegistrationId, principal, attributes) -> { };
// @formatter:on

private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();

/**
* Constructs a {@code OAuth2ClientHttpRequestInterceptor} using the provided
* parameters.
* @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which
* manages the authorized client(s)
*/
public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager) {
this(authorizedClientManager, new RequestAttributeClientRegistrationIdResolver());
}

/**
* Constructs a {@code OAuth2ClientHttpRequestInterceptor} using the provided
* parameters.
* @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which
* manages the authorized client(s)
* @param clientRegistrationIdResolver the strategy for resolving a
* {@code clientRegistrationId} from the intercepted request
*/
public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager,
ClientRegistrationIdResolver clientRegistrationIdResolver) {
Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
Assert.notNull(clientRegistrationIdResolver, "clientRegistrationIdResolver cannot be null");
this.authorizedClientManager = authorizedClientManager;
this.clientRegistrationIdResolver = clientRegistrationIdResolver;
}

/**
Expand Down Expand Up @@ -238,20 +220,31 @@ public static OAuth2AuthorizationFailureHandler authorizationFailureHandler(
}

/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
* @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to
* use
* Sets the strategy for resolving a {@code clientRegistrationId} from an intercepted
* request.
* @param clientRegistrationIdResolver the strategy for resolving a
* {@code clientRegistrationId} from an intercepted request
*/
public void setClientRegistrationIdResolver(ClientRegistrationIdResolver clientRegistrationIdResolver) {
Assert.notNull(clientRegistrationIdResolver, "clientRegistrationIdResolver cannot be null");
this.clientRegistrationIdResolver = clientRegistrationIdResolver;
}

/**
* Sets the strategy for resolving a {@link Authentication principal} from an
* intercepted request.
* @param principalResolver the strategy for resolving a {@link Authentication
* principal}
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
public void setPrincipalResolver(PrincipalResolver principalResolver) {
Assert.notNull(principalResolver, "principalResolver cannot be null");
this.principalResolver = principalResolver;
}

@Override
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution)
throws IOException {
Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
Authentication principal = this.principalResolver.resolve(request);
if (principal == null) {
principal = ANONYMOUS_AUTHENTICATION;
}
Expand Down Expand Up @@ -378,4 +371,24 @@ public interface ClientRegistrationIdResolver {

}

/**
* A strategy for resolving a {@link Authentication principal} from an intercepted
* request.
*/
@FunctionalInterface
public interface PrincipalResolver {

/**
* Resolve the {@link Authentication principal} from the current request, which is
* used to obtain an {@link OAuth2AuthorizedClient}.
* @param request the intercepted request, containing HTTP method, URI, headers,
* and request attributes
* @return the {@link Authentication principal} to be used for resolving an
* {@link OAuth2AuthorizedClient}.
*/
@Nullable
Authentication resolve(HttpRequest request);

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.web.client;

import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;

import org.springframework.http.HttpRequest;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.util.Assert;

/**
* A strategy for resolving a {@link Authentication principal} from an intercepted request
* using {@link ClientHttpRequest#getAttributes() attributes}.
*
* @author Steve Riesenberg
* @since 6.4
*/
public class RequestAttributePrincipalResolver implements OAuth2ClientHttpRequestInterceptor.PrincipalResolver {

private static final String PRINCIPAL_ATTR_NAME = RequestAttributePrincipalResolver.class.getName()
.concat(".principal");

@Override
public Authentication resolve(HttpRequest request) {
return (Authentication) request.getAttributes().get(PRINCIPAL_ATTR_NAME);
}

/**
* Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the
* {@link Authentication principal} to be used to look up the
* {@link OAuth2AuthorizedClient}.
* @param principal the {@link Authentication principal} to be used to look up the
* {@link OAuth2AuthorizedClient}
* @return the {@link Consumer} to populate the attributes
*/
public static Consumer<Map<String, Object>> principal(Authentication principal) {
Assert.notNull(principal, "principal cannot be null");
return (attributes) -> attributes.put(PRINCIPAL_ATTR_NAME, principal);
}

/**
* Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the
* {@link Authentication principal} to be used to look up the
* {@link OAuth2AuthorizedClient}.
* @param principalName the {@code principalName} to be used to look up the
* {@link OAuth2AuthorizedClient}
* @return the {@link Consumer} to populate the attributes
*/
public static Consumer<Map<String, Object>> principal(String principalName) {
Assert.hasText(principalName, "principalName cannot be empty");
Authentication principal = createAuthentication(principalName);
return (attributes) -> attributes.put(PRINCIPAL_ATTR_NAME, principal);
}

private static Authentication createAuthentication(String principalName) {
return new AbstractAuthenticationToken(Collections.emptySet()) {
@Override
public Object getPrincipal() {
return principalName;
}

@Override
public Object getCredentials() {
return null;
}
};
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.web.client;

import org.springframework.http.HttpRequest;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;

/**
* A strategy for resolving a {@link Authentication principal} from an intercepted request
* using the {@link SecurityContextHolder}.
*
* @author Steve Riesenberg
* @since 6.4
*/
public class SecurityContextHolderPrincipalResolver implements OAuth2ClientHttpRequestInterceptor.PrincipalResolver {

private final SecurityContextHolderStrategy securityContextHolderStrategy;

/**
* Constructs a {@code SecurityContextHolderPrincipalResolver}.
*/
public SecurityContextHolderPrincipalResolver() {
this(SecurityContextHolder.getContextHolderStrategy());
}

/**
* Constructs a {@code SecurityContextHolderPrincipalResolver} using the provided
* parameters.
* @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to
* use for resolving the {@link Authentication principal}
*/
public SecurityContextHolderPrincipalResolver(SecurityContextHolderStrategy securityContextHolderStrategy) {
this.securityContextHolderStrategy = securityContextHolderStrategy;
}

@Override
public Authentication resolve(HttpRequest request) {
return this.securityContextHolderStrategy.getContext().getAuthentication();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
Expand Down Expand Up @@ -110,15 +109,15 @@ public class OAuth2ClientHttpRequestInterceptorTests {
@Mock
private OAuth2AuthorizedClientRepository authorizedClientRepository;

@Mock
private SecurityContextHolderStrategy securityContextHolderStrategy;

@Mock
private OAuth2AuthorizedClientService authorizedClientService;

@Mock
private OAuth2ClientHttpRequestInterceptor.ClientRegistrationIdResolver clientRegistrationIdResolver;

@Mock
private OAuth2ClientHttpRequestInterceptor.PrincipalResolver principalResolver;

@Captor
private ArgumentCaptor<OAuth2AuthorizeRequest> authorizeRequestCaptor;

Expand Down Expand Up @@ -167,13 +166,6 @@ public void constructorWhenAuthorizedClientManagerIsNullThenThrowsIllegalArgumen
.withMessage("authorizedClientManager cannot be null");
}

@Test
public void constructorWhenClientRegistrationIdResolverIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new OAuth2ClientHttpRequestInterceptor(this.authorizedClientManager, null))
.withMessage("clientRegistrationIdResolver cannot be null");
}

@Test
public void setAuthorizationFailureHandlerWhenNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException()
Expand All @@ -198,10 +190,16 @@ public void authorizationFailureHandlerWhenAuthorizedClientServiceIsNullThenThro
}

@Test
public void setSecurityContextHolderStrategyWhenNullThenThrowsIllegalArgumentException() {
public void setClientRegistrationIdResolverWhenNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.requestInterceptor.setSecurityContextHolderStrategy(null))
.withMessage("securityContextHolderStrategy cannot be null");
.isThrownBy(() -> this.requestInterceptor.setClientRegistrationIdResolver(null))
.withMessage("clientRegistrationIdResolver cannot be null");
}

@Test
public void setPrincipalResolverWhenNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.requestInterceptor.setPrincipalResolver(null))
.withMessage("principalResolver cannot be null");
}

@Test
Expand Down Expand Up @@ -605,8 +603,7 @@ public void interceptWhenUnauthorizedAndAuthorizationFailureHandlerSetWithAuthor

@Test
public void interceptWhenCustomClientRegistrationIdResolverSetThenUsed() {
this.requestInterceptor = new OAuth2ClientHttpRequestInterceptor(this.authorizedClientManager,
this.clientRegistrationIdResolver);
this.requestInterceptor.setClientRegistrationIdResolver(this.clientRegistrationIdResolver);
this.requestInterceptor.setAuthorizationFailureHandler(this.authorizationFailureHandler);
given(this.authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class)))
.willReturn(this.authorizedClient);
Expand All @@ -625,31 +622,29 @@ public void interceptWhenCustomClientRegistrationIdResolverSetThenUsed() {
this.server.verify();
verify(this.authorizedClientManager).authorize(this.authorizeRequestCaptor.capture());
verify(this.clientRegistrationIdResolver).resolve(any(HttpRequest.class));
verifyNoMoreInteractions(this.clientRegistrationIdResolver, this.authorizedClientManager);
verifyNoMoreInteractions(this.authorizedClientManager, this.clientRegistrationIdResolver);
verifyNoInteractions(this.authorizationFailureHandler);
OAuth2AuthorizeRequest authorizeRequest = this.authorizeRequestCaptor.getValue();
assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(clientRegistrationId);
assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal);
}

@Test
public void interceptWhenCustomSecurityContextHolderStrategySetThenUsed() {
this.requestInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
public void interceptWhenCustomPrincipalResolverSetThenUsed() {
this.requestInterceptor.setPrincipalResolver(this.principalResolver);
given(this.authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class)))
.willReturn(this.authorizedClient);

bindToRestClient(withRequestInterceptor());
this.server.expect(requestTo(REQUEST_URI))
.andExpect(hasAuthorizationHeader(this.authorizedClient.getAccessToken()))
.andRespond(withApplicationJson());
SecurityContext securityContext = new SecurityContextImpl();
securityContext.setAuthentication(this.principal);
given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext);
given(this.principalResolver.resolve(any(HttpRequest.class))).willReturn(this.principal);
performRequest(withClientRegistrationId());
this.server.verify();
verify(this.authorizedClientManager).authorize(this.authorizeRequestCaptor.capture());
verify(this.securityContextHolderStrategy).getContext();
verifyNoMoreInteractions(this.authorizedClientManager, this.securityContextHolderStrategy);
verify(this.principalResolver).resolve(any(HttpRequest.class));
verifyNoMoreInteractions(this.authorizedClientManager, this.principalResolver);
OAuth2AuthorizeRequest authorizeRequest = this.authorizeRequestCaptor.getValue();
assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId());
assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal);
Expand Down

0 comments on commit d725e7b

Please sign in to comment.