diff --git a/velox/connectors/hive/storage_adapters/gcs/GCSFileSystem.cpp b/velox/connectors/hive/storage_adapters/gcs/GCSFileSystem.cpp index 583a25c291e8..0c1138f87135 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GCSFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/GCSFileSystem.cpp @@ -28,6 +28,8 @@ #include #include +#include "TokenProvider.h" +#include "OutOfBandOAuth2Credentials.h" namespace facebook::velox { namespace { @@ -65,10 +67,13 @@ inline void checkGCSStatus( class GCSReadFile final : public ReadFile { public: - GCSReadFile(const std::string& path, std::shared_ptr client) - : client_(std::move(client)) { + GCSReadFile(const std::string& path, + std::shared_ptr client, + std::function&, const std::optional&, const std::optional>&)> refreshTokenFn) + : client_(std::move(client)), refreshTokenFn_(std::move(refreshTokenFn)) { // assumption it's a proper path setBucketAndKeyFromGCSPath(path, bucket_, key_); + refreshTokenFn_(client_, std::optional(path), std::nullopt); } // Gets the length of the file. @@ -175,6 +180,7 @@ class GCSReadFile final : public ReadFile { } std::shared_ptr client_; + std::function&, const std::optional&, const std::optional>&)> refreshTokenFn_; std::string bucket_; std::string key_; std::atomic length_ = -1; @@ -184,9 +190,11 @@ class GCSWriteFile final : public WriteFile { public: explicit GCSWriteFile( const std::string& path, - std::shared_ptr client) - : client_(client) { + std::shared_ptr client, + std::function&, const std::optional&, const std::optional>&)> refreshTokenFn) + : client_(client), refreshTokenFn_(std::move(refreshTokenFn)) { setBucketAndKeyFromGCSPath(path, bucket_, key_); + refreshTokenFn_(client_, std::optional(path), std::nullopt); } ~GCSWriteFile() { @@ -244,6 +252,7 @@ class GCSWriteFile final : public WriteFile { gcs::ObjectWriteStream stream_; std::shared_ptr client_; + std::function&, const std::optional&, const std::optional>&)> refreshTokenFn_; std::string bucket_; std::string key_; std::atomic size_{-1}; @@ -267,8 +276,12 @@ class GCSFileSystem::Impl { // Use the input Config parameters and initialize the GCSClient. void initializeClient() { auto options = gc::Options{}; + + std::string accessTokenEnabled = + hiveConfig_->config()->get("gcs.access_token_enabled", "false"); + auto scheme = hiveConfig_->gcsScheme(); - if (scheme == "https") { + if (scheme == "https" && accessTokenEnabled == "false") { options.set( gc::MakeGoogleDefaultCredentials()); } else { @@ -312,6 +325,16 @@ class GCSFileSystem::Impl { << "Config hive.gcs.json-key-file-path is empty or key file path not found"; } + if (accessTokenEnabled == "true") { + std::string providerClassName = + hiveConfig_->config()->get("gcs.access_token_provider", ""); + throw std::runtime_error("AccessTokenProvider not configured. Please implement an out of band access token provider."); + // Out of band access token provider can implement custom/internal integrations. + // Callers must provide their own provider to use the access tokens with this client. + // Set credentials using: options.set(credentials_); + + } + client_ = std::make_shared(options); } @@ -319,6 +342,14 @@ class GCSFileSystem::Impl { return client_; } + std::shared_ptr credentials_; + void setCredentialsContext(const std::string& path, const std::optional>& permissions) { + if (credentials_) { + credentials_->SetContext(path, permissions); + } + } + + private: const std::shared_ptr hiveConfig_; std::shared_ptr client_; @@ -337,7 +368,11 @@ std::unique_ptr GCSFileSystem::openFileForRead( std::string_view path, const FileOptions& options) { const auto gcspath = gcsPath(path); - auto gcsfile = std::make_unique(gcspath, impl_->getClient()); + auto refreshTokenLambda = [impl_ = this->impl_](std::shared_ptr& client, const std::optional& path, const std::optional>& permissions) -> void { + impl_->setCredentialsContext(path.value_or(""), "READ"); + }; + + auto gcsfile = std::make_unique(gcspath, impl_->getClient(), refreshTokenLambda); gcsfile->initialize(options); return gcsfile; } @@ -346,7 +381,11 @@ std::unique_ptr GCSFileSystem::openFileForWrite( std::string_view path, const FileOptions& /*unused*/) { const auto gcspath = gcsPath(path); - auto gcsfile = std::make_unique(gcspath, impl_->getClient()); + auto refreshTokenLambda = [impl_ = this->impl_](std::shared_ptr& client, const std::optional& path, const std::optional>& permissions) -> void { + impl_->setCredentialsContext(path.value_or(""), "WRITE"); + }; + + auto gcsfile = std::make_unique(gcspath, impl_->getClient(), refreshTokenLambda); gcsfile->initialize(); return gcsfile; } @@ -363,6 +402,7 @@ void GCSFileSystem::remove(std::string_view path) { setBucketAndKeyFromGCSPath(file, bucket, object); if (!object.empty()) { + impl_->setCredentialsContext(file, "DELETE"); auto stat = impl_->getClient()->GetObjectMetadata(bucket, object); if (!stat.ok()) { checkGCSStatus( @@ -390,6 +430,8 @@ bool GCSFileSystem::exists(std::string_view path) { std::string object; setBucketAndKeyFromGCSPath(file, bucket, object); using ::google::cloud::StatusOr; + + impl_->setCredentialsContext(file, "LIST"); StatusOr metadata = impl_->getClient()->GetBucketMetadata(bucket); @@ -406,6 +448,8 @@ std::vector GCSFileSystem::list(std::string_view path) { std::string bucket; std::string object; setBucketAndKeyFromGCSPath(file, bucket, object); + + impl_->setCredentialsContext(file, "LIST"); for (auto&& metadata : impl_->getClient()->ListObjects(bucket)) { if (!metadata.ok()) { checkGCSStatus( diff --git a/velox/connectors/hive/storage_adapters/gcs/OutOfBandOAuth2Credentials.cc b/velox/connectors/hive/storage_adapters/gcs/OutOfBandOAuth2Credentials.cc new file mode 100644 index 000000000000..a14776af0e34 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/gcs/OutOfBandOAuth2Credentials.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OutOfBandOAuth2Credentials.h" + +thread_local std::string OutOfBandOAuth2Credentials::current_path_; +thread_local std::optional> OutOfBandOAuth2Credentials::current_operation_; + +OutOfBandOAuth2Credentials::OutOfBandOAuth2Credentials(std::shared_ptr token_provider) + : token_provider_(std::move(token_provider)) { +} + +void OutOfBandOAuth2Credentials::SetContext(const std::string& path, const std::optional>& operation) { + current_path_ = path; + current_operation_ = operation; +} + +google::cloud::StatusOr OutOfBandOAuth2Credentials::AuthorizationHeader() { + auto token_pair = token_provider_->getAccessToken(current_path_, std::nullopt); + std::string access_token = token_pair.first; + + if (access_token.empty()) { + return google::cloud::Status(google::cloud::StatusCode::kPermissionDenied, "Failed to fetch access token"); + } + return "Authorization: Bearer " + access_token; +} diff --git a/velox/connectors/hive/storage_adapters/gcs/OutOfBandOAuth2Credentials.h b/velox/connectors/hive/storage_adapters/gcs/OutOfBandOAuth2Credentials.h new file mode 100644 index 000000000000..6b757a9fa357 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/gcs/OutOfBandOAuth2Credentials.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OUT_OF_BAND_OAUTH2_CREDENTIALS_H +#define OUT_OF_BAND_OAUTH2_CREDENTIALS_H + +#include "google/cloud/storage/oauth2/credentials.h" +#include +#include +#include "TokenProvider.h" + +class OutOfBandOAuth2Credentials : public google::cloud::storage::oauth2::Credentials { +public: + OutOfBandOAuth2Credentials(std::shared_ptr token_provider); + virtual ~OutOfBandOAuth2Credentials() = default; + + void SetContext(const std::string& path, const std::optional>& operation); + + google::cloud::StatusOr AuthorizationHeader() override; + +private: + std::shared_ptr token_provider_; + static thread_local std::string current_path_; + static thread_local std::optional> current_operation_; +}; + +#endif diff --git a/velox/connectors/hive/storage_adapters/gcs/TokenProvider.h b/velox/connectors/hive/storage_adapters/gcs/TokenProvider.h new file mode 100644 index 000000000000..5abe10ca1de2 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/gcs/TokenProvider.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TOKEN_PROVIDER_H +#define TOKEN_PROVIDER_H + +#include +#include +#include + +class TokenProvider { +public: + virtual ~TokenProvider() = default; + + virtual std::pair getAccessToken( + const std::optional& targetPath = std::nullopt, + const std::optional>& permissions = std::nullopt + ) const = 0; +}; + +#endif