Skip to content

Refactor SessionCredentials #470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,131 @@ public String getClientSecret() {
return clientSecret;
}

/**
* Launch a browser to collect an authorization code and exchange the code for an OAuth token.
*
* @return A {@code SessionCredentials} instance representing the retrieved OAuth token.
* @throws IOException if the webserver cannot be started, or if the browser cannot be opened.
*/
public SessionCredentials launchExternalBrowser() throws IOException {
Map<String, String> params = getOAuthCallbackParameters();
return exchangeCallbackParameters(params);
}

/**
* Exchange callback parameters for OAuth credentials.
*
* @param query The callback parameters from the OAuth flow
* @return A {@code SessionCredentials} instance representing the retrieved OAuth token.
*/
public SessionCredentials exchangeCallbackParameters(Map<String, String> query) {
validateCallbackParameters(query);
Token token = exchange(query.get("code"), query.get("state"));
return new SessionCredentials.Builder()
.withHttpClient(this.hc)
.withClientId(this.clientId)
.withClientSecret(this.clientSecret)
.withTokenUrl(this.tokenUrl)
.withToken(token)
.build();
}

/**
* Launches an external browser to collect OAuth callback parameters and exchanges them for an
* OAuth token.
*
* @return A {@code Token} instance containing the OAuth access token and related credentials
* @throws IOException if the local HTTP server cannot be started, the browser cannot be opened,
* or there are network issues during the token exchange
* @throws DatabricksException if the OAuth callback contains an error, missing required
* parameters, or if there's a state mismatch during the token exchange.
*/
Token getTokenFromExternalBrowser() throws IOException {
Map<String, String> params = getOAuthCallbackParameters();
validateCallbackParameters(params);
return exchange(params.get("code"), params.get("state"));
}

protected void desktopBrowser() throws IOException {
Desktop.getDesktop().browse(URI.create(this.authUrl));
}

/**
* Handles the OAuth callback by setting up a local HTTP server, launching the browser, and
* collecting the callback parameters.
*
* @return A map containing the callback parameters from the OAuth flow.
* @throws IOException if the webserver cannot be started, or if the browser cannot be opened.
*/
private Map<String, String> getOAuthCallbackParameters() throws IOException {
URL redirect = new URL(getRedirectUrl());
if (!Arrays.asList("localhost", "127.0.0.1").contains(redirect.getHost())) {
throw new IllegalArgumentException(
"cannot listen on "
+ redirect.getHost()
+ ", redirectUrl host must be one of: localhost, 127.0.0.1");
}
CallbackResponseHandler handler = new CallbackResponseHandler();
HttpServer httpServer =
HttpServer.create(new InetSocketAddress(redirect.getHost(), redirect.getPort()), 0);
httpServer.createContext("/", handler);
httpServer.start();
desktopBrowser();
Map<String, String> params = handler.getParams();
httpServer.stop(0);
return params;
}

/**
* Validates the OAuth callback parameters to ensure they contain the required fields and no error
* conditions.
*
* @param query The callback parameters to validate
* @throws DatabricksException if validation fails due to error conditions or missing required
* parameters
*/
private void validateCallbackParameters(Map<String, String> query) {
if (query.containsKey("error")) {
throw new DatabricksException(query.get("error") + ": " + query.get("error_description"));
}
if (!query.containsKey("code") || !query.containsKey("state")) {
throw new DatabricksException("No code returned in callback");
}
}

/**
* Exchange authorization code for OAuth token.
*
* @param code The authorization code from the OAuth callback
* @param state The state parameter from the OAuth callback
* @return A {@code Token} instance representing the OAuth token
*/
private Token exchange(String code, String state) {
if (!this.state.equals(state)) {
throw new DatabricksException(
"state mismatch: original state: " + this.state + "; retrieved state: " + state);
}
Map<String, String> params = new HashMap<>();
params.put("grant_type", "authorization_code");
params.put("code", code);
params.put("code_verifier", this.verifier);
params.put("redirect_uri", this.redirectUrl);
Map<String, String> headers = new HashMap<>();
if (this.tokenUrl.contains("microsoft")) {
headers.put("Origin", this.redirectUrl);
}
Token token =
RefreshableTokenSource.retrieveToken(
this.hc,
this.clientId,
this.clientSecret,
this.tokenUrl,
params,
headers,
AuthParameterPosition.BODY);
return token;
}

static class CallbackResponseHandler implements HttpHandler {
private final Logger LOG = LoggerFactory.getLogger(getClass().getName());
// Protects params
Expand Down Expand Up @@ -258,75 +383,4 @@ public Map<String, String> getParams() {
}
}
}

/**
* Launch a browser to collect an authorization code and exchange the code for an OAuth token.
*
* @return A {@code SessionCredentials} instance representing the retrieved OAuth token.
* @throws IOException if the webserver cannot be started, or if the browser cannot be opened
*/
public SessionCredentials launchExternalBrowser() throws IOException {
URL redirect = new URL(getRedirectUrl());
if (!Arrays.asList("localhost", "127.0.0.1").contains(redirect.getHost())) {
throw new IllegalArgumentException(
"cannot listen on "
+ redirect.getHost()
+ ", redirectUrl host must be one of: localhost, 127.0.0.1");
}
CallbackResponseHandler handler = new CallbackResponseHandler();
HttpServer httpServer =
HttpServer.create(new InetSocketAddress(redirect.getHost(), redirect.getPort()), 0);
httpServer.createContext("/", handler);
httpServer.start();
desktopBrowser();
Map<String, String> params = handler.getParams();
httpServer.stop(0);
return exchangeCallbackParameters(params);
}

