|
17 | 17 | import static org.pac4j.core.util.CommonHelper.assertNotNull;
|
18 | 18 | import static org.pac4j.core.util.CommonHelper.isNotEmpty;
|
19 | 19 |
|
| 20 | +import com.fasterxml.jackson.core.type.TypeReference; |
20 | 21 | import com.google.common.collect.ImmutableMap;
|
21 | 22 | import com.google.common.collect.ImmutableMap.Builder;
|
22 | 23 | import com.nimbusds.jose.JOSEException;
|
|
29 | 30 | import com.nimbusds.oauth2.sdk.auth.PrivateKeyJWT;
|
30 | 31 | import com.nimbusds.oauth2.sdk.auth.Secret;
|
31 | 32 | import com.nimbusds.oauth2.sdk.id.ClientID;
|
| 33 | +import com.nimbusds.oauth2.sdk.token.BearerAccessToken; |
| 34 | +import java.io.BufferedWriter; |
32 | 35 | import java.io.IOException;
|
| 36 | +import java.io.OutputStreamWriter; |
| 37 | +import java.net.HttpURLConnection; |
| 38 | +import java.net.URL; |
| 39 | +import java.nio.charset.StandardCharsets; |
33 | 40 | import java.security.Principal;
|
34 | 41 | import java.security.PrivateKey;
|
35 | 42 | import java.text.ParseException;
|
36 | 43 | import java.time.Instant;
|
37 | 44 | import java.util.Arrays;
|
38 | 45 | import java.util.Collection;
|
39 | 46 | import java.util.Date;
|
| 47 | +import java.util.HashMap; |
40 | 48 | import java.util.List;
|
41 | 49 | import java.util.Map;
|
42 | 50 | import java.util.Optional;
|
|
51 | 59 | import org.openmetadata.common.utils.CommonUtil;
|
52 | 60 | import org.openmetadata.schema.security.client.OidcClientConfig;
|
53 | 61 | import org.openmetadata.service.OpenMetadataApplicationConfig;
|
| 62 | +import org.openmetadata.service.util.JsonUtils; |
| 63 | +import org.pac4j.core.context.HttpConstants; |
54 | 64 | import org.pac4j.core.exception.TechnicalException;
|
55 | 65 | import org.pac4j.core.util.CommonHelper;
|
| 66 | +import org.pac4j.core.util.HttpUtils; |
56 | 67 | import org.pac4j.oidc.client.AzureAd2Client;
|
57 | 68 | import org.pac4j.oidc.client.GoogleOidcClient;
|
58 | 69 | import org.pac4j.oidc.client.OidcClient;
|
@@ -371,11 +382,49 @@ private static void removeOrRenewOidcCredentials(
|
371 | 382 | if (SecurityUtil.isCredentialsExpired(credentials)) {
|
372 | 383 | LOG.debug("Expired credentials found, trying to renew.");
|
373 | 384 | profilesUpdated = true;
|
374 |
| - OidcAuthenticator authenticator = new OidcAuthenticator(client.getConfiguration(), client); |
375 |
| - authenticator.refresh(credentials); |
| 385 | + if (client.getConfiguration() |
| 386 | + instanceof AzureAd2OidcConfiguration azureAd2OidcConfiguration) { |
| 387 | + refreshAccessTokenAzureAd2Token(azureAd2OidcConfiguration, credentials); |
| 388 | + } else { |
| 389 | + OidcAuthenticator authenticator = new OidcAuthenticator(client.getConfiguration(), client); |
| 390 | + authenticator.refresh(credentials); |
| 391 | + } |
376 | 392 | }
|
377 | 393 | if (profilesUpdated) {
|
378 | 394 | request.getSession().setAttribute(OIDC_CREDENTIAL_PROFILE, credentials);
|
379 | 395 | }
|
380 | 396 | }
|
| 397 | + |
| 398 | + private static void refreshAccessTokenAzureAd2Token( |
| 399 | + AzureAd2OidcConfiguration azureConfig, OidcCredentials azureAdProfile) { |
| 400 | + HttpURLConnection connection = null; |
| 401 | + try { |
| 402 | + Map<String, String> headers = new HashMap<>(); |
| 403 | + headers.put( |
| 404 | + HttpConstants.CONTENT_TYPE_HEADER, HttpConstants.APPLICATION_FORM_ENCODED_HEADER_VALUE); |
| 405 | + headers.put(HttpConstants.ACCEPT_HEADER, HttpConstants.APPLICATION_JSON); |
| 406 | + // get the token endpoint from discovery URI |
| 407 | + URL tokenEndpointURL = azureConfig.findProviderMetadata().getTokenEndpointURI().toURL(); |
| 408 | + connection = HttpUtils.openPostConnection(tokenEndpointURL, headers); |
| 409 | + |
| 410 | + BufferedWriter out = |
| 411 | + new BufferedWriter( |
| 412 | + new OutputStreamWriter(connection.getOutputStream(), StandardCharsets.UTF_8)); |
| 413 | + out.write(azureConfig.makeOauth2TokenRequest(azureAdProfile.getRefreshToken().getValue())); |
| 414 | + out.close(); |
| 415 | + |
| 416 | + int responseCode = connection.getResponseCode(); |
| 417 | + if (responseCode != 200) { |
| 418 | + throw new TechnicalException( |
| 419 | + "request for access token failed: " + HttpUtils.buildHttpErrorMessage(connection)); |
| 420 | + } |
| 421 | + var body = HttpUtils.readBody(connection); |
| 422 | + Map<String, Object> res = JsonUtils.readValue(body, new TypeReference<>() {}); |
| 423 | + azureAdProfile.setAccessToken(new BearerAccessToken((String) res.get("access_token"))); |
| 424 | + } catch (final IOException e) { |
| 425 | + throw new TechnicalException(e); |
| 426 | + } finally { |
| 427 | + HttpUtils.closeConnection(connection); |
| 428 | + } |
| 429 | + } |
381 | 430 | }
|
0 commit comments