CMCD-455-use custom zarr store to raise errors and add retries
renaudjester committed Mar 22, 2024
1 parent 095d99c commit 5374510
Showing 4 changed files with 188 additions and 27 deletions.
143 changes: 143 additions & 0 deletions copernicusmarine/core_functions/custom_zarr_store.py
@@ -0,0 +1,143 @@
import logging
import time
from collections.abc import MutableMapping
from typing import Optional

import botocore.config
import botocore.exceptions
import botocore.session

log = logging.getLogger("copernicus_marine_root_logger")

S3_NUM_RETRIES = 9
S3_INITIAL_RETRY_WAIT_S = 1
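# with_retries (defined below) makes up to S3_NUM_RETRIES attempts,
# sleeping 1, 2, 4, ... 128 s between them (about four minutes in total)
# before letting the final error propagate.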


class CustomS3Store(MutableMapping):
    def __init__(
        self,
        endpoint: str,
        bucket: str,
        root_path: str,
        secret_key: Optional[str] = None,
        access_key: Optional[str] = None,
    ):
        self._root_path = root_path.lstrip("/")
        self._bucket = bucket
        session = botocore.session.get_session()
        # Fall back to anonymous (unsigned) requests when no credentials
        # are provided.
        if secret_key is None and access_key is None:
            self.client = session.create_client(
                "s3",
                endpoint_url=endpoint,
                config=botocore.config.Config(
                    signature_version=botocore.UNSIGNED
                ),
            )
        else:
            self.client = session.create_client(
                "s3",
                endpoint_url=endpoint,
                aws_secret_access_key=secret_key,
                aws_access_key_id=access_key,
            )

    def __getitem__(self, key):
        def fn():
            full_key = f"{self._root_path}/{key}"
            try:
                resp = self.client.get_object(
                    Bucket=self._bucket, Key=full_key
                )
                return resp["Body"].read()
            except botocore.exceptions.ClientError as e:
                # zarr expects a KeyError for missing keys, so client
                # errors on GET are surfaced as one.
                raise KeyError(key) from e

        return with_retries(fn)

    def __contains__(self, key):
        full_key = f"{self._root_path}/{key}"
        try:
            self.client.head_object(Bucket=self._bucket, Key=full_key)
            return True
        except botocore.exceptions.ClientError as e:
            # HEAD on a missing key returns 404, or 403 when the caller
            # lacks list permission on the bucket.
            if "404" in str(e) or "403" in str(e):
                return False
            raise

    def __setitem__(self, key, value, headers=None):
        def fn():
            full_key = f"{self._root_path}/{key}"
            final_headers = headers if headers is not None else {}
            self.client.put_object(
                Bucket=self._bucket, Key=full_key, Body=value, **final_headers
            )

        return with_retries(fn)

    def __delitem__(self, key):
        def fn():
            full_key = f"{self._root_path}/{key}"
            self.client.delete_object(Bucket=self._bucket, Key=full_key)

        return with_retries(fn)

    # Example of headers: {"ContentType": "application/json", "ContentEncoding": "gzip"}
    def set_item_with_headers(self, key, value, headers):
        # pylint: disable=unnecessary-dunder-call
        return self.__setitem__(key, value, headers)

    def keys(self):
        keys = []
        cursor = self._root_path
        # list_objects_v2 returns at most 1,000 keys per call; paginate by
        # passing the last key seen as StartAfter until IsTruncated is False.
        while True:
            resp = self.client.list_objects_v2(
                Bucket=self._bucket, Prefix=self._root_path, StartAfter=cursor
            )
            entries = resp.get("Contents", [])
            keys += [
                o["Key"].removeprefix(self._root_path).lstrip("/")
                for o in entries
            ]
            if not resp["IsTruncated"]:
                break
            cursor = entries[-1]["Key"]
        return keys

    def __iter__(self):
        keys = self.keys()
        yield from keys

    def __len__(self):
        return len(self.keys())

    def clear(self):
        keys = self.keys()
        idx = 0
        while idx < len(keys):
            # delete_objects accepts at most 1,000 keys per request.
            some_keys = keys[idx : idx + 1000]
            objects = list(
                map(lambda k: {"Key": f"{self._root_path}/{k}"}, some_keys)
            )
            self.client.delete_objects(
                Bucket=self._bucket, Delete={"Objects": objects}
            )
            idx += 1000


def with_retries(fn):
    retry_delay = S3_INITIAL_RETRY_WAIT_S
    for idx_try in range(S3_NUM_RETRIES):
        try:
            return fn()
        # KeyError is a normal error that we want to propagate
        # (e.g. if we try to get a chunk and it doesn't exist,
        # we want the caller to know this has happened -- and not retry!)
        except KeyError:
            raise
        except Exception as e:
            if idx_try == S3_NUM_RETRIES - 1:
                raise e
            log.error(f"S3 error: {e}")
            log.info(f"Retrying in {retry_delay} s...")
            time.sleep(retry_delay)
            retry_delay *= 2
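
Because CustomS3Store implements MutableMapping, it can be exercised dict-style before being handed to xarray. A minimal usage sketch (the endpoint, bucket, and root path below are hypothetical placeholders; anonymous unsigned access is assumed):

from copernicusmarine.core_functions.custom_zarr_store import CustomS3Store

store = CustomS3Store(
    endpoint="https://s3.example-region.cloudferro.com",  # hypothetical
    bucket="example-bucket",                              # hypothetical
    root_path="arco/product/dataset/geoChunked.zarr",     # hypothetical
)
metadata = store[".zmetadata"]  # GET with retries; raises KeyError if absent
has_group = ".zgroup" in store  # single HEAD request; False on 404/403
n_objects = len(store)          # paginated listing under root_path

Missing keys surface immediately as KeyError, so zarr treats them as absent chunks, while transient S3 failures are retried with exponential backoff instead of aborting the read.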
16 changes: 14 additions & 2 deletions copernicusmarine/core_functions/sessions.py
@@ -7,8 +7,10 @@
 import requests
 import xarray

+from copernicusmarine.core_functions.custom_zarr_store import CustomS3Store
 from copernicusmarine.core_functions.utils import (
     construct_query_params_for_marine_data_store_monitoring,
+    parse_access_dataset_url,
 )

 TRUST_ENV = True
@@ -33,8 +35,18 @@ def get_configured_request_session() -> requests.Session:


 def open_zarr(
-    *args, copernicus_marine_username: Optional[str] = None, **kwargs
+    dataset_url: str,
+    copernicus_marine_username: Optional[str] = None,
+    **kwargs,
 ) -> xarray.Dataset:
+    (
+        endpoint,
+        bucket,
+        root_path,
+    ) = parse_access_dataset_url(dataset_url)
+    store = CustomS3Store(
+        endpoint=endpoint, bucket=bucket, root_path=root_path
+    )
     kwargs.update(
         {
             "storage_options": {
@@ -46,4 +58,4 @@ def open_zarr(
             }
         }
     )
-    return xarray.open_zarr(*args, **kwargs)
+    return xarray.open_zarr(store, **kwargs)
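
Callers now pass the dataset URL directly; open_zarr parses it and builds the retrying store itself. A sketch of the new call pattern, using the illustrative URL shape documented in utils.py rather than a real dataset:

from copernicusmarine.core_functions.sessions import open_zarr

dataset = open_zarr(
    "https://s3.region.cloudferro.com/bucket/arco/product/dataset/geoChunked.zarr",
    consolidated=True,  # assumption: extra kwargs are forwarded to xarray.open_zarr
)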
28 changes: 28 additions & 0 deletions copernicusmarine/core_functions/utils.py
@@ -2,6 +2,7 @@
 import logging
 import os
 import pathlib
+import re
 from datetime import datetime
 from importlib.metadata import version
 from typing import (
@@ -13,6 +14,7 @@
     Iterator,
     List,
     Optional,
+    Tuple,
     TypeVar,
     Union,
 )
@@ -173,3 +175,29 @@ async def worker():

     results = await asyncio.gather(*[worker() for _ in range(per_batch)])
     return [s for r in results for s in r]
+
+
+# Example data_path
+# https://s3.waw3-1.cloudferro.com/mdl-native-01/native/NWSHELF_MULTIYEAR_BGC_004_011/cmems_mod_nws_bgc-pft_myint_7km-3D-diato_P1M-m_202105
+# https://s3.region.cloudferro.com/bucket/arco/product/dataset/geoChunked.zarr
+# https://s3.region.cloudferro.com:443/bucket/arco/product/dataset/geoChunked.zarr
+def parse_access_dataset_url(
+    data_path: str, only_dataset_root_path: bool = False
+) -> Tuple[str, str, str]:
+
+    match = re.search(
+        r"^(http|https):\/\/([\w\-\.]+)(:[\d]+)?(\/.*)", data_path
+    )
+    if match:
+        endpoint_url = match.group(1) + "://" + match.group(2)
+        full_path = match.group(4)
+        segments = full_path.split("/")
+        bucket = segments[1]
+        path = (
+            "/".join(segments[2:])
+            if not only_dataset_root_path
+            else "/".join(segments[2:5]) + "/"
+        )
+        return endpoint_url, bucket, path
+    else:
+        raise Exception(f"Invalid data path: {data_path}")
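
A worked example tracing the regex groups on the third documented URL (illustrative values, not a real dataset); note that an explicit port is captured by group(3) but dropped from the returned endpoint:

endpoint, bucket, path = parse_access_dataset_url(
    "https://s3.region.cloudferro.com:443/bucket/arco/product/dataset/geoChunked.zarr"
)
# endpoint == "https://s3.region.cloudferro.com"  (the ":443" is dropped)
# bucket   == "bucket"
# path     == "arco/product/dataset/geoChunked.zarr"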
28 changes: 3 additions & 25 deletions copernicusmarine/download_functions/download_original_files.py
@@ -20,6 +20,7 @@
     construct_url_with_query_params,
     flatten,
     get_unique_filename,
+    parse_access_dataset_url,
 )

 logger = logging.getLogger("copernicus_marine_root_logger")
@@ -197,8 +198,8 @@ def _download_header(
     directory_out: pathlib.Path,
     only_list_root_path: bool = False,
 ) -> Tuple[str, Tuple[str, str], list[str], float, list[str]]:
-    (endpoint_url, bucket, path) = parse_original_files_dataset_url(
-        data_path, only_list_root_path
+    (endpoint_url, bucket, path) = parse_access_dataset_url(
+        data_path, only_dataset_root_path=only_list_root_path
     )

     filenames, sizes, total_size = [], [], 0.0
@@ -412,29 +413,6 @@ def _original_files_file_download(
 # /////////////////////////////


-# Example data_path
-# https://s3.waw3-1.cloudferro.com/mdl-native-01/native/NWSHELF_MULTIYEAR_BGC_004_011/cmems_mod_nws_bgc-pft_myint_7km-3D-diato_P1M-m_202105
-def parse_original_files_dataset_url(
-    data_path: str, only_dataset_root_path: bool
-) -> Tuple[str, str, str]:
-    match = re.search(
-        r"^(http|https):\/\/([\w\-\.]+)(:[\d]+)?(\/.*)", data_path
-    )
-    if match:
-        endpoint_url = match.group(1) + "://" + match.group(2)
-        full_path = match.group(4)
-        segments = full_path.split("/")
-        bucket = segments[1]
-        path = (
-            "/".join(segments[2:])
-            if not only_dataset_root_path
-            else "/".join(segments[2:5]) + "/"
-        )
-        return endpoint_url, bucket, path
-    else:
-        raise Exception(f"Invalid data path: {data_path}")
-
-
 def create_filenames_out(
     filenames_in: list[str],
     overwrite: bool,
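
For original-files listings, only_list_root_path maps onto the shared helper's only_dataset_root_path flag. With the native-files example URL from utils.py, the root-path variant keeps the first three path segments and appends a trailing slash (values shown for illustration):

url = (
    "https://s3.waw3-1.cloudferro.com/mdl-native-01/native/"
    "NWSHELF_MULTIYEAR_BGC_004_011/cmems_mod_nws_bgc-pft_myint_7km-3D-diato_P1M-m_202105"
)
endpoint, bucket, root = parse_access_dataset_url(url, only_dataset_root_path=True)
# endpoint == "https://s3.waw3-1.cloudferro.com"
# bucket   == "mdl-native-01"
# root     == "native/NWSHELF_MULTIYEAR_BGC_004_011/cmems_mod_nws_bgc-pft_myint_7km-3D-diato_P1M-m_202105/"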