protected void desktopBrowser() throws IOException {
Desktop.getDesktop().browse(URI.create(this.authUrl));
}

public SessionCredentials exchangeCallbackParameters(Map<String, String> query) {
if (query.containsKey("error")) {
throw new DatabricksException(query.get("error") + ": " + query.get("error_description"));
}
if (!query.containsKey("code") || !query.containsKey("state")) {
throw new DatabricksException("No code returned in callback");
}
return exchange(query.get("code"), query.get("state"));
}

public SessionCredentials exchange(String code, String state) {
if (!this.state.equals(state)) {
throw new DatabricksException(
"state mismatch: original state: " + this.state + "; retrieved state: " + state);
}
Map<String, String> params = new HashMap<>();
params.put("grant_type", "authorization_code");
params.put("code", code);
params.put("code_verifier", this.verifier);
params.put("redirect_uri", this.redirectUrl);
Map<String, String> headers = new HashMap<>();
if (this.tokenUrl.contains("microsoft")) {
headers.put("Origin", this.redirectUrl);
}
Token token =
RefreshableTokenSource.retrieveToken(
this.hc,
this.clientId,
this.clientSecret,
this.tokenUrl,
params,
headers,
AuthParameterPosition.BODY);
return new SessionCredentials.Builder()
.withHttpClient(this.hc)
.withClientId(this.clientId)
.withClientSecret(this.clientSecret)
.withTokenUrl(this.tokenUrl)
.withToken(token)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.io.IOException;
import java.nio.file.Path;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -66,39 +67,38 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
LOGGER.debug("Found cached token for {}:{}", config.getHost(), clientId);

try {
// Create SessionCredentials with the cached token and try to refresh if needed
SessionCredentials cachedCreds =
new SessionCredentials.Builder()
.withToken(cachedToken)
.withHttpClient(config.getHttpClient())
.withClientId(clientId)
.withClientSecret(clientSecret)
.withTokenUrl(config.getOidcEndpoints().getTokenEndpoint())
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withTokenCache(tokenCache)
.build();
// Create SessionCredentialsTokenSource with the cached token and try to refresh if needed
SessionCredentialsTokenSource tokenSource =
new SessionCredentialsTokenSource(
cachedToken,
config.getHttpClient(),
config.getOidcEndpoints().getTokenEndpoint(),
clientId,
clientSecret,
Optional.of(config.getEffectiveOAuthRedirectUrl()),
Optional.of(tokenCache));

LOGGER.debug("Using cached token, will immediately refresh");
cachedCreds.token = cachedCreds.refresh();
return cachedCreds.configure(config);
tokenSource.token = tokenSource.refresh();
return OAuthHeaderFactory.fromTokenSource(tokenSource);
} catch (Exception e) {
// If token refresh fails, log and continue to browser auth
LOGGER.info("Token refresh failed: {}, falling back to browser auth", e.getMessage());
}
}

// If no cached token or refresh failed, perform browser auth
SessionCredentials credentials =
SessionCredentialsTokenSource tokenSource =
performBrowserAuth(config, clientId, clientSecret, tokenCache);
tokenCache.save(credentials.getToken());
return credentials.configure(config);
tokenCache.save(tokenSource.getToken());
return OAuthHeaderFactory.fromTokenSource(tokenSource);
} catch (IOException | DatabricksException e) {
LOGGER.error("Failed to authenticate: {}", e.getMessage());
return null;
}
}

SessionCredentials performBrowserAuth(
SessionCredentialsTokenSource performBrowserAuth(
DatabricksConfig config, String clientId, String clientSecret, TokenCache tokenCache)
throws IOException {
LOGGER.debug("Performing browser authentication");
Expand All @@ -113,18 +113,17 @@ SessionCredentials performBrowserAuth(
.build();
Consent consent = client.initiateConsent();

// Use the existing browser flow to get credentials
SessionCredentials credentials = consent.launchExternalBrowser();

// Create a new SessionCredentials with the same token but with our token cache
return new SessionCredentials.Builder()
.withToken(credentials.getToken())
.withHttpClient(config.getHttpClient())
.withClientId(config.getClientId())
.withClientSecret(config.getClientSecret())
.withTokenUrl(config.getOidcEndpoints().getTokenEndpoint())
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withTokenCache(tokenCache)
.build();
// Use the existing browser flow to get credentials.
Token token = consent.getTokenFromExternalBrowser();

// Create a SessionCredentialsTokenSource with the token from browser auth.
return new SessionCredentialsTokenSource(
token,
config.getHttpClient(),
config.getOidcEndpoints().getTokenEndpoint(),
config.getClientId(),
config.getClientSecret(),
Optional.ofNullable(config.getEffectiveOAuthRedirectUrl()),
Optional.ofNullable(tokenCache));
}
}
Loading
Loading