Skip to content

Commit

Permalink
Rename class of credential providers and optimize expired time.
Browse files Browse the repository at this point in the history
  • Loading branch information
yuqi1129 committed Jan 3, 2025
1 parent d2ba98b commit b34c526
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,18 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GravitinoOSSCredentialProvider implements CredentialsProvider {
public class OSSCredentialProvider implements CredentialsProvider {

private static final Logger LOG = LoggerFactory.getLogger(GravitinoOSSCredentialProvider.class);
private static final Logger LOG = LoggerFactory.getLogger(OSSCredentialProvider.class);
private Credentials basicCredentials;
private final String filesetIdentifier;
private long expirationTime;
private final GravitinoClient client;
private final Configuration configuration;

public GravitinoOSSCredentialProvider(URI uri, Configuration conf) {
private long expirationTime = Long.MAX_VALUE;
private static final double EXPIRATION_TIME_FACTOR = 0.9D;

public OSSCredentialProvider(URI uri, Configuration conf) {
this.filesetIdentifier =
conf.get(GravitinoVirtualFileSystemConfiguration.GVFS_FILESET_IDENTIFIER);
this.client = GravitinoVirtualFileSystemUtils.createClient(conf);
Expand All @@ -67,7 +69,7 @@ public void setCredentials(Credentials credentials) {}
@Override
public Credentials getCredentials() {
// If the credentials are null or about to expire, refresh the credentials.
if (basicCredentials == null || System.currentTimeMillis() > expirationTime - 5 * 60 * 1000) {
if (basicCredentials == null || System.currentTimeMillis() >= expirationTime) {
synchronized (this) {
refresh();
}
Expand Down Expand Up @@ -110,9 +112,12 @@ private void refresh() {
basicCredentials = new DefaultCredentials(accessKeyId, secretAccessKey);
}

expirationTime = credential.expireTimeInMs();
if (expirationTime <= 0) {
expirationTime = Long.MAX_VALUE;
if (credential.expireTimeInMs() > 0) {
expirationTime =
System.currentTimeMillis()
+ (long)
((credential.expireTimeInMs() - System.currentTimeMillis())
* EXPIRATION_TIME_FACTOR);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ public FileSystem getFileSystem(Path path, Map<String, String> config) throws IO
&& config.containsKey(
GravitinoVirtualFileSystemConfiguration.FS_GRAVITINO_SERVER_URI_KEY)) {
hadoopConfMap.put(
Constants.CREDENTIALS_PROVIDER_KEY,
GravitinoOSSCredentialProvider.class.getCanonicalName());
Constants.CREDENTIALS_PROVIDER_KEY, OSSCredentialProvider.class.getCanonicalName());
}

hadoopConfMap.forEach(configuration::set);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,18 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GravitinoS3CredentialProvider implements AWSCredentialsProvider {
public class S3CredentialProvider implements AWSCredentialsProvider {

private static final Logger LOG = LoggerFactory.getLogger(GravitinoS3CredentialProvider.class);
private static final Logger LOG = LoggerFactory.getLogger(S3CredentialProvider.class);
private final GravitinoClient client;
private final String filesetIdentifier;
private final Configuration configuration;

private AWSCredentials basicSessionCredentials;
private long expirationTime;
private long expirationTime = Long.MAX_VALUE;
private static final double EXPIRATION_TIME_FACTOR = 0.9D;

public GravitinoS3CredentialProvider(final URI uri, final Configuration conf) {
public S3CredentialProvider(final URI uri, final Configuration conf) {
this.filesetIdentifier =
conf.get(GravitinoVirtualFileSystemConfiguration.GVFS_FILESET_IDENTIFIER);
this.configuration = conf;
Expand All @@ -65,8 +66,7 @@ public GravitinoS3CredentialProvider(final URI uri, final Configuration conf) {
@Override
public AWSCredentials getCredentials() {
// Refresh credentials if they are null or about to expire in 5 minutes
if (basicSessionCredentials == null
|| System.currentTimeMillis() > expirationTime - 5 * 60 * 1000) {
if (basicSessionCredentials == null || System.currentTimeMillis() >= expirationTime) {
synchronized (this) {
refresh();
}
Expand Down Expand Up @@ -112,9 +112,12 @@ public void refresh() {
basicSessionCredentials = new BasicAWSCredentials(accessKeyId, secretAccessKey);
}

expirationTime = credential.expireTimeInMs();
if (expirationTime <= 0) {
expirationTime = Long.MAX_VALUE;
if (credential.expireTimeInMs() > 0) {
expirationTime =
System.currentTimeMillis()
+ (long)
((credential.expireTimeInMs() - System.currentTimeMillis())
* EXPIRATION_TIME_FACTOR);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ public FileSystem getFileSystem(Path path, Map<String, String> config) throws IO
// will have this key.
if (config.containsKey(GravitinoVirtualFileSystemConfiguration.FS_GRAVITINO_SERVER_URI_KEY)) {
configuration.set(
Constants.AWS_CREDENTIALS_PROVIDER,
GravitinoS3CredentialProvider.class.getCanonicalName());
Constants.AWS_CREDENTIALS_PROVIDER, S3CredentialProvider.class.getCanonicalName());
}

// Hadoop-aws 2 does not support IAMInstanceCredentialsProvider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,9 @@ public FileSystem getFileSystem(@Nonnull Path path, @Nonnull Map<String, String>
if (config.containsKey(GravitinoVirtualFileSystemConfiguration.FS_GRAVITINO_SERVER_URI_KEY)) {
// Test whether SAS works
try {
GravitinoAzureSasCredentialProvider gravitinoAzureSasCredentialProvider =
new GravitinoAzureSasCredentialProvider();
gravitinoAzureSasCredentialProvider.initialize(configuration, null);
String sas = gravitinoAzureSasCredentialProvider.getSASToken(null, null, null, null);
AzureSasCredentialProvider azureSasCredentialProvider = new AzureSasCredentialProvider();
azureSasCredentialProvider.initialize(configuration, null);
String sas = azureSasCredentialProvider.getSASToken(null, null, null, null);
if (sas != null) {
String accountName =
String.format(
Expand All @@ -92,15 +91,15 @@ public FileSystem getFileSystem(@Nonnull Path path, @Nonnull Map<String, String>
FS_AZURE_ACCOUNT_AUTH_TYPE_PROPERTY_NAME + "." + accountName, AuthType.SAS.name());
configuration.set(
FS_AZURE_SAS_TOKEN_PROVIDER_TYPE + "." + accountName,
GravitinoAzureSasCredentialProvider.class.getName());
AzureSasCredentialProvider.class.getName());
configuration.set(FS_AZURE_ACCOUNT_IS_HNS_ENABLED, "true");
} else if (gravitinoAzureSasCredentialProvider.getAzureStorageAccountKey() != null
&& gravitinoAzureSasCredentialProvider.getAzureStorageAccountName() != null) {
} else if (azureSasCredentialProvider.getAzureStorageAccountKey() != null
&& azureSasCredentialProvider.getAzureStorageAccountName() != null) {
configuration.set(
String.format(
"fs.azure.account.key.%s.dfs.core.windows.net",
gravitinoAzureSasCredentialProvider.getAzureStorageAccountName()),
gravitinoAzureSasCredentialProvider.getAzureStorageAccountKey());
azureSasCredentialProvider.getAzureStorageAccountName()),
azureSasCredentialProvider.getAzureStorageAccountKey());
}
} catch (Exception e) {
// Can't use SAS, use account key and account key instead
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GravitinoAzureSasCredentialProvider implements SASTokenProvider, Configurable {
public class AzureSasCredentialProvider implements SASTokenProvider, Configurable {

private static final Logger LOGGER =
LoggerFactory.getLogger(GravitinoAzureSasCredentialProvider.class);
private static final Logger LOGGER = LoggerFactory.getLogger(AzureSasCredentialProvider.class);

private Configuration configuration;

Expand All @@ -58,7 +57,8 @@ public class GravitinoAzureSasCredentialProvider implements SASTokenProvider, Co
private String azureStorageAccountName;
private String azureStorageAccountKey;

private long expirationTime;
private long expirationTime = Long.MAX_VALUE;
private static final double EXPIRATION_TIME_FACTOR = 0.9D;

public String getAzureStorageAccountName() {
return azureStorageAccountName;
Expand Down Expand Up @@ -88,7 +88,7 @@ public void initialize(Configuration conf, String accountName) throws IOExceptio
@Override
public String getSASToken(String account, String fileSystem, String path, String operation) {
// Refresh credentials if they are null or about to expire in 5 minutes
if (sasToken == null || System.currentTimeMillis() > expirationTime - 5 * 60 * 1000) {
if (sasToken == null || System.currentTimeMillis() >= expirationTime) {
synchronized (this) {
refresh();
}
Expand Down Expand Up @@ -121,9 +121,12 @@ private void refresh() {
azureStorageAccountKey = credentialMap.get(GRAVITINO_AZURE_STORAGE_ACCOUNT_KEY);
}

this.expirationTime = credential.expireTimeInMs();
if (expirationTime <= 0) {
expirationTime = Long.MAX_VALUE;
if (credential.expireTimeInMs() > 0) {
expirationTime =
System.currentTimeMillis()
+ (long)
((credential.expireTimeInMs() - System.currentTimeMillis())
* EXPIRATION_TIME_FACTOR);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import com.google.cloud.hadoop.util.AccessTokenProvider;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import org.apache.gravitino.NameIdentifier;
import org.apache.gravitino.client.GravitinoClient;
import org.apache.gravitino.credential.Credential;
Expand All @@ -34,18 +36,19 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GravitinoGCSCredentialProvider implements AccessTokenProvider {
private static final Logger LOG = LoggerFactory.getLogger(GravitinoGCSCredentialProvider.class);
public class GCSCredentialProvider implements AccessTokenProvider {
private static final Logger LOG = LoggerFactory.getLogger(GCSCredentialProvider.class);
private Configuration configuration;
private GravitinoClient client;
private String filesetIdentifier;

private AccessToken accessToken;
private long expirationTime;
private long expirationTime = Long.MAX_VALUE;
private static final double EXPIRATION_TIME_FACTOR = 0.9D;

@Override
public AccessToken getAccessToken() {
if (accessToken == null || expirationTime < System.currentTimeMillis() + 5 * 60 * 1000) {
if (accessToken == null || System.currentTimeMillis() >= expirationTime) {
try {
refresh();
} catch (IOException e) {
Expand All @@ -67,24 +70,28 @@ public void refresh() throws IOException {
Fileset fileset = filesetCatalog.loadFileset(NameIdentifier.of(idents[2], idents[3]));
Credential[] credentials = fileset.supportsCredentials().getCredentials();

Optional<Credential> optionalCredential = getCredential(credentials);
// Can't find any credential, use the default one.
if (credentials.length == 0) {
if (!optionalCredential.isPresent()) {
LOG.warn(
"No credential found for fileset: {}, try to use static JSON file", filesetIdentifier);
return;
}

Credential credential = credentials[0];
Credential credential = optionalCredential.get();
Map<String, String> credentialMap = credential.toProperties();

if (GCSTokenCredential.GCS_TOKEN_CREDENTIAL_TYPE.equals(
credentialMap.get(Credential.CREDENTIAL_TYPE))) {
String sessionToken = credentialMap.get(GCSTokenCredential.GCS_TOKEN_NAME);
accessToken = new AccessToken(sessionToken, expirationTime);
accessToken = new AccessToken(sessionToken, credential.expireTimeInMs());

expirationTime = credential.expireTimeInMs();
if (expirationTime <= 0) {
expirationTime = Long.MAX_VALUE;
if (credential.expireTimeInMs() > 0) {
expirationTime =
System.currentTimeMillis()
+ (long)
((credential.expireTimeInMs() - System.currentTimeMillis())
* EXPIRATION_TIME_FACTOR);
}
}
}
Expand All @@ -101,4 +108,20 @@ public void setConf(Configuration configuration) {
public Configuration getConf() {
return this.configuration;
}

/**
* Get the credential from the credential array. Using dynamic credential first, if not found,
* uses static credential.
*
* @param credentials The credential array.
* @return An optional credential.
*/
private Optional<Credential> getCredential(Credential[] credentials) {
// Use dynamic credential if found.
return Arrays.stream(credentials)
.filter(
credential ->
credential.credentialType().equals(GCSTokenCredential.GCS_TOKEN_CREDENTIAL_TYPE))
.findFirst();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@ public FileSystem getFileSystem(Path path, Map<String, String> config) throws IO
.forEach(configuration::set);

if (config.containsKey(GravitinoVirtualFileSystemConfiguration.FS_GRAVITINO_SERVER_URI_KEY)) {
AccessTokenProvider accessTokenProvider = new GravitinoGCSCredentialProvider();
AccessTokenProvider accessTokenProvider = new GCSCredentialProvider();
accessTokenProvider.setConf(configuration);
// Why is this check necessary?, if Gravitino fails to get any credentials, we fall back to
// the default behavior of the GoogleHadoopFileSystem to use service account credentials.
if (accessTokenProvider.getAccessToken() != null) {
configuration.set(
"fs.gs.auth.access.token.provider.impl",
GravitinoGCSCredentialProvider.class.getName());
"fs.gs.auth.access.token.provider.impl", GCSCredentialProvider.class.getName());
}
}

Expand Down

0 comments on commit b34c526

Please sign in to comment.