Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for out of band access tokens in the GCS Client #11181

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions velox/connectors/hive/storage_adapters/gcs/GCSFileSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <stdexcept>

#include <google/cloud/storage/client.h>
#include "TokenProvider.h"
#include "OutOfBandOAuth2Credentials.h"

namespace facebook::velox {
namespace {
Expand Down Expand Up @@ -65,10 +67,13 @@ inline void checkGCSStatus(

class GCSReadFile final : public ReadFile {
public:
GCSReadFile(const std::string& path, std::shared_ptr<gcs::Client> client)
: client_(std::move(client)) {
GCSReadFile(const std::string& path,
std::shared_ptr<gcs::Client> client,
std::function<void(std::shared_ptr<gcs::Client>&, const std::optional<std::string>&, const std::optional<std::vector<std::string>>&)> refreshTokenFn)
: client_(std::move(client)), refreshTokenFn_(std::move(refreshTokenFn)) {
// assumption it's a proper path
setBucketAndKeyFromGCSPath(path, bucket_, key_);
refreshTokenFn_(client_, std::optional<std::string>(path), std::nullopt);
}

// Gets the length of the file.
Expand Down Expand Up @@ -175,6 +180,7 @@ class GCSReadFile final : public ReadFile {
}

std::shared_ptr<gcs::Client> client_;
std::function<void(std::shared_ptr<gcs::Client>&, const std::optional<std::string>&, const std::optional<std::vector<std::string>>&)> refreshTokenFn_;
std::string bucket_;
std::string key_;
std::atomic<int64_t> length_ = -1;
Expand All @@ -184,9 +190,11 @@ class GCSWriteFile final : public WriteFile {
public:
explicit GCSWriteFile(
const std::string& path,
std::shared_ptr<gcs::Client> client)
: client_(client) {
std::shared_ptr<gcs::Client> client,
std::function<void(std::shared_ptr<gcs::Client>&, const std::optional<std::string>&, const std::optional<std::vector<std::string>>&)> refreshTokenFn)
: client_(client), refreshTokenFn_(std::move(refreshTokenFn)) {
setBucketAndKeyFromGCSPath(path, bucket_, key_);
refreshTokenFn_(client_, std::optional<std::string>(path), std::nullopt);
}

~GCSWriteFile() {
Expand Down Expand Up @@ -244,6 +252,7 @@ class GCSWriteFile final : public WriteFile {

gcs::ObjectWriteStream stream_;
std::shared_ptr<gcs::Client> client_;
std::function<void(std::shared_ptr<gcs::Client>&, const std::optional<std::string>&, const std::optional<std::vector<std::string>>&)> refreshTokenFn_;
std::string bucket_;
std::string key_;
std::atomic<int64_t> size_{-1};
Expand All @@ -267,8 +276,12 @@ class GCSFileSystem::Impl {
// Use the input Config parameters and initialize the GCSClient.
void initializeClient() {
auto options = gc::Options{};

std::string accessTokenEnabled =
hiveConfig_->config()->get<std::string>("gcs.access_token_enabled", "false");

auto scheme = hiveConfig_->gcsScheme();
if (scheme == "https") {
if (scheme == "https" && accessTokenEnabled == "false") {
options.set<gc::UnifiedCredentialsOption>(
gc::MakeGoogleDefaultCredentials());
} else {
Expand Down Expand Up @@ -312,13 +325,31 @@ class GCSFileSystem::Impl {
<< "Config hive.gcs.json-key-file-path is empty or key file path not found";
}

if (accessTokenEnabled == "true") {
std::string providerClassName =
hiveConfig_->config()->get<std::string>("gcs.access_token_provider", "");
throw std::runtime_error("AccessTokenProvider not configured. Please implement an out of band access token provider.");
// Out of band access token provider can implement custom/internal integrations.
// Callers must provide their own provider to use the access tokens with this client.
// Set credentials using: options.set<google::cloud::storage::Oauth2CredentialsOption>(credentials_);

}

client_ = std::make_shared<gcs::Client>(options);
}

std::shared_ptr<gcs::Client> getClient() const {
return client_;
}

std::shared_ptr<OutOfBandOAuth2Credentials> credentials_;
void setCredentialsContext(const std::string& path, const std::optional<std::vector<std::string>>& permissions) {
if (credentials_) {
credentials_->SetContext(path, permissions);
}
}


private:
const std::shared_ptr<HiveConfig> hiveConfig_;
std::shared_ptr<gcs::Client> client_;
Expand All @@ -337,7 +368,11 @@ std::unique_ptr<ReadFile> GCSFileSystem::openFileForRead(
std::string_view path,
const FileOptions& options) {
const auto gcspath = gcsPath(path);
auto gcsfile = std::make_unique<GCSReadFile>(gcspath, impl_->getClient());
auto refreshTokenLambda = [impl_ = this->impl_](std::shared_ptr<gcs::Client>& client, const std::optional<std::string>& path, const std::optional<std::vector<std::string>>& permissions) -> void {
impl_->setCredentialsContext(path.value_or(""), "READ");
};

auto gcsfile = std::make_unique<GCSReadFile>(gcspath, impl_->getClient(), refreshTokenLambda);
gcsfile->initialize(options);
return gcsfile;
}
Expand All @@ -346,7 +381,11 @@ std::unique_ptr<WriteFile> GCSFileSystem::openFileForWrite(
std::string_view path,
const FileOptions& /*unused*/) {
const auto gcspath = gcsPath(path);
auto gcsfile = std::make_unique<GCSWriteFile>(gcspath, impl_->getClient());
auto refreshTokenLambda = [impl_ = this->impl_](std::shared_ptr<gcs::Client>& client, const std::optional<std::string>& path, const std::optional<std::vector<std::string>>& permissions) -> void {
impl_->setCredentialsContext(path.value_or(""), "WRITE");
};

auto gcsfile = std::make_unique<GCSWriteFile>(gcspath, impl_->getClient(), refreshTokenLambda);
gcsfile->initialize();
return gcsfile;
}
Expand All @@ -363,6 +402,7 @@ void GCSFileSystem::remove(std::string_view path) {
setBucketAndKeyFromGCSPath(file, bucket, object);

if (!object.empty()) {
impl_->setCredentialsContext(file, "DELETE");
auto stat = impl_->getClient()->GetObjectMetadata(bucket, object);
if (!stat.ok()) {
checkGCSStatus(
Expand Down Expand Up @@ -390,6 +430,8 @@ bool GCSFileSystem::exists(std::string_view path) {
std::string object;
setBucketAndKeyFromGCSPath(file, bucket, object);
using ::google::cloud::StatusOr;

impl_->setCredentialsContext(file, "LIST");
StatusOr<gcs::BucketMetadata> metadata =
impl_->getClient()->GetBucketMetadata(bucket);

Expand All @@ -406,6 +448,8 @@ std::vector<std::string> GCSFileSystem::list(std::string_view path) {
std::string bucket;
std::string object;
setBucketAndKeyFromGCSPath(file, bucket, object);

impl_->setCredentialsContext(file, "LIST");
for (auto&& metadata : impl_->getClient()->ListObjects(bucket)) {
if (!metadata.ok()) {
checkGCSStatus(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "OutOfBandOAuth2Credentials.h"

thread_local std::string OutOfBandOAuth2Credentials::current_path_;
thread_local std::optional<std::vector<std::string>> OutOfBandOAuth2Credentials::current_operation_;

OutOfBandOAuth2Credentials::OutOfBandOAuth2Credentials(std::shared_ptr<TokenProvider> token_provider)
: token_provider_(std::move(token_provider)) {
}

void OutOfBandOAuth2Credentials::SetContext(const std::string& path, const std::optional<std::vector<std::string>>& operation) {
current_path_ = path;
current_operation_ = operation;
}

google::cloud::StatusOr<std::string> OutOfBandOAuth2Credentials::AuthorizationHeader() {
auto token_pair = token_provider_->getAccessToken(current_path_, std::nullopt);
std::string access_token = token_pair.first;

if (access_token.empty()) {
return google::cloud::Status(google::cloud::StatusCode::kPermissionDenied, "Failed to fetch access token");
}
return "Authorization: Bearer " + access_token;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef OUT_OF_BAND_OAUTH2_CREDENTIALS_H
#define OUT_OF_BAND_OAUTH2_CREDENTIALS_H

#include "google/cloud/storage/oauth2/credentials.h"
#include <memory>
#include <string>
#include "TokenProvider.h"

class OutOfBandOAuth2Credentials : public google::cloud::storage::oauth2::Credentials {
public:
OutOfBandOAuth2Credentials(std::shared_ptr<TokenProvider> token_provider);
virtual ~OutOfBandOAuth2Credentials() = default;

void SetContext(const std::string& path, const std::optional<std::vector<std::string>>& operation);

google::cloud::StatusOr<std::string> AuthorizationHeader() override;

private:
std::shared_ptr<TokenProvider> token_provider_;
static thread_local std::string current_path_;
static thread_local std::optional<std::vector<std::string>> current_operation_;
};

#endif
34 changes: 34 additions & 0 deletions velox/connectors/hive/storage_adapters/gcs/TokenProvider.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef TOKEN_PROVIDER_H
#define TOKEN_PROVIDER_H

#include <optional>
#include <string>
#include <vector>

class TokenProvider {
public:
virtual ~TokenProvider() = default;

virtual std::pair<std::string, int64_t> getAccessToken(
const std::optional<std::string>& targetPath = std::nullopt,
const std::optional<std::vector<std::string>>& permissions = std::nullopt
) const = 0;
};

#endif