GS workspaces can be bucket subfolders #604

Merged Oct 4, 2023 · 4 commits
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Added
+
+- The `GSWorkspace()` can now be initialized with google cloud bucket subfolders.
+
 ### Fixed
 
 - Removed unnecessary code coverage dev requirements.
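The CHANGELOG entry above is the user-facing change. A minimal sketch of what it enables follows; the bucket and subfolder names are invented, and it assumes `GSWorkspace` accepts the bucket path as its first argument, as the entry suggests.

```python
from tango.integrations.gs import GSWorkspace

# Previously, only a bare bucket name was accepted:
workspace = GSWorkspace("my-tango-bucket")

# After this PR, a subfolder inside the bucket works too:
workspace = GSWorkspace("my-tango-bucket/experiments/run1")
```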
109 changes: 72 additions & 37 deletions tango/integrations/gs/common.py
@@ -9,7 +9,7 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import google.auth
 from google.api_core import exceptions
@@ -27,33 +27,50 @@
 logger = logging.getLogger(__name__)
 
 
-def empty_bucket(bucket_name: str):
+def get_bucket_and_prefix(folder_name: str) -> Tuple[str, str]:
     """
-    Removes all the tango-related blobs from the specified bucket.
+    Split bucket name and subfolder name, if present.
+    """
+    split = folder_name.split("/")
+    return split[0], "/".join(split[1:])
+
+
+def empty_bucket_folder(folder_name: str):
+    """
+    Removes all the tango-related blobs from the specified bucket folder.
     Used for testing.
     """
     credentials, project = google.auth.default()
     client = storage.Client(project=project, credentials=credentials)
+    bucket_name, prefix = get_bucket_and_prefix(folder_name)
+
+    prefix = prefix + "/tango-" if prefix else "tango-"
+
     bucket = client.bucket(bucket_name)
     try:
-        bucket.delete_blobs(list(bucket.list_blobs(prefix="tango-")))
+        bucket.delete_blobs(list(bucket.list_blobs(prefix=prefix)))
     except exceptions.NotFound:
         pass
 
 
-def empty_datastore(namespace: str):
+def empty_datastore(folder_name: str):
     """
-    Removes all the tango-related entities from the specified namespace in datastore.
+    Removes all the tango-related entities from the specified namespace subfolder in datastore.
     Used for testing.
     """
     from google.cloud import datastore
 
     credentials, project = google.auth.default()
+    namespace, prefix = get_bucket_and_prefix(folder_name)
+
+    run_kind = prefix + "/run" if prefix else "run"
+    stepinfo_kind = prefix + "/stepinfo" if prefix else "stepinfo"
+
     client = datastore.Client(project=project, credentials=credentials, namespace=namespace)
-    run_query = client.query(kind="run")
+    run_query = client.query(kind=run_kind)
     run_query.keys_only()
     keys = [entity.key for entity in run_query.fetch()]
-    stepinfo_query = client.query(kind="stepinfo")
+    stepinfo_query = client.query(kind=stepinfo_kind)
     stepinfo_query.keys_only()
     keys += [entity.key for entity in stepinfo_query.fetch()]
     client.delete_multi(keys)
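To make the split concrete, here is the new helper exercised on its own, with invented input values:

```python
def get_bucket_and_prefix(folder_name):
    split = folder_name.split("/")
    return split[0], "/".join(split[1:])

# A bare bucket name yields an empty prefix:
assert get_bucket_and_prefix("my-bucket") == ("my-bucket", "")
# A bucket with a subfolder path is split at the first slash:
assert get_bucket_and_prefix("my-bucket/workspaces/run1") == ("my-bucket", "workspaces/run1")
```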
@@ -108,12 +125,19 @@ class GSArtifactWriteError(TangoError):
     pass
 
 
+def join_path(*args) -> str:
+    """
+    We use this since we cannot use `os.path.join` for cloud storage paths.
+    """
+    return "/".join(args).strip("/")
+
+
 class GSClient:
     """
     A client for interacting with Google Cloud Storage. The authorization works by
     providing OAuth2 credentials.
 
-    :param bucket_name: The name of the Google Cloud bucket to use.
+    :param folder_name: The name of the Google Cloud bucket folder to use.
     :param credentials: OAuth2 credentials can be provided. If not provided, default
         gcloud credentials are inferred.
     :param project: Optionally, the project ID can be provided. This is not essential
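`join_path` strips stray leading and trailing slashes, which is what makes an empty prefix harmless. Two invented examples:

```python
def join_path(*args) -> str:
    return "/".join(args).strip("/")

# An empty prefix leaves the path untouched:
assert join_path("", "tango-step-abc") == "tango-step-abc"
# A real prefix is joined with "/" as for any cloud storage path:
assert join_path("workspaces/run1", "tango-step-abc") == "workspaces/run1/tango-step-abc"
```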
@@ -123,7 +147,7 @@ class GSClient:
 
     placeholder_file = ".placeholder"
     """
-    The placeholder file is used for creation of a folder in the cloud bucket,
+    The placeholder file is used for creation of a folder in the cloud bucket folder,
     as empty folders are not allowed. It is also used as a marker for the creation
     time of the folder, hence we use a separate file to mark the artifact as
     uncommitted.
@@ -143,17 +167,20 @@ class GSClient:
 
     def __init__(
         self,
-        bucket_name: str,
+        folder_name: str,
         credentials: Optional[Credentials] = None,
         project: Optional[str] = None,
     ):
         if not credentials:
             credentials, project = google.auth.default()
 
         self.storage = storage.Client(project=project, credentials=credentials)
-        self.bucket_name = bucket_name
+        self.folder_name = folder_name
+
+        self.bucket_name, self.prefix = get_bucket_and_prefix(folder_name)
+        settings_file = self._gs_path(self.settings_file)
 
-        blob = self.storage.bucket(bucket_name).blob(self.settings_file)  # no HTTP request yet
+        blob = self.storage.bucket(self.bucket_name).blob(settings_file)  # no HTTP request yet
         try:
             with blob.open("r") as file_ref:
                 json.load(file_ref)
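After construction, the single folder string is available in three forms. A hypothetical client, assuming default gcloud credentials and invented names:

```python
client = GSClient("my-bucket/workspaces/run1")
client.folder_name  # "my-bucket/workspaces/run1"
client.bucket_name  # "my-bucket"
client.prefix       # "workspaces/run1"
```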
@@ -166,13 +193,12 @@ def url(self, artifact: Optional[str] = None):
         """
         Returns the remote url of the storage artifact.
         """
-        path = f"gs://{self.bucket_name}"
+        path = f"gs://{self.folder_name}"
         if artifact is not None:
             path = f"{path}/{artifact}"
         return path
 
-    @classmethod
-    def _convert_blobs_to_artifact(cls, blobs: List[storage.Blob]) -> GSArtifact:
+    def _convert_blobs_to_artifact(self, blobs: List[storage.Blob]) -> GSArtifact:
         """
         Converts a list of `google.cloud.storage.Blob` to a `GSArtifact`.
         """
@@ -182,22 +208,24 @@ def _convert_blobs_to_artifact(cls, blobs: List[storage.Blob]) -> GSArtifact:
         committed: bool = True
 
         for blob in blobs:
-            if blob.name.endswith(cls.placeholder_file):
+            if blob.name.endswith(self.placeholder_file):
                 created = blob.time_created
-                name = blob.name.replace("/" + cls.placeholder_file, "")
+                name = blob.name.replace("/" + self.placeholder_file, "")
+                if self.prefix:
+                    name = name.replace(self.prefix + "/", "")
                 artifact_path = name  # does not contain bucket info here.
-            elif blob.name.endswith(cls.uncommitted_file):
+            elif blob.name.endswith(self.uncommitted_file):
                 committed = False
 
         assert name is not None, "Folder is not a GSArtifact, should not have happened."
         return GSArtifact(name, artifact_path, created, committed)
 
     @classmethod
-    def from_env(cls, bucket_name: str):
+    def from_env(cls, folder_name: str):
         """
         Constructs the client object from the environment, using default credentials.
         """
-        return cls(bucket_name)
+        return cls(folder_name)
 
     def get(self, artifact: Union[str, GSArtifact]) -> GSArtifact:
         """
@@ -210,42 +238,46 @@ def get(self, artifact: Union[str, GSArtifact]) -> GSArtifact:
             # We have an artifact, and we recreate it with refreshed info.
             path = artifact.artifact_path
 
-        blobs = list(self.storage.bucket(self.bucket_name).list_blobs(prefix=path))
+        prefix = self._gs_path(path)
+        blobs = list(self.storage.bucket(self.bucket_name).list_blobs(prefix=prefix))
         if len(blobs) > 0:
             return self._convert_blobs_to_artifact(blobs)
         else:
             raise GSArtifactNotFound()
 
-    @classmethod
-    def _gs_path(cls, *args):
+    def _gs_path(self, *args):
         """
-        Returns path within google cloud storage bucket. We use this since we cannot
-        use `os.path.join` for cloud storage paths.
+        Returns path within google cloud storage bucket.
         """
-        return "/".join(args)
+        return join_path(self.prefix, *args)
 
     def create(self, artifact: str):
         """
         Creates a new artifact in the remote location. By default, it is uncommitted.
         """
         bucket = self.storage.bucket(self.bucket_name)
         # gives refreshed information
-        if bucket.blob(self._gs_path(artifact, self.placeholder_file)).exists():
+
+        artifact_path = self._gs_path(artifact, self.placeholder_file)
+        if bucket.blob(artifact_path).exists():
             raise GSArtifactConflict(f"{artifact} already exists!")
         else:
             # Additional safety check
             if bucket.blob(self._gs_path(artifact, self.uncommitted_file)).exists():
                 raise GSArtifactConflict(f"{artifact} already exists!")
             bucket.blob(self._gs_path(artifact, self.placeholder_file)).upload_from_string("")
             bucket.blob(self._gs_path(artifact, self.uncommitted_file)).upload_from_string("")
-        return self._convert_blobs_to_artifact(list(bucket.list_blobs(prefix=artifact)))
+        return self._convert_blobs_to_artifact(
+            list(bucket.list_blobs(prefix=self._gs_path(artifact)))
+        )
 
     def delete(self, artifact: GSArtifact):
         """
         Removes the artifact from the remote location.
         """
         bucket = self.storage.bucket(self.bucket_name)
-        blobs = list(bucket.list_blobs(prefix=artifact.artifact_path))
+        prefix = self._gs_path(artifact.artifact_path)
+        blobs = list(bucket.list_blobs(prefix=prefix))
         bucket.delete_blobs(blobs)
 
     def upload(self, artifact: Union[str, GSArtifact], objects_dir: Path):
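Because `_gs_path` now routes through `join_path(self.prefix, ...)` while `_convert_blobs_to_artifact` strips the prefix again, artifact paths stay prefix-relative in memory and prefixed in the bucket. A sketch of the round trip, with invented values:

```python
prefix = "workspaces/run1"
artifact_path = "tango-step-abc"  # as stored on a GSArtifact: prefix-relative

# What _gs_path produces for list_blobs / blob operations:
listing_prefix = "/".join([prefix, artifact_path]).strip("/")
assert listing_prefix == "workspaces/run1/tango-step-abc"
```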
@@ -260,7 +292,7 @@ def upload(self, artifact: Union[str, GSArtifact], objects_dir: Path):
         source_path = str(objects_dir)
 
         def _sync_blob(source_file_path: str, target_file_path: str):
-            blob = self.storage.bucket(self.bucket_name).blob(target_file_path)
+            blob = self.storage.bucket(self.bucket_name).blob(self._gs_path(target_file_path))
             blob.upload_from_filename(source_file_path)
 
         import concurrent.futures
@@ -277,7 +309,7 @@ def _sync_blob(source_file_path: str, target_file_path: str):
         for dirpath, _, filenames in os.walk(source_path):
             for filename in filenames:
                 source_file_path = os.path.join(dirpath, filename)
-                target_file_path = self._gs_path(
+                target_file_path = join_path(
                     folder_path, source_file_path.replace(source_path + "/", "")
                 )
                 upload_futures.append(
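Switching this call from `self._gs_path` to the module-level `join_path` avoids double-prefixing: `_sync_blob` in the previous hunk already applies `self._gs_path` to its target. Schematically, with invented paths:

```python
prefix = "workspaces/run1"
folder_path = "tango-step-abc"

# In the os.walk loop: the prefix is NOT applied yet.
target_file_path = "/".join([folder_path, "data/out.txt"])

# Inside _sync_blob: the prefix is applied exactly once.
blob_path = "/".join([prefix, target_file_path]).strip("/")
assert blob_path == "workspaces/run1/tango-step-abc/data/out.txt"
```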
@@ -328,14 +360,16 @@ def _fetch_blob(blob: storage.Blob):
         import concurrent.futures
 
         bucket = self.storage.bucket(self.bucket_name)
-        bucket.update()
+        # We may not need updates that frequently, with list_blobs(prefix).
+        # bucket.update()
 
         try:
             with concurrent.futures.ThreadPoolExecutor(
                 max_workers=self.NUM_CONCURRENT_WORKERS, thread_name_prefix="GSClient.download()-"
             ) as executor:
                 download_futures = []
-                for blob in bucket.list_blobs(prefix=artifact.artifact_path):
+                prefix = self._gs_path(artifact.artifact_path)
+                for blob in bucket.list_blobs(prefix=prefix):
                     download_futures.append(executor.submit(_fetch_blob, blob))
                 for future in concurrent.futures.as_completed(download_futures):
                     future.result()
@@ -348,6 +382,7 @@ def artifacts(self, prefix: str, uncommitted: bool = True) -> List[GSArtifact]:
         `match` and `uncommitted` criteria. These can include steps and runs.
         """
         list_of_artifacts = []
+        prefix = self._gs_path(prefix)
         for folder_name in self.storage.list_blobs(
             self.bucket_name, prefix=prefix, delimiter="/"
         )._get_next_page_response()["prefixes"]:
@@ -405,15 +440,15 @@ def get_credentials(credentials: Optional[Union[str, Credentials]] = None) -> Credentials:
 
 
 def get_client(
-    bucket_name: str,
+    folder_name: str,
     credentials: Optional[Union[str, Credentials]] = None,
     project: Optional[str] = None,
 ) -> GSClient:
     """
-    Returns a `GSClient` object for a google cloud bucket.
+    Returns a `GSClient` object for a google cloud bucket folder.
     """
     credentials = get_credentials(credentials)
-    return GSClient(bucket_name, credentials=credentials, project=project)
+    return GSClient(folder_name, credentials=credentials, project=project)
 
 
 class Constants(RemoteConstants):
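Putting the client pieces together, end to end. The bucket, subfolder, and step names are invented, and default gcloud credentials are assumed:

```python
client = get_client("my-bucket/workspaces/run1")
artifact = client.create("tango-step-abc")  # lands under workspaces/run1/ in the bucket
print(client.url(artifact.name))            # gs://my-bucket/workspaces/run1/tango-step-abc
```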
14 changes: 7 additions & 7 deletions tango/integrations/gs/step_cache.py
@@ -11,6 +11,7 @@
     GSArtifactNotFound,
     GSArtifactWriteError,
     GSClient,
+    get_bucket_and_prefix,
 )
 from tango.step import Step
 from tango.step_cache import StepCache
@@ -32,24 +33,23 @@ class GSStepCache(RemoteStepCache):
     .. tip::
         Registered as a :class:`~tango.step_cache.StepCache` under the name "gs".
 
-    :param bucket_name: The name of the google cloud bucket to use.
+    :param folder_name: The name of the google cloud bucket folder to use.
     :param client: The google cloud storage client to use.
     """
 
     Constants = Constants
 
-    def __init__(self, bucket_name: str, client: Optional[GSClient] = None):
+    def __init__(self, folder_name: str, client: Optional[GSClient] = None):
         if client is not None:
+            bucket_name, _ = get_bucket_and_prefix(folder_name)
             assert (
                 bucket_name == client.bucket_name
             ), "Assert that bucket name is same as client bucket until we do better"
-            self.bucket_name = bucket_name
+            self.folder_name = folder_name
             self._client = client
         else:
-            self._client = GSClient(bucket_name)
-        super().__init__(
-            tango_cache_dir() / "gs_cache" / make_safe_filename(self._client.bucket_name)
-        )
+            self._client = GSClient(folder_name)
+        super().__init__(tango_cache_dir() / "gs_cache" / make_safe_filename(folder_name))
 
     @property
     def client(self):
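At the step-cache level, the same folder string flows through. A hedged sketch with invented names; the mixed bucket/client check mirrors the assert above:

```python
cache = GSStepCache("my-bucket/workspaces/run1")

# Supplying a pre-built client is fine as long as the bucket parts agree:
client = GSClient("my-bucket/workspaces/run1")
cache = GSStepCache("my-bucket/workspaces/run1", client=client)
```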