diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java index 6ed5b83d3..4ec58480e 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java @@ -1,5 +1,6 @@ package com.databricks.sdk.core; +import com.databricks.sdk.core.oauth.CachedTokenSource; import com.databricks.sdk.core.oauth.OAuthHeaderFactory; import com.databricks.sdk.core.oauth.Token; import com.databricks.sdk.core.utils.AzureUtils; @@ -19,7 +20,7 @@ public String authType() { return AZURE_CLI; } - public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) { + public CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) { String azPath = Optional.ofNullable(config.getEnv()).map(env -> env.get("AZ_PATH")).orElse("az"); @@ -35,7 +36,7 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) { List extendedCmd = new ArrayList<>(cmd); extendedCmd.addAll(Arrays.asList("--subscription", subscription.get())); try { - return getToken(config, extendedCmd); + return getTokenSource(config, extendedCmd); } catch (DatabricksException ex) { LOG.warn("Failed to get token for subscription. Using resource only token."); } @@ -45,14 +46,15 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) { + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors."); } - return getToken(config, cmd); + return getTokenSource(config, cmd); } - protected CliTokenSource getToken(DatabricksConfig config, List cmd) { - CliTokenSource token = + protected CachedTokenSource getTokenSource(DatabricksConfig config, List cmd) { + CliTokenSource tokenSource = new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config.getEnv()); - token.getToken(); // We need this to check if the CLI is installed and to validate the config. - return token; + CachedTokenSource cachedTokenSource = new CachedTokenSource.Builder(tokenSource).build(); + cachedTokenSource.getToken(); // Check if the CLI is installed and to validate the config. + return cachedTokenSource; } private Optional getSubscription(DatabricksConfig config) { @@ -77,8 +79,8 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { try { AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor); String resource = config.getEffectiveAzureLoginAppId(); - CliTokenSource tokenSource = tokenSourceFor(config, resource); - CliTokenSource mgmtTokenSource; + CachedTokenSource tokenSource = tokenSourceFor(config, resource); + CachedTokenSource mgmtTokenSource; try { mgmtTokenSource = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); @@ -86,7 +88,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { LOG.debug("Not including service management token in headers", e); mgmtTokenSource = null; } - CliTokenSource finalMgmtTokenSource = mgmtTokenSource; + CachedTokenSource finalMgmtTokenSource = mgmtTokenSource; return OAuthHeaderFactory.fromSuppliers( tokenSource::getToken, () -> { diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 17e409d93..d3fa200f4 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -1,7 +1,7 @@ package com.databricks.sdk.core; -import com.databricks.sdk.core.oauth.RefreshableTokenSource; import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.oauth.TokenSource; import com.databricks.sdk.core.utils.Environment; import com.databricks.sdk.core.utils.OSUtils; import com.fasterxml.jackson.databind.JsonNode; @@ -18,7 +18,7 @@ import java.util.List; import org.apache.commons.io.IOUtils; -public class CliTokenSource extends RefreshableTokenSource { +public class CliTokenSource implements TokenSource { private List cmd; private String tokenTypeField; private String accessTokenField; @@ -86,7 +86,7 @@ private String getProcessStream(InputStream stream) throws IOException { } @Override - protected Token refresh() { + public Token getToken() { try { ProcessBuilder processBuilder = new ProcessBuilder(cmd); processBuilder.environment().putAll(env.getEnv()); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java index 655d0b599..6d5a2eb9f 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java @@ -1,5 +1,6 @@ package com.databricks.sdk.core; +import com.databricks.sdk.core.oauth.CachedTokenSource; import com.databricks.sdk.core.oauth.OAuthHeaderFactory; import com.databricks.sdk.core.utils.OSUtils; import java.util.*; @@ -47,8 +48,11 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { if (tokenSource == null) { return null; } - tokenSource.getToken(); // We need this for checking if databricks CLI is installed. - return OAuthHeaderFactory.fromTokenSource(tokenSource); + + CachedTokenSource cachedTokenSource = new CachedTokenSource.Builder(tokenSource).build(); + cachedTokenSource.getToken(); // We need this for checking if databricks CLI is installed. + + return OAuthHeaderFactory.fromTokenSource(cachedTokenSource); } catch (DatabricksException e) { String stderr = e.getMessage(); if (stderr.contains("not found")) { diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java index b29b5aa0e..281d2d9a4 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java @@ -46,8 +46,8 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { config.getEffectiveAzureLoginAppId(), idToken.get(), "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); - - return OAuthHeaderFactory.fromTokenSource(tokenSource); + CachedTokenSource cachedTokenSource = new CachedTokenSource.Builder(tokenSource).build(); + return OAuthHeaderFactory.fromTokenSource(cachedTokenSource); } /** diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java index c7c7bb672..a7a041f41 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java @@ -28,8 +28,8 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { } AzureUtils.ensureHostPresent( config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor); - RefreshableTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); - RefreshableTokenSource cloud = + CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId()); + CachedTokenSource cloud = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); return OAuthHeaderFactory.fromSuppliers( @@ -44,29 +44,32 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { } /** - * Creates a RefreshableTokenSource for the specified Azure resource. + * Creates a CachedTokenSource for the specified Azure resource. * - *

This function constructs a RefreshableTokenSource instance that fetches OAuth tokens for the + *

This function constructs a CachedTokenSource instance that fetches OAuth tokens for the * given Azure resource. It uses the authentication parameters provided by the DatabricksConfig * instance to generate the tokens. * * @param config The DatabricksConfig instance containing the required authentication parameters. * @param resource The Azure resource for which OAuth tokens need to be fetched. - * @return A RefreshableTokenSource instance capable of fetching OAuth tokens for the specified - * Azure resource. + * @return A CachedTokenSource instance capable of fetching OAuth tokens for the specified Azure + * resource. */ - private static RefreshableTokenSource tokenSourceFor(DatabricksConfig config, String resource) { + private static CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) { String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint(); String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token"; Map endpointParams = new HashMap<>(); endpointParams.put("resource", resource); - return new ClientCredentials.Builder() - .withHttpClient(config.getHttpClient()) - .withClientId(config.getAzureClientId()) - .withClientSecret(config.getAzureClientSecret()) - .withTokenUrl(tokenUrl) - .withEndpointParametersSupplier(() -> endpointParams) - .withAuthParameterPosition(AuthParameterPosition.BODY) - .build(); + + ClientCredentials clientCredentials = + new ClientCredentials.Builder() + .withHttpClient(config.getHttpClient()) + .withClientId(config.getAzureClientId()) + .withClientSecret(config.getAzureClientSecret()) + .withTokenUrl(tokenUrl) + .withEndpointParametersSupplier(() -> endpointParams) + .withAuthParameterPosition(AuthParameterPosition.BODY) + .build(); + return new CachedTokenSource.Builder(clientCredentials).build(); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java similarity index 52% rename from databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java rename to databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java index c4afd0859..0dbabbb1a 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/RefreshableTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java @@ -1,18 +1,11 @@ package com.databricks.sdk.core.oauth; -import com.databricks.sdk.core.ApiClient; -import com.databricks.sdk.core.DatabricksException; -import com.databricks.sdk.core.http.FormRequest; -import com.databricks.sdk.core.http.HttpClient; -import com.databricks.sdk.core.http.Request; import com.databricks.sdk.core.utils.ClockSupplier; import com.databricks.sdk.core.utils.UtcClockSupplier; import java.time.Duration; import java.time.Instant; -import java.util.Base64; -import java.util.Map; +import java.util.Objects; import java.util.concurrent.CompletableFuture; -import org.apache.http.HttpHeaders; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -23,102 +16,161 @@ * stale tokens will trigger a background refresh, while expired tokens will block until a new token * is fetched. */ -public abstract class RefreshableTokenSource implements TokenSource { +public class CachedTokenSource implements TokenSource { /** * Enum representing the state of the token. FRESH: Token is valid and not close to expiry. STALE: * Token is valid but will expire soon - an async refresh will be triggered if enabled. EXPIRED: * Token has expired and must be refreshed using a blocking call. */ - protected enum TokenState { + private enum TokenState { FRESH, STALE, EXPIRED } - private static final Logger logger = LoggerFactory.getLogger(RefreshableTokenSource.class); + private static final Logger logger = LoggerFactory.getLogger(CachedTokenSource.class); // Default duration before expiry to consider a token as 'stale'. private static final Duration DEFAULT_STALE_DURATION = Duration.ofMinutes(3); // Default additional buffer before expiry to consider a token as expired. + // This is 40 seconds by default since Azure Databricks rejects tokens that are within 30 seconds + // of expiry. private static final Duration DEFAULT_EXPIRY_BUFFER = Duration.ofSeconds(40); - // The current OAuth token. May be null if not yet fetched. - protected volatile Token token; + // The token source to use for refreshing the token. + private final TokenSource tokenSource; // Whether asynchronous refresh is enabled. private boolean asyncEnabled = Boolean.parseBoolean(System.getenv("DATABRICKS_ENABLE_EXPERIMENTAL_ASYNC_TOKEN_REFRESH")); // Duration before expiry to consider a token as 'stale'. - private Duration staleDuration = DEFAULT_STALE_DURATION; + private final Duration staleDuration; // Additional buffer before expiry to consider a token as expired. - private Duration expiryBuffer = DEFAULT_EXPIRY_BUFFER; + private final Duration expiryBuffer; + // Clock supplier for current time. + private final ClockSupplier clockSupplier; + + // The current OAuth token. May be null if not yet fetched. + private volatile Token token; // Whether a refresh is currently in progress (for async refresh). private boolean refreshInProgress = false; // Whether the last refresh attempt succeeded. private boolean lastRefreshSucceeded = true; - // Clock supplier for current time. - private ClockSupplier clockSupplier = new UtcClockSupplier(); - - /** Constructs a new {@code RefreshableTokenSource} with no initial token. */ - public RefreshableTokenSource() {} - /** - * Constructor with initial token. - * - * @param token The initial token to use. - */ - public RefreshableTokenSource(Token token) { - this.token = token; + private CachedTokenSource(Builder builder) { + this.tokenSource = builder.tokenSource; + this.asyncEnabled = builder.asyncEnabled; + this.staleDuration = builder.staleDuration; + this.expiryBuffer = builder.expiryBuffer; + this.clockSupplier = builder.clockSupplier; + this.token = builder.token; } /** - * Set the clock supplier for current time. + * Builder for creating instances of {@link CachedTokenSource}. * - *

Experimental: This method may change or be removed in future releases. - * - * @param clockSupplier The clock supplier to use. - * @return this instance for chaining + *

This builder allows configuration of various aspects of token caching behavior, including + * asynchronous refresh, timing parameters, and initial token state. */ - public RefreshableTokenSource withClockSupplier(ClockSupplier clockSupplier) { - this.clockSupplier = clockSupplier; - return this; - } + public static class Builder { + private final TokenSource tokenSource; + private boolean asyncEnabled = false; + private Duration staleDuration = DEFAULT_STALE_DURATION; + private Duration expiryBuffer = DEFAULT_EXPIRY_BUFFER; + private ClockSupplier clockSupplier = new UtcClockSupplier(); + private Token token; - /** - * Enable or disable asynchronous token refresh. - * - *

Experimental: This method may change or be removed in future releases. - * - * @param enabled true to enable async refresh, false to disable - * @return this instance for chaining - */ - public RefreshableTokenSource withAsyncRefresh(boolean enabled) { - this.asyncEnabled = enabled; - return this; - } + /** + * Creates a new builder with the specified token source. + * + * @param tokenSource The underlying token source to use for refreshing tokens. + * @throws NullPointerException If the token source is null. + */ + public Builder(TokenSource tokenSource) { + this.tokenSource = Objects.requireNonNull(tokenSource); + } - /** - * Set the expiry buffer. If the token's lifetime is less than this buffer, it is considered - * expired. - * - *

Experimental: This method may change or be removed in future releases. - * - * @param buffer the expiry buffer duration - * @return this instance for chaining - */ - public RefreshableTokenSource withExpiryBuffer(Duration buffer) { - this.expiryBuffer = buffer; - return this; - } + /** + * Sets an initial token to use in the cache. + * + *

If provided, this token will be used immediately without requiring an initial refresh from + * the underlying token source. + * + * @param token The initial token to cache. + * @return This builder instance for method chaining. + */ + public Builder setToken(Token token) { + this.token = token; + return this; + } - /** - * Refresh the OAuth token. Subclasses must implement this to define how the token is refreshed. - * - *

This method may throw an exception if the token cannot be refreshed. The specific exception - * type depends on the implementation. - * - * @return The newly refreshed Token. - */ - protected abstract Token refresh(); + /** + * Enables or disables asynchronous token refresh. + * + *

When enabled, stale tokens will trigger a background refresh while continuing to serve the + * current token. When disabled, all refreshes are performed synchronously and will block the + * calling thread. + * + * @param asyncEnabled True to enable asynchronous refresh, false to disable. + * @return This builder instance for method chaining. + */ + public Builder setAsyncEnabled(boolean asyncEnabled) { + this.asyncEnabled = asyncEnabled; + return this; + } + + /** + * Sets the duration before token expiry at which the token is considered stale. + * + *

When asynchronous refresh is enabled, tokens that are stale but not yet expired will + * trigger a background refresh while continuing to serve the current token. + * + * @param staleDuration The duration before expiry to consider a token stale. Must be greater + * than the expiry buffer duration. + * @return This builder instance for method chaining. + */ + public Builder setStaleDuration(Duration staleDuration) { + this.staleDuration = staleDuration; + return this; + } + + /** + * Sets the buffer duration before token expiry at which the token is considered expired. + * + *

Tokens within this buffer of their expiry time will be considered expired and require + * synchronous refresh. + * + * @param expiryBuffer The buffer duration before expiry to consider a token expired. Must be + * less than the stale duration. + * @return This builder instance for method chaining. + */ + public Builder setExpiryBuffer(Duration expiryBuffer) { + this.expiryBuffer = expiryBuffer; + return this; + } + + /** + * Sets the clock supplier to use for time-based operations. + * + *

This is primarily useful for testing scenarios where you need to control the current time. + * In production, the default UTC clock supplier should be used. + * + * @param clockSupplier The clock supplier to use for determining current time. + * @return This builder instance for method chaining. + */ + public Builder setClockSupplier(ClockSupplier clockSupplier) { + this.clockSupplier = clockSupplier; + return this; + } + + /** + * Builds and returns a new {@link CachedTokenSource} instance with the configured parameters. + * + * @return A new CachedTokenSource instance. + */ + public CachedTokenSource build() { + return new CachedTokenSource(this); + } + } /** * Gets the current token, refreshing if necessary. If async refresh is enabled, may return a @@ -177,7 +229,7 @@ protected Token getTokenBlocking() { } lastRefreshSucceeded = false; try { - token = refresh(); + token = tokenSource.getToken(); } catch (Exception e) { logger.error("Failed to refresh token synchronously", e); throw e; @@ -224,8 +276,8 @@ private synchronized void triggerAsyncRefresh() { CompletableFuture.runAsync( () -> { try { - // Attempt to refresh the token in the background - Token newToken = refresh(); + // Attempt to refresh the token in the background. + Token newToken = tokenSource.getToken(); synchronized (this) { token = newToken; refreshInProgress = false; @@ -240,58 +292,4 @@ private synchronized void triggerAsyncRefresh() { }); } } - - /** - * Helper method implementing OAuth token refresh. - * - * @param hc The HTTP client to use for the request. - * @param clientId The client ID to authenticate with. - * @param clientSecret The client secret to authenticate with. - * @param tokenUrl The authorization URL for fetching tokens. - * @param params Additional request parameters. - * @param headers Additional headers. - * @param position The position of the authentication parameters in the request. - * @return The newly fetched Token. - * @throws DatabricksException if the refresh fails - * @throws IllegalArgumentException if the OAuth response contains an error - */ - protected static Token retrieveToken( - HttpClient hc, - String clientId, - String clientSecret, - String tokenUrl, - Map params, - Map headers, - AuthParameterPosition position) { - switch (position) { - case BODY: - if (clientId != null) { - params.put("client_id", clientId); - } - if (clientSecret != null) { - params.put("client_secret", clientSecret); - } - break; - case HEADER: - String authHeaderValue = - "Basic " - + Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes()); - headers.put(HttpHeaders.AUTHORIZATION, authHeaderValue); - break; - } - headers.put("Content-Type", "application/x-www-form-urlencoded"); - Request req = new Request("POST", tokenUrl, FormRequest.wrapValuesInList(params)); - req.withHeaders(headers); - try { - ApiClient apiClient = new ApiClient.Builder().withHttpClient(hc).build(); - OAuthResponse resp = apiClient.execute(req, OAuthResponse.class); - if (resp.getErrorCode() != null) { - throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary()); - } - Instant expiry = Instant.now().plusSeconds(resp.getExpiresIn()); - return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry); - } catch (Exception e) { - throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e); - } - } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java index 1c4b7d6de..8cee3ef29 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java @@ -13,7 +13,7 @@ * support all OAuth endpoints, authentication parameters can be passed in the request body or in * the Authorization header. */ -public class ClientCredentials extends RefreshableTokenSource { +public class ClientCredentials implements TokenSource { public static class Builder { private String clientId; private String clientSecret; @@ -97,7 +97,7 @@ private ClientCredentials( } @Override - protected Token refresh() { + public Token getToken() { Map params = new HashMap<>(); params.put("grant_type", "client_credentials"); if (scopes != null) { @@ -106,6 +106,7 @@ protected Token refresh() { if (endpointParamsSupplier != null) { params.putAll(endpointParamsSupplier.get()); } - return retrieveToken(hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position); + return TokenEndpointClient.retrieveToken( + hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java index 57c97a2ff..aee9fe50f 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Consent.java @@ -269,7 +269,7 @@ private Token exchange(String code, String state) { headers.put("Origin", this.redirectUrl); } Token token = - RefreshableTokenSource.retrieveToken( + TokenEndpointClient.retrieveToken( this.hc, this.clientId, this.clientSecret, diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java index 01bd05d2c..37e7d707e 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java @@ -8,17 +8,16 @@ /** * Manages and provides Databricks data plane tokens. This class is responsible for acquiring and * caching OAuth tokens that are specific to a particular Databricks data plane service endpoint and - * a set of authorization details. It utilizes a {@link DatabricksOAuthTokenSource} for obtaining - * control plane tokens, which may then be exchanged or used to authorize requests for data plane - * tokens. Cached {@link EndpointTokenSource} instances are used to efficiently reuse tokens for - * repeated requests to the same endpoint with the same authorization context. + * a set of authorization details. It utilizes a {@link TokenSource} for obtaining control plane + * tokens, which may then be exchanged or used to authorize requests for data plane tokens. Cached + * {@link EndpointTokenSource} instances are used to efficiently reuse tokens for repeated requests + * to the same endpoint with the same authorization context. */ public class DataPlaneTokenSource { private final HttpClient httpClient; private final TokenSource cpTokenSource; private final String host; - private final ConcurrentHashMap sourcesCache; - + private final ConcurrentHashMap sourcesCache; /** * Caching key for {@link EndpointTokenSource}, based on endpoint and authorization details. This * is a value object that uniquely identifies a token source configuration. @@ -79,12 +78,14 @@ public Token getToken(String endpoint, String authDetails) { TokenSourceKey key = TokenSourceKey.create(endpoint, authDetails); - EndpointTokenSource specificSource = + CachedTokenSource specificSource = sourcesCache.computeIfAbsent( key, k -> - new EndpointTokenSource( - this.cpTokenSource, k.authDetails(), this.httpClient, this.host)); + new CachedTokenSource.Builder( + new EndpointTokenSource( + this.cpTokenSource, k.authDetails(), this.httpClient, this.host)) + .build()); return specificSource.getToken(); } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index 484e0712e..627f238a5 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -14,7 +14,7 @@ * Implementation of TokenSource that handles OAuth token exchange for Databricks authentication. * This class manages the OAuth token exchange flow using ID tokens to obtain access tokens. */ -public class DatabricksOAuthTokenSource extends RefreshableTokenSource { +public class DatabricksOAuthTokenSource implements TokenSource { private static final Logger LOG = LoggerFactory.getLogger(DatabricksOAuthTokenSource.class); /** OAuth client ID used for token exchange. */ @@ -128,7 +128,7 @@ public DatabricksOAuthTokenSource build() { * @throws NullPointerException when any of the required parameters are null. */ @Override - public Token refresh() { + public Token getToken() { Objects.requireNonNull(clientId, "ClientID cannot be null"); Objects.requireNonNull(host, "Host cannot be null"); Objects.requireNonNull(endpoints, "Endpoints cannot be null"); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java index ed08f57d6..dfd601695 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java @@ -13,7 +13,7 @@ * Represents a token source that exchanges a control plane token for an endpoint-specific dataplane * token. It utilizes an underlying {@link TokenSource} to obtain the initial control plane token. */ -public class EndpointTokenSource extends RefreshableTokenSource { +public class EndpointTokenSource implements TokenSource { private static final Logger LOG = LoggerFactory.getLogger(EndpointTokenSource.class); private static final String JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"; private static final String GRANT_TYPE_PARAM = "grant_type"; @@ -67,7 +67,7 @@ public EndpointTokenSource( * @throws NullPointerException if any of the parameters are null. */ @Override - protected Token refresh() { + public Token getToken() { Token cpToken = cpTokenSource.getToken(); Map params = new HashMap<>(); params.put(GRANT_TYPE_PARAM, JWT_GRANT_TYPE); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java index af67daeba..0766d8a1b 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java @@ -78,9 +78,10 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { Optional.of(config.getEffectiveOAuthRedirectUrl()), Optional.of(tokenCache)); + CachedTokenSource cachedTokenSource = new CachedTokenSource.Builder(tokenSource).build(); LOGGER.debug("Using cached token, will immediately refresh"); - tokenSource.token = tokenSource.refresh(); - return OAuthHeaderFactory.fromTokenSource(tokenSource); + cachedTokenSource.getToken(); + return OAuthHeaderFactory.fromTokenSource(cachedTokenSource); } catch (Exception e) { // If token refresh fails, log and continue to browser auth LOGGER.info("Token refresh failed: {}, falling back to browser auth", e.getMessage()); @@ -88,17 +89,17 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { } // If no cached token or refresh failed, perform browser auth - SessionCredentialsTokenSource tokenSource = + CachedTokenSource cachedTokenSource = performBrowserAuth(config, clientId, clientSecret, tokenCache); - tokenCache.save(tokenSource.getToken()); - return OAuthHeaderFactory.fromTokenSource(tokenSource); + tokenCache.save(cachedTokenSource.getToken()); + return OAuthHeaderFactory.fromTokenSource(cachedTokenSource); } catch (IOException | DatabricksException e) { LOGGER.error("Failed to authenticate: {}", e.getMessage()); return null; } } - SessionCredentialsTokenSource performBrowserAuth( + CachedTokenSource performBrowserAuth( DatabricksConfig config, String clientId, String clientSecret, TokenCache tokenCache) throws IOException { LOGGER.debug("Performing browser authentication"); @@ -117,13 +118,16 @@ SessionCredentialsTokenSource performBrowserAuth( Token token = consent.getTokenFromExternalBrowser(); // Create a SessionCredentialsTokenSource with the token from browser auth. - return new SessionCredentialsTokenSource( - token, - config.getHttpClient(), - config.getOidcEndpoints().getTokenEndpoint(), - config.getClientId(), - config.getClientSecret(), - Optional.ofNullable(config.getEffectiveOAuthRedirectUrl()), - Optional.ofNullable(tokenCache)); + SessionCredentialsTokenSource tokenSource = + new SessionCredentialsTokenSource( + token, + config.getHttpClient(), + config.getOidcEndpoints().getTokenEndpoint(), + config.getClientId(), + config.getClientSecret(), + Optional.ofNullable(config.getEffectiveOAuthRedirectUrl()), + Optional.ofNullable(tokenCache)); + + return new CachedTokenSource.Builder(tokenSource).setToken(token).build(); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubOidcCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubOidcCredentialsProvider.java index eeb70797e..c517b8a3e 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubOidcCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubOidcCredentialsProvider.java @@ -58,9 +58,11 @@ public HeaderFactory configure(DatabricksConfig config) throws DatabricksExcepti .build()) .build(); + CachedTokenSource cachedTokenSource = new CachedTokenSource.Builder(clientCredentials).build(); + return () -> { Map headers = new HashMap<>(); - headers.put("Authorization", "Bearer " + clientCredentials.getToken().getAccessToken()); + headers.put("Authorization", "Bearer " + cachedTokenSource.getToken().getAccessToken()); return headers; }; } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthM2MServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthM2MServicePrincipalCredentialsProvider.java index 058fc268c..ec94c014b 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthM2MServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthM2MServicePrincipalCredentialsProvider.java @@ -27,7 +27,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { // https://login.microsoftonline.com/{cfg.azure_tenant_id}/.well-known/oauth-authorization-server try { OpenIDConnectEndpoints jsonResponse = config.getOidcEndpoints(); - ClientCredentials tokenSource = + ClientCredentials clientCredentials = new ClientCredentials.Builder() .withHttpClient(config.getHttpClient()) .withClientId(config.getClientId()) @@ -37,7 +37,10 @@ public OAuthHeaderFactory configure(DatabricksConfig config) { .withAuthParameterPosition(AuthParameterPosition.HEADER) .build(); - return OAuthHeaderFactory.fromTokenSource(tokenSource); + CachedTokenSource cachedTokenSource = + new CachedTokenSource.Builder(clientCredentials).build(); + + return OAuthHeaderFactory.fromTokenSource(cachedTokenSource); } catch (IOException e) { // TODO: Log exception throw new DatabricksException("Unable to fetch OIDC endpoint: " + e.getMessage(), e); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java index b15f55ded..901b6bcc4 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OidcTokenSource.java @@ -15,7 +15,7 @@ * protocol. It communicates with an OAuth server to request access tokens using the client * credentials grant type instead of a client secret. */ -class OidcTokenSource extends RefreshableTokenSource { +class OidcTokenSource implements TokenSource { private final HttpClient httpClient; private final String tokenUrl; @@ -58,7 +58,8 @@ private static void putIfDefined( } } - protected Token refresh() { + @Override + public Token getToken() { Response rawResp; try { rawResp = httpClient.execute(new FormRequest(tokenUrl, params)); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java index 5a05b8751..165b504f0 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java @@ -40,7 +40,9 @@ public String authType() { @Override public OAuthHeaderFactory configure(DatabricksConfig config) { - return OAuthHeaderFactory.fromTokenSource(tokenSource); + CachedTokenSource cachedTokenSource = + new CachedTokenSource.Builder(tokenSource).setToken(tokenSource.getToken()).build(); + return OAuthHeaderFactory.fromTokenSource(cachedTokenSource); } static class Builder { diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentialsTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentialsTokenSource.java index 8c71f2eb3..d1552f35a 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentialsTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentialsTokenSource.java @@ -13,9 +13,11 @@ * *

Implements the refresh_token OAuth grant type with optional token caching. */ -public class SessionCredentialsTokenSource extends RefreshableTokenSource { +public class SessionCredentialsTokenSource implements TokenSource { private static final Logger LOGGER = LoggerFactory.getLogger(SessionCredentialsTokenSource.class); + // The token to use for the session + private Token token; // HTTP client for making token refresh requests private final HttpClient hc; // OAuth token endpoint URL for refresh requests @@ -48,7 +50,7 @@ public SessionCredentialsTokenSource( String clientSecret, Optional redirectUrl, Optional tokenCache) { - super(token); + this.token = token; this.hc = hc; this.tokenUrl = tokenUrl; this.clientId = clientId; @@ -69,7 +71,7 @@ public SessionCredentialsTokenSource( * request fails. */ @Override - protected Token refresh() { + public Token getToken() { if (this.token == null) { throw new DatabricksException("oauth2: token is not set"); } @@ -87,16 +89,16 @@ protected Token refresh() { // cross-origin requests redirectUrl.ifPresent(url -> headers.put("Origin", url)); } - Token newToken = - retrieveToken( + this.token = + TokenEndpointClient.retrieveToken( hc, clientId, clientSecret, tokenUrl, params, headers, AuthParameterPosition.BODY); // Save the refreshed token directly to cache tokenCache.ifPresent( cache -> { - cache.save(newToken); + cache.save(this.token); LOGGER.debug("Saved refreshed token to cache"); }); - return newToken; + return this.token; } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java index 69883dd24..3e8b13a4d 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java @@ -1,13 +1,18 @@ package com.databricks.sdk.core.oauth; +import com.databricks.sdk.core.ApiClient; import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.FormRequest; import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Request; import com.databricks.sdk.core.http.Response; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; +import java.time.Instant; +import java.util.Base64; import java.util.Map; import java.util.Objects; +import org.apache.http.HttpHeaders; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -88,4 +93,58 @@ public static OAuthResponse requestToken( LOG.debug("Successfully obtained token response from {}", tokenEndpointUrl); return response; } + + /** + * Helper method implementing OAuth token refresh. + * + * @param hc The {@link HttpClient} to use for making the request. + * @param clientId The client ID to authenticate with. + * @param clientSecret The client secret to authenticate with. + * @param tokenUrl The authorization URL for fetching tokens. + * @param params Additional request parameters. + * @param headers Additional headers. + * @param position The position of the authentication parameters in the request. + * @return The newly fetched Token. + * @throws DatabricksException if the refresh fails. + * @throws IllegalArgumentException if the OAuth response contains an error. + */ + public static Token retrieveToken( + HttpClient hc, + String clientId, + String clientSecret, + String tokenUrl, + Map params, + Map headers, + AuthParameterPosition position) { + switch (position) { + case BODY: + if (clientId != null) { + params.put("client_id", clientId); + } + if (clientSecret != null) { + params.put("client_secret", clientSecret); + } + break; + case HEADER: + String authHeaderValue = + "Basic " + + Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes()); + headers.put(HttpHeaders.AUTHORIZATION, authHeaderValue); + break; + } + headers.put("Content-Type", "application/x-www-form-urlencoded"); + Request req = new Request("POST", tokenUrl, FormRequest.wrapValuesInList(params)); + req.withHeaders(headers); + try { + ApiClient apiClient = new ApiClient.Builder().withHttpClient(hc).build(); + OAuthResponse resp = apiClient.execute(req, OAuthResponse.class); + if (resp.getErrorCode() != null) { + throw new IllegalArgumentException(resp.getErrorCode() + ": " + resp.getErrorSummary()); + } + Instant expiry = Instant.now().plusSeconds(resp.getExpiresIn()); + return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry); + } catch (Exception e) { + throw new DatabricksException("Failed to refresh credentials: " + e.getMessage(), e); + } + } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java index 9a341b901..d3edae491 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java @@ -37,11 +37,16 @@ public TokenSourceCredentialsProvider(TokenSource tokenSource, String authType) */ @Override public OAuthHeaderFactory configure(DatabricksConfig config) { + // Check if the token source is already cached to prevent double caching + TokenSource cachedTokenSource = + (tokenSource instanceof CachedTokenSource) + ? tokenSource + : new CachedTokenSource.Builder(tokenSource).build(); + try { // Validate that we can get a token before returning a HeaderFactory - tokenSource.getToken().getAccessToken(); - - return OAuthHeaderFactory.fromTokenSource(tokenSource); + cachedTokenSource.getToken().getAccessToken(); + return OAuthHeaderFactory.fromTokenSource(cachedTokenSource); } catch (Exception e) { return null; } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java index 96dc116c2..09cea6e86 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java @@ -77,7 +77,7 @@ public static Map addWorkspaceResourceId( } public static Map addSpManagementToken( - RefreshableTokenSource tokenSource, Map headers) { + TokenSource tokenSource, Map headers) { headers.put("X-Databricks-Azure-SP-Management-Token", tokenSource.getToken().getAccessToken()); return headers; } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java index 4e8a57b06..051a87643 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/AzureCliCredentialsProviderTest.java @@ -5,8 +5,10 @@ import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.times; +import com.databricks.sdk.core.oauth.CachedTokenSource; import com.databricks.sdk.core.oauth.Token; import com.databricks.sdk.core.oauth.TokenSource; +import java.time.Instant; import java.util.Arrays; import java.util.List; import org.junit.jupiter.api.Test; @@ -21,18 +23,17 @@ class AzureCliCredentialsProviderTest { private static final String TOKEN = "t-123"; private static final String TOKEN_TYPE = "token-type"; - private static CliTokenSource mockTokenSource() { - CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class); - Mockito.when(tokenSource.getToken()) - .thenReturn(new Token(TOKEN, TOKEN_TYPE, java.time.Instant.now())); - return tokenSource; + private static CachedTokenSource mockTokenSource() { + CliTokenSource cliTokenSource = Mockito.mock(CliTokenSource.class); + Mockito.when(cliTokenSource.getToken()).thenReturn(new Token(TOKEN, TOKEN_TYPE, Instant.now())); + return new CachedTokenSource.Builder(cliTokenSource).build(); } private static AzureCliCredentialsProvider getAzureCliCredentialsProvider( TokenSource tokenSource) { AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); - Mockito.doReturn(tokenSource).when(provider).getToken(any(), anyList()); + Mockito.doReturn(tokenSource).when(provider).getTokenSource(any(), anyList()); return provider; } @@ -51,7 +52,7 @@ void testWorkSpaceIDUsage() { String token = header.headers().get("Authorization"); assertEquals(token, TOKEN_TYPE + " " + TOKEN); - Mockito.verify(provider, times(2)).getToken(any(), argument.capture()); + Mockito.verify(provider, times(2)).getTokenSource(any(), argument.capture()); List value = argument.getValue(); value = value.subList(value.size() - 2, value.size()); @@ -61,11 +62,13 @@ void testWorkSpaceIDUsage() { @Test void testFallbackWhenTailsToGetTokenForSubscription() { - CliTokenSource tokenSource = mockTokenSource(); + CachedTokenSource tokenSource = mockTokenSource(); AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider()); - Mockito.doThrow(new DatabricksException("error")).when(provider).getToken(any(), anyList()); - Mockito.doReturn(tokenSource).when(provider).getToken(any(), anyList()); + Mockito.doThrow(new DatabricksException("error")) + .when(provider) + .getTokenSource(any(), anyList()); + Mockito.doReturn(tokenSource).when(provider).getTokenSource(any(), anyList()); DatabricksConfig config = new DatabricksConfig() @@ -93,7 +96,7 @@ void testGetTokenWithoutWorkspaceResourceID() { String token = header.headers().get("Authorization"); assertEquals(token, TOKEN_TYPE + " " + TOKEN); - Mockito.verify(provider, times(2)).getToken(any(), argument.capture()); + Mockito.verify(provider, times(2)).getTokenSource(any(), argument.capture()); List value = argument.getValue(); assertFalse(value.contains("--subscription")); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 05ffd805d..16f110991 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -151,7 +151,7 @@ public void testRefreshWithExpiry( when(mock.start()).thenReturn(process); })) { // Test refresh. - Token token = tokenSource.refresh(); + Token token = tokenSource.getToken(); assertEquals("Bearer", token.getTokenType()); assertEquals("test-token", token.getAccessToken()); assertEquals(shouldBeExpired, token.getExpiry().isBefore(Instant.now())); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/CachedTokenSourceTest.java similarity index 69% rename from databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java rename to databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/CachedTokenSourceTest.java index 194c3a2ec..0a29f5a7c 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/RefreshableTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/CachedTokenSourceTest.java @@ -13,7 +13,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -public class RefreshableTokenSourceTest { +public class CachedTokenSourceTest { private static final String TOKEN_TYPE = "Bearer"; private static final String INITIAL_TOKEN = "initial-token"; private static final String REFRESH_TOKEN = "refreshed-token"; @@ -51,10 +51,10 @@ void testAsyncRefreshParametrized( new Token(REFRESH_TOKEN, TOKEN_TYPE, null, Instant.now().plus(Duration.ofMinutes(10))); CountDownLatch refreshCalled = new CountDownLatch(1); - RefreshableTokenSource source = - new RefreshableTokenSource(initialToken) { + TokenSource tokenSource = + new TokenSource() { @Override - protected Token refresh() { + public Token getToken() { refreshCalled.countDown(); try { Thread.sleep(500); @@ -63,7 +63,13 @@ protected Token refresh() { } return refreshedToken; } - }.withAsyncRefresh(asyncEnabled); + }; + + CachedTokenSource source = + new CachedTokenSource.Builder(tokenSource) + .setAsyncEnabled(asyncEnabled) + .setToken(initialToken) + .build(); Token token = source.getToken(); @@ -81,10 +87,10 @@ protected Token refresh() { */ @Test void testAsyncRefreshFailureFallback() throws Exception { - // Create a test clock starting at current time + // Create a mutable clock supplier that we can control TestClockSupplier clockSupplier = new TestClockSupplier(Instant.now()); - // Create a token that expires in 2 minutes from the initial clock time + // Create a token that will be stale (2 minutes until expiry) Token staleToken = new Token( INITIAL_TOKEN, @@ -92,67 +98,83 @@ void testAsyncRefreshFailureFallback() throws Exception { null, Instant.now(clockSupplier.getClock()).plus(Duration.ofMinutes(2))); - class TestSource extends RefreshableTokenSource { + class TestSource implements TokenSource { int refreshCallCount = 0; boolean isFirstRefresh = true; - TestSource(Token token) { - super(token); - } - @Override - protected Token refresh() { + public Token getToken() { refreshCallCount++; + try { + // Sleep to simulate token fetch delay + Thread.sleep(500); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } if (isFirstRefresh) { isFirstRefresh = false; throw new RuntimeException("Simulated async failure"); } + // Return a token that expires in 10 minutes from current time return new Token( - REFRESH_TOKEN, + REFRESH_TOKEN + "-" + refreshCallCount, TOKEN_TYPE, null, Instant.now(clockSupplier.getClock()).plus(Duration.ofMinutes(10))); } } - TestSource source = new TestSource(staleToken); - source.withAsyncRefresh(true); - source.withClockSupplier(clockSupplier); + TestSource testSource = new TestSource(); + CachedTokenSource source = + new CachedTokenSource.Builder(testSource) + .setAsyncEnabled(true) + .setToken(staleToken) + .setClockSupplier(clockSupplier) + .build(); // First call triggers async refresh, which fails - source.getToken(); - Thread.sleep(300); + // Should return stale token immediately (async refresh) + Token token = source.getToken(); + assertEquals(INITIAL_TOKEN, token.getAccessToken(), "Should return stale token immediately"); + Thread.sleep(600); assertEquals( - 1, source.refreshCallCount, "refresh() should have been called once (async, failed)"); + 1, testSource.refreshCallCount, "refresh() should have been called once (async, failed)"); // Token is still stale, so next call should NOT trigger another refresh since the last refresh // failed - source.getToken(); - Thread.sleep(200); + token = source.getToken(); + assertEquals(INITIAL_TOKEN, token.getAccessToken(), "Should still return stale token"); + Thread.sleep(600); assertEquals( 1, - source.refreshCallCount, + testSource.refreshCallCount, "refresh() should NOT be called again while stale after async failure"); - // Advance the clock by 3 minutes to make the token expired + // Advance time by 3 minutes to make the token expired clockSupplier.advanceTime(Duration.ofMinutes(3)); // Now getToken() should call refresh synchronously and return the refreshed token - Token token = source.getToken(); + token = source.getToken(); assertEquals( - REFRESH_TOKEN, + REFRESH_TOKEN + "-2", token.getAccessToken(), "Should return the refreshed token after sync refresh"); + Thread.sleep(600); assertEquals( - 2, source.refreshCallCount, "refresh() should have been called synchronously after expiry"); + 2, + testSource.refreshCallCount, + "refresh() should have been called synchronously after expiry"); // Advance time by 8 minutes to make the token stale again clockSupplier.advanceTime(Duration.ofMinutes(8)); - source.getToken(); - Thread.sleep(300); + // Should return stale token immediately (async refresh) + token = source.getToken(); + assertEquals( + REFRESH_TOKEN + "-2", token.getAccessToken(), "Should return stale token immediately"); + Thread.sleep(600); assertEquals( 3, - source.refreshCallCount, + testSource.refreshCallCount, "refresh() should have been called again asynchronously after making token stale"); } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java index b7e237ddd..968cde75b 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java @@ -203,7 +203,7 @@ void clientCredentials() throws IOException { .withClientSecret("abc") .withTokenUrl("https://tokenUrl") .build(); - Token token = clientCredentials.refresh(); + Token token = clientCredentials.getToken(); assertEquals("accessTokenFromServer", token.getAccessToken()); assertEquals("refreshTokenFromServer", token.getRefreshToken()); } @@ -228,7 +228,7 @@ void sessionCredentials() throws IOException { "abc", Optional.empty(), Optional.empty()); - Token token = sessionCredentialsTokenSource.refresh(); + Token token = sessionCredentialsTokenSource.getToken(); // We check that we are actually getting the token from server response (that is defined // above) rather than what was given while creating session credentials @@ -434,6 +434,9 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException { Optional.empty(), Optional.empty()); + CachedTokenSource cachedTokenSource = + new CachedTokenSource.Builder(browserAuthTokenSource).setToken(browserAuthToken).build(); + // Create config with failing HTTP client and mock token cache DatabricksConfig config = new DatabricksConfig() @@ -450,7 +453,7 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException { // Create our provider and mock the browser auth method ExternalBrowserCredentialsProvider provider = Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache)); - Mockito.doReturn(browserAuthTokenSource) + Mockito.doReturn(cachedTokenSource) .when(provider) .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); @@ -460,6 +463,7 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException { // Configure provider HeaderFactory headerFactory = provider.configure(spyConfig); + assertNotNull(headerFactory); // Verify headers contain the browser auth token (fallback) Map headers = headerFactory.headers(); @@ -507,6 +511,9 @@ void cacheWithInvalidTokensTest() throws IOException { Optional.empty(), Optional.empty()); + CachedTokenSource cachedTokenSource = + new CachedTokenSource.Builder(browserAuthTokenSource).setToken(browserAuthToken).build(); + // Create simple config DatabricksConfig config = new DatabricksConfig() @@ -517,12 +524,13 @@ void cacheWithInvalidTokensTest() throws IOException { // Create our provider and mock the browser auth method ExternalBrowserCredentialsProvider provider = Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache)); - Mockito.doReturn(browserAuthTokenSource) + Mockito.doReturn(cachedTokenSource) .when(provider) .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); // Configure provider HeaderFactory headerFactory = provider.configure(config); + assertNotNull(headerFactory); // Verify headers contain the browser auth token (fallback) Map headers = headerFactory.headers(); assertEquals("Bearer browser_access_token", headers.get("Authorization"));